krystv committed on
Commit
d46a7a4
·
verified ·
1 Parent(s): bc982c7

Upload PMA_VAE_Colab_Training.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. PMA_VAE_Colab_Training.ipynb +218 -197
PMA_VAE_Colab_Training.ipynb CHANGED
@@ -17,7 +17,7 @@
17
  "cell_type": "markdown",
18
  "metadata": {},
19
  "source": [
20
- "# 🎨 PMA-VAE: Parallel Mobile Artistic Variational Autoencoder\n",
21
  "\n",
22
  "**A novel attention-free architecture for image generation, super-resolution, artifact removal, and artistic style transfer.**\n",
23
  "\n",
@@ -31,17 +31,17 @@
31
  "\n",
32
  "## Architecture\n",
33
  "```\n",
34
- "Image → PixelUnshuffle stem → MobileConv stages → Parallel 2D Mamba blocks\n",
35
- " → Multi-scale latent (z_base H/16, z_detail H/8, z_style global)\n",
36
- " → Light parallel decoder with FiLM style modulation → Reconstructed image\n",
37
  "```\n",
38
  "\n",
39
  "## Key Design Decisions\n",
40
- "- **Parallel scan SSM** (Blelloch algorithm) — pure PyTorch, no CUDA kernels needed\n",
41
- "- **Cross-scan 2D** (VMamba-style) — 4 directional scans for global context without attention\n",
42
- "- **PixelShuffle upsampling** — efficient sub-pixel convolution for mobile\n",
43
- "- **Taming-transformers loss recipe** — adaptive discriminator weight balancing\n",
44
- "- **Progressive resolution training** — start small, scale up\n",
45
  "\n",
46
  "---\n",
47
  "**Trainable on free Colab T4 GPU (15GB VRAM) in ~2-4 hours for meaningful results.**"
@@ -68,7 +68,7 @@
68
  " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
69
  " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1024**3:.1f} GB')\n",
70
  "else:\n",
71
- " print('⚠️ No GPU detected! Go to Runtime → Change runtime type → T4 GPU')"
72
  ]
73
  },
74
  {
@@ -90,12 +90,12 @@
90
  "The full model is defined below in a single cell for easy Colab use.\n",
91
  "\n",
92
  "### Component breakdown:\n",
93
- "1. **Parallel Scan (PScan)** — Blelloch parallel prefix scan in pure PyTorch\n",
94
- "2. **Selective SSM (S6)** — Mamba's core mechanism, input-dependent state space\n",
95
- "3. **2D Cross-Scan** — VMamba-style 4-directional scanning for 2D feature maps\n",
96
- "4. **Mobile Conv Blocks** — Depthwise separable + SE + FiLM conditioning\n",
97
- "5. **Encoder** — Progressive downsampling with hybrid MobileConv + Mamba stages\n",
98
- "6. **Decoder** — Lightweight with FiLM style modulation, PixelShuffle upsampling"
99
  ]
100
  },
101
  {
@@ -112,7 +112,7 @@
112
  "\n",
113
  "\n",
114
  "# ============================================================================\n",
115
- "# Parallel Scan (Blelloch) — Pure PyTorch, no CUDA kernels\n",
116
  "# ============================================================================\n",
117
  "\n",
118
  "class PScan(torch.autograd.Function):\n",
@@ -379,13 +379,13 @@
379
  " nn.Conv2d(in_channels * 4, stage_channels[0], 3, padding=1, bias=False),\n",
380
  " nn.BatchNorm2d(stage_channels[0]), nn.SiLU(inplace=True))\n",
381
  "\n",
382
- " # Stage 1: H/2 → H/4 (MobileConv only)\n",
383
  " s1 = [MobileConvBlock(stage_channels[0], stage_channels[1], stride=2)]\n",
384
  " for _ in range(stage_blocks[0] - 1):\n",
385
  " s1.append(MobileConvBlock(stage_channels[1], stage_channels[1]))\n",
386
  " self.stage1 = nn.Sequential(*s1)\n",
387
  "\n",
388
- " # Stage 2: H/4 → H/8 (hybrid MobileConv + Mamba)\n",
389
  " s2 = nn.ModuleList()\n",
390
  " s2.append(MobileConvBlock(stage_channels[1], stage_channels[2], stride=2))\n",
391
  " n_mamba = max(1, (stage_blocks[1] - 1) // 2)\n",
@@ -398,7 +398,7 @@
398
  " self.detail_head_mu = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
399
  " self.detail_head_logvar = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
400
  "\n",
401
- " # Stage 3: H/8 → H/16 (Mamba-heavy)\n",
402
  " s3 = nn.ModuleList()\n",
403
  " s3.append(MobileConvBlock(stage_channels[2], stage_channels[3], stride=2))\n",
404
  " n_mamba3 = max(1, int((stage_blocks[2] - 1) * 0.75))\n",
@@ -533,18 +533,18 @@
533
  "# ============================================================================\n",
534
  "\n",
535
  "def pmavae_small(use_parallel_scan=True):\n",
536
- " \"\"\"~6M params — fast training on free Colab T4\"\"\"\n",
537
  " return PMAVAE(enc_channels=(48, 96, 144, 192), dec_channels=(192, 144, 96, 72, 48),\n",
538
  " enc_blocks=(2, 2, 3, 3), latent_base_dim=24, latent_detail_dim=6,\n",
539
  " latent_style_dim=96, d_state=16, use_parallel_scan=use_parallel_scan)\n",
540
  "\n",
541
  "def pmavae_base(use_parallel_scan=True):\n",
542
- " \"\"\"~15M params — high quality, needs more VRAM\"\"\"\n",
543
  " return PMAVAE(enc_channels=(64, 128, 192, 256), dec_channels=(256, 192, 128, 96, 64),\n",
544
  " enc_blocks=(2, 2, 4, 4), latent_base_dim=32, latent_detail_dim=8,\n",
545
  " latent_style_dim=128, d_state=16, use_parallel_scan=use_parallel_scan)\n",
546
  "\n",
547
- "print('✅ PMA-VAE architecture defined!')"
548
  ]
549
  },
