Upload PMA_VAE_Colab_Training.ipynb with huggingface_hub
Browse files- PMA_VAE_Colab_Training.ipynb +218 -197
PMA_VAE_Colab_Training.ipynb
CHANGED
|
@@ -17,7 +17,7 @@
|
|
| 17 |
"cell_type": "markdown",
|
| 18 |
"metadata": {},
|
| 19 |
"source": [
|
| 20 |
-
"#
|
| 21 |
"\n",
|
| 22 |
"**A novel attention-free architecture for image generation, super-resolution, artifact removal, and artistic style transfer.**\n",
|
| 23 |
"\n",
|
|
@@ -31,17 +31,17 @@
|
|
| 31 |
"\n",
|
| 32 |
"## Architecture\n",
|
| 33 |
"```\n",
|
| 34 |
-
"Image
|
| 35 |
-
"
|
| 36 |
-
"
|
| 37 |
"```\n",
|
| 38 |
"\n",
|
| 39 |
"## Key Design Decisions\n",
|
| 40 |
-
"- **Parallel scan SSM** (Blelloch algorithm)
|
| 41 |
-
"- **Cross-scan 2D** (VMamba-style)
|
| 42 |
-
"- **PixelShuffle upsampling**
|
| 43 |
-
"- **Taming-transformers loss recipe**
|
| 44 |
-
"- **Progressive resolution training**
|
| 45 |
"\n",
|
| 46 |
"---\n",
|
| 47 |
"**Trainable on free Colab T4 GPU (15GB VRAM) in ~2-4 hours for meaningful results.**"
|
|
@@ -68,7 +68,7 @@
|
|
| 68 |
" print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
|
| 69 |
"    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')\n",
|
| 70 |
"else:\n",
|
| 71 |
-
" print('
|
| 72 |
]
|
| 73 |
},
|
| 74 |
{
|
|
@@ -90,12 +90,12 @@
|
|
| 90 |
"The full model is defined below in a single cell for easy Colab use.\n",
|
| 91 |
"\n",
|
| 92 |
"### Component breakdown:\n",
|
| 93 |
-
"1. **Parallel Scan (PScan)**
|
| 94 |
-
"2. **Selective SSM (S6)**
|
| 95 |
-
"3. **2D Cross-Scan**
|
| 96 |
-
"4. **Mobile Conv Blocks**
|
| 97 |
-
"5. **Encoder**
|
| 98 |
-
"6. **Decoder**
|
| 99 |
]
|
| 100 |
},
|
| 101 |
{
|
|
@@ -112,7 +112,7 @@
|
|
| 112 |
"\n",
|
| 113 |
"\n",
|
| 114 |
"# ============================================================================\n",
|
| 115 |
-
"# Parallel Scan (Blelloch)
|
| 116 |
"# ============================================================================\n",
|
| 117 |
"\n",
|
| 118 |
"class PScan(torch.autograd.Function):\n",
|
|
@@ -379,13 +379,13 @@
|
|
| 379 |
" nn.Conv2d(in_channels * 4, stage_channels[0], 3, padding=1, bias=False),\n",
|
| 380 |
" nn.BatchNorm2d(stage_channels[0]), nn.SiLU(inplace=True))\n",
|
| 381 |
"\n",
|
| 382 |
-
" # Stage 1: H/2
|
| 383 |
" s1 = [MobileConvBlock(stage_channels[0], stage_channels[1], stride=2)]\n",
|
| 384 |
" for _ in range(stage_blocks[0] - 1):\n",
|
| 385 |
" s1.append(MobileConvBlock(stage_channels[1], stage_channels[1]))\n",
|
| 386 |
" self.stage1 = nn.Sequential(*s1)\n",
|
| 387 |
"\n",
|
| 388 |
-
" # Stage 2: H/4
|
| 389 |
" s2 = nn.ModuleList()\n",
|
| 390 |
" s2.append(MobileConvBlock(stage_channels[1], stage_channels[2], stride=2))\n",
|
| 391 |
" n_mamba = max(1, (stage_blocks[1] - 1) // 2)\n",
|
|
@@ -398,7 +398,7 @@
|
|
| 398 |
" self.detail_head_mu = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
|
| 399 |
" self.detail_head_logvar = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
|
| 400 |
"\n",
|
| 401 |
-
" # Stage 3: H/8
|
| 402 |
" s3 = nn.ModuleList()\n",
|
| 403 |
" s3.append(MobileConvBlock(stage_channels[2], stage_channels[3], stride=2))\n",
|
| 404 |
" n_mamba3 = max(1, int((stage_blocks[2] - 1) * 0.75))\n",
|
|
@@ -533,18 +533,18 @@
|
|
| 533 |
"# ============================================================================\n",
|
| 534 |
"\n",
|
| 535 |
"def pmavae_small(use_parallel_scan=True):\n",
|
| 536 |
-
" \"\"\"~6M params
|
| 537 |
" return PMAVAE(enc_channels=(48, 96, 144, 192), dec_channels=(192, 144, 96, 72, 48),\n",
|
| 538 |
" enc_blocks=(2, 2, 3, 3), latent_base_dim=24, latent_detail_dim=6,\n",
|
| 539 |
" latent_style_dim=96, d_state=16, use_parallel_scan=use_parallel_scan)\n",
|
| 540 |
"\n",
|
| 541 |
"def pmavae_base(use_parallel_scan=True):\n",
|
| 542 |
-
" \"\"\"~15M params
|
| 543 |
" return PMAVAE(enc_channels=(64, 128, 192, 256), dec_channels=(256, 192, 128, 96, 64),\n",
|
| 544 |
" enc_blocks=(2, 2, 4, 4), latent_base_dim=32, latent_detail_dim=8,\n",
|
| 545 |
" latent_style_dim=128, d_state=16, use_parallel_scan=use_parallel_scan)\n",
|
| 546 |
"\n",
|
| 547 |
-
"print('
|
| 548 |
]
|
| 549 |
},
|
| 550 |
{
|
|
@@ -568,7 +568,7 @@
|
|
| 568 |
" print(f' {k}: {v.shape}')\n",
|
| 569 |
"\n",
|
| 570 |
"params = model.count_parameters()\n",
|
| 571 |
-
"print(f'\\n
|
| 572 |
"print(f' Encoder: {params[\"enc_M\"]:.2f}M | Decoder: {params[\"dec_M\"]:.2f}M')\n",
|
| 573 |
"\n",
|
| 574 |
"del model, x, recon\n",
|
|
@@ -582,12 +582,12 @@
|
|
| 582 |
"## 3. Loss Functions\n",
|
| 583 |
"\n",
|
| 584 |
"Our loss combines:\n",
|
| 585 |
-
"- **L1 reconstruction**
|
| 586 |
-
"- **VGG perceptual**
|
| 587 |
-
"- **PatchGAN discriminator**
|
| 588 |
-
"- **KL with free bits**
|
| 589 |
-
"- **Edge preservation**
|
| 590 |
-
"- **Adaptive discriminator weight**
|
| 591 |
]
|
| 592 |
},
|
| 593 |
{
|
|
@@ -701,7 +701,7 @@
|
|
| 701 |
" d = hinge_d_loss(self.discriminator(inputs.detach()), self.discriminator(recon.detach()))\n",
|
| 702 |
" return d, {'d_loss': d.item()}\n",
|
| 703 |
"\n",
|
| 704 |
-
"print('
|
| 705 |
]
|
| 706 |
},
|
| 707 |
{
|
|
@@ -711,8 +711,8 @@
|
|
| 711 |
"## 4. Dataset Setup\n",
|
| 712 |
"\n",
|
| 713 |
"We use a HuggingFace dataset for training. Options:\n",
|
| 714 |
-
"- `huggan/wikiart`
|
| 715 |
-
"- `ILSVRC/imagenet-1k`
|
| 716 |
"- Any folder of images\n",
|
| 717 |
"\n",
|
| 718 |
"For free Colab, we use a moderate-sized art dataset."
|
|
@@ -724,35 +724,61 @@
|
|
| 724 |
"metadata": {},
|
| 725 |
"outputs": [],
|
| 726 |
"source": [
|
| 727 |
-
"from torch.utils.data import DataLoader, Dataset\n",
|
| 728 |
"from torchvision import transforms\n",
|
| 729 |
"from PIL import Image\n",
|
| 730 |
"import os\n",
|
| 731 |
"\n",
|
| 732 |
-
"# ========
|
| 733 |
-
"
|
| 734 |
-
"
|
| 735 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
" self.col = image_col\n",
|
| 737 |
" self.transform = transforms.Compose([\n",
|
| 738 |
-
" transforms.Resize(int(resolution * 1.15),
|
|
|
|
|
|
|
| 739 |
" transforms.RandomCrop(resolution),\n",
|
| 740 |
" transforms.RandomHorizontalFlip(),\n",
|
| 741 |
" transforms.ToTensor(),\n",
|
| 742 |
" transforms.Normalize([0.5]*3, [0.5]*3)])\n",
|
| 743 |
-
" def __len__(self): return len(self.ds)\n",
|
| 744 |
-
" def __getitem__(self, idx):\n",
|
| 745 |
-
" img = self.ds[idx][self.col]\n",
|
| 746 |
-
" if not isinstance(img, Image.Image): img = Image.fromarray(img)\n",
|
| 747 |
-
" return self.transform(img.convert('RGB'))\n",
|
| 748 |
"\n",
|
| 749 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
"class FolderDataset(Dataset):\n",
|
| 751 |
" def __init__(self, root, resolution=256):\n",
|
| 752 |
" exts = {'.jpg','.jpeg','.png','.bmp','.webp'}\n",
|
| 753 |
-
" self.files = [os.path.join(dp,f) for dp,_,fns in os.walk(root)
|
|
|
|
| 754 |
" self.transform = transforms.Compose([\n",
|
| 755 |
-
" transforms.Resize(int(resolution * 1.15),
|
|
|
|
|
|
|
| 756 |
" transforms.RandomCrop(resolution),\n",
|
| 757 |
" transforms.RandomHorizontalFlip(),\n",
|
| 758 |
" transforms.ToTensor(),\n",
|
|
@@ -761,7 +787,7 @@
|
|
| 761 |
" def __getitem__(self, idx):\n",
|
| 762 |
" return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
|
| 763 |
"\n",
|
| 764 |
-
"print('
|
| 765 |
]
|
| 766 |
},
|
| 767 |
{
|
|
@@ -770,33 +796,58 @@
|
|
| 770 |
"metadata": {},
|
| 771 |
"outputs": [],
|
| 772 |
"source": [
|
| 773 |
-
"# Load dataset\n",
|
| 774 |
"from datasets import load_dataset\n",
|
| 775 |
"\n",
|
| 776 |
-
"# ===
|
| 777 |
-
"
|
|
|
|
|
|
|
| 778 |
"IMAGE_COLUMN = 'image'\n",
|
| 779 |
-
"RESOLUTION = 256
|
| 780 |
-
"BATCH_SIZE = 8
|
| 781 |
-
"NUM_WORKERS = 2\n",
|
| 782 |
"\n",
|
| 783 |
-
"
|
| 784 |
-
"
|
| 785 |
-
"#
|
| 786 |
-
"#
|
| 787 |
-
"#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 788 |
"\n",
|
| 789 |
-
"
|
| 790 |
-
"
|
| 791 |
-
" num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)\n",
|
| 792 |
"\n",
|
| 793 |
-
"
|
| 794 |
-
"print(f'Batches per epoch: {len(dataloader)}')\n",
|
| 795 |
"\n",
|
| 796 |
-
"#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 797 |
"sample = next(iter(dataloader))\n",
|
| 798 |
-
"print(f'Batch shape: {sample.shape}')\n",
|
| 799 |
-
"print(f'Value range: [{sample.min():.2f}, {sample.max():.2f}]')"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
]
|
| 801 |
},
|
| 802 |
{
|
|
@@ -806,12 +857,12 @@
|
|
| 806 |
"## 5. Training\n",
|
| 807 |
"\n",
|
| 808 |
"### Training recipe:\n",
|
| 809 |
-
"- **Phase 1** (256
|
| 810 |
-
"- **Phase 2** (384
|
| 811 |
-
"- **Phase 3** (512
|
| 812 |
"\n",
|
| 813 |
"### Anti-collapse measures:\n",
|
| 814 |
-
"1. **KL warmup**:
|
| 815 |
"2. **Free bits**: Each latent dimension must use at least 0.25 nats\n",
|
| 816 |
"3. **Discriminator cold start**: Only activates after 10000 steps\n",
|
| 817 |
"4. **Adaptive disc weight**: Balances recon vs adversarial gradients\n",
|
|
@@ -921,128 +972,98 @@
|
|
| 921 |
"outputs": [],
|
| 922 |
"source": [
|
| 923 |
"# ============================================================================\n",
|
| 924 |
-
"# Training Loop
|
| 925 |
"# ============================================================================\n",
|
|
|
|
|
|
|
| 926 |
"\n",
|
| 927 |
-
"def visualize_reconstruction(model, batch, step):\n",
|
| 928 |
-
" \"\"\"Show original vs reconstructed images.\"\"\"\n",
|
| 929 |
-
" model.eval()\n",
|
| 930 |
-
" with torch.no_grad():\n",
|
| 931 |
-
" recon, _ = model(batch[:4].to(device))\n",
|
| 932 |
-
" model.train()\n",
|
| 933 |
-
" \n",
|
| 934 |
-
" fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n",
|
| 935 |
-
" for i in range(4):\n",
|
| 936 |
-
" orig = batch[i].permute(1,2,0).cpu().numpy() * 0.5 + 0.5\n",
|
| 937 |
-
" rec = recon[i].permute(1,2,0).cpu().numpy() * 0.5 + 0.5\n",
|
| 938 |
-
" axes[0,i].imshow(orig.clip(0,1))\n",
|
| 939 |
-
" axes[0,i].set_title('Original')\n",
|
| 940 |
-
" axes[0,i].axis('off')\n",
|
| 941 |
-
" axes[1,i].imshow(rec.clip(0,1))\n",
|
| 942 |
-
" axes[1,i].set_title(f'Recon (step {step})')\n",
|
| 943 |
-
" axes[1,i].axis('off')\n",
|
| 944 |
-
" plt.tight_layout()\n",
|
| 945 |
-
" plt.show()\n",
|
| 946 |
-
"\n",
|
| 947 |
-
"def plot_losses(history):\n",
|
| 948 |
-
" \"\"\"Plot training loss curves.\"\"\"\n",
|
| 949 |
-
" if len(history) < 10: return\n",
|
| 950 |
-
" fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
|
| 951 |
-
" steps = [h['step'] for h in history]\n",
|
| 952 |
-
" axes[0].plot(steps, [h['l1'] for h in history], label='L1')\n",
|
| 953 |
-
" axes[0].plot(steps, [h.get('perc',0) for h in history], label='Perceptual')\n",
|
| 954 |
-
" axes[0].set_title('Reconstruction Losses'); axes[0].legend(); axes[0].set_xlabel('Step')\n",
|
| 955 |
-
" axes[1].plot(steps, [h.get('kl_base',0) for h in history], label='KL base')\n",
|
| 956 |
-
" axes[1].plot(steps, [h.get('kl_detail',0) for h in history], label='KL detail')\n",
|
| 957 |
-
" axes[1].set_title('KL Losses'); axes[1].legend(); axes[1].set_xlabel('Step')\n",
|
| 958 |
-
" axes[2].plot(steps, [h.get('d_loss',0) for h in history], label='Disc')\n",
|
| 959 |
-
" axes[2].plot(steps, [h.get('g_loss',0) for h in history], label='Gen')\n",
|
| 960 |
-
" axes[2].set_title('GAN Losses'); axes[2].legend(); axes[2].set_xlabel('Step')\n",
|
| 961 |
-
" plt.tight_layout(); plt.show()\n",
|
| 962 |
-
"\n",
|
| 963 |
-
"# === TRAINING LOOP ===\n",
|
| 964 |
"global_step = 0\n",
|
| 965 |
"history = []\n",
|
| 966 |
"start_time = time.time()\n",
|
| 967 |
-
"vis_batch = next(iter(dataloader)) # Fixed batch for visualization\n",
|
| 968 |
"\n",
|
| 969 |
-
"
|
| 970 |
-
"
|
|
|
|
|
|
|
|
|
|
| 971 |
"print(f' Discriminator starts at step {CONFIG[\"disc_start\"]}\\n')\n",
|
| 972 |
"\n",
|
| 973 |
"model.train()\n",
|
| 974 |
-
"
|
| 975 |
-
"
|
| 976 |
-
"
|
| 977 |
-
"
|
| 978 |
-
"
|
| 979 |
-
"
|
| 980 |
-
"
|
| 981 |
-
" \n",
|
| 982 |
-
"
|
| 983 |
-
"
|
| 984 |
-
"
|
| 985 |
-
"
|
| 986 |
-
"
|
| 987 |
-
"
|
| 988 |
-
"
|
| 989 |
-
"
|
| 990 |
-
"
|
| 991 |
-
"
|
| 992 |
-
"
|
| 993 |
-
"
|
| 994 |
-
"
|
| 995 |
-
"
|
| 996 |
-
"
|
| 997 |
-
"
|
| 998 |
-
"
|
| 999 |
-
"
|
| 1000 |
-
"
|
| 1001 |
-
"
|
| 1002 |
-
"
|
| 1003 |
-
"
|
| 1004 |
-
"
|
| 1005 |
-
"
|
| 1006 |
-
"
|
| 1007 |
-
"
|
| 1008 |
-
" \n",
|
| 1009 |
-
"
|
| 1010 |
-
"
|
| 1011 |
-
"
|
| 1012 |
-
"
|
| 1013 |
-
"
|
| 1014 |
-
"
|
| 1015 |
-
"
|
| 1016 |
-
"
|
| 1017 |
-
"
|
| 1018 |
-
"
|
| 1019 |
-
"
|
| 1020 |
-
"
|
| 1021 |
-
"
|
| 1022 |
-
"
|
| 1023 |
-
"
|
| 1024 |
-
" \n",
|
| 1025 |
-
"
|
| 1026 |
-
"
|
| 1027 |
-
"
|
| 1028 |
-
"
|
| 1029 |
-
"
|
| 1030 |
-
"
|
| 1031 |
-
"
|
| 1032 |
-
"
|
| 1033 |
-
"
|
| 1034 |
-
"
|
| 1035 |
-
"
|
| 1036 |
-
"
|
| 1037 |
-
"
|
| 1038 |
-
"
|
| 1039 |
-
"
|
|
|
|
|
|
|
|
|
|
| 1040 |
"\n",
|
| 1041 |
"# Final save\n",
|
| 1042 |
"torch.save({'model': model.state_dict(), 'config': CONFIG}, 'checkpoints/pma_vae_final.pt')\n",
|
| 1043 |
"total_time = (time.time() - start_time) / 60\n",
|
| 1044 |
-
"print(f'\\n
|
| 1045 |
-
"print(f'
|
| 1046 |
]
|
| 1047 |
},
|
| 1048 |
{
|
|
@@ -1080,7 +1101,7 @@
|
|
| 1080 |
" psnr = -10 * math.log10(mse + 1e-8)\n",
|
| 1081 |
" psnrs.append(psnr)\n",
|
| 1082 |
"\n",
|
| 1083 |
-
"print(f'\\n
|
| 1084 |
"print(f' Average PSNR: {sum(psnrs)/len(psnrs):.2f} dB')\n",
|
| 1085 |
"print(f' Min PSNR: {min(psnrs):.2f} dB')\n",
|
| 1086 |
"print(f' Max PSNR: {max(psnrs):.2f} dB')"
|
|
@@ -1144,7 +1165,7 @@
|
|
| 1144 |
" out = model.decoder(pa['base_mu'], pa['detail_mu'], z_style)\n",
|
| 1145 |
" img = out[0].cpu().permute(1,2,0).numpy() * 0.5 + 0.5\n",
|
| 1146 |
" axes[i].imshow(img.clip(0,1))\n",
|
| 1147 |
-
" axes[i].set_title(f'
|
| 1148 |
" axes[i].axis('off')\n",
|
| 1149 |
"plt.suptitle('Style Interpolation (structure fixed, style varies)', fontsize=14)\n",
|
| 1150 |
"plt.tight_layout()\n",
|
|
@@ -1197,7 +1218,7 @@
|
|
| 1197 |
"model.eval()\n",
|
| 1198 |
"\n",
|
| 1199 |
"# Dummy inputs matching the latent shapes\n",
|
| 1200 |
-
"dummy_base = torch.randn(1, 24, 16, 16, device=device) # For 256
|
| 1201 |
"dummy_detail = torch.randn(1, 6, 32, 32, device=device)\n",
|
| 1202 |
"dummy_style = torch.randn(1, 96, device=device)\n",
|
| 1203 |
"\n",
|
|
@@ -1215,7 +1236,7 @@
|
|
| 1215 |
")\n",
|
| 1216 |
"\n",
|
| 1217 |
"onnx_size = os.path.getsize('pma_vae_decoder.onnx') / 1024**2\n",
|
| 1218 |
-
"print(f'\\n
|
| 1219 |
"print(f' Size: {onnx_size:.1f} MB')\n",
|
| 1220 |
"print(f' Ready for: Core ML, TFLite, ONNX Runtime Mobile')\n",
|
| 1221 |
"\n",
|
|
@@ -1228,7 +1249,7 @@
|
|
| 1228 |
"source": [
|
| 1229 |
"## 9. Progressive Resolution Training\n",
|
| 1230 |
"\n",
|
| 1231 |
-
"After initial training at 256
|
| 1232 |
"The model handles variable resolutions thanks to the convolutional architecture."
|
| 1233 |
]
|
| 1234 |
},
|
|
@@ -1253,7 +1274,7 @@
|
|
| 1253 |
"# for pg in opt_disc.param_groups: pg['lr'] *= 0.5\n",
|
| 1254 |
"# \n",
|
| 1255 |
"# # Continue training (copy the training loop above with dataloader_hr)\n",
|
| 1256 |
-
"# print(f'Phase 2: Training at {NEW_RESOLUTION}
|
| 1257 |
"# print(f'Batches per epoch: {len(dataloader_hr)}')"
|
| 1258 |
]
|
| 1259 |
},
|
|
@@ -1276,8 +1297,8 @@
|
|
| 1276 |
"model.eval()\n",
|
| 1277 |
"\n",
|
| 1278 |
"# Take a high-res image and downsample it\n",
|
| 1279 |
-
"hr_img = test_batch[0:1] # 256
|
| 1280 |
-
"lr_img = F.interpolate(hr_img, scale_factor=0.5, mode='bilinear', align_corners=False) # 128
|
| 1281 |
"lr_upscaled = F.interpolate(lr_img, size=(256, 256), mode='bilinear', align_corners=False)\n",
|
| 1282 |
"\n",
|
| 1283 |
"with torch.no_grad():\n",
|
|
@@ -1313,7 +1334,7 @@
|
|
| 1313 |
"| Component | Choice | Why |\n",
|
| 1314 |
"|---|---|---|\n",
|
| 1315 |
"| Backbone | MobileConv + Parallel 2D Mamba | Fast, efficient, attention-free |\n",
|
| 1316 |
-
"| Downsampling | PixelUnshuffle
|
| 1317 |
"| Upsampling | PixelShuffle (sub-pixel) | Mobile-friendly, no checkerboard |\n",
|
| 1318 |
"| Latent | Multi-scale (base/detail/style) | Controllable, prevents collapse |\n",
|
| 1319 |
"| Style control | FiLM conditioning | Lightweight, multiplicative |\n",
|
|
@@ -1326,13 +1347,13 @@
|
|
| 1326 |
"\n",
|
| 1327 |
"| Feature | PMA-VAE | SD-VAE | NVAE |\n",
|
| 1328 |
"|---|---|---|---|\n",
|
| 1329 |
-
"| Attention-free |
|
| 1330 |
-
"| Mobile-friendly decoder |
|
| 1331 |
-
"| Multi-scale latent |
|
| 1332 |
-
"| Style control built-in |
|
| 1333 |
"| Decoder params | ~4-8M | ~50M | ~100M+ |\n",
|
| 1334 |
-
"| Parallel training |
|
| 1335 |
-
"| Free Colab trainable |
|
| 1336 |
]
|
| 1337 |
},
|
| 1338 |
{
|
|
@@ -1341,7 +1362,7 @@
|
|
| 1341 |
"source": [
|
| 1342 |
"---\n",
|
| 1343 |
"\n",
|
| 1344 |
-
"##
|
| 1345 |
"\n",
|
| 1346 |
"- **Mamba**: Gu & Dao, 2023. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)\n",
|
| 1347 |
"- **VMamba**: Liu et al., 2024. [VMamba: Visual State Space Model](https://arxiv.org/abs/2401.10166)\n",
|
|
@@ -1355,4 +1376,4 @@
|
|
| 1355 |
]
|
| 1356 |
}
|
| 1357 |
]
|
| 1358 |
-
}
|
|
|
|
| 17 |
"cell_type": "markdown",
|
| 18 |
"metadata": {},
|
| 19 |
"source": [
|
| 20 |
+
"# \ud83c\udfa8 PMA-VAE: Parallel Mobile Artistic Variational Autoencoder\n",
|
| 21 |
"\n",
|
| 22 |
"**A novel attention-free architecture for image generation, super-resolution, artifact removal, and artistic style transfer.**\n",
|
| 23 |
"\n",
|
|
|
|
| 31 |
"\n",
|
| 32 |
"## Architecture\n",
|
| 33 |
"```\n",
|
| 34 |
+
"Image \u2192 PixelUnshuffle stem \u2192 MobileConv stages \u2192 Parallel 2D Mamba blocks\n",
|
| 35 |
+
" \u2192 Multi-scale latent (z_base H/16, z_detail H/8, z_style global)\n",
|
| 36 |
+
" \u2192 Light parallel decoder with FiLM style modulation \u2192 Reconstructed image\n",
|
| 37 |
"```\n",
|
| 38 |
"\n",
|
| 39 |
"## Key Design Decisions\n",
|
| 40 |
+
"- **Parallel scan SSM** (Blelloch algorithm) \u2014 pure PyTorch, no CUDA kernels needed\n",
|
| 41 |
+
"- **Cross-scan 2D** (VMamba-style) \u2014 4 directional scans for global context without attention\n",
|
| 42 |
+
"- **PixelShuffle upsampling** \u2014 efficient sub-pixel convolution for mobile\n",
|
| 43 |
+
"- **Taming-transformers loss recipe** \u2014 adaptive discriminator weight balancing\n",
|
| 44 |
+
"- **Progressive resolution training** \u2014 start small, scale up\n",
|
| 45 |
"\n",
|
| 46 |
"---\n",
|
| 47 |
"**Trainable on free Colab T4 GPU (15GB VRAM) in ~2-4 hours for meaningful results.**"
|
|
|
|
| 68 |
" print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
|
| 69 |
"    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')\n",
|
| 70 |
"else:\n",
|
| 71 |
+
" print('\u26a0\ufe0f No GPU detected! Go to Runtime \u2192 Change runtime type \u2192 T4 GPU')"
|
| 72 |
]
|
| 73 |
},
|
| 74 |
{
|
|
|
|
| 90 |
"The full model is defined below in a single cell for easy Colab use.\n",
|
| 91 |
"\n",
|
| 92 |
"### Component breakdown:\n",
|
| 93 |
+
"1. **Parallel Scan (PScan)** \u2014 Blelloch parallel prefix scan in pure PyTorch\n",
|
| 94 |
+
"2. **Selective SSM (S6)** \u2014 Mamba's core mechanism, input-dependent state space\n",
|
| 95 |
+
"3. **2D Cross-Scan** \u2014 VMamba-style 4-directional scanning for 2D feature maps\n",
|
| 96 |
+
"4. **Mobile Conv Blocks** \u2014 Depthwise separable + SE + FiLM conditioning\n",
|
| 97 |
+
"5. **Encoder** \u2014 Progressive downsampling with hybrid MobileConv + Mamba stages\n",
|
| 98 |
+
"6. **Decoder** \u2014 Lightweight with FiLM style modulation, PixelShuffle upsampling"
|
| 99 |
]
|
| 100 |
},
|
| 101 |
{
|
|
|
|
| 112 |
"\n",
|
| 113 |
"\n",
|
| 114 |
"# ============================================================================\n",
|
| 115 |
+
"# Parallel Scan (Blelloch) \u2014 Pure PyTorch, no CUDA kernels\n",
|
| 116 |
"# ============================================================================\n",
|
| 117 |
"\n",
|
| 118 |
"class PScan(torch.autograd.Function):\n",
|
|
|
|
| 379 |
" nn.Conv2d(in_channels * 4, stage_channels[0], 3, padding=1, bias=False),\n",
|
| 380 |
" nn.BatchNorm2d(stage_channels[0]), nn.SiLU(inplace=True))\n",
|
| 381 |
"\n",
|
| 382 |
+
" # Stage 1: H/2 \u2192 H/4 (MobileConv only)\n",
|
| 383 |
" s1 = [MobileConvBlock(stage_channels[0], stage_channels[1], stride=2)]\n",
|
| 384 |
" for _ in range(stage_blocks[0] - 1):\n",
|
| 385 |
" s1.append(MobileConvBlock(stage_channels[1], stage_channels[1]))\n",
|
| 386 |
" self.stage1 = nn.Sequential(*s1)\n",
|
| 387 |
"\n",
|
| 388 |
+
" # Stage 2: H/4 \u2192 H/8 (hybrid MobileConv + Mamba)\n",
|
| 389 |
" s2 = nn.ModuleList()\n",
|
| 390 |
" s2.append(MobileConvBlock(stage_channels[1], stage_channels[2], stride=2))\n",
|
| 391 |
" n_mamba = max(1, (stage_blocks[1] - 1) // 2)\n",
|
|
|
|
| 398 |
" self.detail_head_mu = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
|
| 399 |
" self.detail_head_logvar = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)\n",
|
| 400 |
"\n",
|
| 401 |
+
" # Stage 3: H/8 \u2192 H/16 (Mamba-heavy)\n",
|
| 402 |
" s3 = nn.ModuleList()\n",
|
| 403 |
" s3.append(MobileConvBlock(stage_channels[2], stage_channels[3], stride=2))\n",
|
| 404 |
" n_mamba3 = max(1, int((stage_blocks[2] - 1) * 0.75))\n",
|
|
|
|
| 533 |
"# ============================================================================\n",
|
| 534 |
"\n",
|
| 535 |
"def pmavae_small(use_parallel_scan=True):\n",
|
| 536 |
+
" \"\"\"~6M params \u2014 fast training on free Colab T4\"\"\"\n",
|
| 537 |
" return PMAVAE(enc_channels=(48, 96, 144, 192), dec_channels=(192, 144, 96, 72, 48),\n",
|
| 538 |
" enc_blocks=(2, 2, 3, 3), latent_base_dim=24, latent_detail_dim=6,\n",
|
| 539 |
" latent_style_dim=96, d_state=16, use_parallel_scan=use_parallel_scan)\n",
|
| 540 |
"\n",
|
| 541 |
"def pmavae_base(use_parallel_scan=True):\n",
|
| 542 |
+
" \"\"\"~15M params \u2014 high quality, needs more VRAM\"\"\"\n",
|
| 543 |
" return PMAVAE(enc_channels=(64, 128, 192, 256), dec_channels=(256, 192, 128, 96, 64),\n",
|
| 544 |
" enc_blocks=(2, 2, 4, 4), latent_base_dim=32, latent_detail_dim=8,\n",
|
| 545 |
" latent_style_dim=128, d_state=16, use_parallel_scan=use_parallel_scan)\n",
|
| 546 |
"\n",
|
| 547 |
+
"print('\u2705 PMA-VAE architecture defined!')"
|
| 548 |
]
|
| 549 |
},
|
| 550 |
{
|
|
|
|
| 568 |
" print(f' {k}: {v.shape}')\n",
|
| 569 |
"\n",
|
| 570 |
"params = model.count_parameters()\n",
|
| 571 |
+
"print(f'\\n\ud83d\udcca Parameters: {params[\"total_M\"]:.2f}M total')\n",
|
| 572 |
"print(f' Encoder: {params[\"enc_M\"]:.2f}M | Decoder: {params[\"dec_M\"]:.2f}M')\n",
|
| 573 |
"\n",
|
| 574 |
"del model, x, recon\n",
|
|
|
|
| 582 |
"## 3. Loss Functions\n",
|
| 583 |
"\n",
|
| 584 |
"Our loss combines:\n",
|
| 585 |
+
"- **L1 reconstruction** \u2014 pixel-level fidelity\n",
|
| 586 |
+
"- **VGG perceptual** \u2014 semantic/structural similarity \n",
|
| 587 |
+
"- **PatchGAN discriminator** \u2014 sharp, realistic textures\n",
|
| 588 |
+
"- **KL with free bits** \u2014 prevents posterior collapse\n",
|
| 589 |
+
"- **Edge preservation** \u2014 high-frequency detail via Sobel filters\n",
|
| 590 |
+
"- **Adaptive discriminator weight** \u2014 taming-transformers trick"
|
| 591 |
]
|
| 592 |
},
|
| 593 |
{
|
|
|
|
| 701 |
" d = hinge_d_loss(self.discriminator(inputs.detach()), self.discriminator(recon.detach()))\n",
|
| 702 |
" return d, {'d_loss': d.item()}\n",
|
| 703 |
"\n",
|
| 704 |
+
"print('\u2705 Loss functions defined!')"
|
| 705 |
]
|
| 706 |
},
|
| 707 |
{
|
|
|
|
| 711 |
"## 4. Dataset Setup\n",
|
| 712 |
"\n",
|
| 713 |
"We use a HuggingFace dataset for training. Options:\n",
|
| 714 |
+
"- `huggan/wikiart` \u2014 artistic images (great for style learning)\n",
|
| 715 |
+
"- `ILSVRC/imagenet-1k` \u2014 diverse natural images\n",
|
| 716 |
"- Any folder of images\n",
|
| 717 |
"\n",
|
| 718 |
"For free Colab, we use a moderate-sized art dataset."
|
|
|
|
| 724 |
"metadata": {},
|
| 725 |
"outputs": [],
|
| 726 |
"source": [
|
| 727 |
+
"from torch.utils.data import DataLoader, Dataset, IterableDataset\n",
|
| 728 |
"from torchvision import transforms\n",
|
| 729 |
"from PIL import Image\n",
|
| 730 |
"import os\n",
|
| 731 |
"\n",
|
| 732 |
+
"# ======== Streaming HF Dataset (RAM-safe) ========\n",
|
| 733 |
+
"# This NEVER loads the full dataset into RAM.\n",
|
| 734 |
+
"# Images are decoded one-at-a-time from Parquet shards.\n",
|
| 735 |
+
"\n",
|
| 736 |
+
"class StreamingHFDataset(IterableDataset):\n",
|
| 737 |
+
" \"\"\"\n",
|
| 738 |
+
" Wraps a HuggingFace streaming dataset for PyTorch.\n",
|
| 739 |
+
"    RAM usage: roughly constant (~50-100MB is an estimate; it depends on image size and any upstream shuffle buffer), regardless of dataset size.\n",
|
| 740 |
+
" \n",
|
| 741 |
+
" Key: we use datasets streaming mode which reads Parquet\n",
|
| 742 |
+
" files chunk-by-chunk from HF Hub, never materializing\n",
|
| 743 |
+
" the full dataset in memory.\n",
|
| 744 |
+
" \"\"\"\n",
|
| 745 |
+
" def __init__(self, hf_iterable_dataset, image_col='image', resolution=256):\n",
|
| 746 |
+
" self.ds = hf_iterable_dataset\n",
|
| 747 |
" self.col = image_col\n",
|
| 748 |
" self.transform = transforms.Compose([\n",
|
| 749 |
+
" transforms.Resize(int(resolution * 1.15),\n",
|
| 750 |
+
" interpolation=transforms.InterpolationMode.LANCZOS,\n",
|
| 751 |
+
" antialias=True),\n",
|
| 752 |
" transforms.RandomCrop(resolution),\n",
|
| 753 |
" transforms.RandomHorizontalFlip(),\n",
|
| 754 |
" transforms.ToTensor(),\n",
|
| 755 |
" transforms.Normalize([0.5]*3, [0.5]*3)])\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 756 |
"\n",
|
| 757 |
+
" def __iter__(self):\n",
|
| 758 |
+
" for sample in self.ds:\n",
|
| 759 |
+
" img = sample[self.col]\n",
|
| 760 |
+
" if not isinstance(img, Image.Image):\n",
|
| 761 |
+
" img = Image.fromarray(img)\n",
|
| 762 |
+
" img = img.convert('RGB')\n",
|
| 763 |
+
" # Ensure minimum size for crop\n",
|
| 764 |
+
" w, h = img.size\n",
|
| 765 |
+
" if w < 64 or h < 64:\n",
|
| 766 |
+
" continue # skip tiny images\n",
|
| 767 |
+
" try:\n",
|
| 768 |
+
" yield self.transform(img)\n",
|
| 769 |
+
" except Exception:\n",
|
| 770 |
+
" continue # skip corrupt images\n",
|
| 771 |
+
"\n",
|
| 772 |
+
"# ======== Local folder (non-streaming) ========\n",
|
| 773 |
"class FolderDataset(Dataset):\n",
|
| 774 |
" def __init__(self, root, resolution=256):\n",
|
| 775 |
" exts = {'.jpg','.jpeg','.png','.bmp','.webp'}\n",
|
| 776 |
+
" self.files = [os.path.join(dp,f) for dp,_,fns in os.walk(root)\n",
|
| 777 |
+
" for f in fns if os.path.splitext(f)[1].lower() in exts]\n",
|
| 778 |
" self.transform = transforms.Compose([\n",
|
| 779 |
+
" transforms.Resize(int(resolution * 1.15),\n",
|
| 780 |
+
" interpolation=transforms.InterpolationMode.LANCZOS,\n",
|
| 781 |
+
" antialias=True),\n",
|
| 782 |
" transforms.RandomCrop(resolution),\n",
|
| 783 |
" transforms.RandomHorizontalFlip(),\n",
|
| 784 |
" transforms.ToTensor(),\n",
|
|
|
|
| 787 |
" def __getitem__(self, idx):\n",
|
| 788 |
" return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
|
| 789 |
"\n",
|
| 790 |
+
"print('\u2705 Dataset classes defined!')"
|
| 791 |
]
|
| 792 |
},
|
| 793 |
{
|
|
|
|
| 796 |
"metadata": {},
|
| 797 |
"outputs": [],
|
| 798 |
"source": [
|
|
|
|
| 799 |
"from datasets import load_dataset\n",
|
| 800 |
"\n",
|
| 801 |
+
"# ============================================================================\n",
|
| 802 |
+
"# Dataset Configuration\n",
|
| 803 |
+
"# ============================================================================\n",
|
| 804 |
+
"DATASET_NAME = 'huggan/wikiart' # 80K art images (~5GB)\n",
|
| 805 |
"IMAGE_COLUMN = 'image'\n",
|
| 806 |
+
"RESOLUTION = 256\n",
|
| 807 |
+
"BATCH_SIZE = 8 # Fits T4 15GB with pmavae_small\n",
|
|
|
|
| 808 |
"\n",
|
| 809 |
+
"# ============================================================================\n",
|
| 810 |
+
"# CRITICAL: Use streaming=True to avoid RAM crash!\n",
|
| 811 |
+
"# \n",
|
| 812 |
+
"# Without streaming: HF downloads ALL 5GB of images \u2192 decodes to PIL \u2192\n",
|
| 813 |
+
"# stores in RAM \u2192 Colab's 12GB RAM is exhausted \u2192 kernel crash.\n",
|
| 814 |
+
"# \n",
|
| 815 |
+
"# With streaming: HF reads Parquet shards on-the-fly \u2192 decodes one\n",
|
| 816 |
+
"# image at a time \u2192 constant ~100MB RAM usage.\n",
|
| 817 |
+
"# ============================================================================\n",
|
| 818 |
+
"print(f'Loading {DATASET_NAME} in streaming mode...')\n",
|
| 819 |
+
"raw_stream = load_dataset(DATASET_NAME, split='train', streaming=True)\n",
|
| 820 |
"\n",
|
| 821 |
+
"# Shuffle with a buffer (keeps only 1000 samples in RAM at once)\n",
|
| 822 |
+
"raw_stream = raw_stream.shuffle(seed=42, buffer_size=1000)\n",
|
|
|
|
| 823 |
"\n",
|
| 824 |
+
"dataset = StreamingHFDataset(raw_stream, IMAGE_COLUMN, RESOLUTION)\n",
|
|
|
|
| 825 |
"\n",
|
| 826 |
+
"# ============================================================================\n",
|
| 827 |
+
"# DataLoader for streaming dataset\n",
|
| 828 |
+
"# \n",
|
| 829 |
+
"# IMPORTANT differences from map-style DataLoader:\n",
|
| 830 |
+
"# - num_workers=0 (streaming datasets handle their own I/O)\n",
|
| 831 |
+
"# - No shuffle (already shuffled in the stream buffer above)\n",
|
| 832 |
+
"# - drop_last=True (partial batches can cause issues)\n",
|
| 833 |
+
"# ============================================================================\n",
|
| 834 |
+
"dataloader = DataLoader(\n",
|
| 835 |
+
" dataset,\n",
|
| 836 |
+
" batch_size=BATCH_SIZE,\n",
|
| 837 |
+
" num_workers=0, # streaming handles I/O internally\n",
|
| 838 |
+
" pin_memory=True,\n",
|
| 839 |
+
" drop_last=True,\n",
|
| 840 |
+
")\n",
|
| 841 |
+
"\n",
|
| 842 |
+
"# Quick sanity check \u2014 grab one batch\n",
|
| 843 |
+
"print('Fetching first batch...')\n",
|
| 844 |
"sample = next(iter(dataloader))\n",
|
| 845 |
+
"print(f'\u2705 Batch shape: {sample.shape}')\n",
|
| 846 |
+
"print(f' Value range: [{sample.min():.2f}, {sample.max():.2f}]')\n",
|
| 847 |
+
"print(f' RAM usage: minimal (streaming mode)')\n",
|
| 848 |
+
"print()\n",
|
| 849 |
+
"print('NOTE: With streaming, len(dataloader) is unknown.')\n",
|
| 850 |
+
"print('Training runs by step count, not epoch count.')"
|
| 851 |
]
|
| 852 |
},
|
| 853 |
{
|
|
|
|
| 857 |
"## 5. Training\n",
|
| 858 |
"\n",
|
| 859 |
"### Training recipe:\n",
|
| 860 |
+
"- **Phase 1** (256\u00d7256): Learn structure and composition\n",
|
| 861 |
+
"- **Phase 2** (384\u00d7384): Refine texture details\n",
|
| 862 |
+
"- **Phase 3** (512\u00d7512): Fine-tune for high-res quality\n",
|
| 863 |
"\n",
|
| 864 |
"### Anti-collapse measures:\n",
|
| 865 |
+
"1. **KL warmup**: \u03b2 goes from 0 \u2192 target over first 5000 steps\n",
|
| 866 |
"2. **Free bits**: Each latent dimension must use at least 0.25 nats\n",
|
| 867 |
"3. **Discriminator cold start**: Only activates after 10000 steps\n",
|
| 868 |
"4. **Adaptive disc weight**: Balances recon vs adversarial gradients\n",
|
|
|
|
| 972 |
"outputs": [],
|
| 973 |
"source": [
|
| 974 |
"# ============================================================================\n",
|
| 975 |
+
"# Training Loop \u2014 Streaming-Compatible\n",
|
| 976 |
"# ============================================================================\n",
|
| 977 |
+
"# Since streaming datasets don't have len(), we train by step count.\n",
|
| 978 |
+
"# When the stream is exhausted, the loop below re-creates the iterator to start a new pass.\n",
|
| 979 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 980 |
"global_step = 0\n",
|
| 981 |
"history = []\n",
|
| 982 |
"start_time = time.time()\n",
|
|
|
|
| 983 |
"\n",
|
| 984 |
+
"# Get a fixed batch for visualization (detach from stream)\n",
|
| 985 |
+
"vis_batch = next(iter(dataloader)).clone()\n",
|
| 986 |
+
"\n",
|
| 987 |
+
"print(f'\\n\ud83d\ude80 Starting training! Target: {CONFIG[\"max_steps\"]} steps')\n",
|
| 988 |
+
"print(f' KL warmup: 0 \u2192 {CONFIG[\"kl_weight\"]} over {CONFIG[\"kl_warmup_steps\"]} steps')\n",
|
| 989 |
"print(f' Discriminator starts at step {CONFIG[\"disc_start\"]}\\n')\n",
|
| 990 |
"\n",
|
| 991 |
"model.train()\n",
|
| 992 |
+
"\n",
|
| 993 |
+
"# Infinite iterator over the streaming dataloader\n",
|
| 994 |
+
"data_iter = iter(dataloader)\n",
|
| 995 |
+
"\n",
|
| 996 |
+
"while global_step < CONFIG['max_steps']:\n",
|
| 997 |
+
" # Get next batch (re-create iterator if stream exhausted)\n",
|
| 998 |
+
" try:\n",
|
| 999 |
+
" batch = next(data_iter)\n",
|
| 1000 |
+
" except StopIteration:\n",
|
| 1001 |
+
" # Stream exhausted = 1 epoch done. Re-create.\n",
|
| 1002 |
+
" data_iter = iter(dataloader)\n",
|
| 1003 |
+
" batch = next(data_iter)\n",
|
| 1004 |
+
"\n",
|
| 1005 |
+
" batch = batch.to(device)\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
" # KL warmup\n",
|
| 1008 |
+
" kl_w = CONFIG['kl_weight'] * min(1.0, global_step / max(1, CONFIG['kl_warmup_steps']))\n",
|
| 1009 |
+
" criterion.kl_weight = kl_w\n",
|
| 1010 |
+
"\n",
|
| 1011 |
+
" # === VAE update ===\n",
|
| 1012 |
+
" opt_vae.zero_grad()\n",
|
| 1013 |
+
" with autocast('cuda', enabled=device=='cuda'):\n",
|
| 1014 |
+
" recon, posteriors = model(batch)\n",
|
| 1015 |
+
" loss_vae, log_vae = criterion(batch, recon, posteriors, 0, global_step,\n",
|
| 1016 |
+
" model.get_last_decoder_layer())\n",
|
| 1017 |
+
" scaler_vae.scale(loss_vae).backward()\n",
|
| 1018 |
+
" scaler_vae.unscale_(opt_vae)\n",
|
| 1019 |
+
" gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
| 1020 |
+
" scaler_vae.step(opt_vae)\n",
|
| 1021 |
+
" scaler_vae.update()\n",
|
| 1022 |
+
"\n",
|
| 1023 |
+
" # === Discriminator update ===\n",
|
| 1024 |
+
" opt_disc.zero_grad()\n",
|
| 1025 |
+
" with autocast('cuda', enabled=device=='cuda'):\n",
|
| 1026 |
+
" with torch.no_grad():\n",
|
| 1027 |
+
" recon_d, _ = model(batch)\n",
|
| 1028 |
+
" loss_disc, log_disc = criterion(batch, recon_d, posteriors, 1, global_step)\n",
|
| 1029 |
+
" if global_step >= CONFIG['disc_start']:\n",
|
| 1030 |
+
" scaler_disc.scale(loss_disc).backward()\n",
|
| 1031 |
+
" scaler_disc.unscale_(opt_disc)\n",
|
| 1032 |
+
" torch.nn.utils.clip_grad_norm_(criterion.discriminator.parameters(), 1.0)\n",
|
| 1033 |
+
" scaler_disc.step(opt_disc)\n",
|
| 1034 |
+
" scaler_disc.update()\n",
|
| 1035 |
+
"\n",
|
| 1036 |
+
" global_step += 1\n",
|
| 1037 |
+
" log = {**log_vae, **log_disc, 'step': global_step, 'grad_norm': gn.item(), 'kl_w': kl_w}\n",
|
| 1038 |
+
"\n",
|
| 1039 |
+
" if global_step % CONFIG['log_every'] == 0:\n",
|
| 1040 |
+
" history.append(log)\n",
|
| 1041 |
+
" elapsed = (time.time() - start_time) / 60\n",
|
| 1042 |
+
" print(f\"Step {global_step:6d} | L1:{log['l1']:.4f} | Perc:{log.get('perc',0):.4f} | \"\n",
|
| 1043 |
+
" f\"KL:{log.get('kl_base',0):.1f}/{log.get('kl_detail',0):.1f}/{log.get('kl_style',0):.1f} | \"\n",
|
| 1044 |
+
" f\"D:{log.get('d_loss',0):.4f} | G:{log.get('g_loss',0):.4f} | \"\n",
|
| 1045 |
+
" f\"GN:{log['grad_norm']:.2f} | {elapsed:.1f}min\")\n",
|
| 1046 |
+
"\n",
|
| 1047 |
+
" if global_step % CONFIG['vis_every'] == 0:\n",
|
| 1048 |
+
" clear_output(wait=True)\n",
|
| 1049 |
+
" visualize_reconstruction(model, vis_batch, global_step)\n",
|
| 1050 |
+
" plot_losses(history)\n",
|
| 1051 |
+
"\n",
|
| 1052 |
+
" if global_step % CONFIG['save_every'] == 0:\n",
|
| 1053 |
+
" os.makedirs('checkpoints', exist_ok=True)\n",
|
| 1054 |
+
" torch.save({'model': model.state_dict(),\n",
|
| 1055 |
+
" 'disc': criterion.discriminator.state_dict(),\n",
|
| 1056 |
+
" 'opt_vae': opt_vae.state_dict(),\n",
|
| 1057 |
+
" 'opt_disc': opt_disc.state_dict(),\n",
|
| 1058 |
+
" 'step': global_step, 'config': CONFIG},\n",
|
| 1059 |
+
" f'checkpoints/pma_vae_step{global_step}.pt')\n",
|
| 1060 |
+
" print(f'\ud83d\udcbe Saved checkpoint at step {global_step}')\n",
|
| 1061 |
"\n",
|
| 1062 |
"# Final save\n",
|
| 1063 |
"os.makedirs('checkpoints', exist_ok=True)  # the periodic save branch may never have run, so make sure the directory exists\n",
"torch.save({'model': model.state_dict(), 'config': CONFIG}, 'checkpoints/pma_vae_final.pt')\n",
|
| 1064 |
"total_time = (time.time() - start_time) / 60\n",
|
| 1065 |
+
"print(f'\\n\u2705 Training complete! {global_step} steps in {total_time:.1f} minutes')\n",
|
| 1066 |
+
"print(f'\ud83d\udcbe Final model saved to checkpoints/pma_vae_final.pt')"
|
| 1067 |
]
|
| 1068 |
},
|
| 1069 |
{
|
|
|
|
| 1101 |
" psnr = -10 * math.log10(mse + 1e-8)\n",
|
| 1102 |
" psnrs.append(psnr)\n",
|
| 1103 |
"\n",
|
| 1104 |
+
"print(f'\\n\ud83d\udcca Evaluation Results:')\n",
|
| 1105 |
"print(f' Average PSNR: {sum(psnrs)/len(psnrs):.2f} dB')\n",
|
| 1106 |
"print(f' Min PSNR: {min(psnrs):.2f} dB')\n",
|
| 1107 |
"print(f' Max PSNR: {max(psnrs):.2f} dB')"
|
|
|
|
| 1165 |
" out = model.decoder(pa['base_mu'], pa['detail_mu'], z_style)\n",
|
| 1166 |
" img = out[0].cpu().permute(1,2,0).numpy() * 0.5 + 0.5\n",
|
| 1167 |
" axes[i].imshow(img.clip(0,1))\n",
|
| 1168 |
+
" axes[i].set_title(f'\u03b1={alpha:.2f}')\n",
|
| 1169 |
" axes[i].axis('off')\n",
|
| 1170 |
"plt.suptitle('Style Interpolation (structure fixed, style varies)', fontsize=14)\n",
|
| 1171 |
"plt.tight_layout()\n",
|
|
|
|
| 1218 |
"model.eval()\n",
|
| 1219 |
"\n",
|
| 1220 |
"# Dummy inputs matching the latent shapes\n",
|
| 1221 |
+
"dummy_base = torch.randn(1, 24, 16, 16, device=device) # For 256\u00d7256 input\n",
|
| 1222 |
"dummy_detail = torch.randn(1, 6, 32, 32, device=device)\n",
|
| 1223 |
"dummy_style = torch.randn(1, 96, device=device)\n",
|
| 1224 |
"\n",
|
|
|
|
| 1236 |
")\n",
|
| 1237 |
"\n",
|
| 1238 |
"onnx_size = os.path.getsize('pma_vae_decoder.onnx') / 1024**2\n",
|
| 1239 |
+
"print(f'\\n\ud83d\udcf1 ONNX decoder exported!')\n",
|
| 1240 |
"print(f' Size: {onnx_size:.1f} MB')\n",
|
| 1241 |
"print(f' Ready for: Core ML, TFLite, ONNX Runtime Mobile')\n",
|
| 1242 |
"\n",
|
|
|
|
| 1249 |
"source": [
|
| 1250 |
"## 9. Progressive Resolution Training\n",
|
| 1251 |
"\n",
|
| 1252 |
+
"After initial training at 256\u00d7256, progressively increase resolution.\n",
|
| 1253 |
"The model handles variable resolutions thanks to the convolutional architecture."
|
| 1254 |
]
|
| 1255 |
},
|
|
|
|
| 1274 |
"# for pg in opt_disc.param_groups: pg['lr'] *= 0.5\n",
|
| 1275 |
"# \n",
|
| 1276 |
"# # Continue training (copy the training loop above with dataloader_hr)\n",
|
| 1277 |
+
"# print(f'Phase 2: Training at {NEW_RESOLUTION}\u00d7{NEW_RESOLUTION}')\n",
|
| 1278 |
"# print(f'Batches per epoch: {len(dataloader_hr)}')  # NOTE: only valid if dataloader_hr is not streaming; a streaming loader has no len()"
|
| 1279 |
]
|
| 1280 |
},
|
|
|
|
| 1297 |
"model.eval()\n",
|
| 1298 |
"\n",
|
| 1299 |
"# Take a high-res image and downsample it\n",
|
| 1300 |
+
"hr_img = test_batch[0:1] # 256\u00d7256\n",
|
| 1301 |
+
"lr_img = F.interpolate(hr_img, scale_factor=0.5, mode='bilinear', align_corners=False) # 128\u00d7128\n",
|
| 1302 |
"lr_upscaled = F.interpolate(lr_img, size=(256, 256), mode='bilinear', align_corners=False)\n",
|
| 1303 |
"\n",
|
| 1304 |
"with torch.no_grad():\n",
|
|
|
|
| 1334 |
"| Component | Choice | Why |\n",
|
| 1335 |
"|---|---|---|\n",
|
| 1336 |
"| Backbone | MobileConv + Parallel 2D Mamba | Fast, efficient, attention-free |\n",
|
| 1337 |
+
"| Downsampling | PixelUnshuffle \u2192 stride-2 conv | Lossless initial features |\n",
|
| 1338 |
"| Upsampling | PixelShuffle (sub-pixel) | Mobile-friendly, no checkerboard |\n",
|
| 1339 |
"| Latent | Multi-scale (base/detail/style) | Controllable, prevents collapse |\n",
|
| 1340 |
"| Style control | FiLM conditioning | Lightweight, multiplicative |\n",
|
|
|
|
| 1347 |
"\n",
|
| 1348 |
"| Feature | PMA-VAE | SD-VAE | NVAE |\n",
|
| 1349 |
"|---|---|---|---|\n",
|
| 1350 |
+
"| Attention-free | \u2705 | \u274c | \u274c |\n",
|
| 1351 |
+
"| Mobile-friendly decoder | \u2705 | \u274c | \u274c |\n",
|
| 1352 |
+
"| Multi-scale latent | \u2705 | \u274c | \u2705 |\n",
|
| 1353 |
+
"| Style control built-in | \u2705 | \u274c | \u274c |\n",
|
| 1354 |
"| Decoder params | ~4-8M | ~50M | ~100M+ |\n",
|
| 1355 |
+
"| Parallel training | \u2705 | \u2705 | \u2705 |\n",
|
| 1356 |
+
"| Free Colab trainable | \u2705 | \u274c | \u274c |"
|
| 1357 |
]
|
| 1358 |
},
|
| 1359 |
{
|
|
|
|
| 1362 |
"source": [
|
| 1363 |
"---\n",
|
| 1364 |
"\n",
|
| 1365 |
+
"## \ud83d\udcda References\n",
|
| 1366 |
"\n",
|
| 1367 |
"- **Mamba**: Gu & Dao, 2023. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)\n",
|
| 1368 |
"- **VMamba**: Liu et al., 2024. [VMamba: Visual State Space Model](https://arxiv.org/abs/2401.10166)\n",
|
|
|
|
| 1376 |
]
|
| 1377 |
}
|
| 1378 |
]
|
| 1379 |
+
}
|