Fabrice-TIERCELIN commited on
Commit
a532a56
·
verified ·
1 Parent(s): 0ba18b1

Delete video_to_video

Browse files
Files changed (30) hide show
  1. video_to_video/__init__.py +0 -0
  2. video_to_video/__pycache__/__init__.cpython-39.pyc +0 -0
  3. video_to_video/__pycache__/video_to_video_model.cpython-39.pyc +0 -0
  4. video_to_video/diffusion/__init__.py +0 -0
  5. video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  6. video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc +0 -0
  7. video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc +0 -0
  8. video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc +0 -0
  9. video_to_video/diffusion/diffusion_sdedit.py +0 -443
  10. video_to_video/diffusion/schedules_sdedit.py +0 -85
  11. video_to_video/diffusion/solvers_sdedit.py +0 -204
  12. video_to_video/modules/__init__.py +0 -3
  13. video_to_video/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  14. video_to_video/modules/__pycache__/embedder.cpython-39.pyc +0 -0
  15. video_to_video/modules/__pycache__/t5.cpython-39.pyc +0 -0
  16. video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc +0 -0
  17. video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc +0 -0
  18. video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc +0 -0
  19. video_to_video/modules/embedder.py +0 -75
  20. video_to_video/modules/t5.py +0 -335
  21. video_to_video/modules/unet_v2v.py +0 -2332
  22. video_to_video/utils/__init__.py +0 -0
  23. video_to_video/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  24. video_to_video/utils/__pycache__/config.cpython-39.pyc +0 -0
  25. video_to_video/utils/__pycache__/logger.cpython-39.pyc +0 -0
  26. video_to_video/utils/__pycache__/seed.cpython-39.pyc +0 -0
  27. video_to_video/utils/config.py +0 -169
  28. video_to_video/utils/logger.py +0 -94
  29. video_to_video/utils/seed.py +0 -14
  30. video_to_video/video_to_video_model.py +0 -237
video_to_video/__init__.py DELETED
File without changes
video_to_video/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (153 Bytes)
 
video_to_video/__pycache__/video_to_video_model.cpython-39.pyc DELETED
Binary file (6.97 kB)
 
video_to_video/diffusion/__init__.py DELETED
File without changes
video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (163 Bytes)
 
video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc DELETED
Binary file (10.4 kB)
 
video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc DELETED
Binary file (2.68 kB)
 
video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc DELETED
Binary file (6.18 kB)
 
