{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# šŸ”Ø MicroForge: A Novel Mobile-First Image Generation Architecture\n", "\n", "**A genuinely new architecture combining Recurrent Latent Planning, SSM-Conv Hybrid Backbone, and Deep Compression VAE**\n", "\n", "This notebook demonstrates the complete MicroForge architecture:\n", "- Module-by-module construction and testing\n", "- End-to-end training pipeline (VAE + backbone + planner)\n", "- Inference for text-to-image generation\n", "- Memory and compute profiling\n", "- Staged training curriculum design\n", "\n", "## Architecture Overview\n", "\n", "```\n", "ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”\n", "│ MicroForge Pipeline │\n", "ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¤\n", "│ │\n", "│ Text ──→ [Text Encoder] ──→ text_emb, text_pooled │\n", "│ │ │\n", "│ ā–¼ │\n", "│ Noise ──→ [Recurrent Latent Planner] ◄── plan_t-1 │\n", "│ │ READ: plan ◄── z_t │\n", "│ │ REASON: plan self-attention │\n", "│ │ OUTPUT: planner_tokens │\n", "│ ā–¼ │\n", "│ z_t ──→ [SSM-Conv Backbone] ◄── planner_tokens │\n", "│ │ Per-block: │\n", "│ │ AdaLN-Group conditioning │\n", "│ │ Bidirectional SSM (zigzag scan) │\n", "│ │ Cross-attention to text+plan │\n", "│ │ FFN (expansion=3) │\n", "│ │ Global: Shared MQA attention │\n", "│ ā–¼ │\n", "│ v_pred ──→ [Euler ODE Step] ──→ z_{t-1} │\n", "│ │\n", "│ z_0 ──→ [DC-VAE Decoder] ──→ Image │\n", "│ │\n", "ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜\n", "```\n", "\n", "## Key Innovations\n", "\n", "1. **Recurrent Latent Planner (RLP)**: A compact set of 32 latent tokens that iteratively reason about the image before committing to pixel changes. Inspired by RIN but adapted for diffusion.\n", "\n", "2. **SSM-Conv Hybrid Backbone**: Bidirectional state-space model with zigzag scanning + local DWConv + one globally-shared attention block. O(N) complexity vs O(N²) for transformers.\n", "\n", "3. **Deep Compression VAE**: 32Ɨ spatial compression with residual space-to-channel shortcuts. 512px → 16Ɨ16Ɨ32 latent (only 256 spatial tokens).\n", "\n", "4. **Editing-Ready Architecture**: DreamLite-style spatial concatenation for unified generation + editing with zero extra parameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Setup & Installation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install dependencies\n", "!pip install -q torch torchvision einops timm matplotlib" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import time\n", "import os\n", "\n", "# Auto-detect device\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f'Using device: {device}')\n", "if device == 'cuda':\n", " print(f'GPU: {torch.cuda.get_device_name()}')\n", " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
Architecture Module Tests" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from microforge.vae import MicroForgeVAE\n", "from microforge.backbone import MicroForgeBackbone\n", "from microforge.planner import RecurrentLatentPlanner\n", "from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder\n", "from microforge.training import MicroForgeTrainer, FlowMatchingScheduler, MicroForgeLoss\n", "\n", "print('All modules imported successfully!')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Deep Compression VAE\n", "\n", "The VAE compresses images by 32Ɨ spatially using residual space-to-channel shortcuts (DC-AE technique).\n", "\n", "- **Input**: `[B, 3, H, W]` images\n", "- **Latent**: `[B, C_latent, H/32, W/32]` — for 256px: `[B, 16, 8, 8]` (tiny) or `[B, 32, 8, 8]` (small)\n", "- **Key**: Space-to-channel rearrangement as non-parametric skip connection" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test each VAE configuration\n", "for config in ['tiny', 'small', 'base']:\n", " vae = MicroForgeVAE(config=config)\n", " params = sum(p.numel() for p in vae.parameters())\n", " \n", " x = torch.randn(1, 3, 256, 256)\n", " x_recon, mu, logvar = vae(x)\n", " \n", " print(f'{config:>5}: {params:>12,} params | '\n", " f'{params*4/1e6:>6.1f} MB fp32 | '\n", " f'{params*2/1e6:>6.1f} MB fp16 | '\n", " f'latent: {mu.shape}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 SSM-Conv Hybrid Backbone\n", "\n", "The denoising backbone replaces quadratic attention with:\n", "- **Bidirectional SSM** with zigzag scanning (O(N) complexity)\n", "- **Local DWConv** for spatial feature enhancement\n", "- **One globally-shared MQA attention block** (from DiMSUM)\n", "- **AdaLN-Group conditioning** (46% fewer params than full adaLN)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test each backbone configuration\n", "for config_name in ['tiny', 'small', 'base']:\n", " lc = 16 if config_name == 'tiny' else 32\n", " backbone = MicroForgeBackbone(latent_channels=lc, config=config_name)\n", " params = sum(p.numel() for p in backbone.parameters())\n", " \n", " z = torch.randn(1, lc, 8, 8)\n", " t = torch.rand(1)\n", " text_emb = torch.randn(1, 10, 768)\n", " text_pooled = torch.randn(1, 768)\n", " \n", " start = time.time()\n", " v = backbone(z, t, text_emb, text_pooled)\n", " elapsed = time.time() - start\n", " \n", " print(f'{config_name:>5}: {params:>12,} params | '\n", " f'{params*4/1e6:>6.1f} MB fp32 | '\n", " f'{params*2/1e6:>6.1f} MB fp16 | '\n", " f'latency: {elapsed*1000:.0f}ms')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 Recurrent Latent Planner (Novel Component)\n", "\n", "The RLP is our key innovation — a \"reasoning core\" that maintains persistent plan tokens:\n", "\n", "```\n", "plan_0 = init(text)\n", "for each denoising step:\n", " plan = READ(plan, image_tokens) # absorb image info\n", " plan = REASON(plan) # self-attention over plan\n", " output = PROJECT(plan) # inject into backbone\n", " z_{t-1} = backbone(z_t, output) # guided denoising\n", "```\n", "\n", "Only 32 plan tokens Ɨ D dims = negligible memory overhead." 
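, "\n", "\n", "For intuition, here is a minimal, hedged sketch of what one planner step could look like as standard PyTorch modules. It only illustrates the READ / REASON / PROJECT pattern; the class name, dims, and layout below are illustrative and are *not* the actual `RecurrentLatentPlanner` API (which lives in `microforge.planner`). `img_tokens` is assumed to be already projected to the plan width.\n", "\n", "```python\n",
"import torch.nn as nn\n", "\n",
"class PlannerStepSketch(nn.Module):\n",
"    # Illustrative sketch only; not the microforge.planner implementation.\n",
"    def __init__(self, dim=384, heads=6):\n",
"        super().__init__()\n",
"        self.read = nn.MultiheadAttention(dim, heads, batch_first=True)    # READ: plan attends to image tokens\n",
"        self.reason = nn.MultiheadAttention(dim, heads, batch_first=True)  # REASON: plan self-attention\n",
"        self.project = nn.Linear(dim, dim)                                 # OUTPUT: tokens injected into the backbone\n",
"\n",
"    def forward(self, plan, img_tokens):\n",
"        plan = plan + self.read(plan, img_tokens, img_tokens)[0]\n",
"        plan = plan + self.reason(plan, plan, plan)[0]\n",
"        return plan, self.project(plan)  # (updated plan state, planner_tokens)\n",
"```"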
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "planner = RecurrentLatentPlanner(num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32)\n", "params = sum(p.numel() for p in planner.parameters())\n", "print(f'Planner: {params:,} params = {params*4/1e6:.1f} MB fp32')\n", "print(f'Plan state size: {planner.get_plan_size_bytes()} bytes = {planner.get_plan_size_bytes()/1024:.1f} KB')\n", "\n", "# Test planner with self-conditioning (simulating multi-step)\n", "text_pooled = torch.randn(1, 768)\n", "plan = planner.initialize_plan(text_pooled, batch_size=1)\n", "print(f'\\nInitial plan: {plan.shape}')\n", "\n", "# Simulate 3 denoising steps with plan carry-forward\n", "for step in range(3):\n", " z = torch.randn(1, 32, 8, 8)\n", " img_tokens = z.reshape(1, 32, -1).permute(0, 2, 1)\n", " t_emb = torch.randn(1, 384)\n", " \n", " plan, output = planner(img_tokens, plan, t_emb)\n", " \n", " # Self-condition for next step\n", " plan = planner.initialize_plan(text_pooled, 1, prev_plan=plan)\n", " print(f'Step {step}: plan_norm={plan.norm():.2f}, output_norm={output.norm():.2f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Full Pipeline Assembly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Assemble full pipeline with tiny config (for fast testing)\n", "vae = MicroForgeVAE(config='tiny')\n", "backbone = MicroForgeBackbone(latent_channels=16, config='tiny')\n", "planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)\n", "text_encoder = SimpleTextEncoder(vocab_size=8192, embed_dim=768, num_layers=2)\n", "\n", "pipeline = MicroForgePipeline(vae, backbone, text_encoder, planner, device='cpu')\n", "\n", "# Parameter count\n", "params = pipeline.count_parameters()\n", "print('=== MicroForge Parameter Budget ===')\n", "for name, count in params.items():\n", " print(f' {name:>15}: {count:>12,} ({count*4/1e6:.1f} MB fp32, {count*2/1e6:.1f} MB fp16)')\n", "\n", "# Memory estimate\n", "print('\\n=== Memory Estimates ===')\n", "for res in [128, 256, 512]:\n", " mem = pipeline.get_memory_estimate(res, res)\n", " print(f' {res}x{res}: ~{mem[\"estimated_inference_mb\"]:.0f} MB inference')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. End-to-End Inference Test" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate a test image (random weights = noise, but validates full pipeline)\n", "tokens = torch.randint(0, 8192, (1, 10))\n", "\n", "start = time.time()\n", "with torch.no_grad():\n", " images = pipeline.text2img(\n", " tokens, \n", " height=128, width=128,\n", " num_steps=4, # Few steps for speed\n", " cfg_scale=1.0, # No CFG for untrained model\n", " seed=42\n", " )\n", "elapsed = time.time() - start\n", "\n", "print(f'Generated {images.shape} in {elapsed:.2f}s')\n", "print(f'Range: [{images.min():.2f}, {images.max():.2f}]')\n", "\n", "# Visualize\n", "img = images[0].permute(1, 2, 0).cpu().numpy()\n", "img = (img - img.min()) / (img.max() - img.min() + 1e-8)\n", "\n", "plt.figure(figsize=(4, 4))\n", "plt.imshow(img)\n", "plt.title('MicroForge Output (untrained, random weights)')\n", "plt.axis('off')\n", "plt.tight_layout()\n", "plt.savefig('test_generation.png', dpi=100)\n", "plt.show()\n", "print('Saved to test_generation.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
Training Pipeline Demo\n", "\n", "### 5.1 Stage 1: VAE Training\n", "\n", "Train the VAE on synthetic data to verify the training loop.\n", "In production, use ImageNet or similar with perceptual + adversarial losses." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stage 1: VAE Training\n", "vae_train = MicroForgeVAE(config='tiny').train()\n", "vae_opt = torch.optim.AdamW(vae_train.parameters(), lr=1e-4, weight_decay=0.01)\n", "loss_fn = MicroForgeLoss(lambda_kl=1e-6)\n", "\n", "vae_losses = []\n", "print('=== Stage 1: VAE Training ===')\n", "for step in range(50):\n", "    # Synthetic data: Gaussian noise images (stand-in for a real dataset)\n", "    images = torch.randn(4, 3, 128, 128) * 0.5\n", "\n", "    x_recon, mu, logvar = vae_train(images)\n", "    losses = loss_fn.vae_loss(x_recon, images, mu, logvar)\n", "\n", "    vae_opt.zero_grad()\n", "    losses['total'].backward()\n", "    torch.nn.utils.clip_grad_norm_(vae_train.parameters(), 2.0)\n", "    vae_opt.step()\n", "\n", "    vae_losses.append(losses['recon'].item())\n", "    if step % 10 == 0:\n", "        print(f' Step {step:3d}: recon={losses[\"recon\"].item():.4f}, kl={losses[\"kl\"].item():.2f}')\n", "\n", "plt.figure(figsize=(8, 3))\n", "plt.plot(vae_losses)\n", "plt.xlabel('Step')\n", "plt.ylabel('Reconstruction Loss')\n", "plt.title('Stage 1: VAE Training')\n", "plt.tight_layout()\n", "plt.savefig('vae_training.png', dpi=100)\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 5.2 Stage 2: Backbone Flow Matching Training\n", "\n", "Train the SSM backbone with rectified flow matching.\n", "The VAE is frozen; the backbone learns to predict the velocity v(z_t, t). A sketch of what a flow-matching step looks like internally follows the cell below." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stage 2: Backbone Training with Flow Matching\n", "vae_train.eval()\n", "backbone_train = MicroForgeBackbone(latent_channels=16, config='tiny')\n", "planner_train = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)\n", "\n", "trainer = MicroForgeTrainer(\n", "    vae_train, backbone_train, planner_train,\n", "    lr=1e-4, weight_decay=0.01, use_ema=True\n", ")\n", "\n", "flow_losses = []\n", "print('=== Stage 2: Backbone Flow Matching Training ===')\n", "for step in range(100):\n", "    images = torch.randn(4, 3, 128, 128) * 0.5\n", "    text_emb = torch.randn(4, 10, 768)\n", "    text_pooled = torch.randn(4, 768)\n", "\n", "    losses = trainer.train_step(images, text_emb, text_pooled)\n", "    flow_losses.append(losses['flow'])\n", "\n", "    if step % 20 == 0:\n", "        print(f' Step {step:3d}: flow_loss={losses[\"flow\"]:.4f}')\n", "\n", "plt.figure(figsize=(8, 3))\n", "plt.plot(flow_losses)\n", "plt.xlabel('Step')\n", "plt.ylabel('Flow Matching Loss')\n", "plt.title('Stage 2: Backbone Training')\n", "plt.tight_layout()\n", "plt.savefig('backbone_training.png', dpi=100)\n", "plt.show()" ] }, 
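{ "cell_type": "markdown", "metadata": {}, "source": [ "The `trainer.train_step` call above hides the flow-matching mechanics. The next cell is a minimal, hedged sketch of one rectified-flow training step, written against conventions this notebook already uses: the `backbone(z, t, text_emb, text_pooled)` signature from Section 2.2 and the velocity target formalized in Section 9. It is an illustration, not the actual `MicroForgeTrainer.train_step` implementation; it reuses the VAE forward pass for encoding purely for simplicity." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch of one rectified-flow training step (illustration only;\n", "# not the actual MicroForgeTrainer.train_step implementation).\n",
"def flow_matching_step(backbone, vae, images, text_emb, text_pooled, opt):\n",
"    with torch.no_grad():\n",
"        _, mu, _ = vae(images)           # frozen VAE; reuse forward() to get the posterior\n",
"    z0 = mu                              # clean latent (posterior mean)\n",
"    eps = torch.randn_like(z0)\n",
"    t = torch.rand(z0.shape[0])\n",
"    t_ = t.view(-1, 1, 1, 1)\n",
"    z_t = (1 - t_) * z0 + t_ * eps       # rectified-flow interpolation between data and noise\n",
"    v_target = eps - z0                  # velocity target d z_t / d t\n",
"    v_pred = backbone(z_t, t, text_emb, text_pooled)\n",
"    loss = F.mse_loss(v_pred, v_target)\n",
"    opt.zero_grad()\n",
"    loss.backward()\n",
"    opt.step()\n",
"    return loss.item()\n",
"\n",
"# One demonstration step on the modules trained above\n",
"sketch_opt = torch.optim.AdamW(backbone_train.parameters(), lr=1e-4)\n",
"sketch_loss = flow_matching_step(backbone_train, vae_train, torch.randn(2, 3, 128, 128) * 0.5,\n",
"                                 torch.randn(2, 10, 768), torch.randn(2, 768), sketch_opt)\n",
"print(f'Sketch flow loss: {sketch_loss:.4f}')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 6. 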
Staged Training Curriculum (Production)\n", "\n", "The full training curriculum for a production model:\n", "\n", "```\n", "STAGE 1 — VAE (freeze after):\n", " Data: ImageNet + SAM (mixed res)\n", " Loss: L1 recon + 1e-6*KL + perceptual (LPIPS) + adversarial (PatchGAN)\n", " Steps: 100K, batch=256, lr=1e-4\n", " Hardware: 4Ɨ A100 (or 1Ɨ T4 with grad accumulation)\n", "\n", "STAGE 2 — Backbone Low-Res (128-256px):\n", " Data: Teacher-generated synthetic data (FLUX/SD3.5 outputs)\n", " Loss: Flow matching ||v_pred - v_target||²\n", " Steps: 500K, batch=128, lr=1e-4\n", " Freeze: VAE encoder+decoder\n", " Train: Backbone + Planner\n", "\n", "STAGE 3 — Backbone High-Res (256-512px):\n", " Data: Same + high-res subset\n", " Loss: Flow matching + resolution-adaptive noise schedule\n", " Steps: 200K, batch=64, lr=5e-5\n", " Init: From Stage 2 weights\n", "\n", "STAGE 4 — Knowledge Distillation:\n", " Teacher: FLUX.1-dev or SD3.5-Large\n", " Loss: Flow matching + t-scaled distillation loss\n", " Steps: 100K, batch=64, lr=2e-5\n", "\n", "STAGE 5 — Editing (spatial concat):\n", " Data: InstructPix2Pix pairs + FLUX Kontext edits\n", " Loss: Flow matching on [target | source] concat\n", " Steps: 50K, batch=32, lr=1e-5\n", " Trick: Progressive: T2I → Edit → Joint (DreamLite recipe)\n", "\n", "STAGE 6 — Step Distillation (4-step):\n", " Method: Consistency distillation + LADD\n", " Steps: 50K, batch=128, lr=1e-5\n", " Target: 1-4 step generation\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Demonstrate staged freeze/thaw training\n", "print('=== Staged Training Configuration ===')\n", "print()\n", "\n", "# Stage 1: Only VAE trainable\n", "vae_s = MicroForgeVAE(config='tiny')\n", "backbone_s = MicroForgeBackbone(latent_channels=16, config='tiny')\n", "planner_s = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)\n", "\n", "def count_trainable(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "def freeze(model):\n", " for p in model.parameters():\n", " p.requires_grad_(False)\n", "\n", "def unfreeze(model):\n", " for p in model.parameters():\n", " p.requires_grad_(True)\n", "\n", "# Stage 1: VAE only\n", "freeze(backbone_s)\n", "freeze(planner_s)\n", "unfreeze(vae_s)\n", "print(f'Stage 1 (VAE): {count_trainable(vae_s):,} trainable params')\n", "\n", "# Stage 2: Backbone + Planner only\n", "freeze(vae_s)\n", "unfreeze(backbone_s)\n", "unfreeze(planner_s)\n", "print(f'Stage 2 (Backbone+Planner): {count_trainable(backbone_s) + count_trainable(planner_s):,} trainable params')\n", "\n", "# Stage 5: Editing - all unfrozen but low LR\n", "unfreeze(vae_s)\n", "unfreeze(backbone_s)\n", "unfreeze(planner_s)\n", "total = count_trainable(vae_s) + count_trainable(backbone_s) + count_trainable(planner_s)\n", "print(f'Stage 5 (Joint): {total:,} trainable params')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Memory Profiling for Mobile Deployment\n", "\n", "Target: < 3-4 GB RAM for inference on consumer devices." 
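, "\n", "\n", "The INT8 numbers in the next cell assume post-training quantization of the weights. A minimal sketch using PyTorch dynamic quantization is shown below; note that this path only converts `nn.Linear` modules, so the conv-heavy parts of the VAE and backbone would need static quantization or QAT in practice, and the speed/quality impact is not evaluated here.\n", "\n", "```python\n",
"# Rough sketch: INT8 dynamic quantization of the Linear layers only (assumption:\n",
"# conv/SSM layers stay FP16 and would be handled separately, e.g. via static PTQ).\n",
"import torch\n",
"\n",
"bb_fp32 = MicroForgeBackbone(latent_channels=32, config='base').eval()\n",
"bb_int8 = torch.quantization.quantize_dynamic(bb_fp32, {torch.nn.Linear}, dtype=torch.qint8)\n",
"```"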
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('=== MicroForge Memory Budget ===')\n", "print()\n", "\n", "configs = {\n", " 'Mobile (tiny)': ('tiny', 16, 16, 256),\n", " 'Prototype (small)': ('small', 32, 32, 384),\n", " 'Full (base)': ('base', 32, 32, 512),\n", "}\n", "\n", "for name, (cfg, lc, plan_tokens, plan_dim) in configs.items():\n", " vae = MicroForgeVAE(config=cfg)\n", " bb = MicroForgeBackbone(latent_channels=lc, config=cfg)\n", " pl = RecurrentLatentPlanner(num_plan_tokens=plan_tokens, dim=plan_dim, text_dim=768, latent_channels=lc)\n", " \n", " total_params = sum(p.numel() for p in vae.parameters()) + \\\n", " sum(p.numel() for p in bb.parameters()) + \\\n", " sum(p.numel() for p in pl.parameters())\n", " \n", " fp32_mb = total_params * 4 / 1e6\n", " fp16_mb = total_params * 2 / 1e6\n", " int8_mb = total_params / 1e6\n", " \n", " print(f'{name}:')\n", " print(f' Total params: {total_params:,}')\n", " print(f' FP32: {fp32_mb:.0f} MB | FP16: {fp16_mb:.0f} MB | INT8: {int8_mb:.0f} MB')\n", " \n", " # Activation memory estimate (rough)\n", " # For 512px: latent = 16x16xC, backbone processes 256 tokens\n", " latent_tokens = 16 * 16 # at 512px\n", " act_mb = latent_tokens * plan_dim * 4 / 1e6 * 20 # ~20 intermediate tensors\n", " print(f' Activation memory @512px: ~{act_mb:.0f} MB')\n", " print(f' Total inference @512px (FP16): ~{fp16_mb + act_mb:.0f} MB')\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Editing Readiness Demo\n", "\n", "The architecture supports editing via spatial concatenation:\n", "- **Generation**: `z_input = [z_noise | zeros]` (width-concat)\n", "- **Editing**: `z_input = [z_noise | z_source]` (width-concat)\n", "- **Inpainting**: `z_input = [z_noise | z_masked_source]`\n", "- **Super-res**: `z_input = [z_noise | z_lowres_upsampled]`\n", "\n", "No extra parameters needed — same backbone handles all tasks.\n", "Task is indicated by prepending task tokens to the text prompt." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Demonstrate spatial concatenation for different tasks\n", "B, C, H, W = 1, 16, 8, 8 # Latent dimensions for 256px\n", "\n", "z_noise = torch.randn(B, C, H, W)\n", "z_source = torch.randn(B, C, H, W)\n", "z_zeros = torch.zeros(B, C, H, W)\n", "\n", "# Generation mode\n", "z_gen = torch.cat([z_noise, z_zeros], dim=-1) # [B, C, H, 2W]\n", "print(f'Generation input: {z_gen.shape} (target + blank context)')\n", "\n", "# Editing mode\n", "z_edit = torch.cat([z_noise, z_source], dim=-1)\n", "print(f'Editing input: {z_edit.shape} (target + source context)')\n", "\n", "# Inpainting mode\n", "mask = torch.ones(B, 1, H, W)\n", "mask[:, :, 2:6, 2:6] = 0 # Unmask center region\n", "z_masked = z_source * mask # Zero out inpaint region\n", "z_inpaint = torch.cat([z_noise, z_masked], dim=-1)\n", "print(f'Inpaint input: {z_inpaint.shape} (target + masked source)')\n", "\n", "# The backbone processes all of these identically\n", "bb = MicroForgeBackbone(latent_channels=C, config='tiny')\n", "t = torch.rand(B)\n", "text_emb = torch.randn(B, 5, 768)\n", "text_pooled = torch.randn(B, 768)\n", "\n", "v_gen = bb(z_gen, t, text_emb, text_pooled)\n", "print(f'\\nBackbone output: {v_gen.shape}')\n", "print(f'Target velocity (left half): {v_gen[..., :W].shape}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 9. Mathematical Formulation Summary\n", "\n", "### Forward Process (Rectified Flow)\n", "$$z_t = (1-t) \\cdot z_0 + t \\cdot \\epsilon, \\quad \\epsilon \\sim \\mathcal{N}(0, I)$$\n", "\n", "### Training Objective\n", "$$\\mathcal{L}_{\\text{flow}} = \\mathbb{E}_{t, z_0, \\epsilon} \\left[ w(t) \\|v_\\theta(z_t, t, c) - (\\epsilon - z_0)\\|^2 \\right]$$\n", "\n", "where $w(t) = \\frac{1}{1 + |2t - 1|}$ (t-scaling, peaks at $t=0.5$)\n", "\n", "### Sampling (Euler ODE)\n", "$$z_{t-\\Delta t} = z_t - \\Delta t \\cdot v_\\theta(z_t, t, c)$$\n", "\n", "(minus sign because sampling integrates the ODE from $t=1$, pure noise, down to $t=0$, data)\n", "\n", "### Planner Update\n", "$$p^{(l+1)} = \\text{SelfAttn}(\\text{CrossAttn}(p^{(l)}, \\text{Proj}(z_t)))$$\n", "\n", "### Self-Conditioning\n", "$$p_t = \\sigma(w) \\cdot p_{t+1} + (1 - \\sigma(w)) \\cdot p_{\\text{init}}(c_{\\text{text}})$$\n", "\n", "### VAE Loss\n", "$$\\mathcal{L}_{\\text{VAE}} = \\|x - \\hat{x}\\|_1 + \\lambda_{\\text{KL}} \\cdot D_{\\text{KL}}(q(z|x) \\| \\mathcal{N}(0, I))$$" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10. Ablation Plan\n", "\n", "To validate each component's contribution:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ablations = [\n", "    ('Full MicroForge', True, True, True),\n", "    ('No Planner', True, False, True),\n", "    ('No SSM (attention only)', False, True, False),  # Replace SSM with self-attn\n", "    ('No Shared Attention', True, True, False),  # Remove shared attn block\n", "    ('No DWConv in SSM', True, True, True),  # Remove local_conv from SSM (not a column below)\n", "]\n", "\n", "print('=== Ablation Plan ===')\n", "print(f'{\"Configuration\":>30} | {\"SSM\":>5} | {\"Planner\":>8} | {\"SharedAttn\":>10}')\n", "print('-' * 65)\n", "for name, ssm, planner, shared in ablations:\n", "    print(f'{name:>30} | {\"āœ“\" if ssm else \"āœ—\":>5} | {\"āœ“\" if planner else \"āœ—\":>8} | {\"āœ“\" if shared else \"āœ—\":>10}')\n", "\n", "print()\n", "print('Metrics to track per ablation:')\n", "print(' - FID (quality) on COCO-30K')\n", "print(' - CLIP-Score (prompt adherence)')\n", "print(' - ImageReward (aesthetics)')\n", "print(' - Inference latency (ms)')\n", "print(' - Peak memory (MB)')\n", "print(' - Training convergence speed (steps to target FID)')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 11. 
Dataset Pipeline for Staged Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Dataset recommendations per training stage\n", "print('=== Recommended Datasets ===')\n", "print()\n", "\n", "stages = {\n", " 'Stage 1 - VAE': {\n", " 'datasets': [\n", " 'ImageNet-1K (class-cond, 1.28M images)',\n", " 'SAM-1M (diverse scenes, SA-1B subset)',\n", " 'FFHQ (70K faces for quality tuning)',\n", " ],\n", " 'hub_ids': ['ILSVRC/imagenet-1k', 'facebook/sam', 'NoCrypt/ffhq-512'],\n", " },\n", " 'Stage 2 - Low-Res T2I': {\n", " 'datasets': [\n", " 'JourneyDB-4M (high aesthetic quality)',\n", " 'LAION-Aesthetics-6.5+ (filtered subset)',\n", " 'Teacher-generated synthetic data (FLUX/SD3.5 outputs)',\n", " ],\n", " 'hub_ids': ['JourneyDB/JourneyDB', 'laion/laion2B-en-aesthetic'],\n", " },\n", " 'Stage 3 - High-Res T2I': {\n", " 'datasets': [\n", " 'Same as Stage 2, filtered for >512px',\n", " 'Unsplash-25K (very high quality photos)',\n", " ],\n", " 'hub_ids': [],\n", " },\n", " 'Stage 4 - Knowledge Distillation': {\n", " 'datasets': [\n", " 'Self-generated: 1M prompts → FLUX.1-dev outputs',\n", " 'DiffusionDB-2M (real user prompts)',\n", " ],\n", " 'hub_ids': ['poloclub/diffusiondb'],\n", " },\n", " 'Stage 5 - Editing': {\n", " 'datasets': [\n", " 'InstructPix2Pix (454K editing pairs)',\n", " 'MagicBrush (10K high-quality edits)',\n", " 'GRIT-Entity (subject-driven, 200K)',\n", " 'Custom: FLUX.1-Kontext-generated edit pairs',\n", " ],\n", " 'hub_ids': ['timbrooks/instructpix2pix-clip-filtered', 'osunlp/MagicBrush'],\n", " },\n", "}\n", "\n", "for stage, info in stages.items():\n", " print(f'\\n{stage}:')\n", " for ds in info['datasets']:\n", " print(f' • {ds}')\n", " if info['hub_ids']:\n", " print(f' HF Hub: {info[\"hub_ids\"]}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 12. Comparison with Existing Architectures" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "comparison = [\n", " ('SD-v1.5', '860M', '~3.4 GB', 'O(N²)', 'UNet', 'No', '20-50'),\n", " ('SDXL', '2.6B', '~6.5 GB', 'O(N²)', 'UNet', 'No', '20-50'),\n", " ('FLUX.1-dev', '12B', '~24 GB', 'O(N²)', 'MM-DiT', 'No', '20-50'),\n", " ('SD3.5-Medium', '2.5B', '~6 GB', 'O(N²)', 'MM-DiT', 'No', '28'),\n", " ('SANA-Sprint', '600M+2B', '~5.5 GB', 'O(N)', 'Linear DiT', 'No', '1-4'),\n", " ('SnapGen', '380M+2B', '~4 GB', 'O(N²)', 'Pruned UNet', 'No', '4-28'),\n", " ('DreamLite', '389M+2B', '~4 GB', 'O(N²)', 'Pruned UNet', 'Yes', '4'),\n", " ('MicroForge-tiny', '28M+text', '~0.2 GB*', 'O(N)', 'SSM-Conv', 'Yes', '4-20'),\n", " ('MicroForge-small', '114M+text', '~0.6 GB*', 'O(N)', 'SSM-Conv', 'Yes', '4-20'),\n", " ('MicroForge-base', '240M+text', '~1.2 GB*', 'O(N)', 'SSM-Conv', 'Yes', '4-20'),\n", "]\n", "\n", "print(f'{\"Model\":>18} | {\"Params\":>12} | {\"VRAM\":>10} | {\"Complexity\":>10} | {\"Backbone\":>12} | {\"Edit\":>5} | {\"Steps\":>6}')\n", "print('-' * 95)\n", "for row in comparison:\n", " print(f'{row[0]:>18} | {row[1]:>12} | {row[2]:>10} | {row[3]:>10} | {row[4]:>12} | {row[5]:>5} | {row[6]:>6}')\n", "print()\n", "print('* MicroForge VRAM excludes text encoder (shared/swappable component)')\n", "print(' With CLIP-L (428M): add ~0.9 GB. With Gemma-2-2B: add ~4 GB.')\n", "print(' For mobile: use TinyCLIP (~60M) adding only ~0.12 GB.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 13. 
Export and Save Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save model checkpoint\n", "os.makedirs('checkpoints', exist_ok=True)\n", "\n", "checkpoint = {\n", " 'vae_state_dict': vae_train.state_dict(),\n", " 'backbone_state_dict': backbone_train.state_dict(),\n", " 'planner_state_dict': planner_train.state_dict(),\n", " 'config': {\n", " 'vae_config': 'tiny',\n", " 'backbone_config': 'tiny',\n", " 'latent_channels': 16,\n", " 'plan_tokens': 16,\n", " 'plan_dim': 256,\n", " 'text_dim': 768,\n", " },\n", " 'architecture_version': '0.1.0',\n", "}\n", "\n", "torch.save(checkpoint, 'checkpoints/microforge_tiny_demo.pt')\n", "size_mb = os.path.getsize('checkpoints/microforge_tiny_demo.pt') / 1e6\n", "print(f'Saved checkpoint: {size_mb:.1f} MB')\n", "print('Done!')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.12.0" } }, "nbformat": 4, "nbformat_minor": 4 }