{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# šŸ”¬ LatentRecurrentFlow (LRF) — A Novel Mobile-First Image Generation Architecture\n", "\n", "**A complete implementation of a novel image generation architecture designed for consumer devices.**\n", "\n", "## Key Innovations\n", "\n", "1. **Recursive Latent Refinement (RLR) Core** — HRM-inspired iterative reasoning on image latents with O(1) memory backpropagation\n", "2. **Gated Linear Diffusion (GLD) Blocks** — O(N) subquadratic spatial mixing replacing quadratic self-attention\n", "3. **Compact f=16 VAE** with SnapGen-inspired tiny decoder (1-2M params)\n", "4. **Rectified Flow** training with consistency distillation readiness\n", "5. **Editing-ready architecture** — same latent core supports text-to-image, inpainting, style editing, and more\n", "\n", "### Memory Budget\n", "| Component | FP32 | INT8 (Mobile) |\n", "|-----------|------|---------------|\n", "| VAE Decoder | 4 MB | 1 MB |\n", "| Text Encoder | 44 MB | 11 MB |\n", "| Denoising Core | 2.5 MB | 0.6 MB |\n", "| Activations (256²) | ~200 MB | ~100 MB |\n", "| **Total** | **~250 MB** | **~113 MB** |\n", "\n", "This notebook demonstrates:\n", "1. Architecture design and parameter counting\n", "2. End-to-end VAE training\n", "3. Flow matching denoiser training\n", "4. Sample generation\n", "5. Model saving and loading" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 0. Installation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install dependencies\n", "!pip install -q torch torchvision einops safetensors huggingface_hub pillow matplotlib" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Clone the LRF repo (if not already available)\n", "import os, sys\n", "\n", "# If running from the repo, just add to path\n", "if os.path.exists('lrf'):\n", " sys.path.insert(0, '.')\n", "else:\n", " # Clone from HF Hub\n", " !git clone https://huggingface.co/krystv/LatentRecurrentFlow\n", " sys.path.insert(0, 'LatentRecurrentFlow')\n", "\n", "from lrf.model import LatentRecurrentFlow, RecursiveLatentCore, CompactVAE, GatedLinearAttention\n", "from lrf.training import LRFTrainer, RectifiedFlowScheduler, SyntheticImageTextDataset\n", "from lrf.pipeline import LRFPipeline, LRFTrainingPipeline\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# Device\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f'Using device: {device}')\n", "if device.type == 'cuda':\n", " print(f'GPU: {torch.cuda.get_device_name()}')\n", " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 
Architecture Overview & Parameter Counting" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create model with different configs\n", "configs = {\n", " 'Tiny (5.7M)': LatentRecurrentFlow.tiny_config(),\n", " 'Default (16.3M)': LatentRecurrentFlow.default_config(),\n", "}\n", "\n", "for name, config in configs.items():\n", " model = LatentRecurrentFlow(config)\n", " counts = model.count_parameters()\n", " \n", " print(f'\\n=== {name} ===')\n", " print(f'Config: T_outer={config[\"T_outer\"]}, T_inner={config[\"T_inner\"]}, '\n", " f'num_blocks={config[\"num_blocks\"]}')\n", " print(f'Effective depth: {config[\"T_outer\"] * config[\"T_inner\"] * config[\"num_blocks\"]} layers '\n", " f'(from {config[\"num_blocks\"]} unique blocks)')\n", " for module, count in counts.items():\n", " mb = count * 4 / 1e6\n", " print(f' {module:20s}: {count:>12,} params ({mb:.1f} MB FP32)')\n", " del model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Stage 1: VAE Training\n", "\n", "The VAE learns to compress images into a compact latent space.\n", "- f=16 spatial compression: 256Ɨ256 → 16Ɨ16 latents\n", "- C=16 or C=32 latent channels\n", "- Tiny decoder (~280K params) inspired by SnapGen" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create model for training\n", "config = LatentRecurrentFlow.tiny_config()\n", "model = LatentRecurrentFlow(config).to(device)\n", "\n", "# Create synthetic dataset (replace with real data for actual training)\n", "dataset = SyntheticImageTextDataset(\n", " num_samples=500,\n", " image_size=64,\n", " max_text_length=32\n", ")\n", "dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)\n", "\n", "# Create trainer\n", "trainer = LRFTrainer(model, device, './lrf_checkpoints')\n", "\n", "print(f'Dataset size: {len(dataset)}')\n", "print(f'Batch size: 8')\n", "print(f'Image size: 64x64')\n", "print(f'Latent size: {64//16}x{64//16}x{config[\"latent_channels\"]}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train VAE\n", "vae_optimizer = torch.optim.AdamW(model.vae.parameters(), lr=1e-3, weight_decay=0.01)\n", "\n", "vae_losses = []\n", "num_vae_steps = 100\n", "\n", "print('Training VAE...')\n", "step = 0\n", "for epoch in range(10): # Multiple epochs over small dataset\n", " for batch in dataloader:\n", " if step >= num_vae_steps:\n", " break\n", " losses = trainer.train_vae_step(batch['image'], vae_optimizer)\n", " vae_losses.append(losses['total'])\n", " if step % 20 == 0:\n", " print(f' Step {step}: total={losses[\"total\"]:.4f}, '\n", " f'recon={losses[\"recon\"]:.4f}, kl={losses[\"kl\"]:.4f}')\n", " step += 1\n", " if step >= num_vae_steps:\n", " break\n", "\n", "# Plot VAE loss\n", "plt.figure(figsize=(10, 4))\n", "plt.plot(vae_losses)\n", "plt.xlabel('Step')\n", "plt.ylabel('Loss')\n", "plt.title('VAE Training Loss')\n", "plt.grid(True, alpha=0.3)\n", "plt.show()\n", "\n", "# Save checkpoint\n", "trainer.save_checkpoint('./lrf_checkpoints/vae.pt', 'vae', 0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Visualize VAE reconstruction\n", "model.eval()\n", "with torch.no_grad():\n", " sample_batch = next(iter(dataloader))\n", " images = sample_batch['image'].to(device)\n", " recon, _, _ = model.vae(images)\n", "\n", "fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n", "for i in range(4):\n", " # Original\n", " img = 
images[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n", "    axes[0][i].imshow(np.clip(img, 0, 1))\n", "    axes[0][i].set_title(f'Original {i}')\n", "    axes[0][i].axis('off')\n", "\n", "    # Reconstruction\n", "    rec = recon[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n", "    axes[1][i].imshow(np.clip(rec, 0, 1))\n", "    axes[1][i].set_title(f'Reconstruction {i}')\n", "    axes[1][i].axis('off')\n", "\n", "plt.suptitle('VAE Reconstruction Quality', fontsize=14)\n", "plt.tight_layout()\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Stage 2: Flow Matching Denoiser Training\n", "\n", "The denoising core learns to predict the velocity field for rectified flow.\n", "- The VAE is frozen\n", "- The core and text encoder are trained\n", "- Uses an SNR-weighted flow matching loss" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Freeze the VAE\n", "for p in model.vae.parameters():\n", "    p.requires_grad = False\n", "\n", "# Train with flow matching\n", "flow_params = list(model.core.parameters()) + list(model.text_encoder.parameters())\n", "flow_optimizer = torch.optim.AdamW(flow_params, lr=1e-3, weight_decay=0.01)\n", "\n", "flow_losses = []\n", "num_flow_steps = 100\n", "\n", "print('Training flow matching denoiser...')\n", "model.core.train()\n", "model.text_encoder.train()\n", "\n", "step = 0\n", "for epoch in range(10):\n", "    for batch in dataloader:\n", "        if step >= num_flow_steps:\n", "            break\n", "        losses = trainer.train_flow_step(\n", "            batch['image'], batch['token_ids'], batch['attention_mask'],\n", "            flow_optimizer, cfg_dropout=0.1\n", "        )\n", "        flow_losses.append(losses['flow_loss'])\n", "        if step % 20 == 0:\n", "            print(f'  Step {step}: flow_loss={losses[\"flow_loss\"]:.4f}')\n", "        step += 1\n", "    if step >= num_flow_steps:\n", "        break\n", "\n", "# Plot the flow loss\n", "plt.figure(figsize=(10, 4))\n", "plt.plot(flow_losses)\n", "plt.xlabel('Step')\n", "plt.ylabel('Flow Matching Loss')\n", "plt.title('Denoiser Training Loss')\n", "plt.grid(True, alpha=0.3)\n", "plt.show()\n", "\n", "# Save a checkpoint\n", "trainer.save_checkpoint('./lrf_checkpoints/flow.pt', 'flow', 0)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Generation & Visualization\n", "\n", "Generate images from the trained model with Euler ODE sampling. A self-contained sketch of the underlying rectified-flow mechanics follows the generation cell." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate samples\n", "model.eval()\n", "\n", "# Create prompts (using simple tokenization for the prototype)\n", "prompts = [\n", "    'a beautiful sunset over the ocean with golden light',\n", "    'a cute cat sitting on a windowsill',\n", "    'a mountain landscape with snow and trees',\n", "    'a colorful abstract painting with swirls',\n", "]\n", "\n", "pipe = LRFPipeline(model, device=device)\n", "\n", "# Generate with different step counts\n", "for num_steps in [5, 10, 20]:\n", "    images = pipe(\n", "        prompts,\n", "        num_steps=num_steps,\n", "        cfg_scale=1.0,  # keep CFG low for this barely-trained model\n", "        height=64,\n", "        width=64,\n", "        seed=42,\n", "    )\n", "\n", "    fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n", "    for i in range(4):\n", "        img = images[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n", "        axes[i].imshow(np.clip(img, 0, 1))\n", "        axes[i].set_title(prompts[i][:30] + '...')\n", "        axes[i].axis('off')\n", "    plt.suptitle(f'Generated Images ({num_steps} steps)', fontsize=14)\n", "    plt.tight_layout()\n", "    plt.show()" ] },
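{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the sampler above concrete, here is a self-contained rectified-flow sketch on toy 2-D data: the training target is the straight-line velocity `v* = x1 - x0` along the interpolation `x_t = (1 - t) * x0 + t * x1`, and Euler sampling simply walks that ODE from noise (t=1) back to data (t=0). The time/sign conventions and the tiny MLP are illustrative assumptions; `RectifiedFlowScheduler` may orient t differently." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Rectified-flow sketch on toy 2-D data (conventions assumed; see lrf.training for the real ones)\n", "import torch.nn as nn\n", "\n", "torch.manual_seed(0)\n", "\n", "def sample_data(n):\n", "    # 'Data': a tight Gaussian blob at (2, 2); 'noise' is a standard normal\n", "    return torch.randn(n, 2) * 0.1 + torch.tensor([2.0, 2.0])\n", "\n", "# Tiny stand-in velocity net v(x_t, t) -> R^2\n", "net = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 2))\n", "opt = torch.optim.Adam(net.parameters(), lr=1e-2)\n", "\n", "for _ in range(500):\n", "    x0 = sample_data(256)            # data endpoint\n", "    x1 = torch.randn(256, 2)         # noise endpoint\n", "    t = torch.rand(256, 1)\n", "    x_t = (1 - t) * x0 + t * x1      # straight-line interpolation\n", "    v_target = x1 - x0               # rectified-flow velocity target\n", "    loss = ((net(torch.cat([x_t, t], dim=-1)) - v_target) ** 2).mean()\n", "    opt.zero_grad()\n", "    loss.backward()\n", "    opt.step()\n", "\n", "# Euler sampling: integrate from t=1 (noise) down to t=0 (data)\n", "x = torch.randn(512, 2)\n", "euler_steps = 10\n", "dt = 1.0 / euler_steps\n", "with torch.no_grad():\n", "    for i in range(euler_steps):\n", "        t = torch.full((512, 1), 1.0 - i * dt)\n", "        x = x - dt * net(torch.cat([x, t], dim=-1))  # step against the data-to-noise velocity\n", "\n", "print('sample mean (should approach [2, 2]):', x.mean(dim=0))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 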
Save & Load Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save the complete model\n", "pipe.save_pretrained('./lrf_model')\n", "print('Model saved to ./lrf_model/')\n", "\n", "# List saved files\n", "for f in os.listdir('./lrf_model'):\n", " size = os.path.getsize(f'./lrf_model/{f}')\n", " print(f' {f}: {size/1024:.1f} KB')\n", "\n", "# Reload and verify\n", "pipe_loaded = LRFPipeline.from_pretrained('./lrf_model', device=str(device))\n", "images_loaded = pipe_loaded('test prompt', num_steps=5, height=64, width=64, seed=42)\n", "print(f'\\nReloaded model generates: {images_loaded.shape}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Training Curriculum for Real Data\n", "\n", "The full training curriculum for production-quality models:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Display the full training curriculum\n", "curriculum = LRFTrainingPipeline.get_curriculum()\n", "\n", "print('Full Training Curriculum')\n", "print('=' * 70)\n", "for i, stage_name in enumerate(curriculum):\n", " stage = LRFTrainingPipeline.get_stage_config(stage_name)\n", " print(f'\\nStage {i+1}: {stage_name}')\n", " print(f' Description: {stage[\"description\"]}')\n", " print(f' Freeze: {stage[\"freeze\"]}')\n", " print(f' Train: {stage[\"train\"]}')\n", " print(f' LR: {stage[\"lr\"]}')\n", " print(f' Min steps: {stage[\"min_steps\"]:,}')\n", " if 'resolution' in stage:\n", " print(f' Resolution: {stage[\"resolution\"]}Ɨ{stage[\"resolution\"]}')\n", "\n", "print('\\n' + '=' * 70)\n", "print('\\nRecommended datasets for each stage:')\n", "print(' Stage 1 (VAE): ImageNet, COCO, or any large image dataset')\n", "print(' Stage 2 (Flow 64): Synthetic captions from teacher (SDXL/SD3) + LAION-aesthetic')\n", "print(' Stage 3 (Flow 256): Filtered LAION-aesthetic (score > 6.0) + synthetic')\n", "print(' Stage 4 (Flow 512): High-quality curated dataset + JourneyDB')\n", "print(' Stage 5 (Distill): Same as Stage 4 (distill from own multi-step model)')\n", "print(' Stage 6 (Editing): InstructPix2Pix + MagicBrush + synthetic edit pairs')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. 
Architecture Deep Dive\n", "\n", "### The Recursive Latent Refinement Loop" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Demonstrate the recursive refinement\n", "core = RecursiveLatentCore(\n", "    dim=32, cond_dim=64, num_blocks=2, num_heads=2, head_dim=16,\n", "    T_inner=4, T_outer=2, use_ift_training=False\n", ")\n", "\n", "print('Recursive Latent Core Architecture')\n", "print('=' * 50)\n", "print(f'Unique GLD blocks: {core.num_blocks}')\n", "print(f'T_outer (abstract updates): {core.T_outer}')\n", "print(f'T_inner (refinement steps): {core.T_inner}')\n", "print(f'Total recursions: {core.T_outer * core.T_inner}')\n", "print(f'Effective depth: {core.T_outer * core.T_inner * core.num_blocks} layers')\n", "print(f'Parameter reuse factor: {core.T_outer * core.T_inner}x')\n", "print(f'\\nParameters: {sum(p.numel() for p in core.parameters()):,}')\n", "\n", "# Show the memory savings from IFT\n", "print('\\nMemory comparison:')\n", "eff_depth = core.T_outer * core.T_inner * core.num_blocks\n", "print(f'  Standard backprop: O({eff_depth}) activation memory')\n", "print('  IFT backprop: O(1) activation memory')\n", "print(f'  Memory savings: {eff_depth}x')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Demonstrate GLA complexity scaling\n", "import time\n", "\n", "gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)\n", "\n", "print('GLA Complexity Scaling')\n", "print('=' * 50)\n", "\n", "sizes = [4, 8, 16, 32, 64]\n", "times = []\n", "\n", "for s in sizes:\n", "    x = torch.randn(1, s*s, 64)\n", "\n", "    with torch.no_grad():\n", "        # Warmup\n", "        _ = gla(x, h=s, w=s)\n", "\n", "        # Time 10 forward passes\n", "        t0 = time.perf_counter()\n", "        for _ in range(10):\n", "            _ = gla(x, h=s, w=s)\n", "        dt = (time.perf_counter() - t0) / 10\n", "    times.append(dt)\n", "    print(f'  {s}Ɨ{s} = {s*s:>5} tokens: {dt*1000:.2f}ms')\n", "\n", "# Plot the scaling\n", "plt.figure(figsize=(8, 4))\n", "tokens = [s*s for s in sizes]\n", "plt.plot(tokens, [t*1000 for t in times], 'bo-', label='GLA (O(N))')\n", "# Reference quadratic line\n", "t_ref = times[0] * 1000\n", "quadratic = [t_ref * (n / tokens[0])**2 for n in tokens]\n", "plt.plot(tokens, quadratic, 'r--', label='Quadratic attention (O(N²))', alpha=0.5)\n", "plt.xlabel('Number of tokens')\n", "plt.ylabel('Time (ms)')\n", "plt.title('GLA vs Quadratic Attention Scaling')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)\n", "plt.show()" ] },
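{ "cell_type": "markdown", "metadata": {}, "source": [ "The kernel trick behind that O(N) behaviour, in isolation: replace `softmax(Q K^T) V` with `phi(Q) (phi(K)^T V)`, so a small dƗd key-value summary is built once and queried per token, and the NƗN attention matrix is never materialized. The sketch below is generic (non-gated) linear attention with an ELU+1 feature map, a common choice but an assumption here; the repo's `GatedLinearAttention` adds learned gating on top." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generic linear attention sketch: O(N*d^2) instead of O(N^2*d).\n", "# The ELU+1 feature map is assumed for illustration (Katharopoulos et al., 2020);\n", "# GatedLinearAttention adds learned gates on top of this idea.\n", "def linear_attention(q, k, v, eps=1e-6):\n", "    # q, k, v: (batch, tokens, dim)\n", "    q = F.elu(q) + 1  # positive feature map phi(.)\n", "    k = F.elu(k) + 1\n", "    kv = torch.einsum('bnd,bne->bde', k, v)    # dxd summary, built once: O(N*d^2)\n", "    z = k.sum(dim=1)                           # normalizer accumulator: (batch, dim)\n", "    num = torch.einsum('bnd,bde->bne', q, kv)  # query the summary: O(N*d^2)\n", "    den = torch.einsum('bnd,bd->bn', q, z).unsqueeze(-1) + eps\n", "    return num / den\n", "\n", "q, k, v = (torch.randn(1, 64*64, 16) for _ in range(3))\n", "out = linear_attention(q, k, v)\n", "print(out.shape)  # (1, 4096, 16): no 4096x4096 matrix was ever formed" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Push to HuggingFace Hub (Optional)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Uncomment to push to the Hub\n", "# from huggingface_hub import HfApi\n", "# api = HfApi()\n", "# api.upload_folder(\n", "#     folder_path='./lrf_model',\n", "#     repo_id='your-username/LatentRecurrentFlow',\n", "#     repo_type='model',\n", "# )\n", "print('To push to HF Hub, uncomment the code above and set your repo_id.')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "## Summary\n", "\n", "This notebook demonstrated the LatentRecurrentFlow architecture end-to-end:\n", "\n", "1. āœ… Model creation with parameter counting\n", "2. āœ… VAE training for image compression\n", "3. āœ… Flow matching denoiser training\n", "4. āœ… Image generation with Euler ODE sampling\n", "5. āœ… Model save/load with HF-compatible format\n", "6. 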
āœ… Training curriculum for production\n", "\n", "### Next Steps\n", "- Replace the synthetic data with real image-text pairs\n", "- Scale up to the default config (16M params)\n", "- Train at length on a GPU for real quality\n", "- Add consistency distillation for 4-step generation\n", "- Add the editing fine-tuning stage" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 4 }