Yzy00518 commited on
Commit
ed8c297
·
1 Parent(s): 9632411

Upload src/model/gaussian_diffusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/model/gaussian_diffusion.py +211 -0
src/model/gaussian_diffusion.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ def linear_beta_schedule(timesteps):
7
+ scale = 1.0 # for 100 steps
8
+ beta_start = scale * 0.0001
9
+ beta_end = scale * 0.02
10
+ return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
11
+
12
+ class GaussianDiffusion:
13
+ def __init__(
14
+ self,
15
+ device,
16
+ fix_mode=False,
17
+ text_emb=False,
18
+ fixed_frames=2,
19
+ seq_len=16,
20
+ timesteps=100,
21
+ beta_schedule='linear',
22
+ ):
23
+ self.device = device
24
+ self.fix_mode = fix_mode # autoregressive
25
+ self.fixed_frames = fixed_frames # number of frames to fix
26
+ self.timesteps = timesteps
27
+ self.text_emb = text_emb
28
+ self.seq_len = seq_len
29
+
30
+
31
+ if beta_schedule == 'linear':
32
+ betas = linear_beta_schedule(timesteps)
33
+ elif beta_schedule == 'cosine':
34
+ raise NotImplementedError('cosine schedule is not implemented yet!')
35
+ else:
36
+ raise ValueError(f'unknown beta schedule {beta_schedule}')
37
+
38
+ self.betas = betas.to(self.device)
39
+ self.alphas = (1. - self.betas).to(self.device)
40
+ self.alphas_cumprod = torch.cumprod(self.alphas, axis=0).to(self.device)
41
+ self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.).to(self.device)
42
+
43
+ # calculations for diffusion q(x_t | x_{t-1}) and others
44
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(self.device)
45
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod).to(self.device)
46
+ self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod).to(self.device)
47
+ self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod).to(self.device)
48
+ self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1).to(self.device)
49
+
50
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
51
+ self.posterior_variance = (
52
+ self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
53
+ ).to(self.device)
54
+ # below: log calculation clipped because the posterior variance is 0 at the beginning
55
+ # of the diffusion chain
56
+ self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min =1e-20)).to(self.device)
57
+
58
+ self.posterior_mean_coef1 = (
59
+ self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
60
+ ).to(self.device)
61
+ self.posterior_mean_coef2 = (
62
+ (1.0 - self.alphas_cumprod_prev)
63
+ * torch.sqrt(self.alphas)
64
+ / (1.0 - self.alphas_cumprod)
65
+ ).to(self.device)
66
+
67
+ # get the param of given timestep t
68
+ def _extract(self, a, t, x_shape):
69
+ batch_size = t.shape[0]
70
+ out = a.to(t.device).gather(0, t).float()
71
+ out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(self.device)
72
+ return out
73
+
74
+ # forward diffusion (using the nice property): q(x_t | x_0)
75
+ def q_sample(self, x_start, t, noise=None):
76
+ if noise is None:
77
+ noise = torch.randn_like(x_start)
78
+
79
+ sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
80
+ sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
81
+ return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
82
+
83
+ # Get the mean and variance of q(x_t | x_0).
84
+ def q_mean_variance(self, x_start, t):
85
+ mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
86
+ variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
87
+ log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
88
+ return mean, variance, log_variance
89
+
90
+ # Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0)
91
+ def q_posterior_mean_variance(self, x_start, x_t, t):
92
+ posterior_mean = (
93
+ self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
94
+ + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
95
+ )
96
+ posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
97
+ posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
98
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
99
+
100
+ # compute x_0 from x_t and pred noise: the reverse of `q_sample`
101
+ def predict_start_from_noise(self, x_t, t, noise):
102
+ return (
103
+ self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
104
+ self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
105
+ )
106
+
107
+ # compute predicted mean and variance of p(x_{t-1} | x_t)
108
+ def p_mean_variance(self, model, x_t, t, clip_denoised=True, **kwargs):
109
+ # predict noise using model
110
+ assert 'text' in kwargs, 'text is required'
111
+ assert 'prog_ind' in kwargs, 'prog_ind is required'
112
+ assert 'joints_orig' in kwargs, 'joints_orig is required'
113
+ pred_noise = model(x_t, t,
114
+ text_emb=kwargs['text'],
115
+ prog_ind=kwargs['prog_ind'],
116
+ joints_orig=kwargs['joints_orig'])
117
+
118
+ # use cfg for text embedding:
119
+ if kwargs['use_cfg']:
120
+ pred_noise_empty = model(x_t, t,
121
+ text_emb=torch.zeros_like(kwargs['text']),
122
+ prog_ind=kwargs['prog_ind'],
123
+ joints_orig=kwargs['joints_orig'])
124
+ pred_noise = pred_noise_empty + kwargs['cfg_alpha'] * (pred_noise - pred_noise_empty)
125
+
126
+ # get the predicted x_0: different from the algorithm2 in the paper
127
+ x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
128
+
129
+ if clip_denoised:
130
+ x_recon = torch.clamp(x_recon, min=-1., max=1.)
131
+ model_mean, posterior_variance, posterior_log_variance = \
132
+ self.q_posterior_mean_variance(x_recon, x_t, t)
133
+ return model_mean, posterior_variance, posterior_log_variance
134
+
135
+ # denoise_step: sample x_{t-1} from x_t and pred_noise
136
+ # @torch.no_grad()
137
+ def p_sample(self, model, x_t, t, clip_denoised=True, **kwargs):
138
+ if 'disc_model' in kwargs:
139
+ disc_model = kwargs['disc_model']
140
+ try:
141
+ cg_alpha = kwargs['cg_alpha'] # default 1.0
142
+ cg_diffusion_steps = kwargs['cg_diffusion_steps']
143
+ except:
144
+ print("cg_alpha and cg_diffusion_steps are not provided!")
145
+ print("Using default values: cg_alpha=1.0, cg_diffusion_steps=5")
146
+ cg_alpha = 1.0
147
+ cg_diffusion_steps = 5
148
+ # predict mean and variance
149
+ model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t,
150
+ clip_denoised=clip_denoised, **kwargs)
151
+ model_mean = torch.tensor(model_mean, requires_grad=True)
152
+ noise = torch.randn_like(x_t)
153
+ # no noise when t == 0
154
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))
155
+ if t.item() < cg_diffusion_steps:
156
+ pred_syn = disc_model(model_mean, t) # y = f(theta, x) theta fixed
157
+ pred_syn.backward()
158
+
159
+ grad = model_mean.grad * cg_alpha
160
+ model_mean = model_mean - nonzero_mask * (0.5 * model_log_variance).exp() * grad
161
+
162
+ # compute x_{t-1}
163
+ pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
164
+ return pred_img
165
+
166
+ # denoise: reverse diffusion
167
+ # @torch.no_grad()
168
+ def p_sample_loop(self, model, shape, fixed_points=None, **kwargs):
169
+ batch_size = shape[0]
170
+ device = next(model.parameters()).device
171
+
172
+ # start from pure noise (for each example in the batch)
173
+ img = torch.randn(shape, device=device)
174
+ # notice that if we are in fixed mode, we need to fix the first 2 frames
175
+ if self.fix_mode:
176
+ assert not (fixed_points is None), 'fixed_points is required for fixed mode'
177
+ img[:, :self.fixed_frames, :] = fixed_points
178
+ imgs = []
179
+
180
+ for i in reversed(range(0, self.timesteps)):
181
+ img = self.p_sample(model, img, torch.full((batch_size,), i, device=device, dtype=torch.long), **kwargs)
182
+ if self.fix_mode:
183
+ img[:, :self.fixed_frames, :] = fixed_points
184
+ imgs.append(img)
185
+ return imgs
186
+
187
+
188
+ # sample new images
189
+ # @torch.no_grad()
190
+ def sample(self, model, batch_size=1, seq_len=16, channels=135,
191
+ fixed_points=None, clip_denoised=True, **kwargs):
192
+ return self.p_sample_loop(model, shape=(batch_size, seq_len, channels),
193
+ fixed_points=fixed_points, clip_denoised=clip_denoised, **kwargs)
194
+
195
+ # compute train losses
196
+ def train_losses(self, model, x_start, t, mask=None, **kwargs):
197
+ assert not (mask is None and self.fixed_mode), 'mask is required for fixed mode'
198
+ if mask is None:
199
+ mask = torch.zeros_like(x_start).to(dtype=torch.bool, device=self.device)
200
+
201
+ mask_inv = torch.logical_not(mask)
202
+ # generate random noise
203
+ noise = torch.randn_like(x_start).to(device=self.device)
204
+ noise[mask] = 0.
205
+
206
+ # get x_t
207
+ x_noisy = self.q_sample(x_start, t, noise=noise)
208
+ predicted_noise = model(x_noisy, t, text_emb=kwargs['text'], prog_ind=kwargs['prog_ind'], joints_orig=kwargs['joints_orig'])
209
+
210
+ loss = F.smooth_l1_loss(noise[mask_inv], predicted_noise[mask_inv])
211
+ return loss