video_to_video/diffusion/diffusion_sdedit.py DELETED
@@ -1,443 +0,0 @@
1
- import random
2
-
3
- import torch
4
-
5
- from .schedules_sdedit import karras_schedule
6
- from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
7
-
8
- from video_to_video.utils.logger import get_logger
9
-
10
- logger = get_logger()
11
-
12
- __all__ = ['GaussianDiffusion']
13
-
14
-
15
- def _i(tensor, t, x):
16
- shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
17
- return tensor[t.to(tensor.device)].view(shape).to(x.device)
18
-
19
- class GaussianDiffusion(object):
20
-
21
- def __init__(self, sigmas):
22
- self.sigmas = sigmas
23
- self.alphas = torch.sqrt(1 - sigmas**2)
24
- self.num_timesteps = len(sigmas)
25
-
26
- def diffuse(self, x0, t, noise=None):
27
- noise = torch.randn_like(x0) if noise is None else noise
28
- xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
29
-
30
- return xt
31
-
32
- def get_velocity(self, x0, xt, t):
33
- sigmas = _i(self.sigmas, t, xt)
34
- alphas = _i(self.alphas, t, xt)
35
- velocity = (alphas * xt - x0) / sigmas
36
- return velocity
37
-
38
- def get_x0(self, v, xt, t):
39
- sigmas = _i(self.sigmas, t, xt)
40
- alphas = _i(self.alphas, t, xt)
41
- x0 = alphas * xt - sigmas * v
42
- return x0
43
-
44
- def denoise(self,
45
- xt,
46
- t,
47
- s,
48
- model,
49
- model_kwargs={},
50
- guide_scale=None,
51
- guide_rescale=None,
52
- clamp=None,
53
- percentile=None,
54
- variant_info=None,):
55
- s = t - 1 if s is None else s
56
-
57
- # hyperparams
58
- sigmas = _i(self.sigmas, t, xt)
59
- alphas = _i(self.alphas, t, xt)
60
- alphas_s = _i(self.alphas, s.clamp(0), xt)
61
- alphas_s[s < 0] = 1.
62
- sigmas_s = torch.sqrt(1 - alphas_s**2)
63
-
64
- # precompute variables
65
- betas = 1 - (alphas / alphas_s)**2
66
- coef1 = betas * alphas_s / sigmas**2
67
- coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
68
- var = betas * (sigmas_s / sigmas)**2
69
- log_var = torch.log(var).clamp_(-20, 20)
70
-
71
- # prediction
72
- if guide_scale is None:
73
- assert isinstance(model_kwargs, dict)
74
- out = model(xt, t=t, **model_kwargs)
75
- else:
76
- # classifier-free guidance
77
- assert isinstance(model_kwargs, list)
78
- if len(model_kwargs) > 3:
79
- y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
80
- else:
81
- y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], variant_info=variant_info)
82
- if guide_scale == 1.:
83
- out = y_out
84
- else:
85
- if len(model_kwargs) > 3:
86
- u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
87
- else:
88
- u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], variant_info=variant_info)
89
- out = u_out + guide_scale * (y_out - u_out)
90
-
91
- if guide_rescale is not None:
92
- assert guide_rescale >= 0 and guide_rescale <= 1
93
- ratio = (
94
- y_out.flatten(1).std(dim=1) / # noqa
95
- (out.flatten(1).std(dim=1) + 1e-12)
96
- ).view((-1, ) + (1, ) * (y_out.ndim - 1))
97
- out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
98
-
99
- x0 = alphas * xt - sigmas * out
100
-
101
- # restrict the range of x0
102
- if percentile is not None:
103
- assert percentile > 0 and percentile <= 1
104
- s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
105
- s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
106
- x0 = torch.min(s, torch.max(-s, x0)) / s
107
- elif clamp is not None:
108
- x0 = x0.clamp(-clamp, clamp)
109
-
110
- # recompute eps using the restricted x0
111
- eps = (xt - alphas * x0) / sigmas
112
-
113
- # compute mu (mean of posterior distribution) using the restricted x0
114
- mu = coef1 * x0 + coef2 * xt
115
- return mu, var, log_var, x0, eps
116
-
117
-
118
- @torch.no_grad()
119
- def sample(self,
120
- noise,
121
- model,
122
- model_kwargs={},
123
- condition_fn=None,
124
- guide_scale=None,
125
- guide_rescale=None,
126
- clamp=None,
127
- percentile=None,
128
- solver='euler_a',
129
- solver_mode='fast',
130
- steps=20,
131
- t_max=None,
132
- t_min=None,
133
- discretization=None,
134
- discard_penultimate_step=None,
135
- return_intermediate=None,
136
- show_progress=False,
137
- seed=-1,
138
- chunk_inds=None,
139
- **kwargs):
140
- # sanity check
141
- assert isinstance(steps, (int, torch.LongTensor))
142
- assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
143
- assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
144
- assert discretization in (None, 'leading', 'linspace', 'trailing')
145
- assert discard_penultimate_step in (None, True, False)
146
- assert return_intermediate in (None, 'x0', 'xt')
147
-
148
- # function of diffusion solver
149
- solver_fn = {
150
- 'heun': sample_heun,
151
- 'dpmpp_2m_sde': sample_dpmpp_2m_sde
152
- }[solver]
153
-
154
- # options
155
- schedule = 'karras' if 'karras' in solver else None
156
- discretization = discretization or 'linspace'
157
- seed = seed if seed >= 0 else random.randint(0, 2**31)
158
- if isinstance(steps, torch.LongTensor):
159
- discard_penultimate_step = False
160
- if discard_penultimate_step is None:
161
- discard_penultimate_step = True if solver in (
162
- 'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
163
- 'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
164
-
165
- # function for denoising xt to get x0
166
- intermediates = []
167
-
168
- def model_fn(xt, sigma):
169
- # denoising
170
- t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
171
- x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
172
- guide_rescale, clamp, percentile)[-2]
173
-
174
- # collect intermediate outputs
175
- if return_intermediate == 'xt':
176
- intermediates.append(xt)
177
- elif return_intermediate == 'x0':
178
- intermediates.append(x0)
179
- return x0
180
-
181
- mask_cond = model_kwargs[3]['mask_cond']
182
- def model_chunk_fn(xt, sigma):
183
- # denoising
184
- t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
185
- O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
186
- cut_f_ind = O_LEN//2
187
-
188
- results_list = []
189
- for i in range(len(chunk_inds)):
190
- ind_start, ind_end = chunk_inds[i]
191
- xt_chunk = xt[:,:,ind_start:ind_end].clone()
192
- cur_f = xt_chunk.size(2)
193
- model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
194
- x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
195
- guide_rescale, clamp, percentile)[-2]
196
- if i == 0:
197
- results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
198
- elif i == len(chunk_inds)-1:
199
- results_list.append(x0_chunk[:,:,cut_f_ind:])
200
- else:
201
- results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
202
- x0 = torch.concat(results_list, dim=2)
203
- torch.cuda.empty_cache()
204
- return x0
205
-
206
- # get timesteps
207
- if isinstance(steps, int):
208
- steps += 1 if discard_penultimate_step else 0
209
- t_max = self.num_timesteps - 1 if t_max is None else t_max
210
- t_min = 0 if t_min is None else t_min
211
-
212
- # discretize timesteps
213
- if discretization == 'leading':
214
- steps = torch.arange(t_min, t_max + 1,
215
- (t_max - t_min + 1) / steps).flip(0)
216
- elif discretization == 'linspace':
217
- steps = torch.linspace(t_max, t_min, steps)
218
- elif discretization == 'trailing':
219
- steps = torch.arange(t_max, t_min - 1,
220
- -((t_max - t_min + 1) / steps))
221
- if solver_mode == 'fast':
222
- t_mid = 500
223
- steps1 = torch.arange(t_max, t_mid - 1,
224
- -((t_max - t_mid + 1) / 4))
225
- steps2 = torch.arange(t_mid, t_min - 1,
226
- -((t_mid - t_min + 1) / 11))
227
- steps = torch.concat([steps1, steps2])
228
- else:
229
- raise NotImplementedError(
230
- f'{discretization} discretization not implemented')
231
- steps = steps.clamp_(t_min, t_max)
232
- steps = torch.as_tensor(
233
- steps, dtype=torch.float32, device=noise.device)
234
-
235
- # get sigmas
236
- sigmas = self._t_to_sigma(steps)
237
- sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
238
- if schedule == 'karras':
239
- if sigmas[0] == float('inf'):
240
- sigmas = karras_schedule(
241
- n=len(steps) - 1,
242
- sigma_min=sigmas[sigmas > 0].min().item(),
243
- sigma_max=sigmas[sigmas < float('inf')].max().item(),
244
- rho=7.).to(sigmas)
245
- sigmas = torch.cat([
246
- sigmas.new_tensor([float('inf')]), sigmas,
247
- sigmas.new_zeros([1])
248
- ])
249
- else:
250
- sigmas = karras_schedule(
251
- n=len(steps),
252
- sigma_min=sigmas[sigmas > 0].min().item(),
253
- sigma_max=sigmas.max().item(),
254
- rho=7.).to(sigmas)
255
- sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
256
- if discard_penultimate_step:
257
- sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
258
-
259
- fn = model_chunk_fn if chunk_inds is not None else model_fn
260
- x0 = solver_fn(
261
- noise, fn, sigmas, show_progress=show_progress, **kwargs)
262
- return (x0, intermediates) if return_intermediate is not None else x0
263
-
264
- @torch.no_grad()
265
- def sample_sr(self,
266
- noise,
267
- model,
268
- model_kwargs={},
269
- condition_fn=None,
270
- guide_scale=None,
271
- guide_rescale=None,
272
- clamp=None,
273
- percentile=None,
274
- solver='euler_a',
275
- solver_mode='fast',
276
- steps=20,
277
- t_max=None,
278
- t_min=None,
279
- discretization=None,
280
- discard_penultimate_step=None,
281
- return_intermediate=None,
282
- show_progress=False,
283
- seed=-1,
284
- chunk_inds=None,
285
- variant_info=None,
286
- **kwargs):
287
- # sanity check
288
- assert isinstance(steps, (int, torch.LongTensor))
289
- assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
290
- assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
291
- assert discretization in (None, 'leading', 'linspace', 'trailing')
292
- assert discard_penultimate_step in (None, True, False)
293
- assert return_intermediate in (None, 'x0', 'xt')
294
-
295
- # function of diffusion solver
296
- solver_fn = {
297
- 'heun': sample_heun,
298
- 'dpmpp_2m_sde': sample_dpmpp_2m_sde
299
- }[solver]
300
-
301
- # options
302
- schedule = 'karras' if 'karras' in solver else None
303
- discretization = discretization or 'linspace'
304
- seed = seed if seed >= 0 else random.randint(0, 2**31)
305
- if isinstance(steps, torch.LongTensor):
306
- discard_penultimate_step = False
307
- if discard_penultimate_step is None:
308
- discard_penultimate_step = True if solver in (
309
- 'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
310
- 'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
311
-
312
- # function for denoising xt to get x0
313
- intermediates = []
314
-
315
- def model_fn(xt, sigma, variant_info=None):
316
- # denoising
317
- t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
318
- x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
319
- guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
320
-
321
- # collect intermediate outputs
322
- if return_intermediate == 'xt':
323
- intermediates.append(xt)
324
- elif return_intermediate == 'x0':
325
- print('add intermediate outputs x0')
326
- intermediates.append(x0)
327
- return x0
328
-
329
- # mask_cond = model_kwargs[3]['mask_cond']
330
- def model_chunk_fn(xt, sigma, variant_info=None):
331
- # denoising
332
- t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
333
- O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
334
- cut_f_ind = O_LEN//2
335
-
336
- results_list = []
337
- for i in range(len(chunk_inds)):
338
- ind_start, ind_end = chunk_inds[i]
339
- xt_chunk = xt[:,:,ind_start:ind_end].clone()
340
- model_kwargs[2]['hint_chunk'] = model_kwargs[2]['hint'][:,:,ind_start:ind_end].clone() # new added
341
- cur_f = xt_chunk.size(2)
342
- # model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
343
- x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
344
- guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
345
- if i == 0:
346
- results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
347
- elif i == len(chunk_inds)-1:
348
- results_list.append(x0_chunk[:,:,cut_f_ind:])
349
- else:
350
- results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
351
- x0 = torch.concat(results_list, dim=2)
352
- torch.cuda.empty_cache()
353
- return x0
354
-
355
- # get timesteps
356
- if isinstance(steps, int):
357
- steps += 1 if discard_penultimate_step else 0
358
- t_max = self.num_timesteps - 1 if t_max is None else t_max
359
- t_min = 0 if t_min is None else t_min
360
-
361
- # discretize timesteps
362
- if discretization == 'leading':
363
- steps = torch.arange(t_min, t_max + 1,
364
- (t_max - t_min + 1) / steps).flip(0)
365
- elif discretization == 'linspace':
366
- steps = torch.linspace(t_max, t_min, steps)
367
- elif discretization == 'trailing':
368
- steps = torch.arange(t_max, t_min - 1,
369
- -((t_max - t_min + 1) / steps))
370
- if solver_mode == 'fast':
371
- t_mid = 500
372
- steps1 = torch.arange(t_max, t_mid - 1,
373
- -((t_max - t_mid + 1) / 4))
374
- steps2 = torch.arange(t_mid, t_min - 1,
375
- -((t_mid - t_min + 1) / 11))
376
- steps = torch.concat([steps1, steps2])
377
- else:
378
- raise NotImplementedError(
379
- f'{discretization} discretization not implemented')
380
- steps = steps.clamp_(t_min, t_max)
381
- steps = torch.as_tensor(
382
- steps, dtype=torch.float32, device=noise.device)
383
-
384
- # get sigmas
385
- sigmas = self._t_to_sigma(steps)
386
- sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
387
- if schedule == 'karras':
388
- if sigmas[0] == float('inf'):
389
- sigmas = karras_schedule(
390
- n=len(steps) - 1,
391
- sigma_min=sigmas[sigmas > 0].min().item(),
392
- sigma_max=sigmas[sigmas < float('inf')].max().item(),
393
- rho=7.).to(sigmas)
394
- sigmas = torch.cat([
395
- sigmas.new_tensor([float('inf')]), sigmas,
396
- sigmas.new_zeros([1])
397
- ])
398
- else:
399
- sigmas = karras_schedule(
400
- n=len(steps),
401
- sigma_min=sigmas[sigmas > 0].min().item(),
402
- sigma_max=sigmas.max().item(),
403
- rho=7.).to(sigmas)
404
- sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
405
- if discard_penultimate_step:
406
- sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
407
-
408
-
409
- fn = model_chunk_fn if chunk_inds is not None else model_fn
410
- x0 = solver_fn(
411
- noise, fn, sigmas, variant_info=variant_info, show_progress=show_progress, **kwargs)
412
- return (x0, intermediates) if return_intermediate is not None else x0
413
-
414
-
415
- def _sigma_to_t(self, sigma):
416
- if sigma == float('inf'):
417
- t = torch.full_like(sigma, len(self.sigmas) - 1)
418
- else:
419
- log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
420
- (1 - self.sigmas**2)).log().to(sigma)
421
- log_sigma = sigma.log()
422
- dists = log_sigma - log_sigmas[:, None]
423
- low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
424
- max=log_sigmas.shape[0] - 2)
425
- high_idx = low_idx + 1
426
- low, high = log_sigmas[low_idx], log_sigmas[high_idx]
427
- w = (low - log_sigma) / (low - high)
428
- w = w.clamp(0, 1)
429
- t = (1 - w) * low_idx + w * high_idx
430
- t = t.view(sigma.shape)
431
- if t.ndim == 0:
432
- t = t.unsqueeze(0)
433
- return t
434
-
435
- def _t_to_sigma(self, t):
436
- t = t.float()
437
- low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
438
- log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
439
- (1 - self.sigmas**2)).log().to(t)
440
- log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
441
- log_sigma[torch.isnan(log_sigma)
442
- | torch.isinf(log_sigma)] = float('inf')
443
- return log_sigma.exp()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/diffusion/schedules_sdedit.py DELETED
@@ -1,85 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
-
3
- import math
4
-
5
- import torch
6
-
7
-
8
- def betas_to_sigmas(betas):
9
- return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
10
-
11
-
12
- def sigmas_to_betas(sigmas):
13
- square_alphas = 1 - sigmas**2
14
- betas = 1 - torch.cat(
15
- [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
16
- return betas
17
-
18
-
19
- def logsnrs_to_sigmas(logsnrs):
20
- return torch.sqrt(torch.sigmoid(-logsnrs))
21
-
22
-
23
- def sigmas_to_logsnrs(sigmas):
24
- square_sigmas = sigmas**2
25
- return torch.log(square_sigmas / (1 - square_sigmas))
26
-
27
-
28
- def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
29
- t_min = math.atan(math.exp(-0.5 * logsnr_min))
30
- t_max = math.atan(math.exp(-0.5 * logsnr_max))
31
- t = torch.linspace(1, 0, n)
32
- logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
33
- return logsnrs
34
-
35
-
36
- def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
37
- logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
38
- logsnrs += 2 * math.log(1 / scale)
39
- return logsnrs
40
-
41
-
42
- def _logsnr_cosine_interp(n,
43
- logsnr_min=-15,
44
- logsnr_max=15,
45
- scale_min=2,
46
- scale_max=4):
47
- t = torch.linspace(1, 0, n)
48
- logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
49
- logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
50
- logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
51
- return logsnrs
52
-
53
-
54
- def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
55
- ramp = torch.linspace(1, 0, n)
56
- min_inv_rho = sigma_min**(1 / rho)
57
- max_inv_rho = sigma_max**(1 / rho)
58
- sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
59
- sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
60
- return sigmas
61
-
62
-
63
- def logsnr_cosine_interp_schedule(n,
64
- logsnr_min=-15,
65
- logsnr_max=15,
66
- scale_min=2,
67
- scale_max=4):
68
- return logsnrs_to_sigmas(
69
- _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max))
70
-
71
-
72
- def noise_schedule(schedule='logsnr_cosine_interp',
73
- n=1000,
74
- zero_terminal_snr=False,
75
- **kwargs):
76
- # compute sigmas
77
- sigmas = {
78
- 'logsnr_cosine_interp': logsnr_cosine_interp_schedule
79
- }[schedule](n, **kwargs)
80
-
81
- # post-processing
82
- if zero_terminal_snr and sigmas.max() != 1.0:
83
- scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
84
- sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
85
- return sigmas
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/diffusion/solvers_sdedit.py DELETED
@@ -1,204 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
-
3
- import torch
4
- import torchsde
5
- from tqdm.auto import trange
6
-
7
- from video_to_video.utils.logger import get_logger
8
-
9
- logger = get_logger()
10
-
11
- def get_ancestral_step(sigma_from, sigma_to, eta=1.):
12
- """
13
- Calculates the noise level (sigma_down) to step down to and the amount
14
- of noise to add (sigma_up) when doing an ancestral sampling step.
15
- """
16
- if not eta:
17
- return sigma_to, 0.
18
- sigma_up = min(
19
- sigma_to,
20
- eta * (
21
- sigma_to**2 * # noqa
22
- (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
23
- sigma_down = (sigma_to**2 - sigma_up**2)**0.5
24
- return sigma_down, sigma_up
25
-
26
-
27
- def get_scalings(sigma):
28
- c_out = -sigma
29
- c_in = 1 / (sigma**2 + 1.**2)**0.5
30
- return c_out, c_in
31
-
32
-
33
- @torch.no_grad()
34
- def sample_heun(noise,
35
- model,
36
- sigmas,
37
- s_churn=0.,
38
- s_tmin=0.,
39
- s_tmax=float('inf'),
40
- s_noise=1.,
41
- show_progress=True):
42
- """
43
- Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
44
- """
45
- x = noise * sigmas[0]
46
- for i in trange(len(sigmas) - 1, disable=not show_progress):
47
- gamma = 0.
48
- if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
49
- gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
50
- eps = torch.randn_like(x) * s_noise
51
- sigma_hat = sigmas[i] * (gamma + 1)
52
- if gamma > 0:
53
- x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
54
- if sigmas[i] == float('inf'):
55
- # Euler method
56
- denoised = model(noise, sigma_hat)
57
- x = denoised + sigmas[i + 1] * (gamma + 1) * noise
58
- else:
59
- _, c_in = get_scalings(sigma_hat)
60
- denoised = model(x * c_in, sigma_hat)
61
- d = (x - denoised) / sigma_hat
62
- dt = sigmas[i + 1] - sigma_hat
63
- if sigmas[i + 1] == 0:
64
- # Euler method
65
- x = x + d * dt
66
- else:
67
- # Heun's method
68
- x_2 = x + d * dt
69
- _, c_in = get_scalings(sigmas[i + 1])
70
- denoised_2 = model(x_2 * c_in, sigmas[i + 1])
71
- d_2 = (x_2 - denoised_2) / sigmas[i + 1]
72
- d_prime = (d + d_2) / 2
73
- x = x + d_prime * dt
74
- return x
75
-
76
-
77
- class BatchedBrownianTree:
78
- """
79
- A wrapper around torchsde.BrownianTree that enables batches of entropy.
80
- """
81
-
82
- def __init__(self, x, t0, t1, seed=None, **kwargs):
83
- t0, t1, self.sign = self.sort(t0, t1)
84
- w0 = kwargs.get('w0', torch.zeros_like(x))
85
- if seed is None:
86
- seed = torch.randint(0, 2**63 - 1, []).item()
87
- self.batched = True
88
- try:
89
- assert len(seed) == x.shape[0]
90
- w0 = w0[0]
91
- except TypeError:
92
- seed = [seed]
93
- self.batched = False
94
- self.trees = [
95
- torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
96
- for s in seed
97
- ]
98
-
99
- @staticmethod
100
- def sort(a, b):
101
- return (a, b, 1) if a < b else (b, a, -1)
102
-
103
- def __call__(self, t0, t1):
104
- t0, t1, sign = self.sort(t0, t1)
105
- w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
106
- self.sign * sign)
107
- return w if self.batched else w[0]
108
-
109
-
110
- class BrownianTreeNoiseSampler:
111
- """
112
- A noise sampler backed by a torchsde.BrownianTree.
113
-
114
- Args:
115
- x (Tensor): The tensor whose shape, device and dtype to use to generate
116
- random samples.
117
- sigma_min (float): The low end of the valid interval.
118
- sigma_max (float): The high end of the valid interval.
119
- seed (int or List[int]): The random seed. If a list of seeds is
120
- supplied instead of a single integer, then the noise sampler will
121
- use one BrownianTree per batch item, each with its own seed.
122
- transform (callable): A function that maps sigma to the sampler's
123
- internal timestep.
124
- """
125
-
126
- def __init__(self,
127
- x,
128
- sigma_min,
129
- sigma_max,
130
- seed=None,
131
- transform=lambda x: x):
132
- self.transform = transform
133
- t0 = self.transform(torch.as_tensor(sigma_min))
134
- t1 = self.transform(torch.as_tensor(sigma_max))
135
- self.tree = BatchedBrownianTree(x, t0, t1, seed)
136
-
137
- def __call__(self, sigma, sigma_next):
138
- t0 = self.transform(torch.as_tensor(sigma))
139
- t1 = self.transform(torch.as_tensor(sigma_next))
140
- return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
141
-
142
-
143
- @torch.no_grad()
144
- def sample_dpmpp_2m_sde(noise,
145
- model,
146
- sigmas,
147
- eta=1.,
148
- s_noise=1.,
149
- solver_type='midpoint',
150
- show_progress=True,
151
- variant_info=None):
152
- """
153
- DPM-Solver++ (2M) SDE.
154
- """
155
- assert solver_type in {'heun', 'midpoint'}
156
-
157
- x = noise * sigmas[0]
158
- sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
159
- sigmas < float('inf')].max()
160
- noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
161
- old_denoised = None
162
- h_last = None
163
-
164
- for i in trange(len(sigmas) - 1, disable=not show_progress):
165
- logger.info(f'step: {i}')
166
- if sigmas[i] == float('inf'):
167
- # Euler method
168
- denoised = model(noise, sigmas[i], variant_info=variant_info)
169
- x = denoised + sigmas[i + 1] * noise
170
- else:
171
- _, c_in = get_scalings(sigmas[i])
172
- denoised = model(x * c_in, sigmas[i], variant_info=variant_info)
173
- if sigmas[i + 1] == 0:
174
- # Denoising step
175
- x = denoised
176
- else:
177
- # DPM-Solver++(2M) SDE
178
- t, s = -sigmas[i].log(), -sigmas[i + 1].log()
179
- h = s - t
180
- eta_h = eta * h
181
-
182
- x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
183
- (-h - eta_h).expm1().neg() * denoised
184
-
185
- if old_denoised is not None:
186
- r = h_last / h
187
- if solver_type == 'heun':
188
- x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
189
- (1 / r) * (denoised - old_denoised)
190
- elif solver_type == 'midpoint':
191
- x = x + 0.5 * (-h - eta_h).expm1().neg() * \
192
- (1 / r) * (denoised - old_denoised)
193
-
194
- x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
195
- i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
196
-
197
- old_denoised = denoised
198
- h_last = h
199
-
200
- if variant_info is not None and variant_info.get('type') == 'variant1':
201
- x_long, x_short = x.chunk(2, dim=0)
202
- x = x_long * (1-variant_info['alpha']) + x_short * variant_info['alpha']
203
-
204
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/modules/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .embedder import *
2
- from .unet_v2v import *
3
- # from .unet_v2v_deform import *
 
 
 
 
video_to_video/modules/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (206 Bytes)
 
video_to_video/modules/__pycache__/embedder.cpython-39.pyc DELETED
Binary file (2.58 kB)
 
video_to_video/modules/__pycache__/t5.cpython-39.pyc DELETED
Binary file (7.07 kB)
 
video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc DELETED
Binary file (47.6 kB)
 
video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc DELETED
Binary file (47.8 kB)
 
video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc DELETED
Binary file (48.2 kB)
 
video_to_video/modules/embedder.py DELETED
@@ -1,75 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
-
3
- import os
4
-
5
- import numpy as np
6
- import open_clip
7
- import torch
8
- import torch.nn as nn
9
- import torchvision.transforms as T
10
-
11
-
12
- class FrozenOpenCLIPEmbedder(nn.Module):
13
- """
14
- Uses the OpenCLIP transformer encoder for text
15
- """
16
- LAYERS = ['last', 'penultimate']
17
-
18
- def __init__(self,
19
- pretrained='laion2b_s32b_b79k',
20
- arch='ViT-H-14',
21
- device='cuda',
22
- max_length=77,
23
- freeze=True,
24
- layer='penultimate'):
25
- super().__init__()
26
- assert layer in self.LAYERS
27
- model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
28
-
29
- del model.visual
30
- self.model = model
31
- self.device = device
32
- self.max_length = max_length
33
-
34
- if freeze:
35
- self.freeze()
36
- self.layer = layer
37
- if self.layer == 'last':
38
- self.layer_idx = 0
39
- elif self.layer == 'penultimate':
40
- self.layer_idx = 1
41
- else:
42
- raise NotImplementedError()
43
-
44
- def freeze(self):
45
- self.model = self.model.eval()
46
- for param in self.parameters():
47
- param.requires_grad = False
48
-
49
- def forward(self, text):
50
- tokens = open_clip.tokenize(text)
51
- z = self.encode_with_transformer(tokens.to(self.device))
52
- return z
53
-
54
- def encode_with_transformer(self, text):
55
- x = self.model.token_embedding(text)
56
- x = x + self.model.positional_embedding
57
- x = x.permute(1, 0, 2)
58
- x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
59
- x = x.permute(1, 0, 2)
60
- x = self.model.ln_final(x)
61
- return x
62
-
63
- def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
64
- for i, r in enumerate(self.model.transformer.resblocks):
65
- if i == len(self.model.transformer.resblocks) - self.layer_idx:
66
- break
67
- if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
68
- ):
69
- x = checkpoint(r, x, attn_mask)
70
- else:
71
- x = r(x, attn_mask=attn_mask)
72
- return x
73
-
74
- def encode(self, text):
75
- return self(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/modules/t5.py DELETED
@@ -1,335 +0,0 @@
1
- # Adapted from PixArt
2
- #
3
- # Copyright (C) 2023 PixArt-alpha/PixArt-alpha
4
- #
5
- # This program is free software: you can redistribute it and/or modify
6
- # it under the terms of the GNU Affero General Public License as published
7
- # by the Free Software Foundation, either version 3 of the License, or
8
- # (at your option) any later version.
9
- #
10
- # This program is distributed in the hope that it will be useful,
11
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
- # GNU Affero General Public License for more details.
14
- #
15
- #
16
- # This source code is licensed under the license found in the
17
- # LICENSE file in the root directory of this source tree.
18
- # --------------------------------------------------------
19
- # References:
20
- # PixArt: https://github.com/PixArt-alpha/PixArt-alpha
21
- # T5: https://github.com/google-research/text-to-text-transfer-transformer
22
- # --------------------------------------------------------
23
-
24
- import html
25
- import re
26
-
27
- import ftfy
28
- import torch
29
- from transformers import AutoTokenizer, T5EncoderModel
30
-
31
- # from opensora.registry import MODELS
32
-
33
-
34
- class T5Embedder:
35
- def __init__(
36
- self,
37
- device,
38
- from_pretrained=None,
39
- *,
40
- cache_dir=None,
41
- hf_token=None,
42
- use_text_preprocessing=True,
43
- t5_model_kwargs=None,
44
- torch_dtype=None,
45
- use_offload_folder=None,
46
- model_max_length=120,
47
- local_files_only=False,
48
- ):
49
- self.device = torch.device(device)
50
- self.torch_dtype = torch_dtype or torch.bfloat16
51
- self.cache_dir = cache_dir
52
-
53
- if t5_model_kwargs is None:
54
- t5_model_kwargs = {
55
- "low_cpu_mem_usage": True,
56
- "torch_dtype": self.torch_dtype,
57
- }
58
-
59
- if use_offload_folder is not None:
60
- t5_model_kwargs["offload_folder"] = use_offload_folder
61
- t5_model_kwargs["device_map"] = {
62
- "shared": self.device,
63
- "encoder.embed_tokens": self.device,
64
- "encoder.block.0": self.device,
65
- "encoder.block.1": self.device,
66
- "encoder.block.2": self.device,
67
- "encoder.block.3": self.device,
68
- "encoder.block.4": self.device,
69
- "encoder.block.5": self.device,
70
- "encoder.block.6": self.device,
71
- "encoder.block.7": self.device,
72
- "encoder.block.8": self.device,
73
- "encoder.block.9": self.device,
74
- "encoder.block.10": self.device,
75
- "encoder.block.11": self.device,
76
- "encoder.block.12": "disk",
77
- "encoder.block.13": "disk",
78
- "encoder.block.14": "disk",
79
- "encoder.block.15": "disk",
80
- "encoder.block.16": "disk",
81
- "encoder.block.17": "disk",
82
- "encoder.block.18": "disk",
83
- "encoder.block.19": "disk",
84
- "encoder.block.20": "disk",
85
- "encoder.block.21": "disk",
86
- "encoder.block.22": "disk",
87
- "encoder.block.23": "disk",
88
- "encoder.final_layer_norm": "disk",
89
- "encoder.dropout": "disk",
90
- }
91
- else:
92
- t5_model_kwargs["device_map"] = {
93
- "shared": self.device,
94
- "encoder": self.device,
95
- }
96
-
97
- self.use_text_preprocessing = use_text_preprocessing
98
- self.hf_token = hf_token
99
-
100
- self.tokenizer = AutoTokenizer.from_pretrained(
101
- from_pretrained,
102
- cache_dir=cache_dir,
103
- local_files_only=local_files_only,
104
- )
105
- self.model = T5EncoderModel.from_pretrained(
106
- from_pretrained,
107
- cache_dir=cache_dir,
108
- local_files_only=local_files_only,
109
- **t5_model_kwargs,
110
- ).eval()
111
- self.model_max_length = model_max_length
112
-
113
- def get_text_embeddings(self, texts):
114
- text_tokens_and_mask = self.tokenizer(
115
- texts,
116
- max_length=self.model_max_length,
117
- padding="max_length",
118
- truncation=True,
119
- return_attention_mask=True,
120
- add_special_tokens=True,
121
- return_tensors="pt",
122
- )
123
-
124
- input_ids = text_tokens_and_mask["input_ids"].to(self.device)
125
- attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
126
- with torch.no_grad():
127
- text_encoder_embs = self.model(
128
- input_ids=input_ids,
129
- attention_mask=attention_mask,
130
- )["last_hidden_state"].detach()
131
- return text_encoder_embs, attention_mask
132
-
133
-
134
- # @MODELS.register_module("t5")
135
- class T5Encoder:
136
- def __init__(
137
- self,
138
- from_pretrained=None,
139
- model_max_length=120,
140
- device="cuda",
141
- dtype=torch.float,
142
- cache_dir=None,
143
- shardformer=False,
144
- local_files_only=False,
145
- ):
146
- assert from_pretrained is not None, "Please specify the path to the T5 model"
147
-
148
- self.t5 = T5Embedder(
149
- device=device,
150
- torch_dtype=dtype,
151
- from_pretrained=from_pretrained,
152
- cache_dir=cache_dir,
153
- model_max_length=model_max_length,
154
- local_files_only=local_files_only,
155
- )
156
- self.t5.model.to(dtype=dtype)
157
- self.y_embedder = None
158
-
159
- self.model_max_length = model_max_length
160
- self.output_dim = self.t5.model.config.d_model
161
- self.dtype = dtype
162
-
163
- if shardformer:
164
- self.shardformer_t5()
165
-
166
- def shardformer_t5(self):
167
- from colossalai.shardformer import ShardConfig, ShardFormer
168
-
169
- from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
170
- from opensora.utils.misc import requires_grad
171
-
172
- shard_config = ShardConfig(
173
- tensor_parallel_process_group=None,
174
- pipeline_stage_manager=None,
175
- enable_tensor_parallelism=False,
176
- enable_fused_normalization=False,
177
- enable_flash_attention=False,
178
- enable_jit_fused=True,
179
- enable_sequence_parallelism=False,
180
- enable_sequence_overlap=False,
181
- )
182
- shard_former = ShardFormer(shard_config=shard_config)
183
- optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
184
- self.t5.model = optim_model.to(self.dtype)
185
-
186
- # ensure the weights are frozen
187
- requires_grad(self.t5.model, False)
188
-
189
- def encode(self, text):
190
- caption_embs, emb_masks = self.t5.get_text_embeddings(text)
191
- caption_embs = caption_embs[:, None]
192
- return dict(y=caption_embs, mask=emb_masks)
193
-
194
- def null(self, n):
195
- null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
196
- return null_y
197
-
198
-
199
- def basic_clean(text):
200
- text = ftfy.fix_text(text)
201
- text = html.unescape(html.unescape(text))
202
- return text.strip()
203
-
204
-
205
- BAD_PUNCT_REGEX = re.compile(
206
- r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
207
- ) # noqa
208
-
209
-
210
- def clean_caption(caption):
211
- import urllib.parse as ul
212
-
213
- from bs4 import BeautifulSoup
214
-
215
- caption = str(caption)
216
- caption = ul.unquote_plus(caption)
217
- caption = caption.strip().lower()
218
- caption = re.sub("<person>", "person", caption)
219
- # urls:
220
- caption = re.sub(
221
- r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
222
- "",
223
- caption,
224
- ) # regex for urls
225
- caption = re.sub(
226
- r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
227
- "",
228
- caption,
229
- ) # regex for urls
230
- # html:
231
- caption = BeautifulSoup(caption, features="html.parser").text
232
-
233
- # @<nickname>
234
- caption = re.sub(r"@[\w\d]+\b", "", caption)
235
-
236
- # 31C0—31EF CJK Strokes
237
- # 31F0—31FF Katakana Phonetic Extensions
238
- # 3200—32FF Enclosed CJK Letters and Months
239
- # 3300—33FF CJK Compatibility
240
- # 3400—4DBF CJK Unified Ideographs Extension A
241
- # 4DC0—4DFF Yijing Hexagram Symbols
242
- # 4E00—9FFF CJK Unified Ideographs
243
- caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
244
- caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
245
- caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
246
- caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
247
- caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
248
- caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
249
- caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
250
- #######################################################
251
-
252
- # все виды тире / all types of dash --> "-"
253
- caption = re.sub(
254
- r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
255
- "-",
256
- caption,
257
- )
258
-
259
- # кавычки к одному стандарту
260
- caption = re.sub(r"[`´«»“”¨]", '"', caption)
261
- caption = re.sub(r"[‘’]", "'", caption)
262
-
263
- # &quot;
264
- caption = re.sub(r"&quot;?", "", caption)
265
- # &amp
266
- caption = re.sub(r"&amp", "", caption)
267
-
268
- # ip adresses:
269
- caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
270
-
271
- # article ids:
272
- caption = re.sub(r"\d:\d\d\s+$", "", caption)
273
-
274
- # \n
275
- caption = re.sub(r"\\n", " ", caption)
276
-
277
- # "#123"
278
- caption = re.sub(r"#\d{1,3}\b", "", caption)
279
- # "#12345.."
280
- caption = re.sub(r"#\d{5,}\b", "", caption)
281
- # "123456.."
282
- caption = re.sub(r"\b\d{6,}\b", "", caption)
283
- # filenames:
284
- caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
285
-
286
- #
287
- caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
288
- caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
289
-
290
- caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
291
- caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
292
-
293
- # this-is-my-cute-cat / this_is_my_cute_cat
294
- regex2 = re.compile(r"(?:\-|\_)")
295
- if len(re.findall(regex2, caption)) > 3:
296
- caption = re.sub(regex2, " ", caption)
297
-
298
- caption = basic_clean(caption)
299
-
300
- caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
301
- caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
302
- caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
303
-
304
- caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
305
- caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
306
- caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
307
- caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
308
- caption = re.sub(r"\bpage\s+\d+\b", "", caption)
309
-
310
- caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
311
-
312
- caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
313
-
314
- caption = re.sub(r"\b\s+\:\s+", r": ", caption)
315
- caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
316
- caption = re.sub(r"\s+", " ", caption)
317
-
318
- caption.strip()
319
-
320
- caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
321
- caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
322
- caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
323
- caption = re.sub(r"^\.\S+$", "", caption)
324
-
325
- return caption.strip()
326
-
327
-
328
- def text_preprocessing(text, use_text_preprocessing: bool = True):
329
- if use_text_preprocessing:
330
- # The exact text cleaning as was in the training stage:
331
- text = clean_caption(text)
332
- text = clean_caption(text)
333
- return text
334
- else:
335
- return text.lower().strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/modules/unet_v2v.py DELETED
@@ -1,2332 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
-
3
- import math
4
- import os
5
- from abc import abstractmethod
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- import xformers
11
- import xformers.ops
12
- from einops import rearrange
13
- from fairscale.nn.checkpoint import checkpoint_wrapper
14
- from timm.models.vision_transformer import Mlp
15
-
16
-
17
- USE_TEMPORAL_TRANSFORMER = True
18
-
19
-
20
- class CaptionEmbedder(nn.Module):
21
- """
22
- Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
23
- """
24
-
25
- def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
26
- super().__init__()
27
- self.y_proj = Mlp(
28
- in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
29
- )
30
- self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
31
- self.uncond_prob = uncond_prob
32
-
33
- def token_drop(self, caption, force_drop_ids=None):
34
- """
35
- Drops labels to enable classifier-free guidance.
36
- """
37
- if force_drop_ids is None:
38
- drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
39
- else:
40
- drop_ids = force_drop_ids == 1
41
- caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
42
- return caption
43
-
44
- def forward(self, caption, train, force_drop_ids=None):
45
- if train:
46
- assert caption.shape[2:] == self.y_embedding.shape
47
- use_dropout = self.uncond_prob > 0
48
- if (train and use_dropout) or (force_drop_ids is not None):
49
- caption = self.token_drop(caption, force_drop_ids)
50
- caption = self.y_proj(caption)
51
- return caption
52
-
53
-
54
- class DropPath(nn.Module):
55
- r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
56
- """
57
-
58
- def __init__(self, p):
59
- super(DropPath, self).__init__()
60
- self.p = p
61
-
62
- def forward(self, *args, zero=None, keep=None):
63
- if not self.training:
64
- return args[0] if len(args) == 1 else args
65
-
66
- # params
67
- x = args[0]
68
- b = x.size(0)
69
- n = (torch.rand(b) < self.p).sum()
70
-
71
- # non-zero and non-keep mask
72
- mask = x.new_ones(b, dtype=torch.bool)
73
- if keep is not None:
74
- mask[keep] = False
75
- if zero is not None:
76
- mask[zero] = False
77
-
78
- # drop-path index
79
- index = torch.where(mask)[0]
80
- index = index[torch.randperm(len(index))[:n]]
81
- if zero is not None:
82
- index = torch.cat([index, torch.where(zero)[0]], dim=0)
83
-
84
- # drop-path multiplier
85
- multiplier = x.new_ones(b)
86
- multiplier[index] = 0.0
87
- output = tuple(u * self.broadcast(multiplier, u) for u in args)
88
- return output[0] if len(args) == 1 else output
89
-
90
- def broadcast(self, src, dst):
91
- assert src.size(0) == dst.size(0)
92
- shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
93
- return src.view(shape)
94
-
95
-
96
- def sinusoidal_embedding(timesteps, dim):
97
- # check input
98
- half = dim // 2
99
- timesteps = timesteps.float()
100
-
101
- # compute sinusoidal embedding
102
- sinusoid = torch.outer(
103
- timesteps, torch.pow(10000,
104
- -torch.arange(half).to(timesteps).div(half)))
105
- x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
106
- if dim % 2 != 0:
107
- x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
108
- return x
109
-
110
-
111
- def exists(x):
112
- return x is not None
113
-
114
-
115
- def default(val, d):
116
- if exists(val):
117
- return val
118
- return d() if callable(d) else d
119
-
120
-
121
- def prob_mask_like(shape, prob, device):
122
- if prob == 1:
123
- return torch.ones(shape, device=device, dtype=torch.bool)
124
- elif prob == 0:
125
- return torch.zeros(shape, device=device, dtype=torch.bool)
126
- else:
127
- mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
128
- # aviod mask all, which will cause find_unused_parameters error
129
- if mask.all():
130
- mask[0] = False
131
- return mask
132
-
133
-
134
- class MemoryEfficientCrossAttention(nn.Module):
135
-
136
- def __init__(self,
137
- query_dim,
138
- context_dim=None,
139
- heads=8,
140
- dim_head=64,
141
- max_bs=16384,
142
- dropout=0.0):
143
- super().__init__()
144
- inner_dim = dim_head * heads
145
- context_dim = default(context_dim, query_dim)
146
-
147
- self.max_bs = max_bs
148
- self.heads = heads
149
- self.dim_head = dim_head
150
-
151
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
152
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
153
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
154
- self.to_out = nn.Sequential(
155
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
156
- self.attention_op: Optional[Any] = None
157
-
158
- def forward(self, x, context=None, mask=None):
159
- q = self.to_q(x)
160
- context = default(context, x)
161
- k = self.to_k(context)
162
- v = self.to_v(context)
163
-
164
- b, _, _ = q.shape
165
- q, k, v = map(
166
- lambda t: t.unsqueeze(3).reshape(b, t.shape[
167
- 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
168
- b * self.heads, t.shape[1], self.dim_head).contiguous(),
169
- (q, k, v),
170
- )
171
-
172
- # actually compute the attention, what we cannot get enough of.
173
- if q.shape[0] > self.max_bs:
174
- q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
175
- k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
176
- v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
177
- out_list = []
178
- for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
179
- out = xformers.ops.memory_efficient_attention(
180
- q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
181
- out_list.append(out)
182
- out = torch.cat(out_list, dim=0)
183
- else:
184
- out = xformers.ops.memory_efficient_attention(
185
- q, k, v, attn_bias=None, op=self.attention_op)
186
-
187
- if exists(mask):
188
- raise NotImplementedError
189
- out = (
190
- out.unsqueeze(0).reshape(
191
- b, self.heads, out.shape[1],
192
- self.dim_head).permute(0, 2, 1,
193
- 3).reshape(b, out.shape[1],
194
- self.heads * self.dim_head))
195
- return self.to_out(out)
196
-
197
-
198
- class RelativePositionBias(nn.Module):
199
-
200
- def __init__(self, heads=8, num_buckets=32, max_distance=128):
201
- super().__init__()
202
- self.num_buckets = num_buckets
203
- self.max_distance = max_distance
204
- self.relative_attention_bias = nn.Embedding(num_buckets, heads)
205
-
206
- @staticmethod
207
- def _relative_position_bucket(relative_position,
208
- num_buckets=32,
209
- max_distance=128):
210
- ret = 0
211
- n = -relative_position
212
-
213
- num_buckets //= 2
214
- ret += (n < 0).long() * num_buckets
215
- n = torch.abs(n)
216
-
217
- max_exact = num_buckets // 2
218
- is_small = n < max_exact
219
-
220
- val_if_large = max_exact + (
221
- torch.log(n.float() / max_exact)
222
- / math.log(max_distance / max_exact) * # noqa
223
- (num_buckets - max_exact)).long()
224
- val_if_large = torch.min(
225
- val_if_large, torch.full_like(val_if_large, num_buckets - 1))
226
-
227
- ret += torch.where(is_small, n, val_if_large)
228
- return ret
229
-
230
- def forward(self, n, device):
231
- q_pos = torch.arange(n, dtype=torch.long, device=device)
232
- k_pos = torch.arange(n, dtype=torch.long, device=device)
233
- rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
234
- rp_bucket = self._relative_position_bucket(
235
- rel_pos,
236
- num_buckets=self.num_buckets,
237
- max_distance=self.max_distance)
238
- values = self.relative_attention_bias(rp_bucket)
239
- return rearrange(values, 'i j h -> h i j')
240
-
241
-
242
- class SpatialTransformer(nn.Module):
243
- """
244
- Transformer block for image-like data.
245
- First, project the input (aka embedding)
246
- and reshape to b, t, d.
247
- Then apply standard transformer action.
248
- Finally, reshape to image
249
- NEW: use_linear for more efficiency instead of the 1x1 convs
250
- """
251
-
252
- def __init__(self,
253
- in_channels,
254
- n_heads,
255
- d_head,
256
- depth=1,
257
- dropout=0.,
258
- context_dim=None,
259
- disable_self_attn=False,
260
- use_linear=False,
261
- use_checkpoint=True,
262
- is_ctrl=False):
263
- super().__init__()
264
- if exists(context_dim) and not isinstance(context_dim, list):
265
- context_dim = [context_dim]
266
- self.in_channels = in_channels
267
- inner_dim = n_heads * d_head
268
- self.norm = torch.nn.GroupNorm(
269
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
270
- if not use_linear:
271
- self.proj_in = nn.Conv2d(
272
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
273
- else:
274
- self.proj_in = nn.Linear(in_channels, inner_dim)
275
-
276
- self.transformer_blocks = nn.ModuleList([
277
- BasicTransformerBlock(
278
- inner_dim,
279
- n_heads,
280
- d_head,
281
- dropout=dropout,
282
- context_dim=context_dim[d],
283
- disable_self_attn=disable_self_attn,
284
- checkpoint=use_checkpoint,
285
- local_type='space',
286
- is_ctrl=is_ctrl) for d in range(depth)
287
- ])
288
- if not use_linear:
289
- self.proj_out = zero_module(
290
- nn.Conv2d(
291
- inner_dim, in_channels, kernel_size=1, stride=1,
292
- padding=0))
293
- else:
294
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
295
- self.use_linear = use_linear
296
-
297
- def forward(self, x, context=None):
298
- # note: if no context is given, cross-attention defaults to self-attention
299
- if not isinstance(context, list):
300
- context = [context]
301
- _, _, h, w = x.shape
302
- # print('x shape:', x.shape) # [64, 320, 90, 160]
303
- x_in = x
304
- x = self.norm(x)
305
- if not self.use_linear:
306
- x = self.proj_in(x)
307
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
308
- if self.use_linear:
309
- x = self.proj_in(x)
310
- for i, block in enumerate(self.transformer_blocks):
311
- x = block(x, context=context[i], h=h, w=w)
312
- if self.use_linear:
313
- x = self.proj_out(x)
314
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
315
- if not self.use_linear:
316
- x = self.proj_out(x)
317
- return x + x_in
318
-
319
-
320
- _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
321
-
322
-
323
- class CrossAttention(nn.Module):
324
-
325
- def __init__(self,
326
- query_dim,
327
- context_dim=None,
328
- heads=8,
329
- dim_head=64,
330
- dropout=0.):
331
- super().__init__()
332
- inner_dim = dim_head * heads
333
- context_dim = default(context_dim, query_dim)
334
-
335
- self.scale = dim_head**-0.5
336
- self.heads = heads
337
-
338
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
339
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
340
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
341
-
342
- self.to_out = nn.Sequential(
343
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
344
-
345
- def forward(self, x, context=None, mask=None):
346
- h = self.heads
347
-
348
- q = self.to_q(x)
349
- context = default(context, x)
350
- k = self.to_k(context)
351
- v = self.to_v(context)
352
-
353
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
354
- (q, k, v))
355
-
356
- # force cast to fp32 to avoid overflowing
357
- if _ATTN_PRECISION == 'fp32':
358
- with torch.autocast(enabled=False, device_type='cuda'):
359
- q, k = q.float(), k.float()
360
- sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
361
- else:
362
- sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
363
-
364
- del q, k
365
-
366
- if exists(mask):
367
- mask = rearrange(mask, 'b ... -> b (...)')
368
- max_neg_value = -torch.finfo(sim.dtype).max
369
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
370
- sim.masked_fill_(~mask, max_neg_value)
371
-
372
- # attention, what we cannot get enough of
373
- sim = sim.softmax(dim=-1)
374
-
375
- out = torch.einsum('b i j, b j d -> b i d', sim, v)
376
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
377
- return self.to_out(out)
378
-
379
-
380
-
381
-
382
- class SpatialAttention(nn.Module):
383
- def __init__(self):
384
- super(SpatialAttention, self).__init__()
385
- self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, padding=7 // 2, bias=False)
386
- self.sigmoid = nn.Sigmoid()
387
- def forward(self, x):
388
-
389
- max_out, _ = torch.max(x, dim=1, keepdim=True)
390
- avg_out = torch.mean(x, dim=1, keepdim=True)
391
-
392
- weight = torch.cat([max_out, avg_out], dim=1)
393
- weight = self.conv1(weight)
394
-
395
- out = self.sigmoid(weight) * x
396
- return out
397
-
398
- class TemporalLocalAttention(nn.Module): # b c t h w
399
- def __init__(self, dim, kernel_size=7):
400
- super(TemporalLocalAttention, self).__init__()
401
- self.conv1 = nn.Linear(in_features=2, out_features=1, bias=False)
402
- self.sigmoid = nn.Sigmoid()
403
-
404
- def forward(self, x):
405
-
406
- max_out, _ = torch.max(x, dim=-1, keepdim=True)
407
- avg_out = torch.mean(x, dim=-1, keepdim=True)
408
-
409
- weight = torch.cat([max_out, avg_out], dim=-1)
410
- weight = self.conv1(weight)
411
-
412
- out = self.sigmoid(weight) * x
413
- return out
414
-
415
-
416
- class BasicTransformerBlock(nn.Module):
417
-
418
- def __init__(self,
419
- dim,
420
- n_heads,
421
- d_head,
422
- dropout=0.,
423
- context_dim=None,
424
- gated_ff=True,
425
- checkpoint=True,
426
- disable_self_attn=False,
427
- local_type=None,
428
- is_ctrl=False):
429
- super().__init__()
430
- self.local_type = local_type
431
- self.is_ctrl = is_ctrl
432
- attn_cls = MemoryEfficientCrossAttention
433
- self.disable_self_attn = disable_self_attn
434
- self.attn1 = attn_cls( # self-attn
435
- query_dim=dim,
436
- heads=n_heads,
437
- dim_head=d_head,
438
- dropout=dropout,
439
- context_dim=context_dim if self.disable_self_attn else None)
440
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
441
-
442
- attn_cls2 = MemoryEfficientCrossAttention
443
-
444
- self.attn2 = attn_cls2(
445
- query_dim=dim,
446
- context_dim=context_dim,
447
- heads=n_heads,
448
- dim_head=d_head,
449
- dropout=dropout)
450
- self.norm1 = nn.LayerNorm(dim)
451
- self.norm2 = nn.LayerNorm(dim)
452
- self.norm3 = nn.LayerNorm(dim)
453
- self.checkpoint = checkpoint
454
-
455
- if self.local_type == 'space' and self.is_ctrl:
456
- self.local1 = SpatialAttention()
457
-
458
- if self.local_type == 'temp' and self.is_ctrl:
459
- self.local1 = TemporalLocalAttention(dim=dim)
460
- self.local2 = TemporalLocalAttention(dim=dim)
461
-
462
- def forward_(self, x, context=None):
463
- return checkpoint(self._forward, (x, context), self.parameters(),
464
- self.checkpoint)
465
-
466
- def forward(self, x, context=None, h=None, w=None):
467
-
468
- if self.local_type == 'space' and self.is_ctrl: # [b*t,(hw), c]
469
-
470
- x_local = rearrange(x, 'b (h w) c -> b c h w', h=h)
471
- x_local = self.local1(x_local)
472
- x_local = rearrange(x_local, 'b c h w -> b (h w) c')
473
-
474
- x = self.attn1(
475
- self.norm1(x_local),
476
- context=context if self.disable_self_attn else None) + x
477
-
478
- x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
479
- x = self.ff(self.norm3(x)) + x
480
-
481
- if self.local_type == 'temp' and self.is_ctrl:
482
-
483
- # x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
484
- x_local = self.local1(x)
485
-
486
- x = self.attn1(
487
- self.norm1(x_local),
488
- context=context if self.disable_self_attn else None) + x
489
-
490
- # x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
491
- x_local = self.local2(x)
492
-
493
- x = self.attn2(self.norm2(x_local), context=context) + x
494
- x = self.ff(self.norm3(x)) + x
495
-
496
- # elif self.local_type == 'space' and self.is_ctrl:
497
- # # print('*** use original attention ***')
498
- # x = self.attn1(
499
- # self.norm1(x),
500
- # context=context if self.disable_self_attn else None) + x # self-attention
501
-
502
- # x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
503
- # x = self.ff(self.norm3(x)) + x
504
-
505
- return x
506
-
507
-
508
- # feedforward
509
- class GEGLU(nn.Module):
510
-
511
- def __init__(self, dim_in, dim_out):
512
- super().__init__()
513
- self.proj = nn.Linear(dim_in, dim_out * 2)
514
-
515
- def forward(self, x):
516
- x, gate = self.proj(x).chunk(2, dim=-1)
517
- return x * F.gelu(gate)
518
-
519
-
520
- def zero_module(module):
521
- """
522
- Zero out the parameters of a module and return it.
523
- """
524
- for p in module.parameters():
525
- p.detach().zero_()
526
- return module
527
-
528
-
529
- class FeedForward(nn.Module):
530
-
531
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
532
- super().__init__()
533
- inner_dim = int(dim * mult)
534
- dim_out = default(dim_out, dim)
535
- project_in = nn.Sequential(nn.Linear(
536
- dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
537
-
538
- self.net = nn.Sequential(project_in, nn.Dropout(dropout),
539
- nn.Linear(inner_dim, dim_out))
540
-
541
- def forward(self, x):
542
- return self.net(x)
543
-
544
-
545
- class Upsample(nn.Module):
546
- """
547
- An upsampling layer with an optional convolution.
548
- :param channels: channels in the inputs and outputs.
549
- :param use_conv: a bool determining if a convolution is applied.
550
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
551
- upsampling occurs in the inner-two dimensions.
552
- """
553
-
554
- def __init__(self,
555
- channels,
556
- use_conv,
557
- dims=2,
558
- out_channels=None,
559
- padding=1):
560
- super().__init__()
561
- self.channels = channels
562
- self.out_channels = out_channels or channels
563
- self.use_conv = use_conv
564
- self.dims = dims
565
- if use_conv:
566
- self.conv = nn.Conv2d(
567
- self.channels, self.out_channels, 3, padding=padding)
568
-
569
- def forward(self, x):
570
- assert x.shape[1] == self.channels
571
- if self.dims == 3:
572
- x = F.interpolate(
573
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
574
- mode='nearest')
575
- else:
576
- x = F.interpolate(x, scale_factor=2, mode='nearest')
577
- x = x[..., 1:-1, :]
578
- if self.use_conv:
579
- x = self.conv(x)
580
- return x
581
-
582
-
583
- class ResBlock(nn.Module):
584
- """
585
- A residual block that can optionally change the number of channels.
586
- :param channels: the number of input channels.
587
- :param emb_channels: the number of timestep embedding channels.
588
- :param dropout: the rate of dropout.
589
- :param out_channels: if specified, the number of out channels.
590
- :param use_conv: if True and out_channels is specified, use a spatial
591
- convolution instead of a smaller 1x1 convolution to change the
592
- channels in the skip connection.
593
- :param dims: determines if the signal is 1D, 2D, or 3D.
594
- :param use_checkpoint: if True, use gradient checkpointing on this module.
595
- :param up: if True, use this block for upsampling.
596
- :param down: if True, use this block for downsampling.
597
- """
598
-
599
- def __init__(
600
- self,
601
- channels,
602
- emb_channels,
603
- dropout,
604
- out_channels=None,
605
- use_conv=False,
606
- use_scale_shift_norm=False,
607
- dims=2,
608
- up=False,
609
- down=False,
610
- use_temporal_conv=True,
611
- use_image_dataset=False,
612
- ):
613
- super().__init__()
614
- self.channels = channels
615
- self.emb_channels = emb_channels
616
- self.dropout = dropout
617
- self.out_channels = out_channels or channels
618
- self.use_conv = use_conv
619
- self.use_scale_shift_norm = use_scale_shift_norm
620
- self.use_temporal_conv = use_temporal_conv
621
-
622
- self.in_layers = nn.Sequential(
623
- nn.GroupNorm(32, channels),
624
- nn.SiLU(),
625
- nn.Conv2d(channels, self.out_channels, 3, padding=1),
626
- )
627
-
628
- self.updown = up or down
629
-
630
- if up:
631
- self.h_upd = Upsample(channels, False, dims)
632
- self.x_upd = Upsample(channels, False, dims)
633
- elif down:
634
- self.h_upd = Downsample(channels, False, dims)
635
- self.x_upd = Downsample(channels, False, dims)
636
- else:
637
- self.h_upd = self.x_upd = nn.Identity()
638
-
639
- self.emb_layers = nn.Sequential(
640
- nn.SiLU(),
641
- nn.Linear(
642
- emb_channels,
643
- 2 * self.out_channels
644
- if use_scale_shift_norm else self.out_channels,
645
- ),
646
- )
647
- self.out_layers = nn.Sequential(
648
- nn.GroupNorm(32, self.out_channels),
649
- nn.SiLU(),
650
- nn.Dropout(p=dropout),
651
- zero_module(
652
- nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
653
- )
654
-
655
- if self.out_channels == channels:
656
- self.skip_connection = nn.Identity()
657
- elif use_conv:
658
- self.skip_connection = conv_nd(
659
- dims, channels, self.out_channels, 3, padding=1)
660
- else:
661
- self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
662
-
663
- if self.use_temporal_conv:
664
- self.temopral_conv = TemporalConvBlock_v2(
665
- self.out_channels,
666
- self.out_channels,
667
- dropout=0.1,
668
- use_image_dataset=use_image_dataset)
669
-
670
- def forward(self, x, emb, batch_size, variant_info=None):
671
- """
672
- Apply the block to a Tensor, conditioned on a timestep embedding.
673
- :param x: an [N x C x ...] Tensor of features.
674
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
675
- :return: an [N x C x ...] Tensor of outputs.
676
- """
677
- return self._forward(x, emb, batch_size, variant_info)
678
-
679
- def _forward(self, x, emb, batch_size, variant_info):
680
- if self.updown:
681
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
682
- h = in_rest(x)
683
- h = self.h_upd(h)
684
- x = self.x_upd(x)
685
- h = in_conv(h)
686
- else:
687
- h = self.in_layers(x)
688
- emb_out = self.emb_layers(emb).type(h.dtype)
689
- while len(emb_out.shape) < len(h.shape):
690
- emb_out = emb_out[..., None]
691
- if self.use_scale_shift_norm:
692
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
693
- scale, shift = th.chunk(emb_out, 2, dim=1)
694
- h = out_norm(h) * (1 + scale) + shift
695
- h = out_rest(h)
696
- else:
697
- h = h + emb_out
698
- h = self.out_layers(h)
699
- h = self.skip_connection(x) + h
700
-
701
- if self.use_temporal_conv:
702
- h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
703
- h = self.temopral_conv(h, variant_info=variant_info)
704
- h = rearrange(h, 'b c f h w -> (b f) c h w')
705
- return h
706
-
707
-
708
- class Downsample(nn.Module):
709
- """
710
- A downsampling layer with an optional convolution.
711
- :param channels: channels in the inputs and outputs.
712
- :param use_conv: a bool determining if a convolution is applied.
713
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
714
- downsampling occurs in the inner-two dimensions.
715
- """
716
-
717
- def __init__(self,
718
- channels,
719
- use_conv,
720
- dims=2,
721
- out_channels=None,
722
- padding=(2, 1)):
723
- super().__init__()
724
- self.channels = channels
725
- self.out_channels = out_channels or channels
726
- self.use_conv = use_conv
727
- self.dims = dims
728
- stride = 2 if dims != 3 else (1, 2, 2)
729
- if use_conv:
730
- self.op = nn.Conv2d(
731
- self.channels,
732
- self.out_channels,
733
- 3,
734
- stride=stride,
735
- padding=padding)
736
- else:
737
- assert self.channels == self.out_channels
738
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
739
-
740
- def forward(self, x):
741
- assert x.shape[1] == self.channels
742
- return self.op(x)
743
-
744
-
745
- class Resample(nn.Module):
746
-
747
- def __init__(self, in_dim, out_dim, mode):
748
- assert mode in ['none', 'upsample', 'downsample']
749
- super(Resample, self).__init__()
750
- self.in_dim = in_dim
751
- self.out_dim = out_dim
752
- self.mode = mode
753
-
754
- def forward(self, x, reference=None):
755
- if self.mode == 'upsample':
756
- assert reference is not None
757
- x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
758
- elif self.mode == 'downsample':
759
- x = F.adaptive_avg_pool2d(
760
- x, output_size=tuple(u // 2 for u in x.shape[-2:]))
761
- return x
762
-
763
-
764
- class ResidualBlock(nn.Module):
765
-
766
- def __init__(self,
767
- in_dim,
768
- embed_dim,
769
- out_dim,
770
- use_scale_shift_norm=True,
771
- mode='none',
772
- dropout=0.0):
773
- super(ResidualBlock, self).__init__()
774
- self.in_dim = in_dim
775
- self.embed_dim = embed_dim
776
- self.out_dim = out_dim
777
- self.use_scale_shift_norm = use_scale_shift_norm
778
- self.mode = mode
779
-
780
- # layers
781
- self.layer1 = nn.Sequential(
782
- nn.GroupNorm(32, in_dim), nn.SiLU(),
783
- nn.Conv2d(in_dim, out_dim, 3, padding=1))
784
- self.resample = Resample(in_dim, in_dim, mode)
785
- self.embedding = nn.Sequential(
786
- nn.SiLU(),
787
- nn.Linear(embed_dim,
788
- out_dim * 2 if use_scale_shift_norm else out_dim))
789
- self.layer2 = nn.Sequential(
790
- nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
791
- nn.Conv2d(out_dim, out_dim, 3, padding=1))
792
- self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
793
- in_dim, out_dim, 1)
794
-
795
- # zero out the last layer params
796
- nn.init.zeros_(self.layer2[-1].weight)
797
-
798
- def forward(self, x, e, reference=None):
799
- identity = self.resample(x, reference)
800
- x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
801
- e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
802
- if self.use_scale_shift_norm:
803
- scale, shift = e.chunk(2, dim=1)
804
- x = self.layer2[0](x) * (1 + scale) + shift
805
- x = self.layer2[1:](x)
806
- else:
807
- x = x + e
808
- x = self.layer2(x)
809
- x = x + self.shortcut(identity)
810
- return x
811
-
812
-
813
- class AttentionBlock(nn.Module):
814
-
815
- def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
816
- # consider head_dim first, then num_heads
817
- num_heads = dim // head_dim if head_dim else num_heads
818
- head_dim = dim // num_heads
819
- assert num_heads * head_dim == dim
820
- super(AttentionBlock, self).__init__()
821
- self.dim = dim
822
- self.context_dim = context_dim
823
- self.num_heads = num_heads
824
- self.head_dim = head_dim
825
- self.scale = math.pow(head_dim, -0.25)
826
-
827
- # layers
828
- self.norm = nn.GroupNorm(32, dim)
829
- self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
830
- if context_dim is not None:
831
- self.context_kv = nn.Linear(context_dim, dim * 2)
832
- self.proj = nn.Conv2d(dim, dim, 1)
833
-
834
- # zero out the last layer params
835
- nn.init.zeros_(self.proj.weight)
836
-
837
- def forward(self, x, context=None):
838
- r"""x: [B, C, H, W].
839
- context: [B, L, C] or None.
840
- """
841
- identity = x
842
- b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
843
-
844
- # compute query, key, value
845
- x = self.norm(x)
846
- q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
847
- if context is not None:
848
- ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
849
- d).permute(0, 2, 3,
850
- 1).chunk(
851
- 2, dim=1)
852
- k = torch.cat([ck, k], dim=-1)
853
- v = torch.cat([cv, v], dim=-1)
854
-
855
- # compute attention
856
- attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
857
- attn = F.softmax(attn, dim=-1)
858
-
859
- # gather context
860
- x = torch.matmul(v, attn.transpose(-1, -2))
861
- x = x.reshape(b, c, h, w)
862
-
863
- # output
864
- x = self.proj(x)
865
- return x + identity
866
-
867
-
868
- class TemporalAttentionBlock(nn.Module):
869
-
870
- def __init__(self,
871
- dim,
872
- heads=4,
873
- dim_head=32,
874
- rotary_emb=None,
875
- use_image_dataset=False,
876
- use_sim_mask=False):
877
- super().__init__()
878
- # consider num_heads first, as pos_bias needs fixed num_heads
879
- dim_head = dim // heads
880
- assert heads * dim_head == dim
881
- self.use_image_dataset = use_image_dataset
882
- self.use_sim_mask = use_sim_mask
883
-
884
- self.scale = dim_head**-0.5
885
- self.heads = heads
886
- hidden_dim = dim_head * heads
887
-
888
- self.norm = nn.GroupNorm(32, dim)
889
- self.rotary_emb = rotary_emb
890
- self.to_qkv = nn.Linear(dim, hidden_dim * 3)
891
- self.to_out = nn.Linear(hidden_dim, dim)
892
-
893
- def forward(self,
894
- x,
895
- pos_bias=None,
896
- focus_present_mask=None,
897
- video_mask=None):
898
-
899
- identity = x
900
- n, height, device = x.shape[2], x.shape[-2], x.device
901
-
902
- x = self.norm(x)
903
- x = rearrange(x, 'b c f h w -> b (h w) f c')
904
-
905
- qkv = self.to_qkv(x).chunk(3, dim=-1)
906
-
907
- if exists(focus_present_mask) and focus_present_mask.all():
908
- # if all batch samples are focusing on present
909
- # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
910
- values = qkv[-1]
911
- out = self.to_out(values)
912
- out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
913
-
914
- return out + identity
915
-
916
- # split out heads
917
- q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
918
- k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
919
- v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)
920
-
921
- # scale
922
-
923
- q = q * self.scale
924
-
925
- # rotate positions into queries and keys for time attention
926
- if exists(self.rotary_emb):
927
- q = self.rotary_emb.rotate_queries_or_keys(q)
928
- k = self.rotary_emb.rotate_queries_or_keys(k)
929
-
930
- # similarity
931
- # shape [b (hw) h n n], n=f
932
- sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
933
-
934
- # relative positional bias
935
-
936
- if exists(pos_bias):
937
- sim = sim + pos_bias
938
-
939
- if (focus_present_mask is None and video_mask is not None):
940
- # video_mask: [B, n]
941
- mask = video_mask[:, None, :] * video_mask[:, :, None]
942
- mask = mask.unsqueeze(1).unsqueeze(1)
943
- sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
944
- elif exists(focus_present_mask) and not (~focus_present_mask).all():
945
- attend_all_mask = torch.ones((n, n),
946
- device=device,
947
- dtype=torch.bool)
948
- attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
949
-
950
- mask = torch.where(
951
- rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
952
- rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
953
- rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
954
- )
955
-
956
- sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
957
-
958
- if self.use_sim_mask:
959
- sim_mask = torch.tril(
960
- torch.ones((n, n), device=device, dtype=torch.bool),
961
- diagonal=0)
962
- sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
963
-
964
- # numerical stability
965
- sim = sim - sim.amax(dim=-1, keepdim=True).detach()
966
- attn = sim.softmax(dim=-1)
967
-
968
- # aggregate values
969
-
970
- out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
971
- out = rearrange(out, '... h n d -> ... n (h d)')
972
- out = self.to_out(out)
973
-
974
- out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
975
-
976
- if self.use_image_dataset:
977
- out = identity + 0 * out
978
- else:
979
- out = identity + out
980
- return out
981
-
982
-
983
- class TemporalTransformer(nn.Module):
984
- """
985
- Transformer block for image-like data.
986
- First, project the input (aka embedding)
987
- and reshape to b, t, d.
988
- Then apply standard transformer action.
989
- Finally, reshape to image
990
- """
991
-
992
- def __init__(self,
993
- in_channels,
994
- n_heads,
995
- d_head,
996
- depth=1,
997
- dropout=0.,
998
- context_dim=None,
999
- disable_self_attn=False,
1000
- use_linear=False,
1001
- use_checkpoint=True,
1002
- only_self_att=True,
1003
- multiply_zero=False,
1004
- is_ctrl=False):
1005
- super().__init__()
1006
- self.multiply_zero = multiply_zero
1007
- self.only_self_att = only_self_att
1008
- self.use_adaptor = False
1009
- if self.only_self_att:
1010
- context_dim = None
1011
- if not isinstance(context_dim, list):
1012
- context_dim = [context_dim]
1013
- self.in_channels = in_channels
1014
- inner_dim = n_heads * d_head
1015
- self.norm = torch.nn.GroupNorm(
1016
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
1017
- if not use_linear:
1018
- self.proj_in = nn.Conv1d(
1019
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
1020
- else:
1021
- self.proj_in = nn.Linear(in_channels, inner_dim)
1022
- if self.use_adaptor:
1023
- self.adaptor_in = nn.Linear(frames, frames)
1024
-
1025
- self.transformer_blocks = nn.ModuleList([
1026
- BasicTransformerBlock(
1027
- inner_dim,
1028
- n_heads,
1029
- d_head,
1030
- dropout=dropout,
1031
- context_dim=context_dim[d],
1032
- checkpoint=use_checkpoint,
1033
- local_type='temp',
1034
- is_ctrl=is_ctrl) for d in range(depth)
1035
- ])
1036
- if not use_linear:
1037
- self.proj_out = zero_module(
1038
- nn.Conv1d(
1039
- inner_dim, in_channels, kernel_size=1, stride=1,
1040
- padding=0))
1041
- else:
1042
- self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
1043
- if self.use_adaptor:
1044
- self.adaptor_out = nn.Linear(frames, frames)
1045
- self.use_linear = use_linear
1046
-
1047
- def forward(self, x, context=None):
1048
- # note: if no context is given, cross-attention defaults to self-attention
1049
- if self.only_self_att:
1050
- context = None
1051
- if not isinstance(context, list):
1052
- context = [context]
1053
- b, _, _, h, w = x.shape
1054
- x_in = x
1055
- x = self.norm(x)
1056
-
1057
- if not self.use_linear:
1058
- x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
1059
- x = self.proj_in(x)
1060
- if self.use_linear:
1061
- x = rearrange(
1062
- x, 'b c f h w -> (b h w) f c').contiguous()
1063
- x = self.proj_in(x)
1064
- x = rearrange(
1065
- x, 'bhw f c -> bhw c f').contiguous()
1066
-
1067
- # print('x shape:', x.shape) # [28800, 512, 32]
1068
- if self.only_self_att: # no cross-attention
1069
- x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
1070
- for i, block in enumerate(self.transformer_blocks):
1071
- x = block(x, h=h, w=w)
1072
- # print('x shape:', x.shape) # [43200, 32, 512]
1073
- x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
1074
- else:
1075
- x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
1076
- for i, block in enumerate(self.transformer_blocks):
1077
- context[i] = rearrange(
1078
- context[i], '(b f) l con -> b f l con',
1079
- f=self.frames).contiguous()
1080
- # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
1081
- for j in range(b):
1082
- context_i_j = repeat(
1083
- context[i][j],
1084
- 'f l con -> (f r) l con',
1085
- r=(h * w) // self.frames,
1086
- f=self.frames).contiguous()
1087
- x[j] = block(x[j], context=context_i_j)
1088
-
1089
- if self.use_linear:
1090
- x = rearrange(x, 'b hw f c -> (b hw) f c').contiguous()
1091
- x = self.proj_out(x)
1092
- x = rearrange(
1093
- x, '(b h w) f c -> b c f h w', b=b, h=h, w=w).contiguous()
1094
- if not self.use_linear:
1095
- # print('x shape:', x.shape) # [2, 21600, 32, 512]
1096
- x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
1097
- x = self.proj_out(x)
1098
- x = rearrange(
1099
- x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
1100
-
1101
- if self.multiply_zero:
1102
- x = 0.0 * x + x_in
1103
- else:
1104
- x = x + x_in
1105
- return x
1106
-
1107
-
1108
- class TemporalAttentionMultiBlock(nn.Module):
1109
-
1110
- def __init__(
1111
- self,
1112
- dim,
1113
- heads=4,
1114
- dim_head=32,
1115
- rotary_emb=None,
1116
- use_image_dataset=False,
1117
- use_sim_mask=False,
1118
- temporal_attn_times=1,
1119
- ):
1120
- super().__init__()
1121
- self.att_layers = nn.ModuleList([
1122
- TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
1123
- use_image_dataset, use_sim_mask)
1124
- for _ in range(temporal_attn_times)
1125
- ])
1126
-
1127
- def forward(self,
1128
- x,
1129
- pos_bias=None,
1130
- focus_present_mask=None,
1131
- video_mask=None):
1132
- for layer in self.att_layers:
1133
- x = layer(x, pos_bias, focus_present_mask, video_mask)
1134
- return x
1135
-
1136
-
1137
- class InitTemporalConvBlock(nn.Module):
1138
-
1139
- def __init__(self,
1140
- in_dim,
1141
- out_dim=None,
1142
- dropout=0.0,
1143
- use_image_dataset=False):
1144
- super(InitTemporalConvBlock, self).__init__()
1145
- if out_dim is None:
1146
- out_dim = in_dim
1147
- self.in_dim = in_dim
1148
- self.out_dim = out_dim
1149
- self.use_image_dataset = use_image_dataset
1150
-
1151
- # conv layers
1152
- self.conv = nn.Sequential(
1153
- nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1154
- nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1155
-
1156
- # zero out the last layer params,so the conv block is identity
1157
- nn.init.zeros_(self.conv[-1].weight)
1158
- nn.init.zeros_(self.conv[-1].bias)
1159
-
1160
- def forward(self, x):
1161
- identity = x
1162
- x = self.conv(x)
1163
- if self.use_image_dataset:
1164
- x = identity + 0 * x
1165
- else:
1166
- x = identity + x
1167
- return x
1168
-
1169
-
1170
- class TemporalConvBlock(nn.Module):
1171
-
1172
- def __init__(self,
1173
- in_dim,
1174
- out_dim=None,
1175
- dropout=0.0,
1176
- use_image_dataset=False):
1177
- super(TemporalConvBlock, self).__init__()
1178
- if out_dim is None:
1179
- out_dim = in_dim
1180
- self.in_dim = in_dim
1181
- self.out_dim = out_dim
1182
- self.use_image_dataset = use_image_dataset
1183
-
1184
- # conv layers
1185
- self.conv1 = nn.Sequential(
1186
- nn.GroupNorm(32, in_dim), nn.SiLU(),
1187
- nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
1188
- self.conv2 = nn.Sequential(
1189
- nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1190
- nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1191
-
1192
- # zero out the last layer params,so the conv block is identity
1193
- nn.init.zeros_(self.conv2[-1].weight)
1194
- nn.init.zeros_(self.conv2[-1].bias)
1195
-
1196
- def forward(self, x):
1197
- identity = x
1198
- x = self.conv1(x)
1199
- x = self.conv2(x)
1200
- if self.use_image_dataset:
1201
- x = identity + 0 * x
1202
- else:
1203
- x = identity + x
1204
- return x
1205
-
1206
-
1207
- class TemporalConvBlock_v2(nn.Module):
1208
-
1209
- def __init__(self,
1210
- in_dim,
1211
- out_dim=None,
1212
- dropout=0.0,
1213
- use_image_dataset=False):
1214
- super(TemporalConvBlock_v2, self).__init__()
1215
- if out_dim is None:
1216
- out_dim = in_dim
1217
- self.in_dim = in_dim
1218
- self.out_dim = out_dim
1219
- self.use_image_dataset = use_image_dataset
1220
-
1221
- # conv layers
1222
- self.conv1 = nn.Sequential(
1223
- nn.GroupNorm(32, in_dim), nn.SiLU(),
1224
- nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
1225
- self.conv2 = nn.Sequential(
1226
- nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1227
- nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1228
- self.conv3 = nn.Sequential(
1229
- nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1230
- nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1231
- self.conv4 = nn.Sequential(
1232
- nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1233
- nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1234
-
1235
- # zero out the last layer params,so the conv block is identity
1236
- nn.init.zeros_(self.conv4[-1].weight)
1237
- nn.init.zeros_(self.conv4[-1].bias)
1238
-
1239
- def forward(self, x, variant_info=None):
1240
- if variant_info is not None and variant_info.get('type') == 'variant2':
1241
- # print(x.shape) # torch.Size([1, 320, 32, 90, 160])
1242
- _, _, f, _, _ = x.shape
1243
- assert f % 4 == 0, "f must be divisible by 4"
1244
- x_short = rearrange(x, "b c (n s) h w -> (n b) c s h w", n=4)
1245
- x_short = self.conv1(x_short)
1246
- x_short = self.conv2(x_short)
1247
- x_short = self.conv3(x_short)
1248
- x_short = self.conv4(x_short)
1249
- x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
1250
-
1251
- identity = x
1252
- x = self.conv1(x)
1253
- x = self.conv2(x)
1254
- x = self.conv3(x)
1255
- x = self.conv4(x)
1256
-
1257
- x = x * (1-variant_info['alpha']) + x_short * variant_info['alpha']
1258
-
1259
-
1260
- elif variant_info is not None and variant_info.get('type') == 'variant1':
1261
- identity = x
1262
- x_long, x_short = x.chunk(2, dim=0)
1263
-
1264
- x_short = rearrange(x_short, "b c (n s) h w -> (n b) c s h w", n=4)
1265
- x_short = self.conv1(x_short)
1266
- x_short = self.conv2(x_short)
1267
- x_short = self.conv3(x_short)
1268
- x_short = self.conv4(x_short)
1269
- x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
1270
-
1271
- x_long = self.conv1(x_long)
1272
- x_long = self.conv2(x_long)
1273
- x_long = self.conv3(x_long)
1274
- x_long = self.conv4(x_long)
1275
-
1276
- x = torch.cat([x_long, x_short], dim=0)
1277
-
1278
-
1279
- elif variant_info is None:
1280
- identity = x
1281
- x = self.conv1(x)
1282
- x = self.conv2(x)
1283
- x = self.conv3(x)
1284
- x = self.conv4(x)
1285
-
1286
-
1287
- if self.use_image_dataset:
1288
- x = identity + 0.0 * x
1289
- else:
1290
- x = identity + x
1291
- return x
1292
-
1293
-
1294
- class Vid2VidSDUNet(nn.Module):
1295
-
1296
- def __init__(self,
1297
- in_dim=4,
1298
- dim=320,
1299
- y_dim=1024,
1300
- context_dim=1024,
1301
- out_dim=4,
1302
- dim_mult=[1, 2, 4, 4],
1303
- num_heads=8,
1304
- head_dim=64,
1305
- num_res_blocks=2,
1306
- attn_scales=[1 / 1, 1 / 2, 1 / 4],
1307
- use_scale_shift_norm=True,
1308
- dropout=0.1,
1309
- temporal_attn_times=1,
1310
- temporal_attention=True,
1311
- use_checkpoint=True,
1312
- use_image_dataset=False,
1313
- use_fps_condition=False,
1314
- use_sim_mask=False,
1315
- training=False,
1316
- inpainting=True):
1317
- embed_dim = dim * 4
1318
- num_heads = num_heads if num_heads else dim // 32
1319
- super(Vid2VidSDUNet, self).__init__()
1320
- self.in_dim = in_dim
1321
- self.dim = dim
1322
- self.y_dim = y_dim
1323
- self.context_dim = context_dim
1324
- self.embed_dim = embed_dim
1325
- self.out_dim = out_dim
1326
- self.dim_mult = dim_mult
1327
- # for temporal attention
1328
- self.num_heads = num_heads
1329
- # for spatial attention
1330
- self.head_dim = head_dim
1331
- self.num_res_blocks = num_res_blocks
1332
- self.attn_scales = attn_scales
1333
- self.use_scale_shift_norm = use_scale_shift_norm
1334
- self.temporal_attn_times = temporal_attn_times
1335
- self.temporal_attention = temporal_attention
1336
- self.use_checkpoint = use_checkpoint
1337
- self.use_image_dataset = use_image_dataset
1338
- self.use_fps_condition = use_fps_condition
1339
- self.use_sim_mask = use_sim_mask
1340
- self.training = training
1341
- self.inpainting = inpainting
1342
-
1343
- use_linear_in_temporal = False
1344
- transformer_depth = 1
1345
- disabled_sa = False
1346
- # params
1347
- enc_dims = [dim * u for u in [1] + dim_mult]
1348
- dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
1349
- shortcut_dims = []
1350
- scale = 1.0
1351
-
1352
- # embeddings
1353
- self.time_embed = nn.Sequential(
1354
- nn.Linear(dim, embed_dim), nn.SiLU(),
1355
- nn.Linear(embed_dim, embed_dim))
1356
-
1357
- if self.use_fps_condition:
1358
- self.fps_embedding = nn.Sequential(
1359
- nn.Linear(dim, embed_dim), nn.SiLU(),
1360
- nn.Linear(embed_dim, embed_dim))
1361
- nn.init.zeros_(self.fps_embedding[-1].weight)
1362
- nn.init.zeros_(self.fps_embedding[-1].bias)
1363
-
1364
- # encoder
1365
- self.input_blocks = nn.ModuleList()
1366
- init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
1367
- # need an initial temporal attention?
1368
- if temporal_attention:
1369
- if USE_TEMPORAL_TRANSFORMER:
1370
- init_block.append(
1371
- TemporalTransformer(
1372
- dim,
1373
- num_heads,
1374
- head_dim,
1375
- depth=transformer_depth,
1376
- context_dim=context_dim,
1377
- disable_self_attn=disabled_sa,
1378
- use_linear=use_linear_in_temporal,
1379
- multiply_zero=use_image_dataset,
1380
- is_ctrl=True
1381
- ))
1382
- else:
1383
- init_block.append(
1384
- TemporalAttentionMultiBlock(
1385
- dim,
1386
- num_heads,
1387
- head_dim,
1388
- rotary_emb=self.rotary_emb,
1389
- temporal_attn_times=temporal_attn_times,
1390
- use_image_dataset=use_image_dataset))
1391
- self.input_blocks.append(init_block)
1392
- shortcut_dims.append(dim)
1393
- for i, (in_dim,
1394
- out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
1395
- for j in range(num_res_blocks):
1396
- block = nn.ModuleList([
1397
- ResBlock(
1398
- in_dim,
1399
- embed_dim,
1400
- dropout,
1401
- out_channels=out_dim,
1402
- use_scale_shift_norm=False,
1403
- use_image_dataset=use_image_dataset,
1404
- )
1405
- ])
1406
- if scale in attn_scales:
1407
- block.append(
1408
- SpatialTransformer(
1409
- out_dim,
1410
- out_dim // head_dim,
1411
- head_dim,
1412
- depth=1,
1413
- context_dim=self.context_dim,
1414
- disable_self_attn=False,
1415
- use_linear=True,
1416
- is_ctrl=True
1417
- ))
1418
- if self.temporal_attention:
1419
- if USE_TEMPORAL_TRANSFORMER:
1420
- block.append(
1421
- TemporalTransformer(
1422
- out_dim,
1423
- out_dim // head_dim,
1424
- head_dim,
1425
- depth=transformer_depth,
1426
- context_dim=context_dim,
1427
- disable_self_attn=disabled_sa,
1428
- use_linear=use_linear_in_temporal,
1429
- multiply_zero=use_image_dataset,
1430
- is_ctrl=True
1431
- ))
1432
- else:
1433
- block.append(
1434
- TemporalAttentionMultiBlock(
1435
- out_dim,
1436
- num_heads,
1437
- head_dim,
1438
- rotary_emb=self.rotary_emb,
1439
- use_image_dataset=use_image_dataset,
1440
- use_sim_mask=use_sim_mask,
1441
- temporal_attn_times=temporal_attn_times))
1442
- in_dim = out_dim
1443
- self.input_blocks.append(block)
1444
- shortcut_dims.append(out_dim)
1445
-
1446
- # downsample
1447
- if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
1448
- downsample = Downsample(
1449
- out_dim, True, dims=2, out_channels=out_dim)
1450
- shortcut_dims.append(out_dim)
1451
- scale /= 2.0
1452
- self.input_blocks.append(downsample)
1453
-
1454
- self.middle_block = nn.ModuleList([
1455
- ResBlock(
1456
- out_dim,
1457
- embed_dim,
1458
- dropout,
1459
- use_scale_shift_norm=False,
1460
- use_image_dataset=use_image_dataset,
1461
- ),
1462
- SpatialTransformer(
1463
- out_dim,
1464
- out_dim // head_dim,
1465
- head_dim,
1466
- depth=1,
1467
- context_dim=self.context_dim,
1468
- disable_self_attn=False,
1469
- use_linear=True,
1470
- is_ctrl=True
1471
- )
1472
- ])
1473
-
1474
- if self.temporal_attention:
1475
- if USE_TEMPORAL_TRANSFORMER:
1476
- self.middle_block.append(
1477
- TemporalTransformer(
1478
- out_dim,
1479
- out_dim // head_dim,
1480
- head_dim,
1481
- depth=transformer_depth,
1482
- context_dim=context_dim,
1483
- disable_self_attn=disabled_sa,
1484
- use_linear=use_linear_in_temporal,
1485
- multiply_zero=use_image_dataset,
1486
- is_ctrl=True
1487
-
1488
- ))
1489
- else:
1490
- self.middle_block.append(
1491
- TemporalAttentionMultiBlock(
1492
- out_dim,
1493
- num_heads,
1494
- head_dim,
1495
- rotary_emb=self.rotary_emb,
1496
- use_image_dataset=use_image_dataset,
1497
- use_sim_mask=use_sim_mask,
1498
- temporal_attn_times=temporal_attn_times))
1499
-
1500
- self.middle_block.append(
1501
- ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
1502
-
1503
- # decoder
1504
- self.output_blocks = nn.ModuleList()
1505
- for i, (in_dim,
1506
- out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
1507
- for j in range(num_res_blocks + 1):
1508
- block = nn.ModuleList([
1509
- ResBlock(
1510
- in_dim + shortcut_dims.pop(),
1511
- embed_dim,
1512
- dropout,
1513
- out_dim,
1514
- use_scale_shift_norm=False,
1515
- use_image_dataset=use_image_dataset,
1516
- )
1517
- ])
1518
- if scale in attn_scales:
1519
- block.append(
1520
- SpatialTransformer(
1521
- out_dim,
1522
- out_dim // head_dim,
1523
- head_dim,
1524
- depth=1,
1525
- context_dim=1024,
1526
- disable_self_attn=False,
1527
- use_linear=True,
1528
- is_ctrl=True))
1529
- if self.temporal_attention:
1530
- if USE_TEMPORAL_TRANSFORMER:
1531
- block.append(
1532
- TemporalTransformer(
1533
- out_dim,
1534
- out_dim // head_dim,
1535
- head_dim,
1536
- depth=transformer_depth,
1537
- context_dim=context_dim,
1538
- disable_self_attn=disabled_sa,
1539
- use_linear=use_linear_in_temporal,
1540
- multiply_zero=use_image_dataset,
1541
- is_ctrl=True))
1542
- else:
1543
- block.append(
1544
- TemporalAttentionMultiBlock(
1545
- out_dim,
1546
- num_heads,
1547
- head_dim,
1548
- rotary_emb=self.rotary_emb,
1549
- use_image_dataset=use_image_dataset,
1550
- use_sim_mask=use_sim_mask,
1551
- temporal_attn_times=temporal_attn_times))
1552
- in_dim = out_dim
1553
-
1554
- # upsample
1555
- if i != len(dim_mult) - 1 and j == num_res_blocks:
1556
- upsample = Upsample(
1557
- out_dim, True, dims=2.0, out_channels=out_dim)
1558
- scale *= 2.0
1559
- block.append(upsample)
1560
- self.output_blocks.append(block)
1561
-
1562
- # head
1563
- self.out = nn.Sequential(
1564
- nn.GroupNorm(32, out_dim), nn.SiLU(),
1565
- nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
1566
-
1567
- # zero out the last layer params
1568
- nn.init.zeros_(self.out[-1].weight)
1569
-
1570
- def forward(self,
1571
- x,
1572
- t,
1573
- y,
1574
- x_lr=None,
1575
- fps=None,
1576
- video_mask=None,
1577
- focus_present_mask=None,
1578
- prob_focus_present=0.,
1579
- mask_last_frame_num=0):
1580
-
1581
- batch, c, f, h, w = x.shape
1582
- device = x.device
1583
- self.batch = batch
1584
-
1585
- # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
1586
- if mask_last_frame_num > 0:
1587
- focus_present_mask = None
1588
- video_mask[-mask_last_frame_num:] = False
1589
- else:
1590
- focus_present_mask = default(
1591
- focus_present_mask, lambda: prob_mask_like(
1592
- (batch, ), prob_focus_present, device=device))
1593
-
1594
- if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
1595
- time_rel_pos_bias = self.time_rel_pos_bias(
1596
- x.shape[2], device=x.device)
1597
- else:
1598
- time_rel_pos_bias = None
1599
-
1600
- # embeddings
1601
- e = self.time_embed(sinusoidal_embedding(t, self.dim))
1602
- context = y
1603
-
1604
- # repeat f times for spatial e and context
1605
- e = e.repeat_interleave(repeats=f, dim=0)
1606
- context = context.repeat_interleave(repeats=f, dim=0)
1607
-
1608
- # always in shape (b f) c h w, except for temporal layer
1609
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1610
- # encoder
1611
- xs = []
1612
- for ind, block in enumerate(self.input_blocks):
1613
- x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1614
- focus_present_mask, video_mask)
1615
- xs.append(x)
1616
-
1617
- # middle
1618
- for block in self.middle_block:
1619
- x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1620
- focus_present_mask, video_mask)
1621
-
1622
- # decoder
1623
- for block in self.output_blocks:
1624
- x = torch.cat([x, xs.pop()], dim=1)
1625
- x = self._forward_single(
1626
- block,
1627
- x,
1628
- e,
1629
- context,
1630
- time_rel_pos_bias,
1631
- focus_present_mask,
1632
- video_mask,
1633
- reference=xs[-1] if len(xs) > 0 else None)
1634
-
1635
- # head
1636
- x = self.out(x)
1637
-
1638
- # reshape back to (b c f h w)
1639
- x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
1640
- return x
1641
-
1642
- def _forward_single(self,
1643
- module,
1644
- x,
1645
- e,
1646
- context,
1647
- time_rel_pos_bias,
1648
- focus_present_mask,
1649
- video_mask,
1650
- reference=None):
1651
- if isinstance(module, ResidualBlock):
1652
- module = checkpoint_wrapper(
1653
- module) if self.use_checkpoint else module
1654
- x = x.contiguous()
1655
- x = module(x, e, reference)
1656
- elif isinstance(module, ResBlock):
1657
- module = checkpoint_wrapper(
1658
- module) if self.use_checkpoint else module
1659
- x = x.contiguous()
1660
- x = module(x, e, self.batch)
1661
- elif isinstance(module, SpatialTransformer):
1662
- module = checkpoint_wrapper(
1663
- module) if self.use_checkpoint else module
1664
- x = module(x, context)
1665
- elif isinstance(module, TemporalTransformer):
1666
- module = checkpoint_wrapper(
1667
- module) if self.use_checkpoint else module
1668
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1669
- x = module(x, context)
1670
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1671
- elif isinstance(module, CrossAttention):
1672
- module = checkpoint_wrapper(
1673
- module) if self.use_checkpoint else module
1674
- x = module(x, context)
1675
- elif isinstance(module, MemoryEfficientCrossAttention):
1676
- module = checkpoint_wrapper(
1677
- module) if self.use_checkpoint else module
1678
- x = module(x, context)
1679
- elif isinstance(module, BasicTransformerBlock):
1680
- module = checkpoint_wrapper(
1681
- module) if self.use_checkpoint else module
1682
- x = module(x, context)
1683
- elif isinstance(module, FeedForward):
1684
- x = module(x, context)
1685
- elif isinstance(module, Upsample):
1686
- x = module(x)
1687
- elif isinstance(module, Downsample):
1688
- x = module(x)
1689
- elif isinstance(module, Resample):
1690
- x = module(x, reference)
1691
- elif isinstance(module, TemporalAttentionBlock):
1692
- module = checkpoint_wrapper(
1693
- module) if self.use_checkpoint else module
1694
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1695
- x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1696
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1697
- elif isinstance(module, TemporalAttentionMultiBlock):
1698
- module = checkpoint_wrapper(
1699
- module) if self.use_checkpoint else module
1700
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1701
- x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1702
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1703
- elif isinstance(module, InitTemporalConvBlock):
1704
- module = checkpoint_wrapper(
1705
- module) if self.use_checkpoint else module
1706
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1707
- x = module(x)
1708
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1709
- elif isinstance(module, TemporalConvBlock):
1710
- module = checkpoint_wrapper(
1711
- module) if self.use_checkpoint else module
1712
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1713
- x = module(x)
1714
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1715
- elif isinstance(module, nn.ModuleList):
1716
- for block in module:
1717
- x = self._forward_single(block, x, e, context,
1718
- time_rel_pos_bias, focus_present_mask,
1719
- video_mask, reference)
1720
- else:
1721
- x = module(x)
1722
- return x
1723
-
1724
-
1725
- class ControlledV2VUNet(Vid2VidSDUNet):
1726
- def __init__(self):
1727
- super(ControlledV2VUNet, self).__init__()
1728
- self.VideoControlNet = VideoControlNet()
1729
-
1730
- def forward(self,
1731
- x,
1732
- t,
1733
- y,
1734
- hint=None,
1735
- variant_info=None,
1736
- hint_chunk=None,
1737
- t_hint=None,
1738
- s_cond=None,
1739
- mask_cond=None,
1740
- x_lr=None,
1741
- fps=None,
1742
- mask=None,
1743
- video_mask=None,
1744
- focus_present_mask=None,
1745
- prob_focus_present=0.,
1746
- mask_last_frame_num=0,
1747
- ):
1748
-
1749
- batch, _, f, _, _= x.shape
1750
- device = x.device
1751
- self.batch = batch
1752
-
1753
- # Process text (new added for t5 encoder)
1754
- # y = self.VideoControlNet.y_embedder(y, self.training).squeeze(1) # [1, 1, 120, 4096] -> [B, 1, 120, 1024].squeeze(1) -> [B, 120, 1024]
1755
-
1756
- if hint_chunk is not None:
1757
- hint = hint_chunk
1758
-
1759
- control = self.VideoControlNet(x, t, y, hint=hint, t_hint=t_hint, \
1760
- mask_cond=mask_cond, s_cond=s_cond, \
1761
- variant_info=variant_info)
1762
-
1763
- # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
1764
- if mask_last_frame_num > 0:
1765
- focus_present_mask = None
1766
- video_mask[-mask_last_frame_num:] = False
1767
- else:
1768
- focus_present_mask = default(
1769
- focus_present_mask, lambda: prob_mask_like(
1770
- (batch, ), prob_focus_present, device=device))
1771
-
1772
- if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
1773
- time_rel_pos_bias = self.time_rel_pos_bias(
1774
- x.shape[2], device=x.device)
1775
- else:
1776
- time_rel_pos_bias = None
1777
-
1778
- e = self.time_embed(sinusoidal_embedding(t, self.dim))
1779
- e = e.repeat_interleave(repeats=f, dim=0)
1780
-
1781
- # context = y
1782
- context = y.repeat_interleave(repeats=f, dim=0)
1783
-
1784
- # always in shape (b f) c h w, except for temporal layer
1785
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1786
- # encoder
1787
- xs = []
1788
- for block in self.input_blocks:
1789
- x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1790
- focus_present_mask, video_mask, variant_info=variant_info)
1791
- xs.append(x)
1792
- # middle
1793
- for block in self.middle_block:
1794
- x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1795
- focus_present_mask, video_mask, variant_info=variant_info)
1796
-
1797
- if control is not None:
1798
- x = control.pop() + x
1799
-
1800
- # decoder
1801
- for block in self.output_blocks:
1802
- if control is None:
1803
- x = torch.cat([x, xs.pop()], dim=1)
1804
- else:
1805
- x = torch.cat([x, xs.pop() + control.pop()], dim=1)
1806
- x = self._forward_single(
1807
- block,
1808
- x,
1809
- e,
1810
- context,
1811
- time_rel_pos_bias,
1812
- focus_present_mask,
1813
- video_mask,
1814
- reference=xs[-1] if len(xs) > 0 else None,
1815
- variant_info=variant_info)
1816
-
1817
- # head
1818
- x = self.out(x)
1819
-
1820
- # reshape back to (b c f h w)
1821
- x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
1822
- return x
1823
-
1824
- def _forward_single(self,
1825
- module,
1826
- x,
1827
- e,
1828
- context,
1829
- time_rel_pos_bias,
1830
- focus_present_mask,
1831
- video_mask,
1832
- reference=None,
1833
- variant_info=None):
1834
- variant_info = None # For Debug
1835
- if isinstance(module, ResidualBlock):
1836
- module = checkpoint_wrapper(
1837
- module) if self.use_checkpoint else module
1838
- x = x.contiguous()
1839
- x = module(x, e, reference)
1840
- elif isinstance(module, ResBlock):
1841
- module = checkpoint_wrapper(
1842
- module) if self.use_checkpoint else module
1843
- x = x.contiguous()
1844
- x = module(x, e, self.batch, variant_info)
1845
- elif isinstance(module, SpatialTransformer):
1846
- module = checkpoint_wrapper(
1847
- module) if self.use_checkpoint else module
1848
- x = module(x, context)
1849
- elif isinstance(module, TemporalTransformer):
1850
- module = checkpoint_wrapper(
1851
- module) if self.use_checkpoint else module
1852
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1853
- x = module(x, context)
1854
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1855
- elif isinstance(module, CrossAttention):
1856
- module = checkpoint_wrapper(
1857
- module) if self.use_checkpoint else module
1858
- x = module(x, context)
1859
- elif isinstance(module, MemoryEfficientCrossAttention):
1860
- module = checkpoint_wrapper(
1861
- module) if self.use_checkpoint else module
1862
- x = module(x, context)
1863
- elif isinstance(module, BasicTransformerBlock):
1864
- module = checkpoint_wrapper(
1865
- module) if self.use_checkpoint else module
1866
- x = module(x, context)
1867
- elif isinstance(module, FeedForward):
1868
- x = module(x, context)
1869
- elif isinstance(module, Upsample):
1870
- x = module(x)
1871
- elif isinstance(module, Downsample):
1872
- x = module(x)
1873
- elif isinstance(module, Resample):
1874
- x = module(x, reference)
1875
- elif isinstance(module, TemporalAttentionBlock):
1876
- module = checkpoint_wrapper(
1877
- module) if self.use_checkpoint else module
1878
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1879
- x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1880
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1881
- elif isinstance(module, TemporalAttentionMultiBlock):
1882
- module = checkpoint_wrapper(
1883
- module) if self.use_checkpoint else module
1884
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1885
- x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1886
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1887
- elif isinstance(module, InitTemporalConvBlock):
1888
- module = checkpoint_wrapper(
1889
- module) if self.use_checkpoint else module
1890
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1891
- x = module(x)
1892
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1893
- elif isinstance(module, TemporalConvBlock):
1894
- module = checkpoint_wrapper(
1895
- module) if self.use_checkpoint else module
1896
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1897
- x = module(x)
1898
- x = rearrange(x, 'b c f h w -> (b f) c h w')
1899
- elif isinstance(module, nn.ModuleList):
1900
- for block in module:
1901
- x = self._forward_single(block, x, e, context,
1902
- time_rel_pos_bias, focus_present_mask,
1903
- video_mask, reference, variant_info)
1904
- else:
1905
- x = module(x)
1906
- return x
1907
-
1908
-
1909
- class VideoControlNet(nn.Module):
1910
-
1911
- def __init__(self,
1912
- in_dim=4,
1913
- dim=320,
1914
- y_dim=1024,
1915
- context_dim=1024,
1916
- out_dim=4,
1917
- dim_mult=[1, 2, 4, 4],
1918
- num_heads=8,
1919
- head_dim=64,
1920
- num_res_blocks=2,
1921
- attn_scales=[1 / 1, 1 / 2, 1 / 4],
1922
- use_scale_shift_norm=True,
1923
- dropout=0.1,
1924
- temporal_attn_times=1,
1925
- temporal_attention=True,
1926
- use_checkpoint=True,
1927
- use_image_dataset=False,
1928
- use_fps_condition=False,
1929
- use_sim_mask=False,
1930
- training=False,
1931
- inpainting=True):
1932
- embed_dim = dim * 4
1933
- num_heads = num_heads if num_heads else dim // 32
1934
- super(VideoControlNet, self).__init__()
1935
- self.in_dim = in_dim
1936
- self.dim = dim
1937
- self.y_dim = y_dim
1938
- self.context_dim = context_dim
1939
- self.embed_dim = embed_dim
1940
- self.out_dim = out_dim
1941
- self.dim_mult = dim_mult
1942
- # for temporal attention
1943
- self.num_heads = num_heads
1944
- # for spatial attention
1945
- self.head_dim = head_dim
1946
- self.num_res_blocks = num_res_blocks
1947
- self.attn_scales = attn_scales
1948
- self.use_scale_shift_norm = use_scale_shift_norm
1949
- self.temporal_attn_times = temporal_attn_times
1950
- self.temporal_attention = temporal_attention
1951
- self.use_checkpoint = use_checkpoint
1952
- self.use_image_dataset = use_image_dataset
1953
- self.use_fps_condition = use_fps_condition
1954
- self.use_sim_mask = use_sim_mask
1955
- self.training = training
1956
- self.inpainting = inpainting
1957
-
1958
- use_linear_in_temporal = False
1959
- transformer_depth = 1
1960
- disabled_sa = False
1961
- # params
1962
- enc_dims = [dim * u for u in [1] + dim_mult]
1963
- dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
1964
- shortcut_dims = []
1965
- scale = 1.0
1966
-
1967
- # CaptionEmbedder (new add)
1968
- # approx_gelu = lambda: nn.GELU(approximate="tanh")
1969
- # self.y_embedder = CaptionEmbedder(
1970
- # in_channels=4096,
1971
- # hidden_size=1024,
1972
- # uncond_prob=0.1,
1973
- # act_layer=approx_gelu,
1974
- # token_num=120,
1975
- # )
1976
-
1977
- # embeddings
1978
- self.time_embed = nn.Sequential(
1979
- nn.Linear(dim, embed_dim), nn.SiLU(),
1980
- nn.Linear(embed_dim, embed_dim))
1981
-
1982
- # self.hint_time_zero_linear = zero_module(nn.Linear(embed_dim, embed_dim))
1983
-
1984
- # scale prompt
1985
- # self.scale_cond = nn.Sequential(
1986
- # nn.Linear(dim, embed_dim), nn.SiLU(),
1987
- # zero_module(nn.Linear(embed_dim, embed_dim)))
1988
-
1989
- if self.use_fps_condition:
1990
- self.fps_embedding = nn.Sequential(
1991
- nn.Linear(dim, embed_dim), nn.SiLU(),
1992
- nn.Linear(embed_dim, embed_dim))
1993
- nn.init.zeros_(self.fps_embedding[-1].weight)
1994
- nn.init.zeros_(self.fps_embedding[-1].bias)
1995
-
1996
- # encoder
1997
- self.input_blocks = nn.ModuleList()
1998
- init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
1999
- # need an initial temporal attention?
2000
- if temporal_attention:
2001
- if USE_TEMPORAL_TRANSFORMER:
2002
- init_block.append(
2003
- TemporalTransformer(
2004
- dim,
2005
- num_heads,
2006
- head_dim,
2007
- depth=transformer_depth,
2008
- context_dim=context_dim,
2009
- disable_self_attn=disabled_sa,
2010
- use_linear=use_linear_in_temporal,
2011
- multiply_zero=use_image_dataset,
2012
- is_ctrl=True,))
2013
- else:
2014
- init_block.append(
2015
- TemporalAttentionMultiBlock(
2016
- dim,
2017
- num_heads,
2018
- head_dim,
2019
- rotary_emb=self.rotary_emb,
2020
- temporal_attn_times=temporal_attn_times,
2021
- use_image_dataset=use_image_dataset))
2022
- self.input_blocks.append(init_block)
2023
- self.zero_convs = nn.ModuleList([self.make_zero_conv(dim)])
2024
- shortcut_dims.append(dim)
2025
- for i, (in_dim,
2026
- out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
2027
- for j in range(num_res_blocks):
2028
- block = nn.ModuleList([
2029
- ResBlock(
2030
- in_dim,
2031
- embed_dim,
2032
- dropout,
2033
- out_channels=out_dim,
2034
- use_scale_shift_norm=False,
2035
- use_image_dataset=use_image_dataset,
2036
- )
2037
- ])
2038
- if scale in attn_scales:
2039
- block.append(
2040
- SpatialTransformer(
2041
- out_dim,
2042
- out_dim // head_dim,
2043
- head_dim,
2044
- depth=1,
2045
- context_dim=self.context_dim,
2046
- disable_self_attn=False,
2047
- use_linear=True,
2048
- is_ctrl=True))
2049
- if self.temporal_attention:
2050
- if USE_TEMPORAL_TRANSFORMER:
2051
- block.append(
2052
- TemporalTransformer(
2053
- out_dim,
2054
- out_dim // head_dim,
2055
- head_dim,
2056
- depth=transformer_depth,
2057
- context_dim=context_dim,
2058
- disable_self_attn=disabled_sa,
2059
- use_linear=use_linear_in_temporal,
2060
- multiply_zero=use_image_dataset,
2061
- is_ctrl=True,))
2062
- else:
2063
- block.append(
2064
- TemporalAttentionMultiBlock(
2065
- out_dim,
2066
- num_heads,
2067
- head_dim,
2068
- rotary_emb=self.rotary_emb,
2069
- use_image_dataset=use_image_dataset,
2070
- use_sim_mask=use_sim_mask,
2071
- temporal_attn_times=temporal_attn_times))
2072
- in_dim = out_dim
2073
- self.input_blocks.append(block)
2074
- self.zero_convs.append(self.make_zero_conv(out_dim))
2075
- shortcut_dims.append(out_dim)
2076
-
2077
- # downsample
2078
- if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
2079
- downsample = Downsample(
2080
- out_dim, True, dims=2, out_channels=out_dim)
2081
- shortcut_dims.append(out_dim)
2082
- scale /= 2.0
2083
- self.input_blocks.append(downsample)
2084
- self.zero_convs.append(self.make_zero_conv(out_dim))
2085
-
2086
- self.middle_block = nn.ModuleList([
2087
- ResBlock(
2088
- out_dim,
2089
- embed_dim,
2090
- dropout,
2091
- use_scale_shift_norm=False,
2092
- use_image_dataset=use_image_dataset,
2093
- ),
2094
- SpatialTransformer(
2095
- out_dim,
2096
- out_dim // head_dim,
2097
- head_dim,
2098
- depth=1,
2099
- context_dim=self.context_dim,
2100
- disable_self_attn=False,
2101
- use_linear=True,
2102
- is_ctrl=True)
2103
- ])
2104
-
2105
- if self.temporal_attention:
2106
- if USE_TEMPORAL_TRANSFORMER:
2107
- self.middle_block.append(
2108
- TemporalTransformer(
2109
- out_dim,
2110
- out_dim // head_dim,
2111
- head_dim,
2112
- depth=transformer_depth,
2113
- context_dim=context_dim,
2114
- disable_self_attn=disabled_sa,
2115
- use_linear=use_linear_in_temporal,
2116
- multiply_zero=use_image_dataset,
2117
- is_ctrl=True,
2118
- ))
2119
- else:
2120
- self.middle_block.append(
2121
- TemporalAttentionMultiBlock(
2122
- out_dim,
2123
- num_heads,
2124
- head_dim,
2125
- rotary_emb=self.rotary_emb,
2126
- use_image_dataset=use_image_dataset,
2127
- use_sim_mask=use_sim_mask,
2128
- temporal_attn_times=temporal_attn_times))
2129
-
2130
- self.middle_block.append(
2131
- ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
2132
-
2133
- self.middle_block_out = self.make_zero_conv(embed_dim)
2134
-
2135
- '''
2136
- add prompt
2137
- '''
2138
- add_dim = 320
2139
- self.add_dim = add_dim
2140
-
2141
- self.input_hint_block = zero_module(nn.Conv2d(4, add_dim, 3, padding=1))
2142
-
2143
- def make_zero_conv(self, in_channels, out_channels=None):
2144
- out_channels = in_channels if out_channels is None else out_channels
2145
- return TimestepEmbedSequential(zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)))
2146
-
2147
- def forward(self,
2148
- x,
2149
- t,
2150
- y,
2151
- s_cond=None,
2152
- hint=None,
2153
- variant_info=None,
2154
- t_hint=None,
2155
- mask_cond=None,
2156
- fps=None,
2157
- video_mask=None,
2158
- focus_present_mask=None,
2159
- prob_focus_present=0.,
2160
- mask_last_frame_num=0):
2161
-
2162
- batch, _, f, _, _ = x.shape
2163
- device = x.device
2164
- self.batch = batch
2165
-
2166
- # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
2167
- if mask_last_frame_num > 0:
2168
- focus_present_mask = None
2169
- video_mask[-mask_last_frame_num:] = False
2170
- else:
2171
- focus_present_mask = default(
2172
- focus_present_mask, lambda: prob_mask_like(
2173
- (batch, ), prob_focus_present, device=device))
2174
-
2175
- if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
2176
- time_rel_pos_bias = self.time_rel_pos_bias(
2177
- x.shape[2], device=x.device)
2178
- else:
2179
- time_rel_pos_bias = None
2180
-
2181
- if hint is not None:
2182
- # add = x.new_zeros(batch, self.add_dim, f, h, w)
2183
- hint = rearrange(hint, 'b c f h w -> (b f) c h w')
2184
- hint = self.input_hint_block(hint)
2185
- # hint = rearrange(hint, '(b f) c h w -> b c f h w', b = batch)
2186
-
2187
- e = self.time_embed(sinusoidal_embedding(t, self.dim))
2188
- e = e.repeat_interleave(repeats=f, dim=0)
2189
-
2190
- context = y.repeat_interleave(repeats=f, dim=0)
2191
-
2192
- # always in shape (b f) c h w, except for temporal layer
2193
- x = rearrange(x, 'b c f h w -> (b f) c h w')
2194
- # print('before x shape:', x.shape) [64, 320, 90, 160]
2195
- # print('hint shape:', hint.shape) [32, 320, 90, 160]
2196
-
2197
- # encoder
2198
- xs = []
2199
- for module, zero_conv in zip(self.input_blocks, self.zero_convs):
2200
- if hint is not None:
2201
- for block in module:
2202
- x = self._forward_single(block, x, e, context, time_rel_pos_bias,
2203
- focus_present_mask, video_mask, variant_info=variant_info)
2204
- if not isinstance(block, TemporalTransformer):
2205
- if hint is not None:
2206
- x += hint
2207
- hint = None
2208
- else:
2209
- x = self._forward_single(module, x, e, context, time_rel_pos_bias,
2210
- focus_present_mask, video_mask, variant_info=variant_info)
2211
- xs.append(zero_conv(x, e, context))
2212
-
2213
- # middle
2214
- for block in self.middle_block:
2215
- x = self._forward_single(block, x, e, context, time_rel_pos_bias,
2216
- focus_present_mask, video_mask, variant_info=variant_info)
2217
- xs.append(self.middle_block_out(x, e, context))
2218
-
2219
- return xs
2220
-
2221
- def _forward_single(self,
2222
- module,
2223
- x,
2224
- e,
2225
- context,
2226
- time_rel_pos_bias,
2227
- focus_present_mask,
2228
- video_mask,
2229
- reference=None,
2230
- variant_info=None,):
2231
- # variant_info = None # For Debug
2232
- if isinstance(module, ResidualBlock):
2233
- module = checkpoint_wrapper(
2234
- module) if self.use_checkpoint else module
2235
- x = x.contiguous()
2236
- x = module(x, e, reference)
2237
- elif isinstance(module, ResBlock):
2238
- module = checkpoint_wrapper(
2239
- module) if self.use_checkpoint else module
2240
- x = x.contiguous()
2241
- x = module(x, e, self.batch, variant_info)
2242
- elif isinstance(module, SpatialTransformer):
2243
- module = checkpoint_wrapper(
2244
- module) if self.use_checkpoint else module
2245
- x = module(x, context)
2246
- elif isinstance(module, TemporalTransformer):
2247
- module = checkpoint_wrapper(
2248
- module) if self.use_checkpoint else module
2249
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2250
- # print("x shape:", x.shape) # [2, 320, 32, 90, 160]
2251
- x = module(x, context)
2252
- x = rearrange(x, 'b c f h w -> (b f) c h w')
2253
- elif isinstance(module, CrossAttention):
2254
- module = checkpoint_wrapper(
2255
- module) if self.use_checkpoint else module
2256
- x = module(x, context)
2257
- elif isinstance(module, MemoryEfficientCrossAttention):
2258
- module = checkpoint_wrapper(
2259
- module) if self.use_checkpoint else module
2260
- x = module(x, context)
2261
- elif isinstance(module, BasicTransformerBlock):
2262
- module = checkpoint_wrapper(
2263
- module) if self.use_checkpoint else module
2264
- x = module(x, context)
2265
- elif isinstance(module, FeedForward):
2266
- x = module(x, context)
2267
- elif isinstance(module, Upsample):
2268
- x = module(x)
2269
- elif isinstance(module, Downsample):
2270
- x = module(x)
2271
- elif isinstance(module, Resample):
2272
- x = module(x, reference)
2273
- elif isinstance(module, TemporalAttentionBlock):
2274
- module = checkpoint_wrapper(
2275
- module) if self.use_checkpoint else module
2276
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2277
- x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
2278
- x = rearrange(x, 'b c f h w -> (b f) c h w')
2279
- elif isinstance(module, TemporalAttentionMultiBlock):
2280
- module = checkpoint_wrapper(
2281
- module) if self.use_checkpoint else module
2282
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2283
- x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
2284
- x = rearrange(x, 'b c f h w -> (b f) c h w')
2285
- elif isinstance(module, InitTemporalConvBlock):
2286
- module = checkpoint_wrapper(
2287
- module) if self.use_checkpoint else module
2288
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2289
- x = module(x)
2290
- x = rearrange(x, 'b c f h w -> (b f) c h w')
2291
- elif isinstance(module, TemporalConvBlock):
2292
- module = checkpoint_wrapper(
2293
- module) if self.use_checkpoint else module
2294
- x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2295
- x = module(x)
2296
- x = rearrange(x, 'b c f h w -> (b f) c h w')
2297
- elif isinstance(module, nn.ModuleList):
2298
- for block in module:
2299
- x = self._forward_single(block, x, e, context,
2300
- time_rel_pos_bias, focus_present_mask,
2301
- video_mask, reference, variant_info)
2302
- else:
2303
- x = module(x)
2304
- return x
2305
-
2306
-
2307
- class TimestepBlock(nn.Module):
2308
- """
2309
- Any module where forward() takes timestep embeddings as a second argument.
2310
- """
2311
-
2312
- @abstractmethod
2313
- def forward(self, x, emb):
2314
- """
2315
- Apply the module to `x` given `emb` timestep embeddings.
2316
- """
2317
-
2318
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
2319
- """
2320
- A sequential module that passes timestep embeddings to the children that
2321
- support it as an extra input.
2322
- """
2323
-
2324
- def forward(self, x, emb, context=None):
2325
- for layer in self:
2326
- if isinstance(layer, TimestepBlock):
2327
- x = layer(x, emb)
2328
- elif isinstance(layer, SpatialTransformer):
2329
- x = layer(x, context)
2330
- else:
2331
- x = layer(x)
2332
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/utils/__init__.py DELETED
File without changes
video_to_video/utils/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (159 Bytes)
 
video_to_video/utils/__pycache__/config.cpython-39.pyc DELETED
Binary file (3.44 kB)
 
video_to_video/utils/__pycache__/logger.cpython-39.pyc DELETED
Binary file (2.14 kB)
 
video_to_video/utils/__pycache__/seed.cpython-39.pyc DELETED
Binary file (467 Bytes)
 
video_to_video/utils/config.py DELETED
@@ -1,169 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
-
3
- import logging
4
- import os
5
- import os.path as osp
6
- from datetime import datetime
7
-
8
- import torch
9
- from easydict import EasyDict
10
-
11
- cfg = EasyDict(__name__='Config: VideoLDM Decoder')
12
-
13
- # ---------------------------work dir--------------------------
14
- cfg.work_dir = 'workspace/'
15
-
16
- # ---------------------------Global Variable-----------------------------------
17
- cfg.resolution = [448, 256]
18
- cfg.max_frames = 32
19
- # -----------------------------------------------------------------------------
20
-
21
- # ---------------------------Dataset Parameter---------------------------------
22
- cfg.mean = [0.5, 0.5, 0.5]
23
- cfg.std = [0.5, 0.5, 0.5]
24
- cfg.max_words = 1000
25
-
26
- # PlaceHolder
27
- cfg.vit_out_dim = 1024
28
- cfg.vit_resolution = [224, 224]
29
- cfg.depth_clamp = 10.0
30
- cfg.misc_size = 384
31
- cfg.depth_std = 20.0
32
-
33
- cfg.frame_lens = 32
34
- cfg.sample_fps = 8
35
-
36
- cfg.batch_sizes = 1
37
- # -----------------------------------------------------------------------------
38
-
39
- # ---------------------------Mode Parameters-----------------------------------
40
- # Diffusion
41
- cfg.schedule = 'cosine'
42
- cfg.num_timesteps = 1000
43
- cfg.mean_type = 'v'
44
- cfg.var_type = 'fixed_small'
45
- cfg.loss_type = 'mse'
46
- cfg.ddim_timesteps = 50
47
- cfg.ddim_eta = 0.0
48
- cfg.clamp = 1.0
49
- cfg.share_noise = False
50
- cfg.use_div_loss = False
51
- cfg.noise_strength = 0.1
52
-
53
- # classifier-free guidance
54
- cfg.p_zero = 0.1
55
- cfg.guide_scale = 3.0
56
-
57
- # clip vision encoder
58
- cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
59
- cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
60
-
61
- # Model
62
- cfg.scale_factor = 0.18215
63
- cfg.use_fp16 = True
64
- cfg.temporal_attention = True
65
- cfg.decoder_bs = 8
66
-
67
- cfg.UNet = {
68
- 'type': 'Vid2VidSDUNet',
69
- 'in_dim': 4,
70
- 'dim': 320,
71
- 'y_dim': cfg.vit_out_dim,
72
- 'context_dim': 1024,
73
- 'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
74
- 'dim_mult': [1, 2, 4, 4],
75
- 'num_heads': 8,
76
- 'head_dim': 64,
77
- 'num_res_blocks': 2,
78
- 'attn_scales': [1 / 1, 1 / 2, 1 / 4],
79
- 'dropout': 0.1,
80
- 'temporal_attention': cfg.temporal_attention,
81
- 'temporal_attn_times': 1,
82
- 'use_checkpoint': False,
83
- 'use_fps_condition': False,
84
- 'use_sim_mask': False,
85
- 'num_tokens': 4,
86
- 'default_fps': 8,
87
- 'input_dim': 1024
88
- }
89
-
90
- cfg.guidances = []
91
-
92
- # auotoencoder from stabel diffusion
93
- cfg.auto_encoder = {
94
- 'type': 'AutoencoderKL',
95
- 'ddconfig': {
96
- 'double_z': True,
97
- 'z_channels': 4,
98
- 'resolution': 256,
99
- 'in_channels': 3,
100
- 'out_ch': 3,
101
- 'ch': 128,
102
- 'ch_mult': [1, 2, 4, 4],
103
- 'num_res_blocks': 2,
104
- 'attn_resolutions': [],
105
- 'dropout': 0.0
106
- },
107
- 'embed_dim': 4,
108
- 'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
109
- }
110
- # clip embedder
111
- cfg.embedder = {
112
- 'type': 'FrozenOpenCLIPEmbedder',
113
- 'layer': 'penultimate',
114
- 'vit_resolution': [224, 224],
115
- 'pretrained': 'open_clip_pytorch_model.bin'
116
- }
117
- # -----------------------------------------------------------------------------
118
-
119
- # ---------------------------Training Settings---------------------------------
120
- # training and optimizer
121
- cfg.ema_decay = 0.9999
122
- cfg.num_steps = 600000
123
- cfg.lr = 5e-5
124
- cfg.weight_decay = 0.0
125
- cfg.betas = (0.9, 0.999)
126
- cfg.eps = 1.0e-8
127
- cfg.chunk_size = 16
128
- cfg.alpha = 0.7
129
- cfg.save_ckp_interval = 1000
130
- # -----------------------------------------------------------------------------
131
-
132
- # ----------------------------Pretrain Settings---------------------------------
133
- # Default: load 2d pretrain
134
- cfg.fix_weight = False
135
- cfg.load_match = False
136
- cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
137
- cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
138
- cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
139
- # -----------------------------------------------------------------------------
140
-
141
- # -----------------------------Visual-------------------------------------------
142
- # Visual videos
143
- cfg.viz_interval = 1000
144
- cfg.visual_train = {
145
- 'type': 'VisualVideoTextDuringTrain',
146
- }
147
- cfg.visual_inference = {
148
- 'type': 'VisualGeneratedVideos',
149
- }
150
- cfg.inference_list_path = ''
151
-
152
- # logging
153
- cfg.log_interval = 100
154
-
155
- # Default log_dir
156
- cfg.log_dir = 'workspace/output_data'
157
- # -----------------------------------------------------------------------------
158
-
159
- # ---------------------------Others--------------------------------------------
160
- # seed
161
- cfg.seed = 8888
162
-
163
- cfg.negative_prompt = 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
164
- CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
165
- signature, jpeg artifacts, deformed, lowres, over-smooth'
166
-
167
- cfg.positive_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
168
- hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
169
- skin pore detailing, hyper sharpness, perfect without deformations.'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/utils/logger.py DELETED
@@ -1,94 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
-
3
- import importlib
4
- import logging
5
- from typing import Optional
6
- from torch import distributed as dist
7
-
8
- init_loggers = {}
9
-
10
- formatter = logging.Formatter(
11
- '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
12
-
13
-
14
- def get_logger(log_file: Optional[str] = None,
15
- log_level: int = logging.INFO,
16
- file_mode: str = 'w'):
17
- """ Get logging logger
18
-
19
- Args:
20
- log_file: Log filename, if specified, file handler will be added to
21
- logger
22
- log_level: Logging level.
23
- file_mode: Specifies the mode to open the file, if filename is
24
- specified (if filemode is unspecified, it defaults to 'w').
25
- """
26
-
27
- logger_name = __name__.split('.')[0]
28
- logger = logging.getLogger(logger_name)
29
- logger.propagate = False
30
- if logger_name in init_loggers:
31
- add_file_handler_if_needed(logger, log_file, file_mode, log_level)
32
- return logger
33
-
34
- # handle duplicate logs to the console
35
- # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
36
- # to the root logger. As logger.propagate is True by default, this root
37
- # level handler causes logging messages from rank>0 processes to
38
- # unexpectedly show up on the console, creating much unwanted clutter.
39
- # To fix this issue, we set the root logger's StreamHandler, if any, to log
40
- # at the ERROR level.
41
- for handler in logger.root.handlers:
42
- if type(handler) is logging.StreamHandler:
43
- handler.setLevel(logging.ERROR)
44
-
45
- stream_handler = logging.StreamHandler()
46
- handlers = [stream_handler]
47
-
48
- if importlib.util.find_spec('torch') is not None:
49
- is_worker0 = is_master()
50
- else:
51
- is_worker0 = True
52
-
53
- if is_worker0 and log_file is not None:
54
- file_handler = logging.FileHandler(log_file, file_mode)
55
- handlers.append(file_handler)
56
-
57
- for handler in handlers:
58
- handler.setFormatter(formatter)
59
- handler.setLevel(log_level)
60
- logger.addHandler(handler)
61
-
62
- if is_worker0:
63
- logger.setLevel(log_level)
64
- else:
65
- logger.setLevel(logging.ERROR)
66
-
67
- init_loggers[logger_name] = True
68
-
69
- return logger
70
-
71
-
72
- def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
73
- for handler in logger.handlers:
74
- if isinstance(handler, logging.FileHandler):
75
- return
76
-
77
- if importlib.util.find_spec('torch') is not None:
78
- is_worker0 = is_master()
79
- else:
80
- is_worker0 = True
81
-
82
- if is_worker0 and log_file is not None:
83
- file_handler = logging.FileHandler(log_file, file_mode)
84
- file_handler.setFormatter(formatter)
85
- file_handler.setLevel(log_level)
86
- logger.addHandler(file_handler)
87
-
88
-
89
- def is_master(group=None):
90
- return dist.get_rank(group) == 0 if is_dist() else True
91
-
92
-
93
- def is_dist():
94
- return dist.is_available() and dist.is_initialized()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/utils/seed.py DELETED
@@ -1,14 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
-
3
- import random
4
-
5
- import numpy as np
6
- import torch
7
-
8
-
9
- def setup_seed(seed):
10
- torch.manual_seed(seed)
11
- torch.cuda.manual_seed_all(seed)
12
- np.random.seed(seed)
13
- random.seed(seed)
14
- torch.backends.cudnn.deterministic = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_to_video/video_to_video_model.py DELETED
@@ -1,237 +0,0 @@
1
- import os
2
- import os.path as osp
3
- import random
4
- from typing import Any, Dict
5
-
6
- import torch
7
- import torch.cuda.amp as amp
8
- import torch.nn.functional as F
9
-
10
- from video_to_video.modules import *
11
- from video_to_video.utils.config import cfg
12
- from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
13
- from video_to_video.diffusion.schedules_sdedit import noise_schedule
14
- from video_to_video.utils.logger import get_logger
15
-
16
- from diffusers import AutoencoderKLTemporalDecoder
17
- import requests
18
-
19
- def download_model(url, model_path):
20
- if not os.path.exists(os.path.join(model_path, 'heavy_deg.pt')):
21
- print(f"Model not found at {model_path}, downloading...")
22
- response = requests.get(url, stream=True)
23
- with open(os.path.join(model_path, 'heavy_deg.pt'), 'wb') as f:
24
- for chunk in response.iter_content(chunk_size=1024):
25
- if chunk:
26
- f.write(chunk)
27
- print(f"Model downloaded to {model_path}")
28
- else:
29
- print(f"Model found at {model_path}, skipping download.")
30
-
31
-
32
- logger = get_logger()
33
-
34
- class VideoToVideo_sr():
35
- def __init__(self, opt, device=torch.device(f'cuda:0')):
36
- self.opt = opt
37
- self.device = device # torch.device(f'cuda:0')
38
-
39
- # text_encoder
40
- text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
41
- text_encoder.model.to(self.device)
42
- self.text_encoder = text_encoder
43
- logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')
44
-
45
- # U-Net with ControlNet
46
- generator = ControlledV2VUNet()
47
- generator = generator.to(self.device)
48
- generator.eval()
49
-
50
- # 确保 cfg.model_path 是文件夹路径,不要加上文件名
51
- cfg.model_path = opt.model_path
52
- # download weight
53
- model_url = 'https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt'
54
- download_model(model_url, cfg.model_path)
55
-
56
- # 拼接完整路径
57
- model_file_path = os.path.join('pretrained_weight', 'I2VGen-XL-based', 'heavy_deg.pt')
58
- print('model_file_path:', model_file_path)
59
-
60
- # 加载模型
61
- load_dict = torch.load(model_file_path, map_location='cpu')
62
-
63
- if 'state_dict' in load_dict:
64
- load_dict = load_dict['state_dict']
65
- ret = generator.load_state_dict(load_dict, strict=False)
66
-
67
- self.generator = generator.half()
68
- logger.info('Load model path {}, with local status {}'.format(cfg.model_path, ret))
69
-
70
- # Noise scheduler
71
- sigmas = noise_schedule(
72
- schedule='logsnr_cosine_interp',
73
- n=1000,
74
- zero_terminal_snr=True,
75
- scale_min=2.0,
76
- scale_max=4.0)
77
- diffusion = GaussianDiffusion(sigmas=sigmas)
78
- self.diffusion = diffusion
79
- logger.info('Build diffusion with GaussianDiffusion')
80
-
81
- # Temporal VAE
82
- vae = AutoencoderKLTemporalDecoder.from_pretrained(
83
- "stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16"
84
- )
85
- vae.eval()
86
- vae.requires_grad_(False)
87
- vae.to(self.device)
88
- self.vae = vae
89
- logger.info('Build Temporal VAE')
90
-
91
- torch.cuda.empty_cache()
92
-
93
- self.negative_prompt = cfg.negative_prompt
94
- self.positive_prompt = cfg.positive_prompt
95
-
96
- negative_y = text_encoder(self.negative_prompt).detach()
97
- self.negative_y = negative_y
98
-
99
- self.chunk_size = opt.chunk_size
100
-
101
-
102
- def test(self, input: Dict[str, Any], total_noise_levels=1000, \
103
- steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32):
104
- video_data = input['video_data']
105
- y = input['y']
106
- (target_h, target_w) = input['target_res']
107
-
108
- video_data = F.interpolate(video_data, [target_h,target_w], mode='bilinear')
109
-
110
- logger.info(f'video_data shape: {video_data.shape}')
111
- frames_num, _, h, w = video_data.shape
112
-
113
- padding = pad_to_fit(h, w)
114
- video_data = F.pad(video_data, padding, 'constant', 1)
115
-
116
- video_data = video_data.unsqueeze(0)
117
- bs = 1
118
- video_data = video_data.to(self.device)
119
-
120
- video_data_feature = self.vae_encode(video_data)
121
- torch.cuda.empty_cache()
122
-
123
- y = self.text_encoder(y).detach()
124
-
125
- with amp.autocast(enabled=True):
126
-
127
- t = torch.LongTensor([total_noise_levels-1]).to(self.device)
128
- noised_lr = self.diffusion.diffuse(video_data_feature, t)
129
-
130
- model_kwargs = [{'y': y}, {'y': self.negative_y}]
131
- model_kwargs.append({'hint': video_data_feature})
132
-
133
- torch.cuda.empty_cache()
134
- chunk_inds = make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) if frames_num > max_chunk_len else None
135
-
136
- solver = 'dpmpp_2m_sde' # 'heun' | 'dpmpp_2m_sde'
137
- gen_vid = self.diffusion.sample_sr(
138
- noise=noised_lr,
139
- model=self.generator,
140
- model_kwargs=model_kwargs,
141
- guide_scale=guide_scale,
142
- guide_rescale=0.2,
143
- solver=solver,
144
- solver_mode=solver_mode,
145
- return_intermediate=None,
146
- steps=steps,
147
- t_max=total_noise_levels - 1,
148
- t_min=0,
149
- discretization='trailing',
150
- chunk_inds=chunk_inds,)
151
- torch.cuda.empty_cache()
152
-
153
- logger.info(f'sampling, finished.')
154
- vid_tensor_gen = self.vae_decode_chunk(gen_vid, chunk_size=self.chunk_size)
155
-
156
- logger.info(f'temporal vae decoding, finished.')
157
-
158
- w1, w2, h1, h2 = padding
159
- vid_tensor_gen = vid_tensor_gen[:,:,h1:h+h1,w1:w+w1]
160
-
161
- gen_video = rearrange(
162
- vid_tensor_gen, '(b f) c h w -> b c f h w', b=bs)
163
-
164
- torch.cuda.empty_cache()
165
-
166
- return gen_video.type(torch.float32).cpu()
167
-
168
- def temporal_vae_decode(self, z, num_f):
169
- return self.vae.decode(z/self.vae.config.scaling_factor, num_frames=num_f).sample
170
-
171
- def vae_decode_chunk(self, z, chunk_size=3):
172
- z = rearrange(z, "b c f h w -> (b f) c h w")
173
- video = []
174
- for ind in range(0, z.shape[0], chunk_size):
175
- num_f = z[ind:ind+chunk_size].shape[0]
176
- video.append(self.temporal_vae_decode(z[ind:ind+chunk_size],num_f))
177
- video = torch.cat(video)
178
- return video
179
-
180
- def vae_encode(self, t, chunk_size=1):
181
- num_f = t.shape[1]
182
- t = rearrange(t, "b f c h w -> (b f) c h w")
183
- z_list = []
184
- for ind in range(0,t.shape[0],chunk_size):
185
- z_list.append(self.vae.encode(t[ind:ind+chunk_size]).latent_dist.sample())
186
- z = torch.cat(z_list, dim=0)
187
- z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
188
- return z * self.vae.config.scaling_factor
189
-
190
-
191
- def pad_to_fit(h, w):
192
- BEST_H, BEST_W = 720, 1280
193
-
194
- if h < BEST_H:
195
- h1, h2 = _create_pad(h, BEST_H)
196
- elif h == BEST_H:
197
- h1 = h2 = 0
198
- else:
199
- h1 = 0
200
- h2 = int((h + 48) // 64 * 64) + 64 - 48 - h
201
-
202
- if w < BEST_W:
203
- w1, w2 = _create_pad(w, BEST_W)
204
- elif w == BEST_W:
205
- w1 = w2 = 0
206
- else:
207
- w1 = 0
208
- w2 = int(w // 64 * 64) + 64 - w
209
- return (w1, w2, h1, h2)
210
-
211
- def _create_pad(h, max_len):
212
- h1 = int((max_len - h) // 2)
213
- h2 = max_len - h1 - h
214
- return h1, h2
215
-
216
-
217
- def make_chunks(f_num, interp_f_num, max_chunk_len, chunk_overlap_ratio=0.5):
218
- MAX_CHUNK_LEN = max_chunk_len
219
- MAX_O_LEN = MAX_CHUNK_LEN * chunk_overlap_ratio
220
- chunk_len = int((MAX_CHUNK_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
221
- o_len = int((MAX_O_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
222
- chunk_inds = sliding_windows_1d(f_num, chunk_len, o_len)
223
- return chunk_inds
224
-
225
-
226
- def sliding_windows_1d(length, window_size, overlap_size):
227
- stride = window_size - overlap_size
228
- ind = 0
229
- coords = []
230
- while ind<length:
231
- if ind+window_size*1.25>=length:
232
- coords.append((ind,length))
233
- break
234
- else:
235
- coords.append((ind,ind+window_size))
236
- ind += stride
237
- return coords