H-Liu1997 commited on
Commit
d378637
·
verified ·
1 Parent(s): f6b9ddc

Upload models/tools/wan_vae.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/tools/wan_vae.py +1117 -0
models/tools/wan_vae.py ADDED
@@ -0,0 +1,1117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module uses modified code from Alibaba Wan Team
2
+ # Original source: https://github.com/Wan-Video/Wan2.2
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+ # Modified to support 1d, 2d, 3d features with (B, C, T, 1, 1), (B, C, T, L, 1), (B, C, T, H, W) respectively.
5
+
6
+ import logging
7
+
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3d convolusion.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (
25
+ self.padding[2],
26
+ self.padding[2],
27
+ self.padding[1],
28
+ self.padding[1],
29
+ 2 * self.padding[0],
30
+ 0,
31
+ )
32
+ self.padding = (0, 0, 0)
33
+
34
+ def forward(self, x, cache_x=None):
35
+ padding = list(self._padding)
36
+ if cache_x is not None and self._padding[4] > 0:
37
+ cache_x = cache_x.to(x.device)
38
+ x = torch.cat([cache_x, x], dim=2)
39
+ padding[4] -= cache_x.shape[2]
40
+ x = F.pad(x, padding)
41
+
42
+ return super().forward(x)
43
+
44
+
45
+ class RMS_norm(nn.Module):
46
+ def __init__(self, dim, bias=False):
47
+ super().__init__()
48
+ broadcastable_dims = (1, 1, 1)
49
+ shape = (dim, *broadcastable_dims)
50
+
51
+ self.scale = dim**0.5
52
+ self.gamma = nn.Parameter(torch.ones(shape))
53
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
54
+
55
+ def forward(self, x):
56
+ return F.normalize(x, dim=(1)) * self.scale * self.gamma + self.bias
57
+
58
+
59
+ class Upsample(nn.Upsample):
60
+ def forward(self, x):
61
+ """
62
+ Fix bfloat16 support for nearest neighbor interpolation.
63
+ """
64
+ return super().forward(x.float()).type_as(x)
65
+
66
+
67
+ class Resample(nn.Module):
68
+ def __init__(self, dim, mode, spatial_dim=2):
69
+ assert mode in (
70
+ "none",
71
+ "upsample_temporal",
72
+ "upsample_spatial",
73
+ "upsample_temporal_spatial",
74
+ "downsample_temporal",
75
+ "downsample_spatial",
76
+ "downsample_temporal_spatial",
77
+ )
78
+ super().__init__()
79
+ self.dim = dim
80
+ self.mode = mode
81
+
82
+ # layers
83
+ if mode == "upsample_temporal":
84
+ self.resample = nn.Identity()
85
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
+ elif mode == "upsample_spatial" and spatial_dim == 2:
87
+ self.resample = nn.Sequential(
88
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
89
+ nn.Conv2d(dim, dim, 3, padding=1),
90
+ )
91
+ elif mode == "upsample_spatial" and spatial_dim == 1:
92
+ self.resample = nn.Sequential(
93
+ Upsample(scale_factor=(2.0, 1.0), mode="nearest-exact"),
94
+ nn.Conv2d(dim, dim, (3, 1), padding=(1, 0)),
95
+ )
96
+ elif mode == "upsample_temporal_spatial" and spatial_dim == 2:
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
99
+ nn.Conv2d(dim, dim, 3, padding=1),
100
+ )
101
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
102
+ elif mode == "upsample_temporal_spatial" and spatial_dim == 1:
103
+ self.resample = nn.Sequential(
104
+ Upsample(scale_factor=(2.0, 1.0), mode="nearest-exact"),
105
+ nn.Conv2d(dim, dim, (3, 1), padding=(1, 0)),
106
+ )
107
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
108
+ elif mode == "downsample_temporal":
109
+ self.resample = nn.Identity()
110
+ self.time_conv = CausalConv3d(
111
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
112
+ )
113
+ elif mode == "downsample_spatial" and spatial_dim == 2:
114
+ self.resample = nn.Sequential(
115
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
116
+ )
117
+ elif mode == "downsample_spatial" and spatial_dim == 1:
118
+ self.resample = nn.Sequential(
119
+ nn.ZeroPad2d((0, 0, 0, 1)), nn.Conv2d(dim, dim, (3, 1), stride=(2, 1))
120
+ )
121
+ elif mode == "downsample_temporal_spatial" and spatial_dim == 2:
122
+ self.resample = nn.Sequential(
123
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
124
+ )
125
+ self.time_conv = CausalConv3d(
126
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
127
+ )
128
+ elif mode == "downsample_temporal_spatial" and spatial_dim == 1:
129
+ self.resample = nn.Sequential(
130
+ nn.ZeroPad2d((0, 0, 0, 1)), nn.Conv2d(dim, dim, (3, 1), stride=(2, 1))
131
+ )
132
+ self.time_conv = CausalConv3d(
133
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
134
+ )
135
+ else:
136
+ self.resample = nn.Identity()
137
+
138
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
139
+ b, c, t, h, w = x.size()
140
+ if self.mode == "upsample_temporal_spatial" or self.mode == "upsample_temporal":
141
+ if feat_cache is not None:
142
+ idx = feat_idx[0]
143
+ if feat_cache[idx] is None:
144
+ feat_cache[idx] = "Rep"
145
+ feat_idx[0] += 1
146
+ else:
147
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
148
+ if (
149
+ cache_x.shape[2] < 2
150
+ and feat_cache[idx] is not None
151
+ and feat_cache[idx] != "Rep"
152
+ ):
153
+ # cache last frame of last two chunk
154
+ cache_x = torch.cat(
155
+ [
156
+ feat_cache[idx][:, :, -1, :, :]
157
+ .unsqueeze(2)
158
+ .to(cache_x.device),
159
+ cache_x,
160
+ ],
161
+ dim=2,
162
+ )
163
+ if (
164
+ cache_x.shape[2] < 2
165
+ and feat_cache[idx] is not None
166
+ and feat_cache[idx] == "Rep"
167
+ ):
168
+ cache_x = torch.cat(
169
+ [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
170
+ dim=2,
171
+ )
172
+ if feat_cache[idx] == "Rep":
173
+ x = self.time_conv(x)
174
+ else:
175
+ x = self.time_conv(x, feat_cache[idx])
176
+ feat_cache[idx] = cache_x
177
+ feat_idx[0] += 1
178
+ x = x.reshape(b, 2, c, t, h, w)
179
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
180
+ x = x.reshape(b, c, t * 2, h, w)
181
+ t = x.shape[2]
182
+ x = rearrange(x, "b c t h w -> (b t) c h w")
183
+ x = self.resample(x)
184
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
185
+
186
+ if (
187
+ self.mode == "downsample_temporal_spatial"
188
+ or self.mode == "downsample_temporal"
189
+ ):
190
+ if feat_cache is not None:
191
+ idx = feat_idx[0]
192
+ if feat_cache[idx] is None:
193
+ feat_cache[idx] = x.clone()
194
+ feat_idx[0] += 1
195
+ else:
196
+ cache_x = x[:, :, -1:, :, :].clone()
197
+ x = self.time_conv(
198
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
199
+ )
200
+ feat_cache[idx] = cache_x
201
+ feat_idx[0] += 1
202
+ return x
203
+
204
+ def init_weight(self, conv):
205
+ conv_weight = conv.weight.detach().clone()
206
+ nn.init.zeros_(conv_weight)
207
+ c1, c2, t, h, w = conv_weight.size()
208
+ one_matrix = torch.eye(c1, c2)
209
+ init_matrix = one_matrix
210
+ nn.init.zeros_(conv_weight)
211
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
212
+ conv.weight = nn.Parameter(conv_weight)
213
+ nn.init.zeros_(conv.bias.data)
214
+
215
+ def init_weight2(self, conv):
216
+ conv_weight = conv.weight.data.detach().clone()
217
+ nn.init.zeros_(conv_weight)
218
+ c1, c2, t, h, w = conv_weight.size()
219
+ init_matrix = torch.eye(c1 // 2, c2)
220
+ conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
221
+ conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
222
+ conv.weight = nn.Parameter(conv_weight)
223
+ nn.init.zeros_(conv.bias.data)
224
+
225
+
226
+ class ResidualBlock(nn.Module):
227
+ def __init__(self, in_dim, out_dim, spatial_dim=2, dropout=0.0):
228
+ super().__init__()
229
+ self.in_dim = in_dim
230
+ self.out_dim = out_dim
231
+ self.spatial_dim = spatial_dim
232
+
233
+ if spatial_dim == 2:
234
+ kernel_size = (3, 3, 3)
235
+ padding = (1, 1, 1)
236
+ elif spatial_dim == 1:
237
+ kernel_size = (3, 3, 1)
238
+ padding = (1, 1, 0)
239
+ elif spatial_dim == 0:
240
+ kernel_size = (3, 1, 1)
241
+ padding = (1, 0, 0)
242
+ else:
243
+ kernel_size = (3, 3, 3)
244
+ padding = (1, 1, 1)
245
+
246
+ # layers
247
+ self.residual = nn.Sequential(
248
+ RMS_norm(in_dim),
249
+ nn.SiLU(),
250
+ CausalConv3d(in_dim, out_dim, kernel_size, padding=padding),
251
+ RMS_norm(out_dim),
252
+ nn.SiLU(),
253
+ nn.Dropout(dropout),
254
+ CausalConv3d(out_dim, out_dim, kernel_size, padding=padding),
255
+ )
256
+ self.shortcut = (
257
+ CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
258
+ )
259
+
260
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
261
+ h = self.shortcut(x)
262
+ for layer in self.residual:
263
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
264
+ idx = feat_idx[0]
265
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
266
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
267
+ # cache last frame of last two chunk
268
+ cache_x = torch.cat(
269
+ [
270
+ feat_cache[idx][:, :, -1, :, :]
271
+ .unsqueeze(2)
272
+ .to(cache_x.device),
273
+ cache_x,
274
+ ],
275
+ dim=2,
276
+ )
277
+ x = layer(x, feat_cache[idx])
278
+ feat_cache[idx] = cache_x
279
+ feat_idx[0] += 1
280
+ else:
281
+ x = layer(x)
282
+ return x + h
283
+
284
+
285
+ class AttentionBlock(nn.Module):
286
+ """
287
+ Causal self-attention with a single head.
288
+ """
289
+
290
+ def __init__(self, dim):
291
+ super().__init__()
292
+ self.dim = dim
293
+
294
+ # layers
295
+ self.norm = RMS_norm(dim)
296
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
297
+ self.proj = nn.Conv2d(dim, dim, 1)
298
+
299
+ # zero out the last layer params
300
+ nn.init.zeros_(self.proj.weight)
301
+
302
+ def forward(self, x):
303
+ identity = x
304
+ b, c, t, h, w = x.size()
305
+ x = self.norm(x)
306
+ x = rearrange(x, "b c t h w -> (b t) c h w")
307
+ # compute query, key, value
308
+ q, k, v = (
309
+ self.to_qkv(x)
310
+ .reshape(b * t, 1, c * 3, -1)
311
+ .permute(0, 1, 3, 2)
312
+ .contiguous()
313
+ .chunk(3, dim=-1)
314
+ )
315
+
316
+ q = q.contiguous()
317
+ k = k.contiguous()
318
+ v = v.contiguous()
319
+
320
+ # apply attention
321
+ x = F.scaled_dot_product_attention(
322
+ q,
323
+ k,
324
+ v,
325
+ )
326
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
327
+
328
+ # output
329
+ x = self.proj(x)
330
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
331
+ return x + identity
332
+
333
+
334
+ def patchify(x, patch_size):
335
+ if patch_size == 1:
336
+ return x
337
+ if x.dim() == 4:
338
+ x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
339
+ elif x.dim() == 5:
340
+ x = rearrange(
341
+ x,
342
+ "b c f (h q) (w r) -> b (c r q) f h w",
343
+ q=patch_size,
344
+ r=patch_size,
345
+ )
346
+ else:
347
+ raise ValueError(f"Invalid input shape: {x.shape}")
348
+
349
+ return x
350
+
351
+
352
+ def unpatchify(x, patch_size):
353
+ if patch_size == 1:
354
+ return x
355
+
356
+ if x.dim() == 4:
357
+ x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
358
+ elif x.dim() == 5:
359
+ x = rearrange(
360
+ x,
361
+ "b (c r q) f h w -> b c f (h q) (w r)",
362
+ q=patch_size,
363
+ r=patch_size,
364
+ )
365
+ return x
366
+
367
+
368
+ class AvgDown3D(nn.Module):
369
+ def __init__(
370
+ self,
371
+ in_channels,
372
+ out_channels,
373
+ factor_t,
374
+ factor_h=1,
375
+ factor_w=1,
376
+ ):
377
+ super().__init__()
378
+ self.in_channels = in_channels
379
+ self.out_channels = out_channels
380
+ self.factor_t = factor_t
381
+ self.factor_h = factor_h
382
+ self.factor_w = factor_w
383
+ self.factor = self.factor_t * self.factor_h * self.factor_w
384
+
385
+ assert in_channels * self.factor % out_channels == 0
386
+ self.group_size = in_channels * self.factor // out_channels
387
+
388
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
389
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
390
+ pad = (0, 0, 0, 0, pad_t, 0)
391
+ x = F.pad(x, pad)
392
+ B, C, T, H, W = x.shape
393
+ x = x.view(
394
+ B,
395
+ C,
396
+ T // self.factor_t,
397
+ self.factor_t,
398
+ H // self.factor_h,
399
+ self.factor_h,
400
+ W // self.factor_w,
401
+ self.factor_w,
402
+ )
403
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
404
+ x = x.view(
405
+ B,
406
+ C * self.factor,
407
+ T // self.factor_t,
408
+ H // self.factor_h,
409
+ W // self.factor_w,
410
+ )
411
+ x = x.view(
412
+ B,
413
+ self.out_channels,
414
+ self.group_size,
415
+ T // self.factor_t,
416
+ H // self.factor_h,
417
+ W // self.factor_w,
418
+ )
419
+ x = x.mean(dim=2)
420
+ return x
421
+
422
+
423
+ class DupUp3D(nn.Module):
424
+ def __init__(
425
+ self,
426
+ in_channels: int,
427
+ out_channels: int,
428
+ factor_t,
429
+ factor_h=1,
430
+ factor_w=1,
431
+ ):
432
+ super().__init__()
433
+ self.in_channels = in_channels
434
+ self.out_channels = out_channels
435
+
436
+ self.factor_t = factor_t
437
+ self.factor_h = factor_h
438
+ self.factor_w = factor_w
439
+ self.factor = self.factor_t * self.factor_h * self.factor_w
440
+
441
+ assert out_channels * self.factor % in_channels == 0
442
+ self.repeats = out_channels * self.factor // in_channels
443
+
444
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
445
+ x = x.repeat_interleave(self.repeats, dim=1)
446
+ x = x.view(
447
+ x.size(0),
448
+ self.out_channels,
449
+ self.factor_t,
450
+ self.factor_h,
451
+ self.factor_w,
452
+ x.size(2),
453
+ x.size(3),
454
+ x.size(4),
455
+ )
456
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
457
+ x = x.view(
458
+ x.size(0),
459
+ self.out_channels,
460
+ x.size(2) * self.factor_t,
461
+ x.size(4) * self.factor_h,
462
+ x.size(6) * self.factor_w,
463
+ )
464
+ if first_chunk:
465
+ x = x[:, :, self.factor_t - 1 :, :, :]
466
+ return x
467
+
468
+
469
+ class Down_ResidualBlock(nn.Module):
470
+ def __init__(
471
+ self,
472
+ in_dim,
473
+ out_dim,
474
+ dropout,
475
+ mult,
476
+ temperal_downsample=False,
477
+ spatial_downsample=False,
478
+ spatial_dim=2,
479
+ ):
480
+ super().__init__()
481
+
482
+ # Determine spatial factors based on spatial_downsample
483
+ down_flag = temperal_downsample or spatial_downsample
484
+ factor_h, factor_w = 1, 1
485
+ if spatial_downsample:
486
+ if spatial_dim == 2:
487
+ factor_h, factor_w = 2, 2
488
+ elif spatial_dim == 1:
489
+ factor_h, factor_w = 2, 1
490
+
491
+ # Shortcut path with downsample
492
+ self.avg_shortcut = AvgDown3D(
493
+ in_dim,
494
+ out_dim,
495
+ factor_t=2 if temperal_downsample else 1,
496
+ factor_h=factor_h,
497
+ factor_w=factor_w,
498
+ )
499
+
500
+ # Main path with residual blocks and downsample
501
+ downsamples = []
502
+ for _ in range(mult):
503
+ downsamples.append(ResidualBlock(in_dim, out_dim, spatial_dim, dropout))
504
+ in_dim = out_dim
505
+
506
+ # Add the final downsample block
507
+ if down_flag:
508
+ if temperal_downsample and spatial_downsample and spatial_dim > 0:
509
+ mode = "downsample_temporal_spatial"
510
+ elif temperal_downsample:
511
+ mode = "downsample_temporal"
512
+ elif spatial_downsample and spatial_dim > 0:
513
+ mode = "downsample_spatial"
514
+ else:
515
+ mode = "none"
516
+ downsamples.append(Resample(out_dim, mode=mode, spatial_dim=spatial_dim))
517
+
518
+ self.downsamples = nn.Sequential(*downsamples)
519
+
520
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
521
+ x_copy = x.clone()
522
+ for module in self.downsamples:
523
+ x = module(x, feat_cache, feat_idx)
524
+
525
+ return x + self.avg_shortcut(x_copy)
526
+
527
+
528
+ class Up_ResidualBlock(nn.Module):
529
+ def __init__(
530
+ self,
531
+ in_dim,
532
+ out_dim,
533
+ dropout,
534
+ mult,
535
+ temperal_upsample=False,
536
+ spatial_upsample=False,
537
+ spatial_dim=2,
538
+ ):
539
+ super().__init__()
540
+
541
+ # Determine spatial factors based on spatial_upsample
542
+ up_flag = temperal_upsample or spatial_upsample
543
+ factor_h, factor_w = 1, 1
544
+ if spatial_upsample:
545
+ if spatial_dim == 2:
546
+ factor_h, factor_w = 2, 2
547
+ elif spatial_dim == 1:
548
+ factor_h, factor_w = 2, 1
549
+
550
+ # Shortcut path with upsample
551
+ if up_flag:
552
+ self.avg_shortcut = DupUp3D(
553
+ in_dim,
554
+ out_dim,
555
+ factor_t=2 if temperal_upsample else 1,
556
+ factor_h=factor_h,
557
+ factor_w=factor_w,
558
+ )
559
+ else:
560
+ self.avg_shortcut = None
561
+
562
+ # Main path with residual blocks and upsample
563
+ upsamples = []
564
+ for _ in range(mult):
565
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
566
+ in_dim = out_dim
567
+
568
+ # Add the final upsample block
569
+ if up_flag:
570
+ if temperal_upsample and spatial_upsample and spatial_dim > 0:
571
+ mode = "upsample_temporal_spatial"
572
+ elif temperal_upsample:
573
+ mode = "upsample_temporal"
574
+ elif spatial_upsample and spatial_dim > 0:
575
+ mode = "upsample_spatial"
576
+ else:
577
+ mode = "none"
578
+ upsamples.append(Resample(out_dim, mode=mode, spatial_dim=spatial_dim))
579
+
580
+ self.upsamples = nn.Sequential(*upsamples)
581
+
582
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
583
+ x_main = x.clone()
584
+ for module in self.upsamples:
585
+ x_main = module(x_main, feat_cache, feat_idx)
586
+ if self.avg_shortcut is not None:
587
+ x_shortcut = self.avg_shortcut(x, first_chunk)
588
+ return x_main + x_shortcut
589
+ else:
590
+ return x_main
591
+
592
+
593
+ class Encoder3d(nn.Module):
594
+ def __init__(
595
+ self,
596
+ input_dim=12,
597
+ dim=128,
598
+ z_dim=4,
599
+ dim_mult=[1, 2, 4, 4],
600
+ num_res_blocks=2,
601
+ attn_scales=[],
602
+ temperal_downsample=[True, True, False],
603
+ spatial_downsample=[True, True, True],
604
+ spatial_dim=2,
605
+ dropout=0.0,
606
+ ):
607
+ super().__init__()
608
+ self.dim = dim
609
+ self.z_dim = z_dim
610
+ self.dim_mult = dim_mult
611
+ self.num_res_blocks = num_res_blocks
612
+ self.attn_scales = attn_scales
613
+ self.temperal_downsample = temperal_downsample
614
+
615
+ # dimensions
616
+ dims = [dim * u for u in [1] + dim_mult]
617
+ scale = 1.0
618
+
619
+ if spatial_dim == 2:
620
+ kernel_size = (3, 3, 3)
621
+ padding = (1, 1, 1)
622
+ elif spatial_dim == 1:
623
+ kernel_size = (3, 3, 1)
624
+ padding = (1, 1, 0)
625
+ elif spatial_dim == 0:
626
+ kernel_size = (3, 1, 1)
627
+ padding = (1, 0, 0)
628
+ else:
629
+ kernel_size = (3, 3, 3)
630
+ padding = (1, 1, 1)
631
+
632
+ # init block
633
+ self.conv1 = CausalConv3d(input_dim, dims[0], kernel_size, padding=padding)
634
+
635
+ # downsample blocks
636
+ downsamples = []
637
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
638
+ t_down_flag = (
639
+ temperal_downsample[i] if i < len(temperal_downsample) else False
640
+ )
641
+ spatial_down_flag = (
642
+ spatial_downsample[i] if i < len(spatial_downsample) else False
643
+ )
644
+ downsamples.append(
645
+ Down_ResidualBlock(
646
+ in_dim=in_dim,
647
+ out_dim=out_dim,
648
+ dropout=dropout,
649
+ mult=num_res_blocks,
650
+ temperal_downsample=t_down_flag,
651
+ spatial_downsample=spatial_down_flag,
652
+ spatial_dim=spatial_dim,
653
+ )
654
+ )
655
+ scale /= 2.0
656
+ self.downsamples = nn.Sequential(*downsamples)
657
+
658
+ # middle blocks
659
+ if spatial_dim > 0:
660
+ self.middle = nn.Sequential(
661
+ ResidualBlock(
662
+ out_dim, out_dim, spatial_dim=spatial_dim, dropout=dropout
663
+ ),
664
+ AttentionBlock(out_dim),
665
+ ResidualBlock(
666
+ out_dim, out_dim, spatial_dim=spatial_dim, dropout=dropout
667
+ ),
668
+ )
669
+ else:
670
+ self.middle = nn.Sequential(
671
+ ResidualBlock(
672
+ out_dim, out_dim, spatial_dim=spatial_dim, dropout=dropout
673
+ ),
674
+ RMS_norm(out_dim),
675
+ CausalConv3d(out_dim, out_dim, 1),
676
+ ResidualBlock(
677
+ out_dim, out_dim, spatial_dim=spatial_dim, dropout=dropout
678
+ ),
679
+ )
680
+
681
+ # # output blocks
682
+ self.head = nn.Sequential(
683
+ RMS_norm(out_dim),
684
+ nn.SiLU(),
685
+ CausalConv3d(out_dim, z_dim, kernel_size, padding=padding),
686
+ )
687
+
688
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
689
+ if feat_cache is not None:
690
+ idx = feat_idx[0]
691
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
692
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
693
+ cache_x = torch.cat(
694
+ [
695
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
696
+ cache_x,
697
+ ],
698
+ dim=2,
699
+ )
700
+ x = self.conv1(x, feat_cache[idx])
701
+ feat_cache[idx] = cache_x
702
+ feat_idx[0] += 1
703
+ else:
704
+ x = self.conv1(x)
705
+
706
+ ## downsamples
707
+ for layer in self.downsamples:
708
+ if feat_cache is not None:
709
+ x = layer(x, feat_cache, feat_idx)
710
+ else:
711
+ x = layer(x)
712
+
713
+ ## middle
714
+ for layer in self.middle:
715
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
716
+ x = layer(x, feat_cache, feat_idx)
717
+ else:
718
+ x = layer(x)
719
+
720
+ ## head
721
+ for layer in self.head:
722
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
723
+ idx = feat_idx[0]
724
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
725
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
726
+ cache_x = torch.cat(
727
+ [
728
+ feat_cache[idx][:, :, -1, :, :]
729
+ .unsqueeze(2)
730
+ .to(cache_x.device),
731
+ cache_x,
732
+ ],
733
+ dim=2,
734
+ )
735
+ x = layer(x, feat_cache[idx])
736
+ feat_cache[idx] = cache_x
737
+ feat_idx[0] += 1
738
+ else:
739
+ x = layer(x)
740
+
741
+ return x
742
+
743
+
744
+ class Decoder3d(nn.Module):
745
+ def __init__(
746
+ self,
747
+ output_dim=12,
748
+ dim=128,
749
+ z_dim=4,
750
+ dim_mult=[1, 2, 4, 4],
751
+ num_res_blocks=2,
752
+ attn_scales=[],
753
+ temperal_upsample=[False, True, True],
754
+ spatial_upsample=[True, True, True],
755
+ spatial_dim=2,
756
+ dropout=0.0,
757
+ ):
758
+ super().__init__()
759
+ self.dim = dim
760
+ self.z_dim = z_dim
761
+ self.dim_mult = dim_mult
762
+ self.num_res_blocks = num_res_blocks
763
+ self.attn_scales = attn_scales
764
+ self.temperal_upsample = temperal_upsample
765
+ self.spatial_upsample = spatial_upsample
766
+
767
+ # dimensions
768
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
769
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
770
+ if spatial_dim == 2:
771
+ kernel_size = (3, 3, 3)
772
+ padding = (1, 1, 1)
773
+ elif spatial_dim == 1:
774
+ kernel_size = (3, 3, 1)
775
+ padding = (1, 1, 0)
776
+ elif spatial_dim == 0:
777
+ kernel_size = (3, 1, 1)
778
+ padding = (1, 0, 0)
779
+ else:
780
+ kernel_size = (3, 3, 3)
781
+ padding = (1, 1, 1)
782
+ # init block
783
+ self.conv1 = CausalConv3d(z_dim, dims[0], kernel_size, padding=padding)
784
+
785
+ # middle blocks
786
+ if spatial_dim > 0:
787
+ self.middle = nn.Sequential(
788
+ ResidualBlock(
789
+ dims[0], dims[0], spatial_dim=spatial_dim, dropout=dropout
790
+ ),
791
+ AttentionBlock(dims[0]),
792
+ ResidualBlock(
793
+ dims[0], dims[0], spatial_dim=spatial_dim, dropout=dropout
794
+ ),
795
+ )
796
+ else:
797
+ self.middle = nn.Sequential(
798
+ ResidualBlock(
799
+ dims[0], dims[0], spatial_dim=spatial_dim, dropout=dropout
800
+ ),
801
+ RMS_norm(dims[0]),
802
+ CausalConv3d(dims[0], dims[0], 1),
803
+ ResidualBlock(
804
+ dims[0], dims[0], spatial_dim=spatial_dim, dropout=dropout
805
+ ),
806
+ )
807
+
808
+ # upsample blocks
809
+ upsamples = []
810
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
811
+ t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
812
+ spatial_up_flag = (
813
+ spatial_upsample[i] if i < len(spatial_upsample) else False
814
+ )
815
+ upsamples.append(
816
+ Up_ResidualBlock(
817
+ in_dim=in_dim,
818
+ out_dim=out_dim,
819
+ dropout=dropout,
820
+ mult=num_res_blocks + 1,
821
+ temperal_upsample=t_up_flag,
822
+ spatial_upsample=spatial_up_flag,
823
+ spatial_dim=spatial_dim,
824
+ )
825
+ )
826
+ self.upsamples = nn.Sequential(*upsamples)
827
+
828
+ # output blocks
829
+ self.head = nn.Sequential(
830
+ RMS_norm(out_dim),
831
+ nn.SiLU(),
832
+ CausalConv3d(out_dim, output_dim, kernel_size, padding=padding),
833
+ )
834
+
835
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
836
+ if feat_cache is not None:
837
+ idx = feat_idx[0]
838
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
839
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
840
+ cache_x = torch.cat(
841
+ [
842
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
843
+ cache_x,
844
+ ],
845
+ dim=2,
846
+ )
847
+ x = self.conv1(x, feat_cache[idx])
848
+ feat_cache[idx] = cache_x
849
+ feat_idx[0] += 1
850
+ else:
851
+ x = self.conv1(x)
852
+
853
+ for layer in self.middle:
854
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
855
+ x = layer(x, feat_cache, feat_idx)
856
+ else:
857
+ x = layer(x)
858
+
859
+ ## upsamples
860
+ for layer in self.upsamples:
861
+ if feat_cache is not None:
862
+ x = layer(x, feat_cache, feat_idx, first_chunk)
863
+ else:
864
+ x = layer(x)
865
+
866
+ ## head
867
+ for layer in self.head:
868
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
869
+ idx = feat_idx[0]
870
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
871
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
872
+ cache_x = torch.cat(
873
+ [
874
+ feat_cache[idx][:, :, -1, :, :]
875
+ .unsqueeze(2)
876
+ .to(cache_x.device),
877
+ cache_x,
878
+ ],
879
+ dim=2,
880
+ )
881
+ x = layer(x, feat_cache[idx])
882
+ feat_cache[idx] = cache_x
883
+ feat_idx[0] += 1
884
+ else:
885
+ x = layer(x)
886
+ return x
887
+
888
+
889
+ def count_conv3d(model):
890
+ count = 0
891
+ for m in model.modules():
892
+ if isinstance(m, CausalConv3d):
893
+ count += 1
894
+ return count
895
+
896
+
897
+ class WanVAE_(nn.Module):
898
+ def __init__(
899
+ self,
900
+ input_dim=12,
901
+ dim=160,
902
+ dec_dim=256,
903
+ z_dim=16,
904
+ dim_mult=[1, 2, 4, 4],
905
+ num_res_blocks=2,
906
+ attn_scales=[],
907
+ temperal_downsample=[True, True, False],
908
+ spatial_downsample=[True, True, True],
909
+ spatial_dim=2,
910
+ dropout=0.0,
911
+ ):
912
+ super().__init__()
913
+ self.dim = dim
914
+ self.z_dim = z_dim
915
+ self.dim_mult = dim_mult
916
+ self.num_res_blocks = num_res_blocks
917
+ self.attn_scales = attn_scales
918
+ self.temperal_downsample = temperal_downsample
919
+ self.spatial_downsample = spatial_downsample
920
+ self.temperal_upsample = temperal_downsample[::-1]
921
+ self.spatial_upsample = spatial_downsample[::-1]
922
+
923
+ # modules
924
+ self.encoder = Encoder3d(
925
+ input_dim,
926
+ dim,
927
+ z_dim * 2,
928
+ dim_mult,
929
+ num_res_blocks,
930
+ attn_scales,
931
+ self.temperal_downsample,
932
+ self.spatial_downsample,
933
+ spatial_dim,
934
+ dropout,
935
+ )
936
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
937
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
938
+ self.decoder = Decoder3d(
939
+ input_dim,
940
+ dec_dim,
941
+ z_dim,
942
+ dim_mult,
943
+ num_res_blocks,
944
+ attn_scales,
945
+ self.temperal_upsample,
946
+ self.spatial_upsample,
947
+ spatial_dim,
948
+ dropout,
949
+ )
950
+
951
+ def forward(self, x, scale=[0, 1]):
952
+ mu = self.encode(x, scale)
953
+ x_recon = self.decode(mu, scale)
954
+ return x_recon, mu
955
+
956
+ def encode(self, x, scale, patch_size=1, return_dist=False):
957
+ self.clear_cache()
958
+ x = patchify(x, patch_size=patch_size)
959
+ t = x.shape[2]
960
+ iter_ = 1 + (t - 1) // 4
961
+ for i in range(iter_):
962
+ self._enc_conv_idx = [0]
963
+ if i == 0:
964
+ out = self.encoder(
965
+ x[:, :, :1, :, :],
966
+ feat_cache=self._enc_feat_map,
967
+ feat_idx=self._enc_conv_idx,
968
+ )
969
+ else:
970
+ out_ = self.encoder(
971
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
972
+ feat_cache=self._enc_feat_map,
973
+ feat_idx=self._enc_conv_idx,
974
+ )
975
+ out = torch.cat([out, out_], 2)
976
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
977
+ if isinstance(scale[0], torch.Tensor):
978
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
979
+ 1, self.z_dim, 1, 1, 1
980
+ )
981
+ else:
982
+ mu = (mu - scale[0]) * scale[1]
983
+ self.clear_cache()
984
+
985
+ if return_dist:
986
+ return mu, log_var
987
+ else:
988
+ return mu
989
+
990
+ def decode(self, z, scale, patch_size=1):
991
+ self.clear_cache()
992
+ if isinstance(scale[0], torch.Tensor):
993
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
994
+ 1, self.z_dim, 1, 1, 1
995
+ )
996
+ else:
997
+ z = z / scale[1] + scale[0]
998
+ iter_ = z.shape[2]
999
+ x = self.conv2(z)
1000
+ for i in range(iter_):
1001
+ self._conv_idx = [0]
1002
+ if i == 0:
1003
+ out = self.decoder(
1004
+ x[:, :, i : i + 1, :, :],
1005
+ feat_cache=self._feat_map,
1006
+ feat_idx=self._conv_idx,
1007
+ first_chunk=True,
1008
+ )
1009
+ else:
1010
+ out_ = self.decoder(
1011
+ x[:, :, i : i + 1, :, :],
1012
+ feat_cache=self._feat_map,
1013
+ feat_idx=self._conv_idx,
1014
+ )
1015
+ out = torch.cat([out, out_], 2)
1016
+ out = unpatchify(out, patch_size=patch_size)
1017
+ self.clear_cache()
1018
+ return out
1019
+
1020
+ @torch.no_grad()
1021
+ def stream_encode(self, x, first_chunk, scale, patch_size=1, return_dist=False):
1022
+ x = patchify(x, patch_size=patch_size)
1023
+ t = x.shape[2]
1024
+ if first_chunk:
1025
+ iter_ = 1 + (t - 1) // 4
1026
+ else:
1027
+ iter_ = t // 4
1028
+ for i in range(iter_):
1029
+ self._enc_conv_idx = [0]
1030
+ if i == 0:
1031
+ if first_chunk:
1032
+ out = self.encoder(
1033
+ x[:, :, :1, :, :],
1034
+ feat_cache=self._enc_feat_map,
1035
+ feat_idx=self._enc_conv_idx,
1036
+ )
1037
+ else:
1038
+ out = self.encoder(
1039
+ x[:, :, :4, :, :],
1040
+ feat_cache=self._enc_feat_map,
1041
+ feat_idx=self._enc_conv_idx,
1042
+ )
1043
+ else:
1044
+ if first_chunk:
1045
+ out_ = self.encoder(
1046
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
1047
+ feat_cache=self._enc_feat_map,
1048
+ feat_idx=self._enc_conv_idx,
1049
+ )
1050
+ else:
1051
+ out_ = self.encoder(
1052
+ x[:, :, 4 * i : 4 * (i + 1), :, :],
1053
+ feat_cache=self._enc_feat_map,
1054
+ feat_idx=self._enc_conv_idx,
1055
+ )
1056
+ out = torch.cat([out, out_], 2)
1057
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
1058
+ if isinstance(scale[0], torch.Tensor):
1059
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1060
+ 1, self.z_dim, 1, 1, 1
1061
+ )
1062
+ else:
1063
+ mu = (mu - scale[0]) * scale[1]
1064
+ self.clear_cache()
1065
+
1066
+ if return_dist:
1067
+ return mu, log_var
1068
+ else:
1069
+ return mu
1070
+
1071
+ @torch.no_grad()
1072
+ def stream_decode(self, z, first_chunk, scale, patch_size=1):
1073
+ if isinstance(scale[0], torch.Tensor):
1074
+ z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
1075
+ else:
1076
+ z = z / scale[1] + scale[0]
1077
+ iter_ = z.shape[2]
1078
+ x = self.conv2(z)
1079
+ for i in range(iter_):
1080
+ self._conv_idx = [0]
1081
+ if i == 0:
1082
+ out = self.decoder(
1083
+ x[:, :, i : i + 1, :, :],
1084
+ feat_cache=self._feat_map,
1085
+ feat_idx=self._conv_idx,
1086
+ first_chunk=first_chunk,
1087
+ )
1088
+ else:
1089
+ out_ = self.decoder(
1090
+ x[:, :, i : i + 1, :, :],
1091
+ feat_cache=self._feat_map,
1092
+ feat_idx=self._conv_idx,
1093
+ )
1094
+ out = torch.cat([out, out_], 2)
1095
+ out = unpatchify(out, patch_size=patch_size)
1096
+ return out
1097
+
1098
+ def reparameterize(self, mu, log_var):
1099
+ std = torch.exp(0.5 * log_var)
1100
+ eps = torch.randn_like(std)
1101
+ return eps * std + mu
1102
+
1103
+ def sample(self, features, deterministic=False):
1104
+ mu, log_var = self.encode(features, return_dist=True)
1105
+ if deterministic:
1106
+ return mu
1107
+ else:
1108
+ return self.reparameterize(mu, log_var)
1109
+
1110
+ def clear_cache(self):
1111
+ self._conv_num = count_conv3d(self.decoder)
1112
+ self._conv_idx = [0]
1113
+ self._feat_map = [None] * self._conv_num
1114
+ # cache encode
1115
+ self._enc_conv_num = count_conv3d(self.encoder)
1116
+ self._enc_conv_idx = [0]
1117
+ self._enc_feat_map = [None] * self._enc_conv_num