javirk1 commited on
Commit
86039d9
·
verified ·
1 Parent(s): 8561a5f

Upload folder using huggingface_hub

Browse files
1080p/decoder_1080p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30db9e31f490cc9a7487f9f50b076eb6c75a03ee4cebbd76f0c131df577e9764
3
+ size 27441527
1080p/encoder_1080p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9811f1df9b8aed0fcff820e833b2a70cec95b89db99a1e0c578d79d5a2fc6af9
3
+ size 23172295
1080p/quantizer_1080p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8c5fef2505d1854984c1560e6485b06e083328ab9b159a0ac245ccca8dbfbf7
3
+ size 8633
540p/decoder_540p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:145e78ddfbce5357170e48ebf118cd5b546f97e67c5b25d7116893dc825e8a79
3
+ size 27441525
540p/encoder_540p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6710c84c362d067cb780ee7906a13272f8ae37cd1342dc26b3f0209cbd5df123
3
+ size 23172295
540p/quantizer_540p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96120f29521563219267edd904d3166cf113a8d726e9e38362e70ef962222e0a
3
+ size 8629
720p/decoder_720p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a80c5899556f517fbe4203d5000c17ec7ee9be9609a5430a588da1919cc9b17b
3
+ size 27441526
720p/encoder_720p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f083d0b791983d4cb5fc918cff6d49bc01433cb15174ff8c811adaef574377a
3
+ size 23172295
720p/quantizer_720p_cs_discrete8_wan_patch2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96126464259e93be1cf0c6d034ca73a1057b4563dc9f57cc3fa177a2573597d0
3
+ size 8631
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa319aa3982c258ee22d2f9096668b9cd42e0f3104b3a37488232edc01776ac0
3
+ size 50493554
config.json CHANGED
@@ -1,31 +1,28 @@
1
  {
2
- "project_name": "wan8_csquant_bs8_20251022_134201",
3
  "teacher_config": {
4
- "dim": "96",
5
- "z_dim": "16",
6
- "dim_mult": "[1, 2, 4, 4]",
7
- "num_res_blocks": "2",
8
- "attn_scales": "[]",
9
- "temporal_downsample": "[False, True, True]",
10
- "dropout": "0.0",
11
- "cls": "<class 'models.temporal_wan.WanVAE_'>"
12
- },
13
  "student_config": {
14
- "dim": "64",
15
- "z_dim": "16",
16
- "dim_mult": "[1, 2, 4, 4]",
17
- "num_res_blocks": "2",
18
- "attn_scales": "[]",
19
- "dropout": "0.0",
20
- "cls": "<class 'models.image_vae.DiscreteImageVAE'>",
21
- "z_channels": "256",
22
- "z_factor": "1",
23
- "embedding_dim": "16",
24
- "levels": "[8, 8, 8, 5, 5, 5]",
25
- "dtype": "torch.float32",
26
- "model_type": "wan_2_1",
27
- "quantizer_cls": "<class 'models.quantizers.ChannelSplitFSQ'>",
28
- "num_codebooks": "1",
 
29
  "K": "2"
30
  }
31
  }
 
1
  {
2
+ "project_name": "wan8_csquant_patch2_bs8_720p",
3
  "teacher_config": {
4
+ "dim": "96",
5
+ "z_dim": "16",
6
+ "dim_mult": "[1, 2, 4, 4]",
7
+ "num_res_blocks": "2", "attn_scales": "[]", "temperal_downsample": "[False, True, True]", "dropout": "0.0", "cls": "<class 'models.temporal_wan.WanVAE_'>"
8
+ },
 
 
 
 
9
  "student_config": {
10
+ "dim": "64",
11
+ "z_dim": "16",
12
+ "dim_mult": "[1, 2, 4]",
13
+ "patch_size": "2",
14
+ "num_res_blocks": "3",
15
+ "attn_scales": "[]",
16
+ "dropout": "0.0",
17
+ "cls": "<class 'models.image_vae.DiscreteImageVAE'>",
18
+ "z_channels": "256",
19
+ "z_factor": "1",
20
+ "embedding_dim": "16",
21
+ "levels": "[8, 8, 8, 5, 5, 5]",
22
+ "dtype": "torch.float32",
23
+ "model_type": "wan_2_1",
24
+ "quantizer_cls": "<class 'models.quantizers.ChannelSplitFSQ'>",
25
+ "num_codebooks": "1",
26
  "K": "2"
27
  }
28
  }
