File size: 3,652 Bytes
e6092ac
 
 
 
 
 
 
 
 
 
 
e9c60a4
 
e6092ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9c60a4
e6092ac
 
 
 
 
 
 
 
 
 
 
 
 
 
e9c60a4
e6092ac
 
 
 
 
 
 
 
 
 
 
 
 
 
95dc5d6
e6092ac
 
 
 
 
 
95dc5d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9c60a4
95dc5d6
 
 
 
 
e9c60a4
95dc5d6
e9c60a4
95dc5d6
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def extract(v, t, x_shape):
    """
    Extract coefficients from *v* at the timesteps given by *t*, reshaped to
    [batch_size, 1, 1, ...] (one trailing singleton axis per remaining
    dimension of *x_shape*) for broadcasting against a batch of samples.

    Args:
        v: 1-D tensor of per-timestep coefficients.
        t: 1-D long tensor of timestep indices, shape [batch_size].
        x_shape: shape of the tensor the result will broadcast against.

    Returns:
        Float tensor of shape [batch_size, 1, ..., 1] on t's device.
    """
    # Fix: gather on t's device instead of hard-coding "cuda", so the code
    # also runs on CPU-only machines; GPU callers are unaffected.
    out = torch.gather(v.to(t.device), index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    """DDPM forward (noising) process and training objective.

    For a clean batch x_0, a timestep t is drawn uniformly per sample and
    the noised sample is formed in closed form:
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
    then the model is trained to predict eps via MSE (Algorithm 1 of the
    DDPM paper).
    """

    def __init__(self, beta_1, beta_T, T, model) -> None:
        super().__init__()
        self.model = model

        # Linear beta schedule, kept in double precision.
        self.register_buffer(
            'betas',
            torch.linspace(beta_1, beta_T, T).double()
        )
        self.T = T

        self.alphas = 1 - self.betas
        # NOTE: despite the name, this is the cumulative product alpha_bar.
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Precomputed coefficients for Algorithm 1:
        # sqrt(alpha_bar) and sqrt(1 - alpha_bar).
        self.register_buffer(
            "sqrt_beta_alphas",
            torch.sqrt(self.beta_alphas),
        )
        self.register_buffer(
            "sqrt_one_minus_beta_alphas",
            torch.sqrt(1 - self.beta_alphas),
        )

    def forward(self, x_0):
        """Noise a batch at random timesteps; return the eps-prediction MSE."""
        batch_size = x_0.shape[0]
        t = torch.randint(self.T, size=(batch_size,), device=x_0.device)
        noise = torch.randn_like(x_0)

        signal_scale = extract(self.sqrt_beta_alphas, t, x_0.shape)
        noise_scale = extract(self.sqrt_one_minus_beta_alphas, t, x_0.shape)
        x_t = signal_scale * x_0 + noise_scale * noise

        return F.mse_loss(self.model(x_t, t), noise, reduction='mean')
    

class GaussianDiffusionSampler(nn.Module):
    """Reverse (denoising) DDPM sampler (Algorithm 2 of the DDPM paper).

    Starting from x_T ~ N(0, I), repeatedly predicts the noise with the
    model and takes one ancestral sampling step per timestep until x_0.
    """

    def __init__(self, beta_1, beta_t, model, T) -> None:
        super().__init__()
        self.model = model
        self.T = T

        # Linear beta schedule, kept in double precision.
        self.register_buffer(
            "betas",
            torch.linspace(beta_1, beta_t, self.T).double()
        )
        self.alphas = 1 - self.betas
        # NOTE: despite the name, this is the cumulative product alpha_bar.
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # alpha_bar_{t-1}: pad a leading 1 (alpha_bar before step 0) onto
        # self.beta_alphas and drop the last element so length stays T.
        self.beta_alphas_prev = F.pad(self.beta_alphas, [1, 0], value=1)[:T]

        # Posterior mean coefficients:
        #   mu_{t-1} = coeff1 * x_t - coeff2 * eps_theta(x_t, t)
        self.register_buffer(
            "coeff1",
            (1 / torch.sqrt(self.alphas))
        )
        self.register_buffer(
            "coeff2",
            self.coeff1 * ((1 - self.alphas) / (torch.sqrt(1 - self.beta_alphas)))
        )

        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.register_buffer(
            "posterior_coeff",
            (1 - self.beta_alphas_prev) / (1 - self.beta_alphas) * self.betas
        )

    def pred_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Mean of p(x_{t-1} | x_t) given the predicted noise eps."""
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """Return (mean, variance) of p(x_{t-1} | x_t) at timesteps t."""
        # Use the (non-zero) posterior variance for the first step and the
        # betas elsewhere, following the reference DDPM implementation.
        var = torch.cat([self.posterior_coeff[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.pred_xt_prev_mean_from_eps(x_t, t, eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """Run the full reverse process from x_T; return x_0 clipped to [-1, 1].

        Fix: the device now follows the input tensor instead of being
        hard-coded to "cuda", so sampling also works on CPU-only machines;
        callers that pass a CUDA tensor see identical behavior.
        """
        x_t = x_T
        for timestep in reversed(range(self.T)):
            print(f"Sampling timestep: {timestep}")

            # One timestep index per batch element.
            t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * timestep
            mean, var = self.p_mean_variance(x_t, t)
            if timestep > 0:
                noise = torch.randn_like(x_t)
            else:
                # No noise is added on the final step (t = 0).
                noise = 0
            x_t = mean + torch.sqrt(var) * noise

        x_0 = x_t
        return torch.clip(x_0, -1, 1)