sqfoo committed on
Commit
cdb543e
·
verified ·
1 Parent(s): 159e29a

Upload Hugging Face related file

Browse files
Files changed (1) hide show
  1. stldm/stldm_hf.py +620 -0
stldm/stldm_hf.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, random
2
+ from torch import nn
3
+ from einops import rearrange
4
+
5
+ from stldm.submodules import *
6
+
7
class Down_Block(nn.Module):
    """Encoder stage: ResNet -> spatial attention -> ResNet -> temporal attention -> resample.

    The stage input is the channel-wise concatenation of the frames and the
    conditioning tensor, so ``in_ch`` must already account for both.
    When ``patch_size`` is None quadratic spatial attention is used,
    otherwise the patch-based linear variant.  The final layer is a
    ``Downsample2D`` unless ``is_last`` is set, in which case it is a
    ``ChannelConversion``.
    """

    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super().__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)

        if patch_size is None:
            spatial_core = Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        else:
            spatial_core = Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)
        self.attn_spatial = Residual(PreNorm(hid_ch, spatial_core))

        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)

        temporal_core = TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        self.attn_temporal = Residual(PreNorm(hid_ch, temporal_core))

        if is_last:
            self.last = ChannelConversion(hid_ch, out_ch)
        else:
            self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch)

    def forward(self, x, time_emb, cond=None, relative_pos=None):
        """Run the stage on a 5-D input ``x`` of shape (B, T, C, H, W).

        ``cond`` is concatenated with the flattened (B*T, C, H, W) frames, so it
        presumably has that same layout -- TODO confirm against the backbone.
        ``relative_pos`` is accepted but not used.

        Returns the resampled features reshaped to (B, T, c, h, w) together
        with the spatial-attention and temporal-attention activations (both
        flattened over B*T) that serve as skip connections.
        """
        assert x.ndim==5
        B, T, C, H, W = x.shape
        frames = x.reshape(B * T, C, H, W)

        # A missing condition is replaced by zeros -> unconditional branch
        if cond is None:
            cond = torch.zeros_like(frames)

        # Share every sample's time embedding across its T frames: (B, D) -> (B*T, D)
        emb = time_emb.unsqueeze(1).repeat(1, T, 1).reshape(B * T, -1)

        hidden = self.block1(torch.cat((frames, cond), dim=1), emb)  # input is (B*T, 2C, H, W)
        spatial_attn = self.attn_spatial(hidden)
        hidden = self.block2(spatial_attn, emb)

        c, h, w = hidden.shape[-3:]
        temporal_attn = self.attn_temporal(hidden.reshape(B, T, c, h, w))
        temporal_attn = temporal_attn.reshape(B * T, c, h, w)

        resampled = self.last(temporal_attn)
        c, h, w = resampled.shape[-3:]
        return resampled.reshape(B, T, c, h, w), spatial_attn, temporal_attn
45
+
46
class MidBlock(nn.Module):
    """Bottleneck stage: three time-conditioned ResNet blocks interleaved with
    quadratic spatial attention and temporal attention.  The channel count is
    ``in_ch`` throughout.
    """

    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super().__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

        spatial_core = Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)
        self.qattn_spatial = Residual(PreNorm(in_ch, spatial_core))

        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

        temporal_core = TemporalAttention(dim=in_ch, heads=heads, dim_head=dim_head)
        self.qattn_time = Residual(PreNorm(in_ch, temporal_core))

        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        """Process a 5-D input of shape (B, T, C, H, W) and return a 5-D tensor
        (B, T, c, h, w).  ``relative_pos`` is accepted but not used.
        """
        assert x.ndim==5
        B, T, C, H, W = x.shape
        flat_shape = (B * T, C, H, W)

        # Share every sample's time embedding across its T frames: (B, D) -> (B*T, D)
        emb = time_emb.unsqueeze(1).repeat(1, T, 1).reshape(B * T, -1)

        hidden = self.block1(x.reshape(*flat_shape), emb)
        hidden = self.qattn_spatial(hidden)
        hidden = self.block2(hidden, emb)  # unlike Down_Block, block2 is time-conditioned here

        # Temporal attention works on the unflattened (B, T, C, H, W) layout
        hidden = self.qattn_time(hidden.reshape((B, T, C, H, W)))
        hidden = self.block3(hidden.reshape(*flat_shape), emb)

        c, h, w = hidden.shape[-3:]
        return hidden.reshape(B, T, c, h, w)