550
  {
@@ -568,7 +568,7 @@
568
  " print(f' {k}: {v.shape}')\n",
569
  "\n",
570
  "params = model.count_parameters()\n",
571
- "print(f'\\n📊 Parameters: {params[\"total_M\"]:.2f}M total')\n",
572
  "print(f' Encoder: {params[\"enc_M\"]:.2f}M | Decoder: {params[\"dec_M\"]:.2f}M')\n",
573
  "\n",
574
  "del model, x, recon\n",
@@ -582,12 +582,12 @@
582
  "## 3. Loss Functions\n",
583
  "\n",
584
  "Our loss combines:\n",
585
- "- **L1 reconstruction** — pixel-level fidelity\n",
586
- "- **VGG perceptual** — semantic/structural similarity \n",
587
- "- **PatchGAN discriminator** — sharp, realistic textures\n",
588
- "- **KL with free bits** — prevents posterior collapse\n",
589
- "- **Edge preservation** — high-frequency detail via Sobel filters\n",
590
- "- **Adaptive discriminator weight** — taming-transformers trick"
591
  ]
592
  },
593
  {
@@ -701,7 +701,7 @@
701
  " d = hinge_d_loss(self.discriminator(inputs.detach()), self.discriminator(recon.detach()))\n",
702
  " return d, {'d_loss': d.item()}\n",
703
  "\n",
704
- "print('✅ Loss functions defined!')"
705
  ]
706
  },
707
  {
@@ -711,8 +711,8 @@
711
  "## 4. Dataset Setup\n",
712
  "\n",
713
  "We use a HuggingFace dataset for training. Options:\n",
714
- "- `huggan/wikiart` — artistic images (great for style learning)\n",
715
- "- `ILSVRC/imagenet-1k` — diverse natural images\n",
716
  "- Any folder of images\n",
717
  "\n",
718
  "For free Colab, we use a moderate-sized art dataset."
@@ -724,35 +724,61 @@
724
  "metadata": {},
725
  "outputs": [],
726
  "source": [
727
- "from torch.utils.data import DataLoader, Dataset\n",
728
  "from torchvision import transforms\n",
729
  "from PIL import Image\n",
730
  "import os\n",
731
  "\n",
732
- "# ======== Option A: HuggingFace Dataset ========\n",
733
- "class HFImageDataset(Dataset):\n",
734
- " def __init__(self, hf_dataset, image_col='image', resolution=256):\n",
735
- " self.ds = hf_dataset\n",
 
 
 
 
 
 
 
 
 
 
 
736
  " self.col = image_col\n",
737
  " self.transform = transforms.Compose([\n",
738
- " transforms.Resize(int(resolution * 1.15), interpolation=transforms.InterpolationMode.LANCZOS, antialias=True),\n",
 
 
739
  " transforms.RandomCrop(resolution),\n",
740
  " transforms.RandomHorizontalFlip(),\n",
741
  " transforms.ToTensor(),\n",
742
  " transforms.Normalize([0.5]*3, [0.5]*3)])\n",
743
- " def __len__(self): return len(self.ds)\n",
744
- " def __getitem__(self, idx):\n",
745
- " img = self.ds[idx][self.col]\n",
746
- " if not isinstance(img, Image.Image): img = Image.fromarray(img)\n",
747
- " return self.transform(img.convert('RGB'))\n",
748
  "\n",
749
- "# ======== Option B: Local folder ========\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  "class FolderDataset(Dataset):\n",
751
  " def __init__(self, root, resolution=256):\n",
752
  " exts = {'.jpg','.jpeg','.png','.bmp','.webp'}\n",
753
- " self.files = [os.path.join(dp,f) for dp,_,fns in os.walk(root) for f in fns if os.path.splitext(f)[1].lower() in exts]\n",
 
754
  " self.transform = transforms.Compose([\n",
755
- " transforms.Resize(int(resolution * 1.15), interpolation=transforms.InterpolationMode.LANCZOS, antialias=True),\n",
 
 
756
  " transforms.RandomCrop(resolution),\n",
757
  " transforms.RandomHorizontalFlip(),\n",
758
  " transforms.ToTensor(),\n",
@@ -761,7 +787,7 @@
761
  " def __getitem__(self, idx):\n",
762
  " return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
763
  "\n",
764
- "print('✅ Dataset classes defined!')"
765
  ]
766
  },
767
  {
@@ -770,33 +796,58 @@
770
  "metadata": {},
771
  "outputs": [],
772
  "source": [
773
- "# Load dataset\n",
774
  "from datasets import load_dataset\n",
775
  "\n",
776
- "# === Choose your dataset ===\n",
777
- "DATASET_NAME = 'huggan/wikiart' # Art images - great for style learning\n",
 
 
778
  "IMAGE_COLUMN = 'image'\n",
779
- "RESOLUTION = 256 # Start with 256, can increase to 512 later\n",
780
- "BATCH_SIZE = 8 # Fits on T4 with small model\n",
781
- "NUM_WORKERS = 2\n",
782
  "\n",
783
- "print(f'Loading {DATASET_NAME}...')\n",
784
- "raw_dataset = load_dataset(DATASET_NAME, split='train', streaming=False)\n",
785
- "# For very large datasets, use streaming=True and take a subset:\n",
786
- "# raw_dataset = load_dataset(DATASET_NAME, split='train', streaming=True)\n",
787
- "# raw_dataset = list(raw_dataset.take(50000))\n",
 
 
 
 
 
 
788
  "\n",
789
- "dataset = HFImageDataset(raw_dataset, IMAGE_COLUMN, RESOLUTION)\n",
790
- "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
791
- " num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)\n",
792
  "\n",
793
- "print(f'Dataset: {len(dataset)} images')\n",
794
- "print(f'Batches per epoch: {len(dataloader)}')\n",
795
  "\n",
796
- "# Quick check\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
797
  "sample = next(iter(dataloader))\n",
798
- "print(f'Batch shape: {sample.shape}')\n",
799
- "print(f'Value range: [{sample.min():.2f}, {sample.max():.2f}]')"
 
 
 
 
800
  ]
801
  },