python/simple_sample_vae.py CHANGED
@@ -8,6 +8,40 @@ from einops import rearrange, pack, unpack
8
  _PERSISTENT = True
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def exists(v):
12
  return v is not None
13
 
@@ -239,20 +273,26 @@ class RMS_norm(nn.Module):
239
  self.gamma = nn.Parameter(torch.ones(shape))
240
  self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
241
 
 
 
 
 
 
 
 
 
242
  def forward(self, x):
243
- return (
244
- F.normalize(x, dim=(1 if self.channel_first else -1))
245
- * self.scale
246
- * self.gamma
247
- + self.bias
248
- )
249
 
250
 
251
  class Upsample(nn.Upsample):
252
 
253
  def forward(self, x):
254
  # Fix bfloat16 support for nearest neighbor interpolation.
255
- return super().forward(x.float()).type_as(x)
 
256
 
257
 
258
  class ResidualBlock2d(nn.Module):
@@ -291,21 +331,77 @@ class AttentionBlock2d(nn.Module):
291
  self.proj = nn.Conv2d(dim, dim, 1)
292
  nn.init.zeros_(self.proj.weight)
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  def forward(self, x):
295
  identity = x
296
  b, c, h, w = x.size()
 
 
 
297
  x = self.norm(x)
298
- q, k, v = (
299
- self.to_qkv(x)
300
- .reshape(b, 1, c * 3, -1)
301
- .permute(0, 1, 3, 2)
302
- .contiguous()
303
- .chunk(3, dim=-1)
304
- )
305
- x = F.scaled_dot_product_attention(q, k, v)
306
- x = x.squeeze(1).permute(0, 2, 1).reshape(b, c, h, w)
307
- x = self.proj(x)
308
- return x + identity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
 
311
  class Resample2d(nn.Module):
@@ -344,6 +440,7 @@ class Encoder2d(nn.Module):
344
  attn_scales=[],
345
  patch_size=1,
346
  in_channels=3,
 
347
  ):
348
  super().__init__()
349
  self.dim = dim
@@ -354,6 +451,8 @@ class Encoder2d(nn.Module):
354
  self.patch_size = patch_size
355
  self.in_channels = in_channels
356
 
 
 
357
  # dimensions
358
  dims = [dim * u for u in [1] + dim_mult]
359
  scale = 1.0
@@ -370,7 +469,7 @@ class Encoder2d(nn.Module):
370
  for _ in range(num_res_blocks):
371
  downsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
372
  if scale in self.attn_scales:
373
- downsamples.append(AttentionBlock2d(out_dim))
374
  in_dim = out_dim
375
  if i != len(dim_mult) - 1:
376
  downsamples.append(Resample2d(out_dim, mode="downsample2d"))
@@ -386,6 +485,7 @@ class Encoder2d(nn.Module):
386
  )
387
 
388
  def forward(self, x):
 
389
  x = self.conv1(x)
390
  x = self.downsamples(x)
391
  x = self.middle(x)
@@ -404,6 +504,8 @@ class Decoder2d(nn.Module):
404
  dropout=0.0,
405
  attn_scales=[],
406
  out_channels=3,
 
 
407
  ):
408
  super().__init__()
409
  self.dim = dim
@@ -412,12 +514,15 @@ class Decoder2d(nn.Module):
412
  self.num_res_blocks = num_res_blocks
413
  self.attn_scales = attn_scales
414
  self.out_channels = out_channels
 
 
 
415
 
416
  # dimensions (mirror of encoder)
417
  base = dim * dim_mult[-1]
418
  dims = [base] + [dim * u for u in dim_mult[::-1]]
419
  scale = 1.0 / (2 ** (len(dim_mult) - 2)) if len(dim_mult) >= 2 else 1.0
