Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn, Tensor | |
class DiffusionSchedule(nn.Module):
    """DDPM-style noise schedule with betas spaced linearly over the steps.

    Holds the per-timestep tables needed to noise clean samples
    (``q_sample``) and to recover clean samples from predicted noise
    (``predict_x0_from_eps``). Every table is registered as a buffer, so it
    follows the module across ``.to(...)`` calls and is saved in
    ``state_dict``.
    """

    def __init__(
        self,
        *,
        timesteps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 2e-2,
    ):
        """Build ``timesteps`` betas linearly spaced from ``beta_start`` to ``beta_end``.

        Args:
            timesteps: Number of diffusion steps; must be at least 1.
            beta_start: Beta at step 0; must lie in ``[0, 1]``.
            beta_end: Beta at the final step; must lie in ``[0, 1]``.

        Raises:
            ValueError: If ``timesteps < 1`` or either beta lies outside
                ``[0, 1]`` (or is NaN). Such betas push ``alpha_bars``
                outside ``[0, 1]``, and the square-root buffers would
                silently become NaN.
        """
        super().__init__()
        if timesteps < 1:
            raise ValueError(f"timesteps must be >= 1, got {timesteps}")
        for name, beta in (("beta_start", beta_start), ("beta_end", beta_end)):
            # Negated range check so that NaN is rejected as well.
            if not 0.0 <= beta <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {beta}")

        # Empty tensor carrying the default dtype and device that
        # torch.linspace would otherwise have produced the buffers in.
        like = torch.empty(0)

        # Build the schedule in float64 (on the CPU, where float64 is always
        # available). In float32, the cumulative product and especially the
        # cancellation in 1 - alpha_bar at small t lose noticeable precision.
        betas = torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float64, device="cpu"
        )
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        self.timesteps = timesteps
        self.register_buffer("betas", betas.to(like))
        self.register_buffer("alphas", alphas.to(like))
        self.register_buffer("alpha_bars", alpha_bars.to(like))
        self.register_buffer("sqrt_alpha_bars", torch.sqrt(alpha_bars).to(like))
        self.register_buffer(
            "sqrt_one_minus_alpha_bars",
            torch.sqrt(1.0 - alpha_bars).to(like),
        )

    def extract(self, values: Tensor, t: Tensor, x: Tensor) -> Tensor:
        """Gather ``values`` at timesteps ``t`` and shape the result to broadcast with ``x``.

        Args:
            values: 1-D per-timestep table, e.g. one of this module's buffers.
            t: Integer timestep indices. Any integer dtype and any device are
                accepted; they are converted to ``int64`` on ``values``'
                device before the lookup.
            x: Tensor the result will be broadcast against.

        Returns:
            The gathered values with trailing singleton dimensions appended
            until the result has ``x.ndim`` dimensions (e.g. ``(B,)`` becomes
            ``(B, 1, 1, 1)`` for a 4-D ``x``).
        """
        # gather requires an int64 index on the same device as the table;
        # this is a no-op when t already satisfies both.
        index = t.to(device=values.device, dtype=torch.long)
        out = values.gather(0, index)
        # Append trailing size-1 dims so the coefficient broadcasts over x;
        # a non-positive count multiplies to an empty tuple (no change).
        return out.reshape(tuple(out.shape) + (1,) * (x.ndim - out.ndim))

    def _sqrt_coefficients(self, t: Tensor, x: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t))`` shaped to broadcast with ``x``."""
        return (
            self.extract(self.sqrt_alpha_bars, t, x),
            self.extract(self.sqrt_one_minus_alpha_bars, t, x),
        )

    def q_sample(self, x0: Tensor, t: Tensor, noise: Tensor) -> Tensor:
        """Noise clean samples ``x0`` to timesteps ``t``.

        Computes ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``,
        with the coefficients broadcast over ``x0``'s trailing dimensions.
        """
        sqrt_ab, sqrt_1mab = self._sqrt_coefficients(t, x0)
        return sqrt_ab * x0 + sqrt_1mab * noise

    def predict_x0_from_eps(
        self,
        xt: Tensor,
        t: Tensor,
        eps_pred: Tensor,
    ) -> Tensor:
        """Recover clean samples from noisy ``xt`` and predicted noise ``eps_pred``.

        This inverts ``q_sample``:
        ``(xt - sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_bar_t)``.
        """
        sqrt_ab, sqrt_1mab = self._sqrt_coefficients(t, xt)
        # alpha_bar_t can reach 0 (e.g. when beta_end == 1); clamp the
        # divisor so the result stays finite instead of inf/NaN.
        return (xt - sqrt_1mab * eps_pred) / sqrt_ab.clamp_min(1e-8)