802
  {
@@ -806,12 +857,12 @@
806
  "## 5. Training\n",
807
  "\n",
808
  "### Training recipe:\n",
809
- "- **Phase 1** (256×256): Learn structure and composition\n",
810
- "- **Phase 2** (384×384): Refine texture details\n",
811
- "- **Phase 3** (512×512): Fine-tune for high-res quality\n",
812
  "\n",
813
  "### Anti-collapse measures:\n",
814
- "1. **KL warmup**: β goes from 0 → target over first 5000 steps\n",
815
  "2. **Free bits**: Each latent dimension must use at least 0.25 nats\n",
816
  "3. **Discriminator cold start**: Only activates after 10000 steps\n",
817
  "4. **Adaptive disc weight**: Balances recon vs adversarial gradients\n",
@@ -921,128 +972,98 @@
921
  "outputs": [],
922
  "source": [
923
  "# ============================================================================\n",
924
- "# Training Loop with Live Visualization\n",
925
  "# ============================================================================\n",
 
 
926
  "\n",
927
- "def visualize_reconstruction(model, batch, step):\n",
928
- " \"\"\"Show original vs reconstructed images.\"\"\"\n",
929
- " model.eval()\n",
930
- " with torch.no_grad():\n",
931
- " recon, _ = model(batch[:4].to(device))\n",
932
- " model.train()\n",
933
- " \n",
934
- " fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n",
935
- " for i in range(4):\n",
936
- " orig = batch[i].permute(1,2,0).cpu().numpy() * 0.5 + 0.5\n",
937
- " rec = recon[i].permute(1,2,0).cpu().numpy() * 0.5 + 0.5\n",
938
- " axes[0,i].imshow(orig.clip(0,1))\n",
939
- " axes[0,i].set_title('Original')\n",
940
- " axes[0,i].axis('off')\n",
941
- " axes[1,i].imshow(rec.clip(0,1))\n",
942
- " axes[1,i].set_title(f'Recon (step {step})')\n",
943
- " axes[1,i].axis('off')\n",
944
- " plt.tight_layout()\n",
945
- " plt.show()\n",
946
- "\n",
947
- "def plot_losses(history):\n",
948
- " \"\"\"Plot training loss curves.\"\"\"\n",
949
- " if len(history) < 10: return\n",
950
- " fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
951
- " steps = [h['step'] for h in history]\n",
952
- " axes[0].plot(steps, [h['l1'] for h in history], label='L1')\n",
953
- " axes[0].plot(steps, [h.get('perc',0) for h in history], label='Perceptual')\n",
954
- " axes[0].set_title('Reconstruction Losses'); axes[0].legend(); axes[0].set_xlabel('Step')\n",
955
- " axes[1].plot(steps, [h.get('kl_base',0) for h in history], label='KL base')\n",
956
- " axes[1].plot(steps, [h.get('kl_detail',0) for h in history], label='KL detail')\n",
957
- " axes[1].set_title('KL Losses'); axes[1].legend(); axes[1].set_xlabel('Step')\n",
958
- " axes[2].plot(steps, [h.get('d_loss',0) for h in history], label='Disc')\n",
959
- " axes[2].plot(steps, [h.get('g_loss',0) for h in history], label='Gen')\n",
960
- " axes[2].set_title('GAN Losses'); axes[2].legend(); axes[2].set_xlabel('Step')\n",
961
- " plt.tight_layout(); plt.show()\n",
962
- "\n",
963
- "# === TRAINING LOOP ===\n",
964
  "global_step = 0\n",
965
  "history = []\n",
966
  "start_time = time.time()\n",
967
- "vis_batch = next(iter(dataloader)) # Fixed batch for visualization\n",
968
  "\n",
969
- "print(f'\\n🚀 Starting training! Target: {CONFIG[\"max_steps\"]} steps')\n",
970
- "print(f' KL warmup: 0 → {CONFIG[\"kl_weight\"]} over {CONFIG[\"kl_warmup_steps\"]} steps')\n",
 
 
 
971
  "print(f' Discriminator starts at step {CONFIG[\"disc_start\"]}\\n')\n",
972
  "\n",
973
  "model.train()\n",
974
- "for epoch in range(CONFIG['num_epochs']):\n",
975
- " for batch_idx, batch in enumerate(dataloader):\n",
976
- " batch = batch.to(device)\n",
977
- " \n",
978
- " # KL warmup\n",
979
- " kl_w = CONFIG['kl_weight'] * min(1.0, global_step / max(1, CONFIG['kl_warmup_steps']))\n",
980
- " criterion.kl_weight = kl_w\n",
981
- " \n",
982
- " # === VAE update ===\n",
983
- " opt_vae.zero_grad()\n",
984
- " with autocast('cuda', enabled=device=='cuda'):\n",
985
- " recon, posteriors = model(batch)\n",
986
- " loss_vae, log_vae = criterion(batch, recon, posteriors, 0, global_step,\n",
987
- " model.get_last_decoder_layer())\n",
988
- " scaler_vae.scale(loss_vae).backward()\n",
989
- " scaler_vae.unscale_(opt_vae)\n",
990
- " gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
991
- " scaler_vae.step(opt_vae)\n",
992
- " scaler_vae.update()\n",
993
- " \n",
994
- " # === Discriminator update ===\n",
995
- " opt_disc.zero_grad()\n",
996
- " with autocast('cuda', enabled=device=='cuda'):\n",
997
- " with torch.no_grad():\n",
998
- " recon_d, _ = model(batch)\n",
999
- " loss_disc, log_disc = criterion(batch, recon_d, posteriors, 1, global_step)\n",
1000
- " if global_step >= CONFIG['disc_start']:\n",
1001
- " scaler_disc.scale(loss_disc).backward()\n",
1002
- " scaler_disc.unscale_(opt_disc)\n",
1003
- " torch.nn.utils.clip_grad_norm_(criterion.discriminator.parameters(), 1.0)\n",
1004
- " scaler_disc.step(opt_disc)\n",
1005
- " scaler_disc.update()\n",
1006
- " \n",
1007
- " global_step += 1\n",
1008
- " \n",
1009
- " # Logging\n",
1010
- " log = {**log_vae, **log_disc, 'step': global_step, 'grad_norm': gn.item(), 'kl_w': kl_w}\n",
1011
- " \n",
1012
- " if global_step % CONFIG['log_every'] == 0:\n",
1013
- " history.append(log)\n",
1014
- " elapsed = (time.time() - start_time) / 60\n",
1015
- " print(f\"Step {global_step:6d} | L1:{log['l1']:.4f} | Perc:{log.get('perc',0):.4f} | \"\n",
1016
- " f\"KL:{log.get('kl_base',0):.1f}/{log.get('kl_detail',0):.1f}/{log.get('kl_style',0):.1f} | \"\n",
1017
- " f\"D:{log.get('d_loss',0):.4f} | G:{log.get('g_loss',0):.4f} | \"\n",
1018
- " f\"GN:{log['grad_norm']:.2f} | {elapsed:.1f}min\")\n",
1019
- " \n",
1020
- " if global_step % CONFIG['vis_every'] == 0:\n",
1021
- " clear_output(wait=True)\n",
1022
- " visualize_reconstruction(model, vis_batch, global_step)\n",
1023
- " plot_losses(history)\n",
1024
- " \n",
1025
- " if global_step % CONFIG['save_every'] == 0:\n",
1026
- " os.makedirs('checkpoints', exist_ok=True)\n",
1027
- " torch.save({'model': model.state_dict(),\n",
1028
- " 'disc': criterion.discriminator.state_dict(),\n",
1029
- " 'opt_vae': opt_vae.state_dict(),\n",
1030
- " 'opt_disc': opt_disc.state_dict(),\n",
1031
- " 'step': global_step, 'config': CONFIG},\n",
1032
- " f'checkpoints/pma_vae_step{global_step}.pt')\n",
1033
- " print(f'💾 Saved checkpoint at step {global_step}')\n",
1034
- " \n",
1035
- " if global_step >= CONFIG['max_steps']:\n",
1036
- " break\n",
1037
- " \n",
1038
- " if global_step >= CONFIG['max_steps']:\n",
1039
- " break\n",
 
 
 
1040
  "\n",
1041
  "# Final save\n",
1042
  "torch.save({'model': model.state_dict(), 'config': CONFIG}, 'checkpoints/pma_vae_final.pt')\n",
1043
  "total_time = (time.time() - start_time) / 60\n",
1044
- "print(f'\\n Training complete! {global_step} steps in {total_time:.1f} minutes')\n",
1045
- "print(f'💾 Final model saved to checkpoints/pma_vae_final.pt')"
1046
  ]
1047
  },