76
+
77
class Up_Block(nn.Module):
    """Decoder stage: resample -> temporal attention -> ResNet -> spatial attention -> ResNet.

    ``in_chs`` is a pair ``(in_ch, skip_ch)``: the channels of the incoming
    features and of each skip tensor concatenated before the two ResNet
    blocks.  The first layer is an ``Upsample2D`` unless ``is_last`` is set,
    in which case it is a ``ChannelConversion``.
    """

    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super().__init__()
        in_ch, skip_ch = in_chs

        if is_last:
            self.up = ChannelConversion(in_ch, hid_ch)
        else:
            self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch)

        if patch_size is None:
            spatial_core = Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        else:
            spatial_core = Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)
        self.attn_spatial = Residual(PreNorm(hid_ch, spatial_core))

        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)

        temporal_core = TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        self.attn_temporal = Residual(PreNorm(hid_ch, temporal_core))

        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, spatialattn_skip, tempattn_skip, relative_pos=None):
        """Run the stage on a 5-D input of shape (B, T, C, H, W).

        ``tempattn_skip`` is concatenated (channel axis) before the first
        ResNet block and ``spatialattn_skip`` before the second; both are
        expected in the flattened (B*T, ...) layout.  ``relative_pos`` is
        accepted but not used.  Returns a (B, T, c, h, w) tensor.
        """
        assert x.ndim==5
        B, T, C, H, W = x.shape

        # Share every sample's time embedding across its T frames: (B, D) -> (B*T, D)
        emb = time_emb.unsqueeze(1).repeat(1, T, 1).reshape(B * T, -1)

        hidden = self.up(x.reshape(B * T, C, H, W))
        c, h, w = hidden.shape[-3:]

        # Temporal attention sees (B, T, c, h, w); flatten again afterwards
        hidden = self.attn_temporal(hidden.reshape(-1, T, c, h, w))
        hidden = hidden.reshape(B * T, c, h, w)

        hidden = self.block1(torch.cat((hidden, tempattn_skip), dim=1), emb)
        hidden = self.attn_spatial(hidden)
        hidden = self.block2(torch.cat((hidden, spatialattn_skip), dim=1), emb)

        c, h, w = hidden.shape[-3:]
        return hidden.reshape(B, T, c, h, w)
