Fix deprecated torch.cuda.amp API in training cell — use torch.amp instead
Browse files- LiquidFlow_Training.ipynb +1 -747
LiquidFlow_Training.ipynb
CHANGED
|
@@ -1,747 +1 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "markdown",
|
| 5 |
-
"metadata": {},
|
| 6 |
-
"source": [
|
| 7 |
-
"# 🌊 LiquidFlow — Liquid-SSM Flow Matching Image Generator\n",
|
| 8 |
-
"\n",
|
| 9 |
-
"A **novel architecture** combining:\n",
|
| 10 |
-
"- **Liquid Time-Constant Networks** (CfC closed-form) — adaptive ODE dynamics, bounded by construction\n",
|
| 11 |
-
"- **Selective State Space Models** (Mamba-style) — linear-time long-range context, parallelizable\n",
|
| 12 |
-
"- **Zigzag Scanning** — 2D spatial awareness for image patches\n",
|
| 13 |
-
"- **Physics-Informed Regularization** — smoothness + total variation constraints\n",
|
| 14 |
-
"- **Rectified Flow Matching** — ODE-based generation (no noise schedule tuning)\n",
|
| 15 |
-
"\n",
|
| 16 |
-
"### 📋 What this notebook does\n",
|
| 17 |
-
"1. **Install & clone** the LiquidFlow codebase\n",
|
| 18 |
-
"2. **Choose a dataset** (CIFAR-10, Flowers-102, CelebA, or custom folder)\n",
|
| 19 |
-
"3. **Choose a model size** (tiny ~6M, small ~14M, base ~38M)\n",
|
| 20 |
-
"4. **Train** with one click — all Colab/Kaggle optimized\n",
|
| 21 |
-
"5. **Generate images** and visualize progress\n",
|
| 22 |
-
"6. **Export** trained model for mobile deployment\n",
|
| 23 |
-
"\n",
|
| 24 |
-
"### 💻 Hardware Requirements\n",
|
| 25 |
-
"| Config | GPU VRAM | Best For |\n",
|
| 26 |
-
"|--------|----------|----------|\n",
|
| 27 |
-
"| tiny-128 (bs=32) | ~4 GB | Colab free T4, Kaggle |\n",
|
| 28 |
-
"| small-128 (bs=16) | ~8 GB | Colab free T4, Kaggle |\n",
|
| 29 |
-
"| base-256 (bs=8) | ~12 GB | Colab Pro, Kaggle |\n",
|
| 30 |
-
"| 512 (bs=4) | ~14 GB | Colab Pro, A100 |"
|
| 31 |
-
]
|
| 32 |
-
},
|
| 33 |
-
{
|
| 34 |
-
"cell_type": "markdown",
|
| 35 |
-
"metadata": {},
|
| 36 |
-
"source": [
|
| 37 |
-
"---\n",
|
| 38 |
-
"## 0. Setup & Install"
|
| 39 |
-
]
|
| 40 |
-
},
|
| 41 |
-
{
|
| 42 |
-
"cell_type": "code",
|
| 43 |
-
"execution_count": null,
|
| 44 |
-
"metadata": {},
|
| 45 |
-
"outputs": [],
|
| 46 |
-
"source": [
|
| 47 |
-
"# Check GPU\n",
|
| 48 |
-
"!nvidia-smi || echo 'No GPU — CPU training only (very slow)'\n",
|
| 49 |
-
"import torch\n",
|
| 50 |
-
"print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')\n",
|
| 51 |
-
"if torch.cuda.is_available():\n",
|
| 52 |
-
" print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')"
|
| 53 |
-
]
|
| 54 |
-
},
|
| 55 |
-
{
|
| 56 |
-
"cell_type": "code",
|
| 57 |
-
"execution_count": null,
|
| 58 |
-
"metadata": {},
|
| 59 |
-
"outputs": [],
|
| 60 |
-
"source": [
|
| 61 |
-
"# Install dependencies\n",
|
| 62 |
-
"!pip install -q torch torchvision einops pillow matplotlib tqdm"
|
| 63 |
-
]
|
| 64 |
-
},
|
| 65 |
-
{
|
| 66 |
-
"cell_type": "code",
|
| 67 |
-
"execution_count": null,
|
| 68 |
-
"metadata": {},
|
| 69 |
-
"outputs": [],
|
| 70 |
-
"source": [
|
| 71 |
-
"# Clone the repo (or just copy the files if already have them)\n",
|
| 72 |
-
"import os\n",
|
| 73 |
-
"if not os.path.exists('liquidflow'):\n",
|
| 74 |
-
" !git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n",
|
| 75 |
-
" !cp -r liquidflow_repo/liquidflow .\n",
|
| 76 |
-
"else:\n",
|
| 77 |
-
" print('liquidflow/ already exists')\n",
|
| 78 |
-
"\n",
|
| 79 |
-
"# Verify\n",
|
| 80 |
-
"from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
|
| 81 |
-
"from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
|
| 82 |
-
"from liquidflow.sampling import euler_sample, heun_sample, generate_grid, make_grid_image\n",
|
| 83 |
-
"print('✅ LiquidFlow imported successfully!')"
|
| 84 |
-
]
|
| 85 |
-
},
|
| 86 |
-
{
|
| 87 |
-
"cell_type": "markdown",
|
| 88 |
-
"metadata": {},
|
| 89 |
-
"source": [
|
| 90 |
-
"---\n",
|
| 91 |
-
"## 1. ⚙️ Configuration — EDIT THIS CELL\n",
|
| 92 |
-
"\n",
|
| 93 |
-
"Choose your dataset, model size, and training hyperparameters."
|
| 94 |
-
]
|
| 95 |
-
},
|
| 96 |
-
{
|
| 97 |
-
"cell_type": "code",
|
| 98 |
-
"execution_count": null,
|
| 99 |
-
"metadata": {},
|
| 100 |
-
"outputs": [],
|
| 101 |
-
"source": [
|
| 102 |
-
"#@title 🎛️ Training Configuration { display-mode: \"form\" }\n",
|
| 103 |
-
"\n",
|
| 104 |
-
"# ============== DATASET ==============\n",
|
| 105 |
-
"#@markdown ### Dataset\n",
|
| 106 |
-
"DATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist', 'afhq', 'lsun_churches']\n",
|
| 107 |
-
"CUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\n",
|
| 108 |
-
"#@markdown > For 'folder': put images in CUSTOM_DATA_DIR. Supports .png/.jpg/.webp\n",
|
| 109 |
-
"\n",
|
| 110 |
-
"# ============== MODEL ==============\n",
|
| 111 |
-
"#@markdown ### Model\n",
|
| 112 |
-
"MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\n",
|
| 113 |
-
"IMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\n",
|
| 114 |
-
"\n",
|
| 115 |
-
"# ============== TRAINING ==============\n",
|
| 116 |
-
"#@markdown ### Training\n",
|
| 117 |
-
"EPOCHS = 100 #@param {type:\"integer\"}\n",
|
| 118 |
-
"BATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\n",
|
| 119 |
-
"LEARNING_RATE = 3e-4 #@param {type:\"number\"}\n",
|
| 120 |
-
"GRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\n",
|
| 121 |
-
"USE_AMP = True #@param {type:\"boolean\"}\n",
|
| 122 |
-
"\n",
|
| 123 |
-
"# ============== PHYSICS LOSS ==============\n",
|
| 124 |
-
"#@markdown ### Physics-Informed Regularization\n",
|
| 125 |
-
"LAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\n",
|
| 126 |
-
"LAMBDA_TV = 0.001 #@param {type:\"number\"}\n",
|
| 127 |
-
"\n",
|
| 128 |
-
"# ============== SAMPLING ==============\n",
|
| 129 |
-
"#@markdown ### Sampling & Logging\n",
|
| 130 |
-
"SAMPLE_EVERY = 5 #@param {type:\"integer\"}\n",
|
| 131 |
-
"SAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\n",
|
| 132 |
-
"LOG_EVERY = 50 #@param {type:\"integer\"}\n",
|
| 133 |
-
"SAVE_EVERY = 10 #@param {type:\"integer\"}\n",
|
| 134 |
-
"\n",
|
| 135 |
-
"# ============== PATHS ==============\n",
|
| 136 |
-
"OUTPUT_DIR = './outputs'\n",
|
| 137 |
-
"DATA_DIR = './data'\n",
|
| 138 |
-
"\n",
|
| 139 |
-
"# ============== AUTO-CONFIG ==============\n",
|
| 140 |
-
"# Smart batch size based on GPU memory\n",
|
| 141 |
-
"import torch\n",
|
| 142 |
-
"if torch.cuda.is_available():\n",
|
| 143 |
-
" vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
|
| 144 |
-
" print(f'GPU VRAM: {vram_gb:.1f} GB')\n",
|
| 145 |
-
" \n",
|
| 146 |
-
" # Auto-adjust batch size if needed\n",
|
| 147 |
-
" recommended = {\n",
|
| 148 |
-
" (32, 'tiny'): 128, (64, 'tiny'): 64, (128, 'tiny'): 32,\n",
|
| 149 |
-
" (32, 'small'): 64, (64, 'small'): 32, (128, 'small'): 16,\n",
|
| 150 |
-
" (256, 'base'): 8, (512, '512'): 4,\n",
|
| 151 |
-
" }\n",
|
| 152 |
-
" key = (IMG_SIZE, MODEL_SIZE)\n",
|
| 153 |
-
" if key in recommended and vram_gb < 16:\n",
|
| 154 |
-
" rec_bs = recommended[key]\n",
|
| 155 |
-
" if BATCH_SIZE > rec_bs:\n",
|
| 156 |
-
" print(f'⚠️ Reducing batch size {BATCH_SIZE} → {rec_bs} for {vram_gb:.0f}GB VRAM')\n",
|
| 157 |
-
" BATCH_SIZE = rec_bs\n",
|
| 158 |
-
"else:\n",
|
| 159 |
-
" print('⚠️ No GPU detected — training will be very slow!')\n",
|
| 160 |
-
" USE_AMP = False\n",
|
| 161 |
-
"\n",
|
| 162 |
-
"print(f'\\n📋 Config: {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')\n",
|
| 163 |
-
"print(f' Physics: λ_smooth={LAMBDA_SMOOTH}, λ_tv={LAMBDA_TV}')\n",
|
| 164 |
-
"print(f' AMP: {USE_AMP}, GradAccum: {GRAD_ACCUM}')"
|
| 165 |
-
]
|
| 166 |
-
},
|
| 167 |
-
{
|
| 168 |
-
"cell_type": "markdown",
|
| 169 |
-
"metadata": {},
|
| 170 |
-
"source": [
|
| 171 |
-
"---\n",
|
| 172 |
-
"## 2. 📦 Load Dataset"
|
| 173 |
-
]
|
| 174 |
-
},
|
| 175 |
-
{
|
| 176 |
-
"cell_type": "code",
|
| 177 |
-
"execution_count": null,
|
| 178 |
-
"metadata": {},
|
| 179 |
-
"outputs": [],
|
| 180 |
-
"source": [
|
| 181 |
-
"import torchvision\n",
|
| 182 |
-
"import torchvision.transforms as transforms\n",
|
| 183 |
-
"from torch.utils.data import DataLoader, Dataset, ConcatDataset\n",
|
| 184 |
-
"from pathlib import Path\n",
|
| 185 |
-
"from PIL import Image\n",
|
| 186 |
-
"import os\n",
|
| 187 |
-
"\n",
|
| 188 |
-
"# Standard transform\n",
|
| 189 |
-
"def get_transform(img_size):\n",
|
| 190 |
-
" return transforms.Compose([\n",
|
| 191 |
-
" transforms.Resize(img_size + img_size // 8),\n",
|
| 192 |
-
" transforms.CenterCrop(img_size),\n",
|
| 193 |
-
" transforms.RandomHorizontalFlip(),\n",
|
| 194 |
-
" transforms.ToTensor(),\n",
|
| 195 |
-
" transforms.Normalize([0.5]*3, [0.5]*3),\n",
|
| 196 |
-
" ])\n",
|
| 197 |
-
"\n",
|
| 198 |
-
"class ImageFolderFlat(Dataset):\n",
|
| 199 |
-
" \"\"\"Load all images from a folder (recursively).\"\"\"\n",
|
| 200 |
-
" def __init__(self, root, transform):\n",
|
| 201 |
-
" self.transform = transform\n",
|
| 202 |
-
" self.files = []\n",
|
| 203 |
-
" for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:\n",
|
| 204 |
-
" self.files.extend(Path(root).rglob(ext))\n",
|
| 205 |
-
" self.files = sorted(self.files)\n",
|
| 206 |
-
" print(f'Found {len(self.files)} images in {root}')\n",
|
| 207 |
-
" def __len__(self): return len(self.files)\n",
|
| 208 |
-
" def __getitem__(self, idx):\n",
|
| 209 |
-
" return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
|
| 210 |
-
"\n",
|
| 211 |
-
"class GrayscaleToRGB:\n",
|
| 212 |
-
" \"\"\"Convert 1-channel grayscale to 3-channel RGB.\"\"\"\n",
|
| 213 |
-
" def __call__(self, x):\n",
|
| 214 |
-
" if x.shape[0] == 1:\n",
|
| 215 |
-
" x = x.repeat(3, 1, 1)\n",
|
| 216 |
-
" return x\n",
|
| 217 |
-
"\n",
|
| 218 |
-
"tfm = get_transform(IMG_SIZE)\n",
|
| 219 |
-
"\n",
|
| 220 |
-
"if DATASET == 'cifar10':\n",
|
| 221 |
-
" dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=tfm)\n",
|
| 222 |
-
" print(f'✅ CIFAR-10: {len(dataset)} images')\n",
|
| 223 |
-
"\n",
|
| 224 |
-
"elif DATASET == 'flowers':\n",
|
| 225 |
-
" ds_train = torchvision.datasets.Flowers102(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
|
| 226 |
-
" ds_val = torchvision.datasets.Flowers102(root=DATA_DIR, split='val', download=True, transform=tfm)\n",
|
| 227 |
-
" ds_test = torchvision.datasets.Flowers102(root=DATA_DIR, split='test', download=True, transform=tfm)\n",
|
| 228 |
-
" dataset = ConcatDataset([ds_train, ds_val, ds_test]) # Use all splits for generation\n",
|
| 229 |
-
" print(f'✅ Flowers-102: {len(dataset)} images (all splits)')\n",
|
| 230 |
-
"\n",
|
| 231 |
-
"elif DATASET == 'celeba':\n",
|
| 232 |
-
" dataset = torchvision.datasets.CelebA(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
|
| 233 |
-
" print(f'✅ CelebA: {len(dataset)} images')\n",
|
| 234 |
-
"\n",
|
| 235 |
-
"elif DATASET == 'fashion_mnist':\n",
|
| 236 |
-
" fm_tfm = transforms.Compose([\n",
|
| 237 |
-
" transforms.Resize(IMG_SIZE),\n",
|
| 238 |
-
" transforms.ToTensor(),\n",
|
| 239 |
-
" transforms.Normalize([0.5], [0.5]),\n",
|
| 240 |
-
" GrayscaleToRGB(),\n",
|
| 241 |
-
" ])\n",
|
| 242 |
-
" dataset = torchvision.datasets.FashionMNIST(root=DATA_DIR, train=True, download=True, transform=fm_tfm)\n",
|
| 243 |
-
" print(f'✅ Fashion-MNIST: {len(dataset)} images (converted to RGB)')\n",
|
| 244 |
-
"\n",
|
| 245 |
-
"elif DATASET == 'afhq':\n",
|
| 246 |
-
" # Download AFHQ from Kaggle or manual\n",
|
| 247 |
-
" afhq_dir = os.path.join(DATA_DIR, 'afhq', 'train')\n",
|
| 248 |
-
" if not os.path.exists(afhq_dir):\n",
|
| 249 |
-
" print('⬇️ Downloading AFHQ...')\n",
|
| 250 |
-
" !pip install -q gdown\n",
|
| 251 |
-
" !gdown 1Gof5BaELXlmSJIlvKMYCe9ONYPebkNsf -O {DATA_DIR}/afhq.zip\n",
|
| 252 |
-
" !unzip -q {DATA_DIR}/afhq.zip -d {DATA_DIR}/afhq\n",
|
| 253 |
-
" dataset = ImageFolderFlat(afhq_dir, tfm)\n",
|
| 254 |
-
" print(f'✅ AFHQ: {len(dataset)} images')\n",
|
| 255 |
-
"\n",
|
| 256 |
-
"elif DATASET == 'lsun_churches':\n",
|
| 257 |
-
" # LSUN requires manual download — point to extracted folder\n",
|
| 258 |
-
" lsun_dir = os.path.join(DATA_DIR, 'lsun_churches')\n",
|
| 259 |
-
" if not os.path.exists(lsun_dir):\n",
|
| 260 |
-
" print('❌ LSUN churches not found. Please download and extract to', lsun_dir)\n",
|
| 261 |
-
" print(' See: https://github.com/fyu/lsun')\n",
|
| 262 |
-
" raise FileNotFoundError(lsun_dir)\n",
|
| 263 |
-
" dataset = ImageFolderFlat(lsun_dir, tfm)\n",
|
| 264 |
-
" print(f'✅ LSUN Churches: {len(dataset)} images')\n",
|
| 265 |
-
"\n",
|
| 266 |
-
"elif DATASET == 'folder':\n",
|
| 267 |
-
" dataset = ImageFolderFlat(CUSTOM_DATA_DIR, tfm)\n",
|
| 268 |
-
" print(f'✅ Custom folder: {len(dataset)} images from {CUSTOM_DATA_DIR}')\n",
|
| 269 |
-
"\n",
|
| 270 |
-
"else:\n",
|
| 271 |
-
" raise ValueError(f'Unknown dataset: {DATASET}')\n",
|
| 272 |
-
"\n",
|
| 273 |
-
"# Show a few samples\n",
|
| 274 |
-
"import matplotlib.pyplot as plt\n",
|
| 275 |
-
"import numpy as np\n",
|
| 276 |
-
"\n",
|
| 277 |
-
"fig, axes = plt.subplots(1, 8, figsize=(16, 2))\n",
|
| 278 |
-
"for i, ax in enumerate(axes):\n",
|
| 279 |
-
" sample = dataset[i]\n",
|
| 280 |
-
" if isinstance(sample, (list, tuple)):\n",
|
| 281 |
-
" sample = sample[0]\n",
|
| 282 |
-
" img = sample * 0.5 + 0.5 # denormalize\n",
|
| 283 |
-
" ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())\n",
|
| 284 |
-
" ax.axis('off')\n",
|
| 285 |
-
"plt.suptitle(f'{DATASET} samples ({IMG_SIZE}×{IMG_SIZE})', fontsize=14)\n",
|
| 286 |
-
"plt.tight_layout()\n",
|
| 287 |
-
"plt.show()"
|
| 288 |
-
]
|
| 289 |
-
},
|
| 290 |
-
{
|
| 291 |
-
"cell_type": "markdown",
|
| 292 |
-
"metadata": {},
|
| 293 |
-
"source": [
|
| 294 |
-
"---\n",
|
| 295 |
-
"## 3. 🏗️ Build Model"
|
| 296 |
-
]
|
| 297 |
-
},
|
| 298 |
-
{
|
| 299 |
-
"cell_type": "code",
|
| 300 |
-
"execution_count": null,
|
| 301 |
-
"metadata": {},
|
| 302 |
-
"outputs": [],
|
| 303 |
-
"source": [
|
| 304 |
-
"import torch\n",
|
| 305 |
-
"from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
|
| 306 |
-
"\n",
|
| 307 |
-
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 308 |
-
"\n",
|
| 309 |
-
"model_factories = {\n",
|
| 310 |
-
" 'tiny': liquidflow_tiny,\n",
|
| 311 |
-
" 'small': liquidflow_small,\n",
|
| 312 |
-
" 'base': liquidflow_base,\n",
|
| 313 |
-
" '512': liquidflow_512,\n",
|
| 314 |
-
"}\n",
|
| 315 |
-
"\n",
|
| 316 |
-
"model = model_factories[MODEL_SIZE](img_size=IMG_SIZE).to(device)\n",
|
| 317 |
-
"\n",
|
| 318 |
-
"num_params = model.count_params()\n",
|
| 319 |
-
"print(f'🏗️ LiquidFlow-{MODEL_SIZE}')\n",
|
| 320 |
-
"print(f' Parameters: {num_params:,} ({num_params/1e6:.1f}M)')\n",
|
| 321 |
-
"print(f' Image size: {IMG_SIZE}×{IMG_SIZE}')\n",
|
| 322 |
-
"print(f' Patch size: {model.patch_size}')\n",
|
| 323 |
-
"print(f' Num patches: {model.num_patches}')\n",
|
| 324 |
-
"print(f' Model dim: {model.d_model}')\n",
|
| 325 |
-
"print(f' Depth: {model.depth}')\n",
|
| 326 |
-
"print(f' Device: {device}')\n",
|
| 327 |
-
"\n",
|
| 328 |
-
"# Quick forward pass test\n",
|
| 329 |
-
"with torch.no_grad():\n",
|
| 330 |
-
" test_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
|
| 331 |
-
" test_t = torch.tensor([0.5], device=device)\n",
|
| 332 |
-
" test_v = model(test_x, test_t)\n",
|
| 333 |
-
" assert test_v.shape == test_x.shape\n",
|
| 334 |
-
" print(f' ✅ Forward pass OK: {test_x.shape} → {test_v.shape}')"
|
| 335 |
-
]
|
| 336 |
-
},
|
| 337 |
-
{
|
| 338 |
-
"cell_type": "markdown",
|
| 339 |
-
"metadata": {},
|
| 340 |
-
"source": [
|
| 341 |
-
"---\n",
|
| 342 |
-
"## 4. 🚀 Train"
|
| 343 |
-
]
|
| 344 |
-
},
|
| 345 |
-
{
|
| 346 |
-
"cell_type": "code",
|
| 347 |
-
"execution_count": null,
|
| 348 |
-
"metadata": {},
|
| 349 |
-
"outputs": [],
|
| 350 |
-
"source": [
|
| 351 |
-
"import math\n",
|
| 352 |
-
"import time\n",
|
| 353 |
-
"import json\n",
|
| 354 |
-
"import torch.nn as nn\n",
|
| 355 |
-
"from torch.cuda.amp import autocast, GradScaler\n",
|
| 356 |
-
"from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
|
| 357 |
-
"from liquidflow.sampling import euler_sample, make_grid_image\n",
|
| 358 |
-
"from IPython.display import display, clear_output\n",
|
| 359 |
-
"import matplotlib.pyplot as plt\n",
|
| 360 |
-
"\n",
|
| 361 |
-
"# Prepare\n",
|
| 362 |
-
"os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\n",
|
| 363 |
-
"os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n",
|
| 364 |
-
"\n",
|
| 365 |
-
"dataloader = DataLoader(\n",
|
| 366 |
-
" dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
|
| 367 |
-
" num_workers=2, pin_memory=True, drop_last=True\n",
|
| 368 |
-
")\n",
|
| 369 |
-
"\n",
|
| 370 |
-
"optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,\n",
|
| 371 |
-
" betas=(0.9, 0.999), weight_decay=0.01)\n",
|
| 372 |
-
"\n",
|
| 373 |
-
"total_steps = EPOCHS * len(dataloader) // GRAD_ACCUM\n",
|
| 374 |
-
"warmup_steps = min(500, total_steps // 10)\n",
|
| 375 |
-
"\n",
|
| 376 |
-
"def cosine_lr(step):\n",
|
| 377 |
-
" if step < warmup_steps:\n",
|
| 378 |
-
" return step / max(1, warmup_steps)\n",
|
| 379 |
-
" progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n",
|
| 380 |
-
" return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))\n",
|
| 381 |
-
"\n",
|
| 382 |
-
"scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)\n",
|
| 383 |
-
"criterion = PhysicsInformedFlowLoss(\n",
|
| 384 |
-
" lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV\n",
|
| 385 |
-
").to(device)\n",
|
| 386 |
-
"ema = EMAModel(model, decay=0.9999)\n",
|
| 387 |
-
"scaler = GradScaler(enabled=USE_AMP)\n",
|
| 388 |
-
"\n",
|
| 389 |
-
"# Training log\n",
|
| 390 |
-
"all_losses = []\n",
|
| 391 |
-
"global_step = 0\n",
|
| 392 |
-
"\n",
|
| 393 |
-
"print(f'🚀 Training {EPOCHS} epochs, {total_steps} steps')\n",
|
| 394 |
-
"print(f' Effective batch: {BATCH_SIZE} × {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}')\n",
|
| 395 |
-
"print(f' LR: {LEARNING_RATE} → warmup {warmup_steps} steps → cosine decay')\n",
|
| 396 |
-
"print()\n",
|
| 397 |
-
"\n",
|
| 398 |
-
"t_start = time.time()\n",
|
| 399 |
-
"\n",
|
| 400 |
-
"for epoch in range(EPOCHS):\n",
|
| 401 |
-
" model.train()\n",
|
| 402 |
-
" epoch_loss = 0.0\n",
|
| 403 |
-
" epoch_flow = 0.0\n",
|
| 404 |
-
" n_batches = 0\n",
|
| 405 |
-
"\n",
|
| 406 |
-
" for batch_idx, batch_data in enumerate(dataloader):\n",
|
| 407 |
-
" if isinstance(batch_data, (list, tuple)):\n",
|
| 408 |
-
" x1 = batch_data[0].to(device)\n",
|
| 409 |
-
" else:\n",
|
| 410 |
-
" x1 = batch_data.to(device)\n",
|
| 411 |
-
"\n",
|
| 412 |
-
" B = x1.shape[0]\n",
|
| 413 |
-
" x0 = torch.randn_like(x1)\n",
|
| 414 |
-
" t = torch.rand(B, device=device)\n",
|
| 415 |
-
" t_e = t.view(B, 1, 1, 1)\n",
|
| 416 |
-
" x_t = t_e * x1 + (1 - t_e) * x0\n",
|
| 417 |
-
"\n",
|
| 418 |
-
" with autocast(enabled=USE_AMP):\n",
|
| 419 |
-
" v_pred = model(x_t, t)\n",
|
| 420 |
-
" loss, ld = criterion(v_pred, x0, x1, t, step=global_step)\n",
|
| 421 |
-
" loss = loss / GRAD_ACCUM\n",
|
| 422 |
-
"\n",
|
| 423 |
-
" scaler.scale(loss).backward()\n",
|
| 424 |
-
"\n",
|
| 425 |
-
" if (batch_idx + 1) % GRAD_ACCUM == 0:\n",
|
| 426 |
-
" scaler.unscale_(optimizer)\n",
|
| 427 |
-
" gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
| 428 |
-
" scaler.step(optimizer)\n",
|
| 429 |
-
" scaler.update()\n",
|
| 430 |
-
" optimizer.zero_grad()\n",
|
| 431 |
-
" scheduler.step()\n",
|
| 432 |
-
" ema.update(model)\n",
|
| 433 |
-
" global_step += 1\n",
|
| 434 |
-
"\n",
|
| 435 |
-
" epoch_loss += ld['total'].item()\n",
|
| 436 |
-
" epoch_flow += ld['flow'].item()\n",
|
| 437 |
-
" n_batches += 1\n",
|
| 438 |
-
"\n",
|
| 439 |
-
" if global_step % LOG_EVERY == 0:\n",
|
| 440 |
-
" avg = epoch_loss / n_batches\n",
|
| 441 |
-
" avg_f = epoch_flow / n_batches\n",
|
| 442 |
-
" lr_now = scheduler.get_last_lr()[0]\n",
|
| 443 |
-
" elapsed = time.time() - t_start\n",
|
| 444 |
-
" it_s = global_step / elapsed\n",
|
| 445 |
-
" all_losses.append({'step': global_step, 'loss': avg, 'flow': avg_f,\n",
|
| 446 |
-
" 'lr': lr_now, 'epoch': epoch})\n",
|
| 447 |
-
" print(f' E{epoch+1} step {global_step}/{total_steps} | '\n",
|
| 448 |
-
" f'loss={avg:.4f} flow={avg_f:.4f} lr={lr_now:.2e} '\n",
|
| 449 |
-
" f'gn={gn:.2f} [{it_s:.1f} it/s]')\n",
|
| 450 |
-
"\n",
|
| 451 |
-
" # End of epoch\n",
|
| 452 |
-
" avg_epoch = epoch_loss / max(1, n_batches)\n",
|
| 453 |
-
" print(f'\\n📊 Epoch {epoch+1}/{EPOCHS} — avg loss: {avg_epoch:.4f}\\n')\n",
|
| 454 |
-
"\n",
|
| 455 |
-
" # Sample\n",
|
| 456 |
-
" if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:\n",
|
| 457 |
-
" model.eval()\n",
|
| 458 |
-
" ema.apply_shadow(model)\n",
|
| 459 |
-
" with torch.no_grad():\n",
|
| 460 |
-
" n_samples = min(16, BATCH_SIZE)\n",
|
| 461 |
-
" imgs = euler_sample(model, (n_samples, 3, IMG_SIZE, IMG_SIZE),\n",
|
| 462 |
-
" num_steps=SAMPLE_STEPS, device=device)\n",
|
| 463 |
-
" imgs = imgs.clamp(-1, 1) * 0.5 + 0.5\n",
|
| 464 |
-
" grid = make_grid_image(imgs, nrow=4)\n",
|
| 465 |
-
" grid.save(f'{OUTPUT_DIR}/samples/epoch_{epoch+1:04d}.png')\n",
|
| 466 |
-
"\n",
|
| 467 |
-
" # Display inline\n",
|
| 468 |
-
" fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
|
| 469 |
-
" ax.imshow(grid)\n",
|
| 470 |
-
" ax.set_title(f'Epoch {epoch+1} — {MODEL_SIZE}-{IMG_SIZE} on {DATASET}')\n",
|
| 471 |
-
" ax.axis('off')\n",
|
| 472 |
-
" plt.tight_layout()\n",
|
| 473 |
-
" plt.show()\n",
|
| 474 |
-
"\n",
|
| 475 |
-
" ema.restore(model)\n",
|
| 476 |
-
" model.train()\n",
|
| 477 |
-
"\n",
|
| 478 |
-
" # Checkpoint\n",
|
| 479 |
-
" if (epoch + 1) % SAVE_EVERY == 0:\n",
|
| 480 |
-
" ckpt = {\n",
|
| 481 |
-
" 'model': model.state_dict(),\n",
|
| 482 |
-
" 'optimizer': optimizer.state_dict(),\n",
|
| 483 |
-
" 'scheduler': scheduler.state_dict(),\n",
|
| 484 |
-
" 'ema': ema.state_dict(),\n",
|
| 485 |
-
" 'epoch': epoch,\n",
|
| 486 |
-
" 'global_step': global_step,\n",
|
| 487 |
-
" }\n",
|
| 488 |
-
" torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/epoch_{epoch+1:04d}.pt')\n",
|
| 489 |
-
" torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n",
|
| 490 |
-
" print(f'💾 Checkpoint saved: epoch {epoch+1}')\n",
|
| 491 |
-
"\n",
|
| 492 |
-
"# Save final\n",
|
| 493 |
-
"ema.apply_shadow(model)\n",
|
| 494 |
-
"torch.save({'model': model.state_dict(), 'config': {\n",
|
| 495 |
-
" 'model_size': MODEL_SIZE, 'img_size': IMG_SIZE, 'dataset': DATASET,\n",
|
| 496 |
-
" 'num_params': num_params, 'epochs': EPOCHS,\n",
|
| 497 |
-
"}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\n",
|
| 498 |
-
"ema.restore(model)\n",
|
| 499 |
-
"\n",
|
| 500 |
-
"elapsed = time.time() - t_start\n",
|
| 501 |
-
"print(f'\\n✅ Training complete! {elapsed/60:.1f} min total')\n",
|
| 502 |
-
"print(f' Final model: {OUTPUT_DIR}/liquidflow_final.pt')"
|
| 503 |
-
]
|
| 504 |
-
},
|
| 505 |
-
{
|
| 506 |
-
"cell_type": "markdown",
|
| 507 |
-
"metadata": {},
|
| 508 |
-
"source": [
|
| 509 |
-
"---\n",
|
| 510 |
-
"## 5. 📈 Training Curves"
|
| 511 |
-
]
|
| 512 |
-
},
|
| 513 |
-
{
|
| 514 |
-
"cell_type": "code",
|
| 515 |
-
"execution_count": null,
|
| 516 |
-
"metadata": {},
|
| 517 |
-
"outputs": [],
|
| 518 |
-
"source": [
|
| 519 |
-
"import matplotlib.pyplot as plt\n",
|
| 520 |
-
"\n",
|
| 521 |
-
"if all_losses:\n",
|
| 522 |
-
" steps = [d['step'] for d in all_losses]\n",
|
| 523 |
-
" losses = [d['loss'] for d in all_losses]\n",
|
| 524 |
-
" flows = [d['flow'] for d in all_losses]\n",
|
| 525 |
-
" lrs = [d['lr'] for d in all_losses]\n",
|
| 526 |
-
"\n",
|
| 527 |
-
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 528 |
-
"\n",
|
| 529 |
-
" ax1.plot(steps, losses, label='Total Loss', alpha=0.8)\n",
|
| 530 |
-
" ax1.plot(steps, flows, label='Flow Loss', alpha=0.8)\n",
|
| 531 |
-
" ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n",
|
| 532 |
-
" ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True, alpha=0.3)\n",
|
| 533 |
-
"\n",
|
| 534 |
-
" ax2.plot(steps, lrs, color='orange')\n",
|
| 535 |
-
" ax2.set_xlabel('Step'); ax2.set_ylabel('LR')\n",
|
| 536 |
-
" ax2.set_title('Learning Rate Schedule'); ax2.grid(True, alpha=0.3)\n",
|
| 537 |
-
"\n",
|
| 538 |
-
" plt.tight_layout()\n",
|
| 539 |
-
" plt.savefig(f'{OUTPUT_DIR}/training_curves.png', dpi=150)\n",
|
| 540 |
-
" plt.show()\n",
|
| 541 |
-
"else:\n",
|
| 542 |
-
" print('No training logs yet.')"
|
| 543 |
-
]
|
| 544 |
-
},
|
| 545 |
-
{
|
| 546 |
-
"cell_type": "markdown",
|
| 547 |
-
"metadata": {},
|
| 548 |
-
"source": [
|
| 549 |
-
"---\n",
|
| 550 |
-
"## 6. 🎨 Generate Images"
|
| 551 |
-
]
|
| 552 |
-
},
|
| 553 |
-
{
|
| 554 |
-
"cell_type": "code",
|
| 555 |
-
"execution_count": null,
|
| 556 |
-
"metadata": {},
|
| 557 |
-
"outputs": [],
|
| 558 |
-
"source": [
|
| 559 |
-
"#@title 🎨 Generation Settings { display-mode: \"form\" }\n",
|
| 560 |
-
"NUM_IMAGES = 16 #@param {type:\"integer\"}\n",
|
| 561 |
-
"GEN_STEPS = 50 #@param [10, 25, 50, 100, 200] {type:\"integer\"}\n",
|
| 562 |
-
"SAMPLER = 'euler' #@param ['euler', 'heun']\n",
|
| 563 |
-
"SEED = 42 #@param {type:\"integer\"}\n",
|
| 564 |
-
"\n",
|
| 565 |
-
"import torch\n",
|
| 566 |
-
"from liquidflow.sampling import euler_sample, heun_sample, make_grid_image\n",
|
| 567 |
-
"import matplotlib.pyplot as plt\n",
|
| 568 |
-
"\n",
|
| 569 |
-
"# Load best model\n",
|
| 570 |
-
"ckpt_path = f'{OUTPUT_DIR}/liquidflow_final.pt'\n",
|
| 571 |
-
"if os.path.exists(ckpt_path):\n",
|
| 572 |
-
" ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)\n",
|
| 573 |
-
" model.load_state_dict(ckpt['model'])\n",
|
| 574 |
-
" print(f'Loaded: {ckpt_path}')\n",
|
| 575 |
-
"else:\n",
|
| 576 |
-
" print(f'No checkpoint found, using current model weights')\n",
|
| 577 |
-
"\n",
|
| 578 |
-
"model.eval()\n",
|
| 579 |
-
"torch.manual_seed(SEED)\n",
|
| 580 |
-
"\n",
|
| 581 |
-
"shape = (NUM_IMAGES, 3, IMG_SIZE, IMG_SIZE)\n",
|
| 582 |
-
"\n",
|
| 583 |
-
"with torch.no_grad():\n",
|
| 584 |
-
" if SAMPLER == 'euler':\n",
|
| 585 |
-
" images = euler_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
|
| 586 |
-
" else:\n",
|
| 587 |
-
" images = heun_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
|
| 588 |
-
"\n",
|
| 589 |
-
"images = images.clamp(-1, 1) * 0.5 + 0.5\n",
|
| 590 |
-
"grid = make_grid_image(images, nrow=int(NUM_IMAGES**0.5))\n",
|
| 591 |
-
"grid.save(f'{OUTPUT_DIR}/generated_final.png')\n",
|
| 592 |
-
"\n",
|
| 593 |
-
"plt.figure(figsize=(10, 10))\n",
|
| 594 |
-
"plt.imshow(grid)\n",
|
| 595 |
-
"plt.title(f'LiquidFlow-{MODEL_SIZE} | {DATASET} {IMG_SIZE}×{IMG_SIZE} | {GEN_STEPS} steps ({SAMPLER})')\n",
|
| 596 |
-
"plt.axis('off')\n",
|
| 597 |
-
"plt.show()"
|
| 598 |
-
]
|
| 599 |
-
},
|
| 600 |
-
{
|
| 601 |
-
"cell_type": "markdown",
|
| 602 |
-
"metadata": {},
|
| 603 |
-
"source": [
|
| 604 |
-
"---\n",
|
| 605 |
-
"## 7. 📱 Export for Mobile (ONNX + TorchScript)"
|
| 606 |
-
]
|
| 607 |
-
},
|
| 608 |
-
{
|
| 609 |
-
"cell_type": "code",
|
| 610 |
-
"execution_count": null,
|
| 611 |
-
"metadata": {},
|
| 612 |
-
"outputs": [],
|
| 613 |
-
"source": [
|
| 614 |
-
"# Export to TorchScript for mobile deployment\n",
|
| 615 |
-
"model.eval()\n",
|
| 616 |
-
"\n",
|
| 617 |
-
"# TorchScript (for PyTorch Mobile / ExecuTorch)\n",
|
| 618 |
-
"example_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
|
| 619 |
-
"example_t = torch.tensor([0.5], device=device)\n",
|
| 620 |
-
"\n",
|
| 621 |
-
"try:\n",
|
| 622 |
-
" traced = torch.jit.trace(model, (example_x, example_t))\n",
|
| 623 |
-
" ts_path = f'{OUTPUT_DIR}/liquidflow_mobile.pt'\n",
|
| 624 |
-
" traced.save(ts_path)\n",
|
| 625 |
-
" ts_size_mb = os.path.getsize(ts_path) / 1e6\n",
|
| 626 |
-
" print(f'✅ TorchScript saved: {ts_path} ({ts_size_mb:.1f} MB)')\n",
|
| 627 |
-
"except Exception as e:\n",
|
| 628 |
-
" print(f'⚠️ TorchScript export failed: {e}')\n",
|
| 629 |
-
"\n",
|
| 630 |
-
"# ONNX\n",
|
| 631 |
-
"try:\n",
|
| 632 |
-
" onnx_path = f'{OUTPUT_DIR}/liquidflow.onnx'\n",
|
| 633 |
-
" torch.onnx.export(\n",
|
| 634 |
-
" model.cpu(), (example_x.cpu(), example_t.cpu()),\n",
|
| 635 |
-
" onnx_path, opset_version=14,\n",
|
| 636 |
-
" input_names=['image', 'timestep'],\n",
|
| 637 |
-
" output_names=['velocity'],\n",
|
| 638 |
-
" dynamic_axes={'image': {0: 'batch'}, 'timestep': {0: 'batch'}, 'velocity': {0: 'batch'}}\n",
|
| 639 |
-
" )\n",
|
| 640 |
-
" onnx_size_mb = os.path.getsize(onnx_path) / 1e6\n",
|
| 641 |
-
" print(f'✅ ONNX saved: {onnx_path} ({onnx_size_mb:.1f} MB)')\n",
|
| 642 |
-
" model.to(device)\n",
|
| 643 |
-
"except Exception as e:\n",
|
| 644 |
-
" print(f'⚠️ ONNX export failed: {e}')\n",
|
| 645 |
-
" model.to(device)"
|
| 646 |
-
]
|
| 647 |
-
},
|
| 648 |
-
{
|
| 649 |
-
"cell_type": "markdown",
|
| 650 |
-
"metadata": {},
|
| 651 |
-
"source": [
|
| 652 |
-
"---\n",
|
| 653 |
-
"## 8. 🔬 Architecture Deep Dive\n",
|
| 654 |
-
"\n",
|
| 655 |
-
"### How LiquidFlow works\n",
|
| 656 |
-
"\n",
|
| 657 |
-
"```\n",
|
| 658 |
-
"Noise x₀ ~ N(0,I) ──→ LiquidFlow v_θ(xₜ, t) ──→ Image x₁\n",
|
| 659 |
-
" │\n",
|
| 660 |
-
" ┌──────┴──────┐\n",
|
| 661 |
-
" │ Patchify │ (img → non-overlapping patches)\n",
|
| 662 |
-
" │ + PosEmb │ (2D learnable positions)\n",
|
| 663 |
-
" │ + DepthConv│ (local structure)\n",
|
| 664 |
-
" └──────┬──────┘\n",
|
| 665 |
-
" │\n",
|
| 666 |
-
" ┌────────────┼────────────┐\n",
|
| 667 |
-
" │ L × LiquidSSM Block │\n",
|
| 668 |
-
" │ ┌──────────────────┐ │\n",
|
| 669 |
-
" │ │ AdaLN (t-cond) │ │\n",
|
| 670 |
-
" │ │ Zigzag Scan │ │ ← rotates scan pattern per layer\n",
|
| 671 |
-
" │ │ SelectiveSSM │ │ ← Mamba-style, input-dependent\n",
|
| 672 |
-
" │ │ + LiquidCfC │ │ ← CfC gating, bounded dynamics\n",
|
| 673 |
-
" │ │ + FFN │ │\n",
|
| 674 |
-
" │ │ + Skip Connect │ │ ← U-Net style long skips\n",
|
| 675 |
-
" │ └──────────────────┘ │\n",
|
| 676 |
-
" └────────────┼────────────┘\n",
|
| 677 |
-
" │\n",
|
| 678 |
-
" ┌──────┴──────┐\n",
|
| 679 |
-
" │ DepthConv │\n",
|
| 680 |
-
" │ Unpatchify │ (patches → img)\n",
|
| 681 |
-
" └──────┬──────┘\n",
|
| 682 |
-
" │\n",
|
| 683 |
-
" velocity v_θ\n",
|
| 684 |
-
"```\n",
|
| 685 |
-
"\n",
|
| 686 |
-
"### Key Innovations\n",
|
| 687 |
-
"\n",
|
| 688 |
-
"1. **Liquid CfC Cell**: Instead of solving the ODE `dx/dt = f(x,t)` numerically, we use the\n",
|
| 689 |
-
" closed-form solution `x(t+Δt) = σ(-f_τ) ⊙ x(t) + (1 - σ(-f_τ)) ⊙ f_x`.\n",
|
| 690 |
-
" The sigmoid gating **guarantees bounded dynamics** — no training explosion possible.\n",
|
| 691 |
-
"\n",
|
| 692 |
-
"2. **SSM + Liquid dual path**: The SSM branch captures long-range spatial dependencies\n",
|
| 693 |
-
" via selective scanning; the Liquid branch adds continuous-time adaptive dynamics.\n",
|
| 694 |
-
" A learnable mixing coefficient balances them.\n",
|
| 695 |
-
"\n",
|
| 696 |
-
"3. **Physics-informed loss**: Smoothness (Laplacian) and Total Variation regularizers\n",
|
| 697 |
-
" act as soft PDE constraints on generated images, improving training stability\n",
|
| 698 |
-
" and reducing artifacts without domain-specific physics knowledge.\n",
|
| 699 |
-
"\n",
|
| 700 |
-
"4. **Flow Matching = Liquid ODE**: Rectified flow trains `v_θ` to follow straight paths\n",
|
| 701 |
-
" from noise to data. This is structurally identical to the LTC ODE, making Liquid\n",
|
| 702 |
-
" networks a natural fit as the velocity field parameterization."
|
| 703 |
-
]
|
| 704 |
-
},
|
| 705 |
-
{
|
| 706 |
-
"cell_type": "markdown",
|
| 707 |
-
"metadata": {},
|
| 708 |
-
"source": [
|
| 709 |
-
"---\n",
|
| 710 |
-
"## 9. 🧪 Recommended Experiments\n",
|
| 711 |
-
"\n",
|
| 712 |
-
"| Experiment | Dataset | Model | IMG_SIZE | Epochs | Notes |\n",
|
| 713 |
-
"|------------|---------|-------|----------|--------|-------|\n",
|
| 714 |
-
"| Quick sanity check | CIFAR-10 | tiny | 32 | 20 | ~5 min on T4 |\n",
|
| 715 |
-
"| Baseline 128×128 | CIFAR-10 | tiny | 128 | 100 | ~2 hrs on T4 |\n",
|
| 716 |
-
"| Quality 128×128 | Flowers-102 | small | 128 | 200 | ~4 hrs on T4 |\n",
|
| 717 |
-
"| Faces 128×128 | CelebA | small | 128 | 50 | ~6 hrs on T4 |\n",
|
| 718 |
-
"| High-res 512×512 | CelebA | 512 | 512 | 100 | needs ≥16GB |\n",
|
| 719 |
-
"| Production | Your data | small | 128 | 300+ | best quality |\n",
|
| 720 |
-
"\n",
|
| 721 |
-
"### Tips for best results:\n",
|
| 722 |
-
"- Start with `tiny` + low epochs to verify everything works\n",
|
| 723 |
-
"- Use `small` for 128×128 production quality\n",
|
| 724 |
-
"- Increase `SAMPLE_STEPS` to 100+ for final generation\n",
|
| 725 |
-
"- `heun` sampler gives better quality at half the steps vs `euler`\n",
|
| 726 |
-
"- Physics loss warmup is automatic — don't increase λ too much"
|
| 727 |
-
]
|
| 728 |
-
}
|
| 729 |
-
],
|
| 730 |
-
"metadata": {
|
| 731 |
-
"accelerator": "GPU",
|
| 732 |
-
"colab": {
|
| 733 |
-
"gpuType": "T4",
|
| 734 |
-
"provenance": []
|
| 735 |
-
},
|
| 736 |
-
"kernelspec": {
|
| 737 |
-
"display_name": "Python 3",
|
| 738 |
-
"name": "python3"
|
| 739 |
-
},
|
| 740 |
-
"language_info": {
|
| 741 |
-
"name": "python",
|
| 742 |
-
"version": "3.10.12"
|
| 743 |
-
}
|
| 744 |
-
},
|
| 745 |
-
"nbformat": 4,
|
| 746 |
-
"nbformat_minor": 4
|
| 747 |
-
}
|
|
|
|
| 1 |
+
{"cells":[{"cell_type":"markdown","metadata":{},"source":["# 🌊 LiquidFlow — Liquid-SSM Flow Matching Image Generator\n","\n","A **novel architecture** combining:\n","- **Liquid Time-Constant Networks** (CfC closed-form) — adaptive ODE dynamics, bounded by construction\n","- **Selective State Space Models** (Mamba-style) — linear-time long-range context, parallelizable\n","- **Zigzag Scanning** — 2D spatial awareness for image patches\n","- **Physics-Informed Regularization** — smoothness + total variation constraints\n","- **Rectified Flow Matching** — ODE-based generation (no noise schedule tuning)\n","\n","### 💻 Hardware Requirements\n","| Config | GPU VRAM | Best For |\n","|--------|----------|----------|\n","| tiny-128 (bs=32) | ~4 GB | Colab free T4, Kaggle |\n","| small-128 (bs=16) | ~8 GB | Colab free T4, Kaggle |\n","| base-256 (bs=8) | ~12 GB | Colab Pro, Kaggle |\n","| 512 (bs=4) | ~14 GB | Colab Pro, A100 |"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 0. Setup & Install"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU — CPU only'\nimport torch\nprint(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available():\n print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\nif not os.path.exists('liquidflow'):\n !git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n !cp -r liquidflow_repo/liquidflow .\nelse:\n print('liquidflow/ already exists — updating...')\n !cd liquidflow_repo && git pull && cp -r liquidflow/* ../liquidflow/\n\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, generate_grid, make_grid_image\nprint('✅ LiquidFlow imported successfully!')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 1. ⚙️ Configuration"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎛️ Training Configuration { display-mode: \"form\" }\n\nDATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist', 'afhq', 'lsun_churches']\nCUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\nMODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\nIMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\nEPOCHS = 100 #@param {type:\"integer\"}\nBATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\nLEARNING_RATE = 3e-4 #@param {type:\"number\"}\nGRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\nUSE_AMP = True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\nLAMBDA_TV = 0.001 #@param {type:\"number\"}\nSAMPLE_EVERY = 5 #@param {type:\"integer\"}\nSAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\nLOG_EVERY = 50 #@param {type:\"integer\"}\nSAVE_EVERY = 10 #@param {type:\"integer\"}\nOUTPUT_DIR = './outputs'\nDATA_DIR = './data'\n\nimport torch\nif torch.cuda.is_available():\n vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9\n print(f'GPU VRAM: {vram_gb:.1f} GB')\n recommended = {(32,'tiny'):128,(64,'tiny'):64,(128,'tiny'):32,(32,'small'):64,(64,'small'):32,(128,'small'):16,(256,'base'):8,(512,'512'):4}\n key = (IMG_SIZE, MODEL_SIZE)\n if key in recommended and vram_gb < 16:\n rec_bs = recommended[key]\n if BATCH_SIZE > rec_bs:\n print(f'⚠️ Reducing batch size {BATCH_SIZE} → {rec_bs} for {vram_gb:.0f}GB VRAM')\n BATCH_SIZE = rec_bs\nelse:\n print('⚠️ No GPU detected'); USE_AMP = False\n\nprint(f'\\n📋 Config: {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 2. 📦 Load Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision\nimport torchvision.transforms as transforms\nfrom torch.utils.data import DataLoader, Dataset, ConcatDataset\nfrom pathlib import Path\nfrom PIL import Image\nimport os, matplotlib.pyplot as plt, numpy as np\n\ndef get_transform(img_size):\n return transforms.Compose([transforms.Resize(img_size+img_size//8), transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5]*3,[0.5]*3)])\n\nclass ImageFolderFlat(Dataset):\n def __init__(self, root, transform):\n self.transform = transform\n self.files = []\n for ext in ['*.png','*.jpg','*.jpeg','*.webp','*.bmp']: self.files.extend(Path(root).rglob(ext))\n self.files = sorted(self.files); print(f'Found {len(self.files)} images in {root}')\n def __len__(self): return len(self.files)\n def __getitem__(self, idx): return self.transform(Image.open(self.files[idx]).convert('RGB'))\n\nclass GrayscaleToRGB:\n def __call__(self, x): return x.repeat(3,1,1) if x.shape[0]==1 else x\n\ntfm = get_transform(IMG_SIZE)\n\nif DATASET == 'cifar10':\n dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=tfm)\nelif DATASET == 'flowers':\n dataset = ConcatDataset([torchvision.datasets.Flowers102(root=DATA_DIR, split=s, download=True, transform=tfm) for s in ['train','val','test']])\nelif DATASET == 'celeba':\n dataset = torchvision.datasets.CelebA(root=DATA_DIR, split='train', download=True, transform=tfm)\nelif DATASET == 'fashion_mnist':\n fm_tfm = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.5],[0.5]), GrayscaleToRGB()])\n dataset = torchvision.datasets.FashionMNIST(root=DATA_DIR, train=True, download=True, transform=fm_tfm)\nelif DATASET == 'folder':\n dataset = ImageFolderFlat(CUSTOM_DATA_DIR, tfm)\nelse:\n raise ValueError(f'Unknown dataset: {DATASET}')\n\nprint(f'✅ {DATASET}: {len(dataset)} images')\n\nfig, axes = plt.subplots(1, 8, figsize=(16, 2))\nfor i, ax in enumerate(axes):\n sample = dataset[i]; sample = sample[0] if isinstance(sample,(list,tuple)) else sample\n ax.imshow((sample*0.5+0.5).permute(1,2,0).clamp(0,1).numpy()); ax.axis('off')\nplt.suptitle(f'{DATASET} samples ({IMG_SIZE}×{IMG_SIZE})'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 3. 🏗️ Build Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torch\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = {'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\n\nnum_params = model.count_params()\nprint(f'🏗️ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M) params')\nprint(f' {IMG_SIZE}×{IMG_SIZE}, patch={model.patch_size}, patches={model.num_patches}, dim={model.d_model}, depth={model.depth}')\n\nwith torch.no_grad():\n v = model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device), torch.tensor([0.5],device=device))\n print(f' ✅ Forward pass OK')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 4. 🚀 Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math, time, json\nimport torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, make_grid_image\nimport matplotlib.pyplot as plt\n\nos.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\nos.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\ndataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9,0.999), weight_decay=0.01)\n\ntotal_steps = EPOCHS * len(dataloader) // GRAD_ACCUM\nwarmup_steps = min(500, total_steps // 10)\ndef cosine_lr(step):\n if step < warmup_steps: return step / max(1, warmup_steps)\n p = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * p))\n\nscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)\ncriterion = PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV).to(device)\nema = EMAModel(model, decay=0.9999)\n\n# Use modern AMP API (no deprecation warnings)\namp_device = 'cuda' if device.type == 'cuda' else 'cpu'\nscaler = torch.amp.GradScaler(amp_device, enabled=USE_AMP)\n\nall_losses = []\nglobal_step = 0\nprint(f'🚀 Training {EPOCHS} epochs, {total_steps} steps')\nprint(f' Effective batch: {BATCH_SIZE} × {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}')\nprint(f' LR: {LEARNING_RATE} → warmup {warmup_steps} steps → cosine decay\\n')\n\nt_start = time.time()\nfor epoch in range(EPOCHS):\n model.train()\n epoch_loss = epoch_flow = 0.0; n_batches = 0\n\n for batch_idx, batch_data in enumerate(dataloader):\n x1 = (batch_data[0] if isinstance(batch_data,(list,tuple)) else batch_data).to(device)\n B = x1.shape[0]\n x0 = torch.randn_like(x1)\n t = torch.rand(B, device=device)\n x_t = t.view(B,1,1,1) * x1 + (1 - t.view(B,1,1,1)) * x0\n\n with torch.amp.autocast(amp_device, enabled=USE_AMP):\n v_pred = model(x_t, t)\n loss, ld = criterion(v_pred, x0, x1, t, step=global_step)\n loss = loss / GRAD_ACCUM\n\n scaler.scale(loss).backward()\n\n if (batch_idx + 1) % GRAD_ACCUM == 0:\n scaler.unscale_(optimizer)\n gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n scaler.step(optimizer); scaler.update(); optimizer.zero_grad()\n scheduler.step(); ema.update(model); global_step += 1\n\n epoch_loss += ld['total'].item(); epoch_flow += ld['flow'].item(); n_batches += 1\n\n if global_step % LOG_EVERY == 0:\n avg = epoch_loss/n_batches; avg_f = epoch_flow/n_batches\n lr_now = scheduler.get_last_lr()[0]; it_s = global_step/(time.time()-t_start)\n all_losses.append({'step':global_step,'loss':avg,'flow':avg_f,'lr':lr_now,'epoch':epoch})\n print(f' E{epoch+1} step {global_step}/{total_steps} | loss={avg:.4f} flow={avg_f:.4f} lr={lr_now:.2e} gn={gn:.2f} [{it_s:.1f} it/s]')\n\n avg_epoch = epoch_loss / max(1, n_batches)\n print(f'\\n📊 Epoch {epoch+1}/{EPOCHS} — avg loss: {avg_epoch:.4f}\\n')\n\n if (epoch+1) % SAMPLE_EVERY == 0 or epoch == 0:\n model.eval(); ema.apply_shadow(model)\n with torch.no_grad():\n imgs = euler_sample(model, (min(16,BATCH_SIZE),3,IMG_SIZE,IMG_SIZE), num_steps=SAMPLE_STEPS, device=device)\n imgs = imgs.clamp(-1,1)*0.5+0.5\n grid = make_grid_image(imgs, nrow=4)\n grid.save(f'{OUTPUT_DIR}/samples/epoch_{epoch+1:04d}.png')\n fig, ax = plt.subplots(1,1,figsize=(8,8)); ax.imshow(grid)\n ax.set_title(f'Epoch {epoch+1} — {MODEL_SIZE}-{IMG_SIZE} on {DATASET}'); ax.axis('off'); plt.tight_layout(); plt.show()\n ema.restore(model); model.train()\n\n if (epoch+1) % SAVE_EVERY == 0:\n ckpt = {'model':model.state_dict(),'optimizer':optimizer.state_dict(),'scheduler':scheduler.state_dict(),'ema':ema.state_dict(),'epoch':epoch,'global_step':global_step}\n torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/epoch_{epoch+1:04d}.pt')\n torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n print(f'💾 Checkpoint saved: epoch {epoch+1}')\n\nema.apply_shadow(model)\ntorch.save({'model':model.state_dict(),'config':{'model_size':MODEL_SIZE,'img_size':IMG_SIZE,'dataset':DATASET,'num_params':num_params,'epochs':EPOCHS}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model)\nprint(f'\\n✅ Training complete! {(time.time()-t_start)/60:.1f} min total')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 5. 📈 Training Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import matplotlib.pyplot as plt\nif all_losses:\n steps=[d['step'] for d in all_losses]; losses=[d['loss'] for d in all_losses]; flows=[d['flow'] for d in all_losses]; lrs=[d['lr'] for d in all_losses]\n fig,(ax1,ax2) = plt.subplots(1,2,figsize=(14,5))\n ax1.plot(steps,losses,label='Total',alpha=0.8); ax1.plot(steps,flows,label='Flow',alpha=0.8)\n ax1.set_xlabel('Step'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True,alpha=0.3)\n ax2.plot(steps,lrs,color='orange'); ax2.set_xlabel('Step'); ax2.set_ylabel('LR'); ax2.set_title('LR Schedule'); ax2.grid(True,alpha=0.3)\n plt.tight_layout(); plt.savefig(f'{OUTPUT_DIR}/training_curves.png',dpi=150); plt.show()\nelse: print('No training logs yet.')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 6. 🎨 Generate Images"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎨 Generation Settings { display-mode: \"form\" }\nNUM_IMAGES = 16 #@param {type:\"integer\"}\nGEN_STEPS = 50 #@param [10, 25, 50, 100, 200] {type:\"integer\"}\nSAMPLER = 'euler' #@param ['euler', 'heun']\nSEED = 42 #@param {type:\"integer\"}\n\nimport torch\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nimport matplotlib.pyplot as plt\n\nckpt_path = f'{OUTPUT_DIR}/liquidflow_final.pt'\nif os.path.exists(ckpt_path):\n model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=False)['model'])\n print(f'Loaded: {ckpt_path}')\nmodel.eval(); torch.manual_seed(SEED)\nshape = (NUM_IMAGES, 3, IMG_SIZE, IMG_SIZE)\nwith torch.no_grad():\n images = (euler_sample if SAMPLER=='euler' else heun_sample)(model, shape, num_steps=GEN_STEPS, device=device)\nimages = images.clamp(-1,1)*0.5+0.5\ngrid = make_grid_image(images, nrow=int(NUM_IMAGES**0.5))\ngrid.save(f'{OUTPUT_DIR}/generated_final.png')\nplt.figure(figsize=(10,10)); plt.imshow(grid)\nplt.title(f'LiquidFlow-{MODEL_SIZE} | {DATASET} {IMG_SIZE}×{IMG_SIZE} | {GEN_STEPS} steps ({SAMPLER})'); plt.axis('off'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 7. 📱 Export for Mobile"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval()\nexample_x = torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device)\nexample_t = torch.tensor([0.5],device=device)\ntry:\n traced = torch.jit.trace(model, (example_x, example_t))\n ts_path = f'{OUTPUT_DIR}/liquidflow_mobile.pt'; traced.save(ts_path)\n print(f'✅ TorchScript: {ts_path} ({os.path.getsize(ts_path)/1e6:.1f} MB)')\nexcept Exception as e: print(f'⚠️ TorchScript failed: {e}')\ntry:\n onnx_path = f'{OUTPUT_DIR}/liquidflow.onnx'\n torch.onnx.export(model.cpu(), (example_x.cpu(), example_t.cpu()), onnx_path, opset_version=14,\n input_names=['image','timestep'], output_names=['velocity'],\n dynamic_axes={'image':{0:'batch'},'timestep':{0:'batch'},'velocity':{0:'batch'}})\n print(f'✅ ONNX: {onnx_path} ({os.path.getsize(onnx_path)/1e6:.1f} MB)'); model.to(device)\nexcept Exception as e: print(f'⚠️ ONNX failed: {e}'); model.to(device)"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 8. 🧪 Recommended Experiments\n","\n","| Goal | Dataset | Model | Size | Epochs | Time (T4) |\n","|------|---------|-------|------|--------|----------|\n","| Sanity check | CIFAR-10 | tiny | 32 | 20 | ~5 min |\n","| Baseline | CIFAR-10 | tiny | 128 | 100 | ~2 hrs |\n","| Quality | Flowers-102 | small | 128 | 200 | ~4 hrs |\n","| Faces | CelebA | small | 128 | 50 | ~6 hrs |\n","| High-res | CelebA | 512 | 512 | 100 | ~12 hrs |"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|