1048
  {
@@ -1080,7 +1101,7 @@
1080
  " psnr = -10 * math.log10(mse + 1e-8)\n",
1081
  " psnrs.append(psnr)\n",
1082
  "\n",
1083
- "print(f'\\n📊 Evaluation Results:')\n",
1084
  "print(f' Average PSNR: {sum(psnrs)/len(psnrs):.2f} dB')\n",
1085
  "print(f' Min PSNR: {min(psnrs):.2f} dB')\n",
1086
  "print(f' Max PSNR: {max(psnrs):.2f} dB')"
@@ -1144,7 +1165,7 @@
1144
  " out = model.decoder(pa['base_mu'], pa['detail_mu'], z_style)\n",
1145
  " img = out[0].cpu().permute(1,2,0).numpy() * 0.5 + 0.5\n",
1146
  " axes[i].imshow(img.clip(0,1))\n",
1147
- " axes[i].set_title(f'α={alpha:.2f}')\n",
1148
  " axes[i].axis('off')\n",
1149
  "plt.suptitle('Style Interpolation (structure fixed, style varies)', fontsize=14)\n",
1150
  "plt.tight_layout()\n",
@@ -1197,7 +1218,7 @@
1197
  "model.eval()\n",
1198
  "\n",
1199
  "# Dummy inputs matching the latent shapes\n",
1200
- "dummy_base = torch.randn(1, 24, 16, 16, device=device) # For 256×256 input\n",
1201
  "dummy_detail = torch.randn(1, 6, 32, 32, device=device)\n",
1202
  "dummy_style = torch.randn(1, 96, device=device)\n",
1203
  "\n",
@@ -1215,7 +1236,7 @@
1215
  ")\n",
1216
  "\n",
1217
  "onnx_size = os.path.getsize('pma_vae_decoder.onnx') / 1024**2\n",
1218
- "print(f'\\n📱 ONNX decoder exported!')\n",
1219
  "print(f' Size: {onnx_size:.1f} MB')\n",
1220
  "print(f' Ready for: Core ML, TFLite, ONNX Runtime Mobile')\n",
1221
  "\n",
@@ -1228,7 +1249,7 @@
1228
  "source": [
1229
  "## 9. Progressive Resolution Training\n",
1230
  "\n",
1231
- "After initial training at 256×256, progressively increase resolution.\n",
1232
  "The model handles variable resolutions thanks to the convolutional architecture."
1233
  ]
1234
  },
@@ -1253,7 +1274,7 @@
1253
  "# for pg in opt_disc.param_groups: pg['lr'] *= 0.5\n",
1254
  "# \n",
1255
  "# # Continue training (copy the training loop above with dataloader_hr)\n",
1256
- "# print(f'Phase 2: Training at {NEW_RESOLUTION}×{NEW_RESOLUTION}')\n",
1257
  "# print(f'Batches per epoch: {len(dataloader_hr)}')"
1258
  ]
1259
  },
@@ -1276,8 +1297,8 @@
1276
  "model.eval()\n",
1277
  "\n",
1278
  "# Take a high-res image and downsample it\n",
1279
- "hr_img = test_batch[0:1] # 256×256\n",
1280
- "lr_img = F.interpolate(hr_img, scale_factor=0.5, mode='bilinear', align_corners=False) # 128×128\n",
1281
  "lr_upscaled = F.interpolate(lr_img, size=(256, 256), mode='bilinear', align_corners=False)\n",
1282
  "\n",
1283
  "with torch.no_grad():\n",
@@ -1313,7 +1334,7 @@
1313
  "| Component | Choice | Why |\n",
1314
  "|---|---|---|\n",
1315
  "| Backbone | MobileConv + Parallel 2D Mamba | Fast, efficient, attention-free |\n",
1316
- "| Downsampling | PixelUnshuffle stride-2 conv | Lossless initial features |\n",
1317
  "| Upsampling | PixelShuffle (sub-pixel) | Mobile-friendly, no checkerboard |\n",
1318
  "| Latent | Multi-scale (base/detail/style) | Controllable, prevents collapse |\n",
1319
  "| Style control | FiLM conditioning | Lightweight, multiplicative |\n",
@@ -1326,13 +1347,13 @@
1326
  "\n",
1327
  "| Feature | PMA-VAE | SD-VAE | NVAE |\n",
1328
  "|---|---|---|---|\n",
1329
- "| Attention-free | | | |\n",
1330
- "| Mobile-friendly decoder | | | |\n",
1331
- "| Multi-scale latent | | | |\n",
1332
- "| Style control built-in | | | |\n",
1333
  "| Decoder params | ~4-8M | ~50M | ~100M+ |\n",
1334
- "| Parallel training | | | |\n",
1335
- "| Free Colab trainable | | | |"
1336
  ]
1337
  },
1338
  {
@@ -1341,7 +1362,7 @@
1341
  "source": [
1342
  "---\n",
1343
  "\n",
1344
- "## 📚 References\n",
1345
  "\n",
1346
  "- **Mamba**: Gu & Dao, 2023. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)\n",
1347
  "- **VMamba**: Liu et al., 2024. [VMamba: Visual State Space Model](https://arxiv.org/abs/2401.10166)\n",
@@ -1355,4 +1376,4 @@
1355
  ]
1356
  }
1357
  ]
1358
- }
 
17
  "cell_type": "markdown",
18
  "metadata": {},
