{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LatentRecurrentFlow (LRF) v3 — Train & Generate\n", "\n", "**One notebook. Run top to bottom. Produces real images.**\n", "\n", "Architecture:\n", "- **TAESD** (pre-trained, 2.4M params, frozen) as the VAE\n", "- **1.47M-param Recursive Denoising Core** (4 shared blocks × 2 recursions = 8 effective layers)\n", "- **Rectified-flow** matching with SNR weighting and classifier-free guidance (CFG) dropout\n", "- **EMA** for stable sampling\n", "\n", "Trains on CIFAR-10 in ~60 min on CPU or ~10 min on GPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q torch torchvision einops diffusers safetensors huggingface_hub matplotlib pillow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download the self-contained training script\n", "import os\n", "if not os.path.exists('lrf_v3.py'):\n", "    !wget -q https://huggingface.co/krystv/LatentRecurrentFlow/resolve/main/lrf_v3.py\n", "from lrf_v3 import *\n", "print(f'Device: {DEVICE}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Architecture" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "notebook_cell_2_architecture()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Train on CIFAR-10\n", "\n", "This will:\n", "1. Load TAESD (pre-trained VAE)\n", "2. Pre-compute CIFAR-10 latents (~4 min CPU)\n", "3. Train the denoiser (30 epochs, ~60 min CPU / ~10 min GPU)\n", "4. Generate class-conditional samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model, vae, losses = train(epochs=30, bs=64, lr=3e-4, out='./lrf_out')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Visualize Results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "import numpy as np\n", "\n", "# Loss curve\n", "plt.figure(figsize=(10, 3))\n", "plt.plot(losses, 'b-')\n", "plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Training Loss')\n", "plt.grid(True, alpha=0.3); plt.show()\n", "\n", "# VAE reconstruction\n", "plt.figure(figsize=(8, 4))\n", "plt.imshow(np.array(Image.open('./lrf_out/vae_check.png')))\n", "plt.title('VAE Check (top=original, bottom=TAESD reconstruction)')\n", "plt.axis('off'); plt.show()\n", "\n", "# Final generation\n", "plt.figure(figsize=(10, 25))\n", "plt.imshow(np.array(Image.open('./lrf_out/final.png')))\n", "classes = ['airplane','auto','bird','cat','deer','dog','frog','horse','ship','truck']\n", "plt.title('Class-conditional generation\\n' + ', '.join(classes))\n", "plt.axis('off'); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Generate Custom Samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchvision\n", "import matplotlib.pyplot as plt\n", "\n", "sched = FlowScheduler()\n", "classes = ['airplane','auto','bird','cat','deer','dog','frog','horse','ship','truck']\n", "\n", "# Generate 8 images of each class\n", "for cls_id, cls_name in enumerate(classes):\n", "    imgs = gen(model, vae, sched, DEVICE, n=8, steps=50, cfg=3.0, cls_id=cls_id)\n", "    # Map samples from [-1, 1] to [0, 1] and move them to the CPU for plotting\n", "    imgs = ((imgs + 1) / 2).clamp(0, 1).detach().cpu()\n", "    grid = torchvision.utils.make_grid(imgs, nrow=8, padding=2)\n", "    plt.figure(figsize=(16, 2))\n", "    plt.imshow(grid.permute(1, 2, 0).numpy())\n", "    plt.title(f'{cls_name} (class {cls_id})')\n", "    plt.axis('off'); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Load Pre-Trained Model (skip training)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load pre-trained model from HuggingFace Hub\n", "from huggingface_hub import hf_hub_download\n", "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "\n", "model_path = hf_hub_download('krystv/LatentRecurrentFlow', 'v3/model.pt')\n", "# The checkpoint holds the model config ('cfg') alongside the weights ('state')\n", "ckpt = torch.load(model_path, map_location='cpu', weights_only=False)\n", "\n", "model = LRF(ckpt['cfg'])\n", "model.load_state_dict(ckpt['state'])\n", "model.to(DEVICE).eval()  # put the denoiser on the same device as the VAE\n", "\n", "vae = get_taesd(DEVICE)\n", "sched = FlowScheduler()\n", "\n", "# Generate 16 samples and save them as a grid\n", "imgs = gen(model, vae, sched, DEVICE, n=16, steps=50, cfg=3.0)\n", "save_grid(imgs, 'quick_gen.png', 4)\n", "\n", "plt.imshow(np.array(Image.open('quick_gen.png')))\n", "plt.axis('off'); plt.show()" ] } ], "metadata": { "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"name": "python", "version": "3.10.0"} }, "nbformat": 4, "nbformat_minor": 4 }