113
+
114
class LDM(nn.Module):
    """U-Net style denoiser built from Down_Block / MidBlock / Up_Block stages.

    Channel widths are ``[in_ch, base_ch * m for m in chs_mult]``; one encoder
    stage, one decoder stage and one condition downsampler are created per
    entry of ``chs_mult``.  When ``patch_size`` is given it is halved at every
    encoder level (it should therefore be a power of two) and mirrored for the
    decoder.
    """

    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super().__init__()

        # Time embedding MLP
        time_dim = 4 * base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=base_ch)

        depth = len(chs_mult)
        widths = [in_ch] + [base_ch * mult for mult in chs_mult]
        enc_in, enc_out = widths[:-1], widths[1:]
        dec_in, dec_out = enc_out[::-1], enc_in[::-1]

        # Per-level patch size for linear attention (all None -> quadratic attention)
        if patch_size is None:
            patches = [None] * depth
        else:
            patches = [patch_size // (2 ** level) for level in range(depth)]

        shared = dict(num_groups=num_groups, heads=heads, dim_head=dim_head)
        down_stages, up_stages, cond_stages = [], [], []

        # Modules are created interleaved (down, up, condition) per level on purpose:
        # the creation order determines the order of random weight initialisation.
        for level in range(depth):
            mirror = depth - level - 1
            down_stages.append(
                Down_Block(
                    in_ch=2 * enc_in[level],  # frames concatenated with the condition
                    hid_ch=enc_in[level],
                    out_ch=enc_out[level],
                    time_dim=time_dim,
                    patch_size=patches[level],
                    is_last=(level == depth - 1),
                    **shared,
                )
            )
            up_stages.append(
                Up_Block(
                    in_chs=(dec_in[level], dec_out[level]),
                    hid_ch=dec_in[level],
                    out_ch=dec_out[level],
                    time_dim=time_dim,
                    patch_size=patches[mirror],
                    is_last=(level == 0),
                    **shared,
                )
            )
            # NOTE: a condition downsampler is created for every level, including
            # the last one; keep it that way so state_dict keys stay unchanged.
            cond_stages.append(Downsample2D(dim_in=enc_in[level], dim_out=enc_out[level]))

        self.downs = nn.ModuleList(down_stages)
        self.ups = nn.ModuleList(up_stages)
        self.conditions = nn.ModuleList(cond_stages)
        self.mid = MidBlock(in_ch=enc_out[-1], time_dim=time_dim, **shared)

    def forward(self, x, time, conds=None):
        """Denoise ``x`` at diffusion step ``time``.

        ``conds`` (optional) is fed to every encoder stage and downsampled by
        the matching ``self.conditions`` module after each stage; passing None
        runs the unconditional branch.
        """
        t = self.time_mlp(time)

        spatial_skips, temporal_skips = [], []
        for down_stage, cond_stage in zip(self.downs, self.conditions):
            x, spatial_attn, time_attn = down_stage(x, t, conds)
            spatial_skips.append(spatial_attn)
            temporal_skips.append(time_attn)
            if conds is not None:
                conds = cond_stage(conds)

        out = self.mid(x, t)

        # Skip connections are consumed in reverse (deepest first)
        for up_stage in self.ups:
            out = up_stage(out, t, spatial_skips.pop(), temporal_skips.pop())

        return out
174
+
175
+ # constants
176
+ from collections import namedtuple
177
+ from torch.cuda.amp import autocast
178
+ import torch.nn.functional as F
179
+ from einops import reduce
180
+ from tqdm.auto import tqdm
181
+
182
# Pair returned by GaussianDiffusion.model_predictions: predicted noise and predicted clean sample
ModelPrediction = namedtuple('ModelPrediction', ('pred_noise', 'pred_x_start'))
183
+
184
def identity(t, *_args, **_kwargs):
    """Return ``t`` unchanged; any further arguments are accepted and ignored."""
    return t
186
+
187
def extract(a, t, x_shape):
    """Look up ``a`` at the indices ``t`` and shape the result for broadcasting.

    Gathers along the last axis of ``a`` and reshapes to
    ``(len(t), 1, ..., 1)`` with as many dimensions as ``x_shape``.
    """
    batch, *_ = t.shape
    singleton_dims = [1] * (len(x_shape) - 1)
    picked = a.gather(-1, t)
    return picked.reshape(batch, *singleton_dims)
191
+
192
def default(val, d):
    """Return ``val`` unless it is None; otherwise fall back to ``d``.

    A callable ``d`` is invoked (lazily) to produce the fallback value.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
196
+
197
def exists(x):
    """Tell whether ``x`` holds a value, i.e. is anything other than None."""
    return not (x is None)
199
+
200
def guidance_scheduler(sampling_step: int, const: float):
    """Build a constant guidance schedule: a 1-D tensor of length
    ``sampling_step`` whose entries all equal ``const``.
    """
    unit_schedule = torch.ones(sampling_step)
    return const * unit_schedule
202
+
203
+ from huggingface_hub import PyTorchModelHubMixin
204
+
205
class GaussianDiffusion(
        nn.Module,
        PyTorchModelHubMixin,
        # optionally, you can add metadata which gets pushed to the model card
        repo_url="https://github.com/sqfoo/stldm_official",
        pipeline_tag="Precipitation_Nowcasting",
        license="mit"):
    """Latent Gaussian diffusion model combining a ``SimVPV2_Model`` backbone
    with an ``LDM`` denoiser.

    The backbone supplies the conditioning tensor and owns the VAE used to
    encode ground-truth frames and decode sampled latents.  The denoiser can
    be trained to predict noise, x0 or v (``objective``).

    Args:
        vp_param: keyword arguments forwarded to ``SimVPV2_Model``.
        stldm_param: keyword arguments forwarded to ``LDM``.
        timesteps: number of diffusion steps requested from the beta schedule.
        sampling_timesteps: number of sampling steps; fewer than the training
            steps switches sampling to DDIM.  Defaults to the training steps.
        objective: one of ``'pred_noise'``, ``'pred_x0'``, ``'pred_v'``.
        beta_schedule: one of ``'linear'``, ``'cosine'``, ``'sigmoid'``.
        schedule_fn_kwargs: extra keyword arguments for the schedule function.
        ddim_sampling_eta: DDIM stochasticity (0 is deterministic).
        offset_noise_strength: scale of the per-frame offset noise in training.
        min_snr_loss_weight: clamp the SNR used in the loss weight.
        min_snr_gamma: upper bound applied when ``min_snr_loss_weight`` is set.

    Raises:
        ValueError: for an unknown ``beta_schedule``.
    """

    def __init__(
        self,
        vp_param: dict,
        stldm_param: dict,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        schedule_fn_kwargs = dict(),
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        super(GaussianDiffusion, self).__init__()

        self.backbone = SimVPV2_Model(**vp_param)
        self.diff_unet = LDM(**stldm_param)

        # Classifier-free guidance schedule.  Initialised here so sampling works
        # even when setup_guidance() is never called (e.g. after from_pretrained).
        self.CFG_sch = None

        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        # The schedule itself is the source of truth for the step count
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        # default num sampling timesteps to number of timesteps at training
        self.sampling_timesteps = default(sampling_timesteps, timesteps)

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal
        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio
        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556
        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

    @property
    def device(self):
        """Device holding the diffusion buffers (taken from ``betas``)."""
        return self.betas.device

    # CFG scheduler => by taking pre-setting scheduler
    def setup_guidance(self, scheduler):
        """Install the classifier-free guidance schedule.

        ``scheduler`` is a tensor indexed by diffusion step (moved to the
        model device), or None to disable guidance.
        """
        if exists(scheduler):
            self.CFG_sch = scheduler.to(self.device)
        else:
            self.CFG_sch = scheduler

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x0 from the noisy sample ``x_t`` and the noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from the noisy sample ``x_t`` and x0."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """Compute the v-parameterisation target from x0 and the noise."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """Recover x0 from the noisy sample ``x_t`` and a v prediction."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the denoiser and return ``ModelPrediction(pred_noise, pred_x_start)``.

        With a guidance schedule installed the output is
        ``cond_out + w * (cond_out - uncond_out)`` where ``w`` is the schedule
        value at step ``t[0]`` (all entries of ``t`` are assumed equal).
        ``clip_x_start`` clamps the predicted x0 to [-1, 1].
        """
        if exists(self.CFG_sch):
            uncond = self.diff_unet(x, t, conds=None)
            model_output = self.diff_unet(x, t, conds=cond)
            time = int(t[0])
            model_output = model_output - self.CFG_sch[time] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)

        # Plain closure instead of functools.partial, which this file never imports
        if clip_x_start:
            def maybe_clip(value):
                return torch.clamp(value, min = -1., max = 1.)
        else:
            maybe_clip = identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
        """Posterior statistics of p(x_{t-1} | x_t) plus the predicted x0."""
        preds = self.model_predictions(x, t, cond=cond, clip_x_start=False)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, cond=None):
        """One ancestral sampling step from step ``t``; returns (x_{t-1}, x0 estimate)."""
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
        """Full ancestral sampling starting from Gaussian noise of ``shape``.

        Returns the final sample, or every intermediate sample stacked along
        dim 1 when ``return_all_timesteps`` is set.
        """
        device = self.device

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
            frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """DDIM sampling over ``self.sampling_timesteps`` steps.

        Returns the final sample, or every intermediate sample stacked along
        dim 1 when ``return_all_timesteps`` is set.
        """
        batch, total_timesteps, sampling_timesteps, eta = shape[0], self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
        device = self.device
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond,
                clip_x_start = False,
                rederive_pred_noise = True
            )

            # Final step: the x0 estimate is the sample
            if time_next < 0:
                frames_pred = x_start
                imgs.append(frames_pred)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(frames_pred)

            frames_pred = x_start * alpha_next.sqrt() + \
                          c * pred_noise + \
                          sigma * noise

            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Generate future frames from the 5-D input ``frames_in`` (B, T_in, C, H, W).

        The backbone provides its own output and the conditioning tensor; the
        sampled latents are decoded with the backbone's VAE.  Returns
        ``(frames_pred, backbone_output)``.
        """
        assert frames_in.ndim == 5
        B = frames_in.shape[0]

        backbone_output, conds, *_ = self.backbone(frames_in)
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        # Latent target shape: the condition's (c, h, w) with batch and time split apart
        *_, c, h, w = conds.shape
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )

        ldm_pred = rearrange(ldm_pred, 'b t c h w -> (b t) c h w')
        frames_pred = self.backbone.vae.decode(ldm_pred)
        frames_pred = rearrange(frames_pred, '(b t) c h w -> b t c h w', b=B)
        return frames_pred, backbone_output

    def predict(self, frames_in, compute_loss=False, **kwargs):
        """Sample a prediction; ``compute_loss`` and ``kwargs`` are accepted but ignored.

        Returns ``(pred, mu)`` where ``mu`` is the backbone output.
        """
        pred, mu = self.sample(frames_in=frames_in)
        return pred, mu

    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Compute the training losses.

        Returns ``(recon_loss, mu_loss, diff_loss, prior_loss)``: the VAE
        reconstruction (+ weighted KL) loss, the backbone loss, the diffusion
        loss and the KL of the fully-noised latent from a standard normal.
        ``validate`` is accepted but does not change the computation.
        """
        B = frames_in.shape[0]

        # Diffusion loss
        backbone_output, conds = self.backbone(frames_in)
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        if random.random() > 0.85: # drop the condition ~15% of the time (unconditional)
            conds = None
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        # Backbone loss
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        # VAE loss on input and target frames together
        all_frames = torch.cat((frames_in, frames_gt), dim=1)
        ae_loss, kl_loss = self.backbone.vae._losses_(
            rearrange(all_frames, 'b t c h w -> (b t) c h w'),
            rearrange(all_frames, 'b t c h w -> (b t) c h w')
        )
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        # Prior loss at t=T [noisy].  The target is encoded again on purpose:
        # the original code does so and the encoder may be stochastic.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

        return recon_loss, mu_loss, diff_loss, prior_loss

    def kl_from_standard_normal(self, mean, log_var):
        """Mean KL divergence of N(mean, exp(log_var)) from N(0, 1)."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Diffuse ``x_start`` to step ``t``; fresh Gaussian noise is drawn when none is given."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """Weighted MSE between the denoiser output and the objective's target.

        ``x_start`` is a 5-D latent (b, T, c, h, w).  Raises ``ValueError`` for
        an unknown objective.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            # One offset per (batch, frame), broadcast over (c, h, w).  Added out of
            # place so a caller-supplied noise tensor is not modified.
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise = noise + offset_noise_strength * rearrange(offset_noise, 'b t -> b t 1 1 1')

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        model_out = self.diff_unet(x, t, conds=cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            target = self.predict_v(x_start, t, noise)
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none') # (B, T, C, H, W)
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    @torch.no_grad()
    def forward(self, input_x, include_mu=False, **kwargs):
        """Inference entry point: return the sampled prediction, and also the
        backbone output when ``include_mu`` is set.  ``kwargs`` are ignored.
        """
        pred, mu = self.predict(input_x, compute_loss=False)
        if include_mu:
            return pred, mu
        else:
            return pred
593
+
594
+ from stldm.modules import SimVPV2_Model, VAE
595
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build a ``GaussianDiffusion`` model from a configuration mapping.

    Args:
        model_config: mapping with the keys ``'vp_param'`` (backbone kwargs),
            ``'stldm_param'`` (denoiser kwargs) and ``'param'`` (remaining
            ``GaussianDiffusion`` keyword arguments).
        print_info: print a short setup banner.
        cfg_str: constant classifier-free guidance strength; None disables
            guidance.

    Returns:
        The constructed model with its guidance schedule installed.
    """
    if print_info:
        print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
        print('Train it from end to end')

    # GaussianDiffusion builds the backbone and the denoiser itself from their
    # parameter dicts, so pass the configs rather than pre-built modules.
    model = GaussianDiffusion(
        vp_param=model_config['vp_param'],
        stldm_param=model_config['stldm_param'],
        **model_config['param'],
    )

    scheduler = None
    if cfg_str is not None:
        # One guidance value per diffusion step; read the step count from the
        # model so a config without an explicit 'timesteps' entry also works.
        scheduler = guidance_scheduler(sampling_step=model.num_timesteps, const=cfg_str)
    model.setup_guidance(scheduler)

    return model
610
+
611
def ae_setup(model_config):
    """Build a backbone from ``model_config['vp_param']`` and return its ``vae``."""
    backbone = SimVPV2_Model(**model_config['vp_param'])
    return backbone.vae
616
+
617
def backbone_setup(model_config):
    """Build and return a ``SimVPV2_Model`` from ``model_config['vp_param']``."""
    return SimVPV2_Model(**model_config['vp_param'])