19
  "source": [
20
+ "# \ud83c\udfa8 PMA-VAE: Parallel Mobile Artistic Variational Autoencoder\n",
21
  "\n",
22
  "**A novel attention-free architecture for image generation, super-resolution, artifact removal, and artistic style transfer.**\n",
23
  "\n",
 
31
  "\n",
32
  "## Architecture\n",
33
  "```\n",
34
+ "Image \u2192 PixelUnshuffle stem \u2192 MobileConv stages \u2192 Parallel 2D Mamba blocks\n",
35
+ " \u2192 Multi-scale latent (z_base H/16, z_detail H/8, z_style global)\n",
36
+ " \u2192 Light parallel decoder with FiLM style modulation \u2192 Reconstructed image\n",
37
  "```\n",
38
  "\n",
39
  "## Key Design Decisions\n",
40
+ "- **Parallel scan SSM** (Blelloch algorithm) \u2014 pure PyTorch, no CUDA kernels needed\n",
41
+ "- **Cross-scan 2D** (VMamba-style) \u2014 4 directional scans for global context without attention\n",
42
+ "- **PixelShuffle upsampling** \u2014 efficient sub-pixel convolution for mobile\n",
43
+ "- **Taming-transformers loss recipe** \u2014 adaptive discriminator weight balancing\n",
44
+ "- **Progressive resolution training** \u2014 start small, scale up\n",
45
  "\n",
46
  "---\n",
47
  "**Trainable on free Colab T4 GPU (15GB VRAM) in ~2-4 hours for meaningful results.**"
 
68
  " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
69
  " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1024**3:.1f} GB')\n",
70
  "else:\n",
71
+ " print('\u26a0\ufe0f No GPU detected! Go to Runtime \u2192 Change runtime type \u2192 T4 GPU')"
72
  ]
73
  },
74
  {
 
90
  "The full model is defined below in a single cell for easy Colab use.\n",
91
  "\n",
92
  "### Component breakdown:\n",
93
+ "1. **Parallel Scan (PScan)** \u2014 Blelloch parallel prefix scan in pure PyTorch\n",
94
+ "2. **Selective SSM (S6)** \u2014 Mamba's core mechanism, input-dependent state space\n",
95
+ "3. **2D Cross-Scan** \u2014 VMamba-style 4-directional scanning for 2D feature maps\n",
96
+ "4. **Mobile Conv Blocks** \u2014 Depthwise separable + SE + FiLM conditioning\n",
97
+ "5. **Encoder** \u2014 Progressive downsampling with hybrid MobileConv + Mamba stages\n",
98
+ "6. **Decoder** \u2014 Lightweight with FiLM style modulation, PixelShuffle upsampling"
99
  ]
100
  },
101
  {
 
112
  "\n",
113
  "\n",
114
  "# ============================================================================\n",
115
+ "# Parallel Scan (Blelloch) \u2014 Pure PyTorch, no CUDA kernels\n",
116
  "# ============================================================================\n",
117
  "\n",
118
  "class PScan(torch.autograd.Function):\n",
 
379
  " nn.Conv2d(in_channels * 4, stage_channels[0], 3, padding=1, bias=False),\n",
380
  " nn.BatchNorm2d(stage_channels[0]), nn.SiLU(inplace=True))\n",
381
  "\n",
382
+ " # Stage 1: H/2 \u2192 H/4 (MobileConv only)\n",
383
  " s1 = [MobileConvBlock(stage_channels[0], stage_channels[1], stride=2)]\n",
384
  " for _ in range(stage_blocks[0] - 1):\n",
385
  " s1.append(MobileConvBlock(stage_channels[1], stage_channels[1]))\n",
386
  " self.stage1 = nn.Sequential(*s1)\n",
387
  "\n",
388
+ " # Stage 2: H/4 \u2192 H/8 (hybrid MobileConv + Mamba)\n",
389
  " s2 = nn.ModuleList()\n",
390
  " s2.append(MobileConvBlock(stage_channels[1], stage_channels[2], stride=2))\n",
391
  " n_mamba = max(1, (stage_blocks[1] - 1) // 2)\n",
 
398
  " self.detail_head_mu = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
399
  " self.detail_head_logvar = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
400
  "\n",
401
+ " # Stage 3: H/8 \u2192 H/16 (Mamba-heavy)\n",
402
  " s3 = nn.ModuleList()\n",
403
  " s3.append(MobileConvBlock(stage_channels[2], stage_channels[3], stride=2))\n",
404
  " n_mamba3 = max(1, int((stage_blocks[2] - 1) * 0.75))\n",
 
533
  "# ============================================================================\n",
534
  "\n",
535
  "def pmavae_small(use_parallel_scan=True):\n",
536
+ " \"\"\"~6M params \u2014 fast training on free Colab T4\"\"\"\n",
537
  " return PMAVAE(enc_channels=(48, 96, 144, 192), dec_channels=(192, 144, 96, 72, 48),\n",
538
  " enc_blocks=(2, 2, 3, 3), latent_base_dim=24, latent_detail_dim=6,\n",
539
  " latent_style_dim=96, d_state=16, use_parallel_scan=use_parallel_scan)\n",
540
  "\n",
541
  "def pmavae_base(use_parallel_scan=True):\n",
542
+ " \"\"\"~15M params \u2014 high quality, needs more VRAM\"\"\"\n",
543
  " return PMAVAE(enc_channels=(64, 128, 192, 256), dec_channels=(256, 192, 128, 96, 64),\n",
544
  " enc_blocks=(2, 2, 4, 4), latent_base_dim=32, latent_detail_dim=8,\n",
545
  " latent_style_dim=128, d_state=16, use_parallel_scan=use_parallel_scan)\n",
546
  "\n",
547
+ "print('\u2705 PMA-VAE architecture defined!')"
548
  ]
549
  },
550
  {
 
568
  " print(f' {k}: {v.shape}')\n",
569
  "\n",
570
  "params = model.count_parameters()\n",
571
+ "print(f'\\n\ud83d\udcca Parameters: {params[\"total_M\"]:.2f}M total')\n",
572
  "print(f' Encoder: {params[\"enc_M\"]:.2f}M | Decoder: {params[\"dec_M\"]:.2f}M')\n",
573
  "\n",
574
  "del model, x, recon\n",
 
582
  "## 3. Loss Functions\n",
583
  "\n",
584
  "Our loss combines:\n",
585
+ "- **L1 reconstruction** \u2014 pixel-level fidelity\n",
586
+ "- **VGG perceptual** \u2014 semantic/structural similarity \n",
587
+ "- **PatchGAN discriminator** \u2014 sharp, realistic textures\n",
588
+ "- **KL with free bits** \u2014 prevents posterior collapse\n",
589
+ "- **Edge preservation** \u2014 high-frequency detail via Sobel filters\n",
590
+ "- **Adaptive discriminator weight** \u2014 taming-transformers trick"
591
  ]
