# Page header captured with the file: Spaces: Sleeping | File size: 1,645 Bytes | e0552b0
import torch
from torch import nn, Tensor
class DiffusionSchedule(nn.Module):
    """Linear beta schedule with precomputed per-timestep coefficients.

    The constructor builds ``timesteps`` evenly spaced betas between
    ``beta_start`` and ``beta_end`` and registers these buffers (so they
    follow the module across ``.to(...)`` calls and appear in its
    ``state_dict``):

    * ``betas``
    * ``alphas``                     -- ``1 - betas``
    * ``alpha_bars``                 -- cumulative product of ``alphas``
    * ``sqrt_alpha_bars``            -- ``sqrt(alpha_bars)``
    * ``sqrt_one_minus_alpha_bars``  -- ``sqrt(1 - alpha_bars)``
    """

    def __init__(
        self,
        *,
        timesteps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 2e-2,
    ) -> None:
        """Precompute the schedule.

        Args:
            timesteps: Number of schedule steps; must be at least 1.
            beta_start: First beta of the linear ramp, in ``[0, 1)``.
            beta_end: Last beta of the linear ramp, in ``[0, 1)``.

        Raises:
            ValueError: If ``timesteps < 1`` or a beta lies outside ``[0, 1)``.
        """
        # Fail fast: an empty schedule or betas outside [0, 1) would otherwise
        # yield empty or NaN-filled buffers that only surface later.
        if timesteps < 1:
            raise ValueError(f"timesteps must be >= 1, got {timesteps}")
        if not (0.0 <= beta_start < 1.0 and 0.0 <= beta_end < 1.0):
            raise ValueError(
                "beta_start and beta_end must lie in [0, 1), "
                f"got {beta_start} and {beta_end}"
            )

        super().__init__()

        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        self.timesteps = timesteps
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bars", alpha_bars)
        self.register_buffer("sqrt_alpha_bars", torch.sqrt(alpha_bars))
        self.register_buffer(
            "sqrt_one_minus_alpha_bars",
            torch.sqrt(1.0 - alpha_bars),
        )

    def extract(self, values: Tensor, t: Tensor, x: Tensor) -> Tensor:
        """Look up ``values`` at indices ``t`` and make the result broadcastable to ``x``.

        Args:
            values: Table to index along dim 0 (the schedule buffers are 1-D).
            t: Integer index tensor with the same number of dims as
                ``values``. Any integer dtype and any device are accepted.
            x: Reference tensor; only its number of dimensions is used.

        Returns:
            The gathered entries with trailing singleton dims appended until
            the result has at least ``x.ndim`` dimensions.
        """
        # ``gather`` requires an int64 index on the same device as ``values``;
        # normalise here so int32 or CPU-resident timesteps do not raise.
        index = t.to(device=values.device, dtype=torch.long)
        out = values.gather(0, index)

        missing = x.ndim - out.ndim
        if missing > 0:
            out = out.reshape(*out.shape, *([1] * missing))
        return out

    def q_sample(self, x0: Tensor, t: Tensor, noise: Tensor) -> Tensor:
        """Return ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x0: Clean input tensor.
            t: Integer timestep indices, as accepted by ``extract``.
            noise: Noise tensor combined with ``x0``; it must broadcast
                against ``x0``.

        Returns:
            The noised tensor at timesteps ``t``.
        """
        sqrt_ab = self.extract(self.sqrt_alpha_bars, t, x0)
        sqrt_1mab = self.extract(
            self.sqrt_one_minus_alpha_bars,
            t,
            x0,
        )
        return sqrt_ab * x0 + sqrt_1mab * noise

    def predict_x0_from_eps(
        self,
        xt: Tensor,
        t: Tensor,
        eps_pred: Tensor,
    ) -> Tensor:
        """Invert ``q_sample`` for ``x0`` given a noise estimate.

        Computes ``(xt - sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_bar_t)``.

        Args:
            xt: Noised tensor.
            t: Integer timestep indices, as accepted by ``extract``.
            eps_pred: Noise estimate; it must broadcast against ``xt``.

        Returns:
            The estimate of the clean tensor.
        """
        sqrt_ab = self.extract(self.sqrt_alpha_bars, t, xt)
        sqrt_1mab = self.extract(
            self.sqrt_one_minus_alpha_bars,
            t,
            xt,
        )
        # Clamp the denominator so a vanishing sqrt(alpha_bar) cannot divide by zero.
        return (xt - sqrt_1mab * eps_pred) / sqrt_ab.clamp_min(1e-8)
|