420
- output_channels = self.out_channels
421
 
422
  # init block
423
  self.conv1 = nn.Conv2d(z_dim, dims[0], kernel_size=3, padding=1)
@@ -432,7 +537,7 @@ class Decoder2d(nn.Module):
432
  for _ in range(num_res_blocks):
433
  upsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
434
  if scale in self.attn_scales:
435
- upsamples.append(AttentionBlock2d(out_dim))
436
  in_dim = out_dim
437
  if i != len(dim_mult) - 1:
438
  upsamples.append(Resample2d(out_dim, mode="upsample2d"))
@@ -451,6 +556,7 @@ class Decoder2d(nn.Module):
451
  x = self.middle(x)
452
  x = self.upsamples(x)
453
  x = self.head(x)
 
454
  return x
455
 
456
 
@@ -468,6 +574,8 @@ class DiscreteImageVAE(nn.Module):
468
  out_channels=3,
469
  embedding_dim=128,
470
  scale=None,
 
 
471
  *args,
472
  **kwargs,
473
  ):
@@ -486,6 +594,8 @@ class DiscreteImageVAE(nn.Module):
486
  dropout=dropout,
487
  attn_scales=attn_scales,
488
  in_channels=in_channels,
 
 
489
  )
490
  self.decoder = Decoder2d(
491
  dim=dim,
@@ -495,6 +605,8 @@ class DiscreteImageVAE(nn.Module):
495
  dropout=dropout,
496
  attn_scales=attn_scales,
497
  out_channels=out_channels,
 
 
498
  )
499
  self.embedding_dim = embedding_dim
500
 
@@ -598,7 +710,7 @@ if __name__ == "__main__":
598
  from PIL import Image
599
  import numpy as np
600
 
601
- def load_image(path, size=(848, 480)):
602
  if not os.path.exists(path):
603
  print(
604
  f"Image not found at {path}, generating random noise. Warning: The tokenizer might to work properly."
@@ -636,7 +748,7 @@ if __name__ == "__main__":
636
  "--checkpoint", type=str, default=None, help="Path to model checkpoint"
637
  )
638
  parser.add_argument(
639
- "--image", type=str, default="assets/demo1.png", help="Path to input image"
640
  )