592
  },
593
  {
 
701
  " d = hinge_d_loss(self.discriminator(inputs.detach()), self.discriminator(recon.detach()))\n",
702
  " return d, {'d_loss': d.item()}\n",
703
  "\n",
704
+ "print('\u2705 Loss functions defined!')"
705
  ]
706
  },
707
  {
 
711
  "## 4. Dataset Setup\n",
712
  "\n",
713
  "We use a HuggingFace dataset for training. Options:\n",
714
+ "- `huggan/wikiart` \u2014 artistic images (great for style learning)\n",
715
+ "- `ILSVRC/imagenet-1k` \u2014 diverse natural images\n",
716
  "- Any folder of images\n",
717
  "\n",
718
  "For free Colab, we use a moderate-sized art dataset."
 
724
  "metadata": {},
725
  "outputs": [],
726
  "source": [
727
+ "from torch.utils.data import DataLoader, Dataset, IterableDataset\n",
728
  "from torchvision import transforms\n",
729
  "from PIL import Image\n",
730
  "import os\n",
731
  "\n",
732
+ "# ======== Streaming HF Dataset (RAM-safe) ========\n",
733
+ "# This NEVER loads the full dataset into RAM.\n",
734
+ "# Images are decoded one-at-a-time from Parquet shards.\n",
735
+ "\n",
736
+ "class StreamingHFDataset(IterableDataset):\n",
737
+ " \"\"\"\n",
738
+ " Wraps a HuggingFace streaming dataset for PyTorch.\n",
739
+ " RAM usage: ~50-100MB regardless of dataset size.\n",
740
+ " \n",
741
+ " Key: we use datasets streaming mode which reads Parquet\n",
742
+ " files chunk-by-chunk from HF Hub, never materializing\n",
743
+ " the full dataset in memory.\n",
744
+ " \"\"\"\n",
745
+ " def __init__(self, hf_iterable_dataset, image_col='image', resolution=256):\n",
746
+ " self.ds = hf_iterable_dataset\n",
747
  " self.col = image_col\n",
748
  " self.transform = transforms.Compose([\n",
749
+ " transforms.Resize(int(resolution * 1.15),\n",
750
+ " interpolation=transforms.InterpolationMode.LANCZOS,\n",
751
+ " antialias=True),\n",
752
  " transforms.RandomCrop(resolution),\n",
753
  " transforms.RandomHorizontalFlip(),\n",
754
  " transforms.ToTensor(),\n",
755
  " transforms.Normalize([0.5]*3, [0.5]*3)])\n",
 
 
 
 
 
756
  "\n",
757
+ " def __iter__(self):\n",
758
+ " for sample in self.ds:\n",
759
+ " img = sample[self.col]\n",
760
+ " if not isinstance(img, Image.Image):\n",
761
+ " img = Image.fromarray(img)\n",
762
+ " img = img.convert('RGB')\n",
763
+ " # Ensure minimum size for crop\n",
764
+ " w, h = img.size\n",
765
+ " if w < 64 or h < 64:\n",
766
+ " continue # skip tiny images\n",
767
+ " try:\n",
768
+ " yield self.transform(img)\n",
769
+ " except Exception:\n",
770
+ " continue # skip corrupt images\n",
771
+ "\n",
772
+ "# ======== Local folder (non-streaming) ========\n",
773
  "class FolderDataset(Dataset):\n",
774
  " def __init__(self, root, resolution=256):\n",
775
  " exts = {'.jpg','.jpeg','.png','.bmp','.webp'}\n",
776
+ " self.files = [os.path.join(dp,f) for dp,_,fns in os.walk(root)\n",
777
+ " for f in fns if os.path.splitext(f)[1].lower() in exts]\n",
778
  " self.transform = transforms.Compose([\n",
779
+ " transforms.Resize(int(resolution * 1.15),\n",
780
+ " interpolation=transforms.InterpolationMode.LANCZOS,\n",
781
+ " antialias=True),\n",
782
  " transforms.RandomCrop(resolution),\n",
783
  " transforms.RandomHorizontalFlip(),\n",
784
  " transforms.ToTensor(),\n",
 
787
  " def __getitem__(self, idx):\n",
788
  " return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
789
  "\n",
790
+ "print('\u2705 Dataset classes defined!')"
791
  ]
792
  },
793
  {
 
796
  "metadata": {},
797
  "outputs": [],
798
  "source": [
 
799
  "from datasets import load_dataset\n",
800
  "\n",
801
+ "# ============================================================================\n",
802
+ "# Dataset Configuration\n",
803
+ "# ============================================================================\n",
804
+ "DATASET_NAME = 'huggan/wikiart' # 80K art images (~5GB)\n",
805
  "IMAGE_COLUMN = 'image'\n",
806
+ "RESOLUTION = 256\n",
807
+ "BATCH_SIZE = 8 # Fits T4 15GB with pmavae_small\n",
 
808
  "\n",
809
+ "# ============================================================================\n",
810
+ "# CRITICAL: Use streaming=True to avoid RAM crash!\n",
811
+ "# \n",
812
+ "# Without streaming: HF downloads ALL 5GB of images \u2192 decodes to PIL \u2192\n",
813
+ "# stores in RAM \u2192 Colab's 12GB RAM is exhausted \u2192 kernel crash.\n",
814
+ "# \n",
815
+ "# With streaming: HF reads Parquet shards on-the-fly \u2192 decodes one\n",
816
+ "# image at a time \u2192 constant ~100MB RAM usage.\n",
817
+ "# ============================================================================\n",
818
+ "print(f'Loading {DATASET_NAME} in streaming mode...')\n",
819
+ "raw_stream = load_dataset(DATASET_NAME, split='train', streaming=True)\n",
820
  "\n",
821
+ "# Shuffle with a buffer (keeps only 1000 samples in RAM at once)\n",
822
+ "raw_stream = raw_stream.shuffle(seed=42, buffer_size=1000)\n",
 
823
  "\n",
824
+ "dataset = StreamingHFDataset(raw_stream, IMAGE_COLUMN, RESOLUTION)\n",
 
825
  "\n",
826
+ "# ============================================================================\n",
827
+ "# DataLoader for streaming dataset\n",
828
+ "# \n",
829
+ "# IMPORTANT differences from map-style DataLoader:\n",
830
+ "# - num_workers=0 (streaming datasets handle their own I/O)\n",
831
+ "# - No shuffle (already shuffled in the stream buffer above)\n",
832
+ "# - drop_last=True (partial batches can cause issues)\n",
833
+ "# ============================================================================\n",
834
+ "dataloader = DataLoader(\n",
835
+ " dataset,\n",
836
+ " batch_size=BATCH_SIZE,\n",
837
+ " num_workers=0, # streaming handles I/O internally\n",
838
+ " pin_memory=True,\n",
839
+ " drop_last=True,\n",
840
+ ")\n",
841
+ "\n",
842
+ "# Quick sanity check \u2014 grab one batch\n",
843
+ "print('Fetching first batch...')\n",
844
  "sample = next(iter(dataloader))\n",
845
+ "print(f'\u2705 Batch shape: {sample.shape}')\n",
846
+ "print(f' Value range: [{sample.min():.2f}, {sample.max():.2f}]')\n",
847
+ "print(f' RAM usage: minimal (streaming mode)')\n",
848
+ "print()\n",
849
+ "print('NOTE: With streaming, len(dataloader) is unknown.')\n",
850
+ "print('Training runs by step count, not epoch count.')"
851
  ]
852
  },
