File size: 1,783 Bytes
b384f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aecc64d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
import torch

class NoiseSchedule:
  """
  Linear-beta diffusion noise schedule with forward noising and reverse sampling.

  Handles:
  - Linear beta schedule (beta runs from 1e-4 to 0.02 over T steps)
  - Forward noising (`noise`)
  - DDIM inference (`ddim_step`, with `ddim_mod` as the stride used to skip steps)
  - DDPM inference (`ddpm_step`)
  - Classifier Free Guidance weights (`w`, a constant 7.5 per timestep; stored
    on the schedule but not consumed by any method of this class)
  """
  def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False):
    """
    Build the schedule tensors.

    Args:
      T: number of diffusion timesteps.
      std: scale applied to the Gaussian noise drawn in `noise` and `generate`.
      shape: per-sample shape of the tensors produced by `generate`.
      ddim_mod: stride between the timesteps visited by DDIM sampling.
      trainer_mode: if True the schedule tensors live on the CPU; otherwise they
        are placed on CUDA when it is available and on the CPU when it is not.
    """
    # Fall back to the CPU instead of crashing on machines without CUDA.
    use_cuda = (not trainer_mode) and torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    self.T = T
    self.std = std
    self.ddim_mod = ddim_mod
    self.shape = shape
    self.beta = torch.tensor(np.linspace(1e-4, 0.02, T), dtype=torch.float32, device=device)
    self.alpha = 1 - self.beta
    self.alpha_bar = self.alpha.cumprod(dim=0)
    self.w = torch.full((T,), 7.5, device=device)

  def noise(self, x, t):
    """
    Forward-noise `x` to timestep `t`.

    Args:
      x: clean tensor; when `t` holds several timesteps, the first axis of `x`
        is treated as the batch axis.
      t: a single timestep index (int or 0-dim tensor) or one index per sample.

    Returns:
      (x_t, eps) where x_t = sqrt(alpha_bar[t]) * x + sqrt(1 - alpha_bar[t]) * eps
      and eps is the `std`-scaled Gaussian noise that was mixed in.
    """
    eps = torch.randn_like(x) * self.std

    # Index on the schedule's device, then move the coefficient next to x.
    index = t.to(self.alpha_bar.device) if torch.is_tensor(t) else t
    abar = self.alpha_bar[index].to(x.device)
    if abar.dim() > 0:
      # One coefficient per sample: reshape to (B, 1, ..., 1) so it scales each
      # sample instead of broadcasting against the last axis of x.
      abar = abar.reshape(-1, *([1] * (x.dim() - 1)))

    return abar.sqrt() * x + (1 - abar).sqrt() * eps, eps

  def ddim_step(self, xt, t, eps):
    """
    Take one deterministic DDIM step from timestep `t` to `max(0, t - ddim_mod)`.

    Args:
      xt: current noisy sample.
      t: current timestep index (a single index, not a batch).
      eps: predicted noise for `xt`.

    Returns:
      The sample at the earlier timestep.
    """
    t_prev = max(0, t - self.ddim_mod)
    abar_t = self.alpha_bar[t]
    abar_prev = self.alpha_bar[t_prev]

    # Predicted clean sample, clamped to [-1, 1].
    x0 = (xt - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()
    x0 = x0.clamp(-1, 1)
    # note that eps = (xt - sqrt(abar[t]) * x0) / sqrt(1 - abar[t]); the
    # un-clamped prediction `eps` is reused here rather than recomputed.
    return abar_prev.sqrt() * x0 + (1 - abar_prev).sqrt() * eps

  def ddpm_step(self, x, eps, t, var=None):
    """
    Take one ancestral DDPM step from timestep `t` to `t - 1`.

    Computes the posterior mean
      (x - beta[t] / sqrt(1 - alpha_bar[t]) * eps) / sqrt(alpha[t])
    and, for t > 0, adds Gaussian noise with variance `var`.

    Args:
      x: current noisy sample.
      eps: predicted noise for `x`.
      t: current timestep index (a single index, not a batch).
      var: variance of the injected noise; defaults to `beta[t]`.

    Returns:
      The sample at the previous timestep (the noise-free mean when t == 0).
    """
    var = self.beta[t] if var is None else var
    eps_coef = self.beta[t] / ((1 - self.alpha_bar[t]) ** 0.5)
    mean = (x - eps_coef * eps) / (self.alpha[t] ** 0.5)
    if t == 0:
      # Final step: return the mean without injecting fresh noise.
      return mean
    # `var` is a variance, so the noise is scaled by its square root.
    return mean + (var ** 0.5) * torch.randn_like(x)

  def generate(self, model, num_images=16, device="cuda"):
    """
    Sample `num_images` tensors of shape `self.shape` with strided DDIM.

    Args:
      model: callable invoked as `model(x, t=t_tensor)` that returns the
        predicted noise for `x`; `t_tensor` has one timestep entry per image.
      num_images: number of samples to draw.
      device: device on which the samples are created.

    Returns:
      Tensor of shape `(num_images, *self.shape)` after the last DDIM step.
    """
    with torch.no_grad():
      x = torch.randn((num_images, *self.shape), device=device) * self.std
      for t in range(self.T - 1, -1, -self.ddim_mod):
        t_tensor = torch.full((num_images,), t, device=device)
        epsilons = model(x, t=t_tensor)
        x = self.ddim_step(x, t=t, eps=epsilons)
      return x