641
  parser.add_argument(
642
  "--output",
@@ -653,16 +765,22 @@ if __name__ == "__main__":
653
 
654
  args = parser.parse_args()
655
 
656
- cs_discrete8_wan = {
657
  "dim": 64,
658
  "z_dim": 16,
659
- "dim_mult": [1, 2, 4, 4],
660
- "num_res_blocks": 2,
 
661
  "attn_scales": [],
662
  "dropout": 0.0,
 
 
 
663
  "embedding_dim": 16,
664
  "levels": [8, 8, 8, 5, 5, 5],
665
  "dtype": torch.float,
 
 
666
  "num_codebooks": 1,
667
  "K": 2,
668
  }
@@ -670,7 +788,7 @@ if __name__ == "__main__":
670
  device = args.device
671
  print(f"Running on {device}")
672
 
673
- vae = DiscreteImageVAE(**cs_discrete8_wan).to(device)
674
 
675
  if args.checkpoint and os.path.exists(args.checkpoint):
676
  print(f"Loading checkpoint from {args.checkpoint}")
 
8
  _PERSISTENT = True
9
 
10
 
11
def patchify(x, patch_size):
    """Fold non-overlapping patch_size x patch_size spatial patches into channels.

    Accepts image tensors (b, c, H, W) or video tensors (b, c, f, H, W);
    H and W must be divisible by patch_size.  Inverse of ``unpatchify``.
    Channel packing order matches einops "(h q) (w r) -> (c r q)".
    """
    if patch_size == 1:
        return x

    p = patch_size
    if x.dim() == 4:
        b, c, full_h, full_w = x.shape
        h, w = full_h // p, full_w // p
        # (b, c, h, q, w, r) -> (b, c, r, q, h, w), then merge (c, r, q).
        grid = x.view(b, c, h, p, w, p).permute(0, 1, 5, 3, 2, 4)
        return grid.reshape(b, c * p * p, h, w)
    if x.dim() == 5:
        b, c, f, full_h, full_w = x.shape
        h, w = full_h // p, full_w // p
        grid = x.view(b, c, f, h, p, w, p).permute(0, 1, 6, 4, 2, 3, 5)
        return grid.reshape(b, c * p * p, f, h, w)
    raise ValueError(f"Invalid input shape: {x.shape}")
27
+
28
+
29
def unpatchify(x, patch_size):
    """Unfold channel-packed patches back to full spatial resolution.

    Inverse of ``patchify``: (b, c*p*p, h, w) -> (b, c, h*p, w*p) for images
    and (b, c*p*p, f, h, w) -> (b, c, f, h*p, w*p) for videos.

    Raises:
        ValueError: if ``x`` is not 4- or 5-dimensional.  (Previously
        unsupported ranks fell through and were returned unchanged, while
        ``patchify`` raised; the pair is now consistent.)
    """
    if patch_size == 1:
        return x

    p = patch_size
    if x.dim() == 4:
        b, packed_c, h, w = x.shape
        c = packed_c // (p * p)
        # Split channels back into (c, r, q), then interleave into (h q)/(w r).
        grid = x.view(b, c, p, p, h, w).permute(0, 1, 4, 3, 5, 2)
        return grid.reshape(b, c, h * p, w * p)
    if x.dim() == 5:
        b, packed_c, f, h, w = x.shape
        c = packed_c // (p * p)
        grid = x.view(b, c, p, p, f, h, w).permute(0, 1, 4, 5, 3, 6, 2)
        return grid.reshape(b, c, f, h * p, w * p)
    raise ValueError(f"Invalid input shape: {x.shape}")
43
+
44
+
45
  def exists(v):
46
  return v is not None
47
 
 
273
  self.gamma = nn.Parameter(torch.ones(shape))
274
  self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
275
 
276
+ # def forward(self, x):
277
+ # return (
278
+ # F.normalize(x, dim=(1 if self.channel_first else -1))
279
+ # * self.scale
280
+ # * self.gamma
281
+ # + self.bias
282
+ # )
283
+
284
  def forward(self, x):
285
+ dim = 1 if self.channel_first else -1
286
+ rms = x.pow(2).mean(dim=dim, keepdim=True).add(1e-6).rsqrt()
287
+ return x * rms * self.gamma + self.bias
 
 
 
288
 
289
 
290
  class Upsample(nn.Upsample):
291
 
292
  def forward(self, x):
293
  # Fix bfloat16 support for nearest neighbor interpolation.
294
+ # return super().forward(x.float()).type_as(x)
295
+ return super().forward(x)
296
 
297
 
298
  class ResidualBlock2d(nn.Module):
 
331
  self.proj = nn.Conv2d(dim, dim, 1)
332
  nn.init.zeros_(self.proj.weight)
333
 
334
+ # def forward(self, x):
335
+ # identity = x
336
+ # b, c, h, w = x.size()
337
+ # x = self.norm(x)
338
+ # q, k, v = (
339
+ # self.to_qkv(x)
340
+ # .reshape(b, 1, c * 3, -1)
341
+ # .permute(0, 1, 3, 2)
342
+ # .contiguous()
343
+ # .chunk(3, dim=-1)
344
+ # )
345
+ # x = F.scaled_dot_product_attention(q, k, v)
346
+ # x = x.squeeze(1).permute(0, 2, 1).reshape(b, c, h, w)
347
+ # x = self.proj(x)
348
+ # return x + identity
349
+
350
  def forward(self, x):
351
  identity = x
352
  b, c, h, w = x.size()
353
+ n_heads = 1 # or c // 64
354
+ head_dim = c // n_heads
355
+
356
  x = self.norm(x)
357
+ qkv = self.to_qkv(x).reshape(b, 3, n_heads, head_dim, h * w)
358
+ q, k, v = qkv.unbind(1) # Each: (b, n_heads, head_dim, h*w)
359
+ q, k, v = q.transpose(-1, -2), k.transpose(-1, -2), v.transpose(-1, -2)
360
+
361
+ x = F.scaled_dot_product_attention(q, k, v) # Flash attention
362
+ x = x.transpose(-1, -2).reshape(b, c, h, w)
363
+ return self.proj(x) + identity
364
+
365
+
366
class FlashAttentionBlock2d(nn.Module):
    """Spatial self-attention block backed by flash-attn's fused kernel.

    Drop-in alternative to AttentionBlock2d.  Requires the optional
    ``flash-attn`` package at call time (imported lazily in ``forward``).
    """

    def __init__(self, dim, n_heads=8):
        super().__init__()
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if dim % n_heads != 0:
            raise ValueError(f"dim {dim} must be divisible by n_heads {n_heads}")
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        # Zero-init the output projection so the block starts as an identity.
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        # Lazy import keeps flash-attn an optional dependency.
        from flash_attn import flash_attn_func

        identity = x
        b, c, h, w = x.size()

        x = self.norm(x)
        qkv = self.to_qkv(x)  # (b, 3*c, h, w)

        # flash_attn_func expects (b, seqlen, nheads, headdim).
        qkv = qkv.reshape(b, 3, self.n_heads, self.head_dim, h * w)
        qkv = qkv.permute(0, 4, 1, 2, 3)  # (b, h*w, 3, n_heads, head_dim)
        q, k, v = qkv.unbind(2)  # each (b, h*w, n_heads, head_dim)

        x = flash_attn_func(q, k, v)  # (b, h*w, n_heads, head_dim)
        x = x.reshape(b, h * w, c).permute(0, 2, 1).reshape(b, c, h, w)

        return self.proj(x) + identity
398
+
399
+
400
class AsymmetricConv2d(nn.Conv2d):
    """Conv2d that zero-pads one column (right) and one row (bottom) first.

    Applies a fixed (left=0, right=1, top=0, bottom=1) zero pad before the
    standard convolution; torch.compile can fuse the pad into the conv.
    """

    def forward(self, x):
        padded = F.pad(x, (0, 1, 0, 1))  # (left, right, top, bottom)
        return super().forward(padded)
405
 
406
 
407
  class Resample2d(nn.Module):
 
440
  attn_scales=[],
441
  patch_size=1,
442
  in_channels=3,
443
+ attn_class=AttentionBlock2d,
444
  ):
