File size: 5,843 Bytes
9f5a022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import torch
import torch.nn as nn
from enum import Enum
from tqdm import trange





# Noise (beta) schedules supported by DiffusionManager.
Schedule = Enum('Schedule', ['LINEAR', 'COSINE'])


class DiffusionManager(nn.Module):
    """Wrap a noise-prediction model with a beta schedule, forward noising,
    reverse sampling and a single training step.

    The wrapped ``model`` is called as ``model(images, timesteps, condition)``
    and its output is used as the predicted noise.
    """

    def __init__(self, model: nn.Module, noise_steps=1000, start=0.0001, end=0.02, device="cpu", **kwargs) -> None:
        """Store the configuration and build the default (linear) schedule.

        Args:
            model: Network called as ``model(images, timesteps, condition)``.
            noise_steps: Number of diffusion steps in the schedule.
            start: First beta of the linear schedule; also reused as the
                offset ``s`` of the cosine schedule.
            end: Last beta of the linear schedule.
            device: Device the schedule and all sampled tensors are moved to.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)

        self.model = model
        self.noise_steps = noise_steps
        self.start = start
        self.end = end
        self.device = device

        # Plain tensor attribute (not a registered buffer); replaced by
        # set_schedule().
        self.schedule = None
        self.set_schedule()

    def _get_schedule(self, schedule_type: Schedule = Schedule.LINEAR):
        """Return the 1-D tensor of betas (length ``noise_steps``).

        Raises:
            ValueError: If ``schedule_type`` is not a known ``Schedule`` member.
        """
        if schedule_type == Schedule.LINEAR:
            return torch.linspace(self.start, self.end, self.noise_steps)

        if schedule_type == Schedule.COSINE:
            # https://arxiv.org/pdf/2102.09672 page 4
            # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
            # line 18
            offset = self.start

            def squared_cosine(t):
                return torch.cos((t / self.noise_steps + offset) / (1 + offset) * torch.pi / 2) ** 2

            def alpha_hat_at(t):
                # Normalised so that alpha_hat_at(0) == 1.
                return squared_cosine(t) / squared_cosine(torch.zeros_like(t))

            steps = torch.arange(self.noise_steps, dtype=torch.get_default_dtype())
            betas = 1 - (alpha_hat_at(steps + 1) / alpha_hat_at(steps))

            # "In practice, we clip beta_t to be no larger than 0.999 to
            # prevent singularities at the end of the diffusion process."
            return torch.minimum(betas, torch.ones_like(betas) * 0.999)

        # Fail loudly here instead of returning None and crashing later
        # with an unrelated AttributeError in set_schedule().
        raise ValueError(f"Unknown schedule type: {schedule_type!r}")

    def set_schedule(self, schedule: Schedule = Schedule.LINEAR):
        """Build the requested beta schedule and store it on ``self.device``."""
        self.schedule = self._get_schedule(schedule).to(self.device)

    def get_schedule_at(self, step):
        """Return ``(beta, alpha, alpha_hat)`` at ``step``.

        ``step`` is used to index the schedule directly, so it may be an int
        or an integer index tensor. Each returned tensor has shape
        ``(-1, 1, 1, 1)`` so it broadcasts against a 4-D image batch.
        """
        beta = self.schedule
        alpha = 1 - beta
        alpha_hat = torch.cumprod(alpha, dim=0)

        return (
            self._unsqueezify(beta[step]),
            self._unsqueezify(alpha[step]),
            self._unsqueezify(alpha_hat[step]),
        )

    @staticmethod
    def _unsqueezify(value):
        """Reshape ``value`` to ``(-1, 1, 1, 1)``."""
        return value.reshape(-1, 1, 1, 1)

    def noise_image(self, image, step):
        """Apply the forward diffusion process to ``image`` at ``step``.

        Returns:
            ``(noised_image, epsilon)`` where ``epsilon`` is the Gaussian
            noise that was mixed in.
        """
        image = image.to(self.device)

        _, _, alpha_hat = self.get_schedule_at(step)
        epsilon = torch.randn_like(image)

        noised_img = torch.sqrt(alpha_hat) * image + torch.sqrt(1 - alpha_hat) * epsilon
        return noised_img, epsilon

    def random_timesteps(self, amt=1):
        """Draw ``amt`` random integer timesteps from ``[1, noise_steps)``."""
        return torch.randint(low=1, high=self.noise_steps, size=(amt,))

    @staticmethod
    def _fit_condition_to_batch(condition, amt):
        """Tile ``condition`` along dim 0 so it provides ``amt`` rows.

        A condition that already has at least ``amt`` entries along dim 0 is
        returned unchanged. A 1-D condition is treated as a single vector and
        repeated into ``(amt, d)``. Otherwise the rows are cycled and the
        result is truncated to exactly ``amt`` rows.
        """
        given = condition.shape[0]
        if given >= amt:
            return condition
        if condition.dim() == 1:
            return condition.repeat(amt, 1)
        if given == 0:
            # Nothing to tile; leave the empty condition as it is.
            return condition

        copies = -(-amt // given)  # ceil(amt / given)
        trailing = [1] * (condition.dim() - 1)
        return condition.repeat(copies, *trailing)[:amt]

    def _sample_loop(self, img_size, condition, amt, use_tqdm):
        """Run the reverse diffusion loop shared by the public samplers.

        Starts from Gaussian noise of shape ``(amt, 3, img_size, img_size)``
        and iterates the steps ``noise_steps - 1`` down to ``1``. The model is
        put in eval mode for the loop and switched back to train mode
        afterwards, even if the loop raises.
        """
        first, stop, stride = self.noise_steps - 1, 0, -1

        self.model.eval()
        try:
            condition = condition.to(self.device)

            with torch.no_grad():
                # Sampled on the default device and then moved, so seeded
                # runs keep producing the same starting noise.
                cur_img = torch.randn((amt, 3, img_size, img_size)).to(self.device)

                if use_tqdm:
                    steps = trange(first, stop, stride, leave=False, dynamic_ncols=True)
                else:
                    steps = range(first, stop, stride)

                for i in steps:
                    timestep = (torch.ones(amt) * i).to(self.device)

                    predicted_noise = self.model(cur_img, timestep, condition)
                    beta, alpha, alpha_hat = self.get_schedule_at(i)

                    cur_img = (1 / torch.sqrt(alpha)) * (
                        cur_img - (beta / torch.sqrt(1 - alpha_hat)) * predicted_noise
                    )
                    # No fresh noise is added on the final step.
                    if i > 1:
                        cur_img = cur_img + torch.sqrt(beta) * torch.randn_like(cur_img)
        finally:
            self.model.train()

        return cur_img

    def sample(self, img_size, condition, amt=5, use_tqdm=True):
        """Sample ``amt`` images, tiling ``condition`` to ``amt`` rows if it
        has fewer.

        Returns:
            Tensor of shape ``(amt, 3, img_size, img_size)``.
        """
        condition = self._fit_condition_to_batch(condition, amt)
        return self._sample_loop(img_size, condition, amt, use_tqdm)

    def sample_multicond(self, img_size, condition, use_tqdm=True):
        """Sample one image per row of ``condition``.

        Returns:
            Tensor of shape ``(condition.shape[0], 3, img_size, img_size)``.
        """
        return self._sample_loop(img_size, condition, condition.shape[0], use_tqdm)

    def training_loop_iteration(self, optimizer, batch, label, criterion):
        """Run one optimisation step on ``batch`` and return the loss value.

        Noises the batch at random timesteps, asks the model to predict the
        noise, and evaluates ``criterion(real_noise, predicted_noise)``. NaNs
        in any intermediate are reported on stdout; the step still runs.

        Returns:
            The loss as a Python float.
        """

        def shout(message):
            # Repeated on purpose so it stands out in a busy training log.
            for _ in range(10):
                print(message)

        batch = batch.to(self.device)

        #label = label.long() # uncomment for nn.Embedding
        label = label.to(self.device)

        timesteps = self.random_timesteps(batch.shape[0]).to(self.device)
        noisy_batch, real_noise = self.noise_image(batch, timesteps)

        if torch.isnan(noisy_batch).any() or torch.isnan(real_noise).any():
            shout("NaNs detected in the noisy batch or real noise")

        pred_noise = self.model(noisy_batch, timesteps, label)

        if torch.isnan(pred_noise).any():
            shout("NaNs detected in the predicted noise")

        loss = criterion(real_noise, pred_noise)

        if torch.isnan(loss).any():
            shout("NaNs detected in the loss")

        # Drop gradients left over from the previous iteration; otherwise
        # every step would be taken on the running sum of all past gradients.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()