853
  {
 
857
  "## 5. Training\n",
858
  "\n",
859
  "### Training recipe:\n",
860
+ "- **Phase 1** (256\u00d7256): Learn structure and composition\n",
861
+ "- **Phase 2** (384\u00d7384): Refine texture details\n",
862
+ "- **Phase 3** (512\u00d7512): Fine-tune for high-res quality\n",
863
  "\n",
864
  "### Anti-collapse measures:\n",
865
+ "1. **KL warmup**: \u03b2 goes from 0 \u2192 target over first 5000 steps\n",
866
  "2. **Free bits**: Each latent dimension must use at least 0.25 nats\n",
867
  "3. **Discriminator cold start**: Only activates after 10000 steps\n",
868
  "4. **Adaptive disc weight**: Balances recon vs adversarial gradients\n",
 
972
  "outputs": [],
973
  "source": [
974
  "# ============================================================================\n",
975
+ "# Training Loop \u2014 Streaming-Compatible\n",
976
  "# ============================================================================\n",
977
+ "# Since streaming datasets don't have len(), we train by step count.\n",
978
+ "# The stream does not restart on its own: when it is exhausted, the loop below re-creates the iterator.\n",
979
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980
  "global_step = 0\n",
981
  "history = []\n",
982
  "start_time = time.time()\n",
 
983
  "\n",
984
+ "# Grab one batch and clone it so the same images are visualized at every vis step\n",
985
+ "vis_batch = next(iter(dataloader)).clone()\n",
986
+ "\n",
987
+ "print(f'\\n\ud83d\ude80 Starting training! Target: {CONFIG[\"max_steps\"]} steps')\n",
988
+ "print(f' KL warmup: 0 \u2192 {CONFIG[\"kl_weight\"]} over {CONFIG[\"kl_warmup_steps\"]} steps')\n",
989
  "print(f' Discriminator starts at step {CONFIG[\"disc_start\"]}\\n')\n",
990
  "\n",
991
  "model.train()\n",
992
+ "\n",
993
+ "# Iterator over the streaming dataloader; the loop below re-creates it when the stream is exhausted\n",
994
+ "data_iter = iter(dataloader)\n",
995
+ "\n",
996
+ "while global_step < CONFIG['max_steps']:\n",
997
+ " # Get next batch (re-create iterator if stream exhausted)\n",
998
+ " try:\n",
999
+ " batch = next(data_iter)\n",
1000
+ " except StopIteration:\n",
1001
+ " # Stream exhausted = 1 epoch done. Re-create.\n",
1002
+ " data_iter = iter(dataloader)\n",
1003
+ " batch = next(data_iter)\n",
1004
+ "\n",
1005
+ " batch = batch.to(device)\n",
1006
+ "\n",
1007
+ " # KL warmup\n",
1008
+ " kl_w = CONFIG['kl_weight'] * min(1.0, global_step / max(1, CONFIG['kl_warmup_steps']))\n",
1009
+ " criterion.kl_weight = kl_w\n",
1010
+ "\n",
1011
+ " # === VAE update ===\n",
1012
+ " opt_vae.zero_grad()\n",
1013
+ " with autocast('cuda', enabled=device=='cuda'):\n",
1014
+ " recon, posteriors = model(batch)\n",
1015
+ " loss_vae, log_vae = criterion(batch, recon, posteriors, 0, global_step,\n",
1016
+ " model.get_last_decoder_layer())\n",
1017
+ " scaler_vae.scale(loss_vae).backward()\n",
1018
+ " scaler_vae.unscale_(opt_vae)\n",
1019
+ " gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
1020
+ " scaler_vae.step(opt_vae)\n",
1021
+ " scaler_vae.update()\n",
1022
+ "\n",
1023
+ " # === Discriminator update ===\n",
1024
+ " opt_disc.zero_grad()\n",
1025
+ " with autocast('cuda', enabled=device=='cuda'):\n",
1026
+ " with torch.no_grad():\n",
1027
+ " recon_d, _ = model(batch)\n",
1028
+ " loss_disc, log_disc = criterion(batch, recon_d, posteriors, 1, global_step)\n",
1029
+ " if global_step >= CONFIG['disc_start']:\n",
1030
+ " scaler_disc.scale(loss_disc).backward()\n",
1031
+ " scaler_disc.unscale_(opt_disc)\n",
1032
+ " torch.nn.utils.clip_grad_norm_(criterion.discriminator.parameters(), 1.0)\n",
1033
+ " scaler_disc.step(opt_disc)\n",
1034
+ " scaler_disc.update()\n",
1035
+ "\n",
1036
+ " global_step += 1\n",
1037
+ " log = {**log_vae, **log_disc, 'step': global_step, 'grad_norm': gn.item(), 'kl_w': kl_w}\n",
1038
+ "\n",
1039
+ " if global_step % CONFIG['log_every'] == 0:\n",
1040
+ " history.append(log)\n",
1041
+ " elapsed = (time.time() - start_time) / 60\n",
1042
+ " print(f\"Step {global_step:6d} | L1:{log['l1']:.4f} | Perc:{log.get('perc',0):.4f} | \"\n",
1043
+ " f\"KL:{log.get('kl_base',0):.1f}/{log.get('kl_detail',0):.1f}/{log.get('kl_style',0):.1f} | \"\n",
1044
+ " f\"D:{log.get('d_loss',0):.4f} | G:{log.get('g_loss',0):.4f} | \"\n",
1045
+ " f\"GN:{log['grad_norm']:.2f} | {elapsed:.1f}min\")\n",
1046
+ "\n",
1047
+ " if global_step % CONFIG['vis_every'] == 0:\n",
1048
+ " clear_output(wait=True)\n",
1049
+ " visualize_reconstruction(model, vis_batch, global_step)\n",
1050
+ " plot_losses(history)\n",
1051
+ "\n",
1052
+ " if global_step % CONFIG['save_every'] == 0:\n",
1053
+ " os.makedirs('checkpoints', exist_ok=True)\n",
1054
+ " torch.save({'model': model.state_dict(),\n",
1055
+ " 'disc': criterion.discriminator.state_dict(),\n",
1056
+ " 'opt_vae': opt_vae.state_dict(),\n",
1057
+ " 'opt_disc': opt_disc.state_dict(),\n",
1058
+ " 'step': global_step, 'config': CONFIG},\n",
1059
+ " f'checkpoints/pma_vae_step{global_step}.pt')\n",
1060
+ " print(f'\ud83d\udcbe Saved checkpoint at step {global_step}')\n",
1061
  "\n",
1062
  "# Final save\n",
1063
  "torch.save({'model': model.state_dict(), 'config': CONFIG}, 'checkpoints/pma_vae_final.pt')\n",
1064
  "total_time = (time.time() - start_time) / 60\n",
1065
+ "print(f'\\n\u2705 Training complete! {global_step} steps in {total_time:.1f} minutes')\n",
1066
+ "print(f'\ud83d\udcbe Final model saved to checkpoints/pma_vae_final.pt')"
1067
  ]
1068
  },