445
  super().__init__()
446
  self.dim = dim
 
451
  self.patch_size = patch_size
452
  self.in_channels = in_channels
453
 
454
+ self.patcher = lambda x: patchify(x, patch_size=patch_size)
455
+
456
  # dimensions
457
  dims = [dim * u for u in [1] + dim_mult]
458
  scale = 1.0
 
469
  for _ in range(num_res_blocks):
470
  downsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
471
  if scale in self.attn_scales:
472
+ downsamples.append(attn_class(out_dim))
473
  in_dim = out_dim
474
  if i != len(dim_mult) - 1:
475
  downsamples.append(Resample2d(out_dim, mode="downsample2d"))
 
485
  )
486
 
487
  def forward(self, x):
488
+ x = self.patcher(x)
489
  x = self.conv1(x)
490
  x = self.downsamples(x)
491
  x = self.middle(x)
 
504
  dropout=0.0,
505
  attn_scales=[],
506
  out_channels=3,
507
+ attn_class=AttentionBlock2d,
508
+ patch_size=1,
509
  ):
510
  super().__init__()
511
  self.dim = dim
 
514
  self.num_res_blocks = num_res_blocks
515
  self.attn_scales = attn_scales
516
  self.out_channels = out_channels
517
+ self.patch_size = patch_size
518
+
519
+ self.unpatcher = lambda x: unpatchify(x, patch_size=patch_size)
520
 
521
  # dimensions (mirror of encoder)
522
  base = dim * dim_mult[-1]
523
  dims = [base] + [dim * u for u in dim_mult[::-1]]
524
  scale = 1.0 / (2 ** (len(dim_mult) - 2)) if len(dim_mult) >= 2 else 1.0