1069
  {
 
1101
  " psnr = -10 * math.log10(mse + 1e-8)\n",
1102
  " psnrs.append(psnr)\n",
1103
  "\n",
1104
+ "print(f'\\n\ud83d\udcca Evaluation Results:')\n",
1105
  "print(f' Average PSNR: {sum(psnrs)/len(psnrs):.2f} dB')\n",
1106
  "print(f' Min PSNR: {min(psnrs):.2f} dB')\n",
1107
  "print(f' Max PSNR: {max(psnrs):.2f} dB')"
 
1165
  " out = model.decoder(pa['base_mu'], pa['detail_mu'], z_style)\n",
1166
  " img = out[0].cpu().permute(1,2,0).numpy() * 0.5 + 0.5\n",
1167
  " axes[i].imshow(img.clip(0,1))\n",
1168
+ " axes[i].set_title(f'\u03b1={alpha:.2f}')\n",
1169
  " axes[i].axis('off')\n",
1170
  "plt.suptitle('Style Interpolation (structure fixed, style varies)', fontsize=14)\n",
1171
  "plt.tight_layout()\n",
 
1218
  "model.eval()\n",
1219
  "\n",
1220
  "# Dummy inputs matching the latent shapes\n",
1221
+ "dummy_base = torch.randn(1, 24, 16, 16, device=device) # For 256\u00d7256 input\n",
1222
  "dummy_detail = torch.randn(1, 6, 32, 32, device=device)\n",
1223
  "dummy_style = torch.randn(1, 96, device=device)\n",
1224
  "\n",
 
1236
  ")\n",
1237
  "\n",
1238
  "onnx_size = os.path.getsize('pma_vae_decoder.onnx') / 1024**2\n",
1239
+ "print(f'\\n\ud83d\udcf1 ONNX decoder exported!')\n",
1240
  "print(f' Size: {onnx_size:.1f} MB')\n",
1241
  "print(f' Ready for: Core ML, TFLite, ONNX Runtime Mobile')\n",
1242
  "\n",
 
1249
  "source": [
1250
  "## 9. Progressive Resolution Training\n",
1251
  "\n",
1252
+ "After initial training at 256\u00d7256, progressively increase resolution.\n",
1253
  "The model handles variable resolutions thanks to the convolutional architecture."
1254
  ]
1255
  },
 
1274
  "# for pg in opt_disc.param_groups: pg['lr'] *= 0.5\n",
1275
  "# \n",
1276
  "# # Continue training (copy the training loop above with dataloader_hr)\n",
1277
+ "# print(f'Phase 2: Training at {NEW_RESOLUTION}\u00d7{NEW_RESOLUTION}')\n",
1278
  "# print(f'Batches per epoch: {len(dataloader_hr)}')"
1279
  ]
1280
  },
 
1297
  "model.eval()\n",
1298
  "\n",
1299
  "# Take a high-res image and downsample it\n",
1300
+ "hr_img = test_batch[0:1] # 256\u00d7256\n",
1301
+ "lr_img = F.interpolate(hr_img, scale_factor=0.5, mode='bilinear', align_corners=False) # 128\u00d7128\n",
1302
  "lr_upscaled = F.interpolate(lr_img, size=(256, 256), mode='bilinear', align_corners=False)\n",
1303
  "\n",
1304
  "with torch.no_grad():\n",
 
1334
  "| Component | Choice | Why |\n",
1335
  "|---|---|---|\n",
1336
  "| Backbone | MobileConv + Parallel 2D Mamba | Fast, efficient, attention-free |\n",
1337
+ "| Downsampling | PixelUnshuffle \u2192 stride-2 conv | Lossless initial features |\n",
1338
  "| Upsampling | PixelShuffle (sub-pixel) | Mobile-friendly, no checkerboard |\n",
1339
  "| Latent | Multi-scale (base/detail/style) | Controllable, prevents collapse |\n",
1340
  "| Style control | FiLM conditioning | Lightweight, multiplicative |\n",
 
1347
  "\n",
1348
  "| Feature | PMA-VAE | SD-VAE | NVAE |\n",
1349
  "|---|---|---|---|\n",
1350
+ "| Attention-free | \u2705 | \u274c | \u274c |\n",
1351
+ "| Mobile-friendly decoder | \u2705 | \u274c | \u274c |\n",
1352
+ "| Multi-scale latent | \u2705 | \u274c | \u2705 |\n",
1353
+ "| Style control built-in | \u2705 | \u274c | \u274c |\n",
1354
  "| Decoder params | ~4-8M | ~50M | ~100M+ |\n",
1355
+ "| Parallel training | \u2705 | \u2705 | \u2705 |\n",
1356
+ "| Free Colab trainable | \u2705 | \u274c | \u274c |"
1357
  ]
1358
  },
1359
  {
 
1362
  "source": [
1363
  "---\n",
1364
  "\n",
1365
+ "## \ud83d\udcda References\n",
1366
  "\n",
1367
  "- **Mamba**: Gu & Dao, 2023. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)\n",
1368
  "- **VMamba**: Liu et al., 2024. [VMamba: Visual State Space Model](https://arxiv.org/abs/2401.10166)\n",
 
1376
  ]
1377
  }
1378
  ]
1379
+ }