525
+ output_channels = self.out_channels * self.patch_size * self.patch_size
526
 
527
  # init block
528
  self.conv1 = nn.Conv2d(z_dim, dims[0], kernel_size=3, padding=1)
 
537
  for _ in range(num_res_blocks):
538
  upsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
539
  if scale in self.attn_scales:
540
+ upsamples.append(attn_class(out_dim))
541
  in_dim = out_dim
542
  if i != len(dim_mult) - 1:
543
  upsamples.append(Resample2d(out_dim, mode="upsample2d"))
 
556
  x = self.middle(x)
557
  x = self.upsamples(x)
558
  x = self.head(x)
559
+ x = self.unpatcher(x)
560
  return x
561
 
562
 
 
574
  out_channels=3,
575
  embedding_dim=128,
576
  scale=None,
577
+ attn_class=AttentionBlock2d,
578
+ patch_size=1,
579
  *args,
580
  **kwargs,
581
  ):
 
594
  dropout=dropout,
595
  attn_scales=attn_scales,
596
  in_channels=in_channels,
597
+ attn_class=attn_class,
598
+ patch_size=patch_size,
599
  )
600
  self.decoder = Decoder2d(
601
  dim=dim,
 
605
  dropout=dropout,
606
  attn_scales=attn_scales,
607
  out_channels=out_channels,
608
+ attn_class=attn_class,
609
+ patch_size=patch_size,
610
  )
611
  self.embedding_dim = embedding_dim
612
 
 
710
  from PIL import Image
711
  import numpy as np
712
 
713
+ def load_image(path, size=(1920, 1080)):
714
  if not os.path.exists(path):
715
  print(
716
  f"Image not found at {path}, generating random noise. Warning: The tokenizer might to work properly."
 
748
  "--checkpoint", type=str, default=None, help="Path to model checkpoint"
749
  )
750
  parser.add_argument(
751
+ "--image", type=str, default="assets/00128.png", help="Path to input image"
752
  )
753
  parser.add_argument(
754
  "--output",
 
765
 
766
  args = parser.parse_args()
767
 
768
+ cs_discrete8_wan_patch2 = {
769
  "dim": 64,
770
  "z_dim": 16,
771
+ "dim_mult": [1, 2, 4],
772
+ "patch_size": 2,
773
+ "num_res_blocks": 3,
774
  "attn_scales": [],
775
  "dropout": 0.0,
776
+ "cls": DiscreteImageVAE,
777
+ "z_channels": 256,
778
+ "z_factor": 1,
779
  "embedding_dim": 16,
780
  "levels": [8, 8, 8, 5, 5, 5],
781
  "dtype": torch.float,
782
+ "model_type": "wan_2_1",
783
+ "quantizer_cls": ChannelSplitFSQ,
784
  "num_codebooks": 1,
785
  "K": 2,
786
  }
 
788
  device = args.device
789
  print(f"Running on {device}")
790
 
791
+ vae = DiscreteImageVAE(**cs_discrete8_wan_patch2).to(device)
792
 
793
  if args.checkpoint and os.path.exists(args.checkpoint):
794
  print(f"Loading checkpoint from {args.checkpoint}")
specs.txt CHANGED
@@ -1,13 +1,17 @@
1
- PSNR: 30.78 ± 3.49
2
- SSIM: 0.898 ± 0.063
3
- LPIPS: 0.123 ± 0.033
4
 
5
  Latent dims: [1, 2, H/8, W/8]
6
 
7
- [480p]
8
- height: 480
9
- width: 848
10
-
11
  [540p]
12
  height: 536
13
- width: 960
 
 
 
 
 
 
 
 
 
1
+ PSNR: 34.61 ± 3.18
2
+ SSIM: 0.961 ± 0.026
3
+ LPIPS: 0.105 ± 0.026
4
 
5
  Latent dims: [1, 2, H/8, W/8]
6
 
 
 
 
 
7
  [540p]
8
  height: 536
9
+ width: 960
10
+
11
+ [720p]
12
+ height: 720
13
+ width: 1280
14
+
15
+ [1080p]
16
+ height: 1080
17
+ width: 1280