krystv committed on
Commit
1d50798
·
verified ·
1 Parent(s): dab968e

Fix deprecated torch.cuda.amp API in training cell — use torch.amp instead

Files changed (1)
  1. LiquidFlow_Training.ipynb +1 -747
LiquidFlow_Training.ipynb CHANGED
@@ -1,747 +1 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# 🌊 LiquidFlow — Liquid-SSM Flow Matching Image Generator\n",
8
- "\n",
9
- "A **novel architecture** combining:\n",
10
- "- **Liquid Time-Constant Networks** (CfC closed-form) — adaptive ODE dynamics, bounded by construction\n",
11
- "- **Selective State Space Models** (Mamba-style) — linear-time long-range context, parallelizable\n",
12
- "- **Zigzag Scanning** — 2D spatial awareness for image patches\n",
13
- "- **Physics-Informed Regularization** — smoothness + total variation constraints\n",
14
- "- **Rectified Flow Matching** — ODE-based generation (no noise schedule tuning)\n",
15
- "\n",
16
- "### 📋 What this notebook does\n",
17
- "1. **Install & clone** the LiquidFlow codebase\n",
18
- "2. **Choose a dataset** (CIFAR-10, Flowers-102, CelebA, or custom folder)\n",
19
- "3. **Choose a model size** (tiny ~6M, small ~14M, base ~38M)\n",
20
- "4. **Train** with one click — all Colab/Kaggle optimized\n",
21
- "5. **Generate images** and visualize progress\n",
22
- "6. **Export** trained model for mobile deployment\n",
23
- "\n",
24
- "### 💻 Hardware Requirements\n",
25
- "| Config | GPU VRAM | Best For |\n",
26
- "|--------|----------|----------|\n",
27
- "| tiny-128 (bs=32) | ~4 GB | Colab free T4, Kaggle |\n",
28
- "| small-128 (bs=16) | ~8 GB | Colab free T4, Kaggle |\n",
29
- "| base-256 (bs=8) | ~12 GB | Colab Pro, Kaggle |\n",
30
- "| 512 (bs=4) | ~14 GB | Colab Pro, A100 |"
31
- ]
32
- },
33
- {
34
- "cell_type": "markdown",
35
- "metadata": {},
36
- "source": [
37
- "---\n",
38
- "## 0. Setup & Install"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": null,
44
- "metadata": {},
45
- "outputs": [],
46
- "source": [
47
- "# Check GPU\n",
48
- "!nvidia-smi || echo 'No GPU — CPU training only (very slow)'\n",
49
- "import torch\n",
50
- "print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')\n",
51
- "if torch.cuda.is_available():\n",
52
- " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": null,
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "# Install dependencies\n",
62
- "!pip install -q torch torchvision einops pillow matplotlib tqdm"
63
- ]
64
- },
65
- {
66
- "cell_type": "code",
67
- "execution_count": null,
68
- "metadata": {},
69
- "outputs": [],
70
- "source": [
71
- "# Clone the repo (or just copy the files if already have them)\n",
72
- "import os\n",
73
- "if not os.path.exists('liquidflow'):\n",
74
- " !git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n",
75
- " !cp -r liquidflow_repo/liquidflow .\n",
76
- "else:\n",
77
- " print('liquidflow/ already exists')\n",
78
- "\n",
79
- "# Verify\n",
80
- "from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
81
- "from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
82
- "from liquidflow.sampling import euler_sample, heun_sample, generate_grid, make_grid_image\n",
83
- "print('✅ LiquidFlow imported successfully!')"
84
- ]
85
- },
86
- {
87
- "cell_type": "markdown",
88
- "metadata": {},
89
- "source": [
90
- "---\n",
91
- "## 1. ⚙️ Configuration — EDIT THIS CELL\n",
92
- "\n",
93
- "Choose your dataset, model size, and training hyperparameters."
94
- ]
95
- },
96
- {
97
- "cell_type": "code",
98
- "execution_count": null,
99
- "metadata": {},
100
- "outputs": [],
101
- "source": [
102
- "#@title 🎛️ Training Configuration { display-mode: \"form\" }\n",
103
- "\n",
104
- "# ============== DATASET ==============\n",
105
- "#@markdown ### Dataset\n",
106
- "DATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist', 'afhq', 'lsun_churches']\n",
107
- "CUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\n",
108
- "#@markdown > For 'folder': put images in CUSTOM_DATA_DIR. Supports .png/.jpg/.webp\n",
109
- "\n",
110
- "# ============== MODEL ==============\n",
111
- "#@markdown ### Model\n",
112
- "MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\n",
113
- "IMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\n",
114
- "\n",
115
- "# ============== TRAINING ==============\n",
116
- "#@markdown ### Training\n",
117
- "EPOCHS = 100 #@param {type:\"integer\"}\n",
118
- "BATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\n",
119
- "LEARNING_RATE = 3e-4 #@param {type:\"number\"}\n",
120
- "GRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\n",
121
- "USE_AMP = True #@param {type:\"boolean\"}\n",
122
- "\n",
123
- "# ============== PHYSICS LOSS ==============\n",
124
- "#@markdown ### Physics-Informed Regularization\n",
125
- "LAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\n",
126
- "LAMBDA_TV = 0.001 #@param {type:\"number\"}\n",
127
- "\n",
128
- "# ============== SAMPLING ==============\n",
129
- "#@markdown ### Sampling & Logging\n",
130
- "SAMPLE_EVERY = 5 #@param {type:\"integer\"}\n",
131
- "SAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\n",
132
- "LOG_EVERY = 50 #@param {type:\"integer\"}\n",
133
- "SAVE_EVERY = 10 #@param {type:\"integer\"}\n",
134
- "\n",
135
- "# ============== PATHS ==============\n",
136
- "OUTPUT_DIR = './outputs'\n",
137
- "DATA_DIR = './data'\n",
138
- "\n",
139
- "# ============== AUTO-CONFIG ==============\n",
140
- "# Smart batch size based on GPU memory\n",
141
- "import torch\n",
142
- "if torch.cuda.is_available():\n",
143
- " vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
144
- " print(f'GPU VRAM: {vram_gb:.1f} GB')\n",
145
- " \n",
146
- " # Auto-adjust batch size if needed\n",
147
- " recommended = {\n",
148
- " (32, 'tiny'): 128, (64, 'tiny'): 64, (128, 'tiny'): 32,\n",
149
- " (32, 'small'): 64, (64, 'small'): 32, (128, 'small'): 16,\n",
150
- " (256, 'base'): 8, (512, '512'): 4,\n",
151
- " }\n",
152
- " key = (IMG_SIZE, MODEL_SIZE)\n",
153
- " if key in recommended and vram_gb < 16:\n",
154
- " rec_bs = recommended[key]\n",
155
- " if BATCH_SIZE > rec_bs:\n",
156
- " print(f'⚠️ Reducing batch size {BATCH_SIZE} → {rec_bs} for {vram_gb:.0f}GB VRAM')\n",
157
- " BATCH_SIZE = rec_bs\n",
158
- "else:\n",
159
- " print('⚠️ No GPU detected — training will be very slow!')\n",
160
- " USE_AMP = False\n",
161
- "\n",
162
- "print(f'\\n📋 Config: {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')\n",
163
- "print(f' Physics: λ_smooth={LAMBDA_SMOOTH}, λ_tv={LAMBDA_TV}')\n",
164
- "print(f' AMP: {USE_AMP}, GradAccum: {GRAD_ACCUM}')"
165
- ]
166
- },
167
- {
168
- "cell_type": "markdown",
169
- "metadata": {},
170
- "source": [
171
- "---\n",
172
- "## 2. 📦 Load Dataset"
173
- ]
174
- },
175
- {
176
- "cell_type": "code",
177
- "execution_count": null,
178
- "metadata": {},
179
- "outputs": [],
180
- "source": [
181
- "import torchvision\n",
182
- "import torchvision.transforms as transforms\n",
183
- "from torch.utils.data import DataLoader, Dataset, ConcatDataset\n",
184
- "from pathlib import Path\n",
185
- "from PIL import Image\n",
186
- "import os\n",
187
- "\n",
188
- "# Standard transform\n",
189
- "def get_transform(img_size):\n",
190
- " return transforms.Compose([\n",
191
- " transforms.Resize(img_size + img_size // 8),\n",
192
- " transforms.CenterCrop(img_size),\n",
193
- " transforms.RandomHorizontalFlip(),\n",
194
- " transforms.ToTensor(),\n",
195
- " transforms.Normalize([0.5]*3, [0.5]*3),\n",
196
- " ])\n",
197
- "\n",
198
- "class ImageFolderFlat(Dataset):\n",
199
- " \"\"\"Load all images from a folder (recursively).\"\"\"\n",
200
- " def __init__(self, root, transform):\n",
201
- " self.transform = transform\n",
202
- " self.files = []\n",
203
- " for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:\n",
204
- " self.files.extend(Path(root).rglob(ext))\n",
205
- " self.files = sorted(self.files)\n",
206
- " print(f'Found {len(self.files)} images in {root}')\n",
207
- " def __len__(self): return len(self.files)\n",
208
- " def __getitem__(self, idx):\n",
209
- " return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
210
- "\n",
211
- "class GrayscaleToRGB:\n",
212
- " \"\"\"Convert 1-channel grayscale to 3-channel RGB.\"\"\"\n",
213
- " def __call__(self, x):\n",
214
- " if x.shape[0] == 1:\n",
215
- " x = x.repeat(3, 1, 1)\n",
216
- " return x\n",
217
- "\n",
218
- "tfm = get_transform(IMG_SIZE)\n",
219
- "\n",
220
- "if DATASET == 'cifar10':\n",
221
- " dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=tfm)\n",
222
- " print(f'✅ CIFAR-10: {len(dataset)} images')\n",
223
- "\n",
224
- "elif DATASET == 'flowers':\n",
225
- " ds_train = torchvision.datasets.Flowers102(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
226
- " ds_val = torchvision.datasets.Flowers102(root=DATA_DIR, split='val', download=True, transform=tfm)\n",
227
- " ds_test = torchvision.datasets.Flowers102(root=DATA_DIR, split='test', download=True, transform=tfm)\n",
228
- " dataset = ConcatDataset([ds_train, ds_val, ds_test]) # Use all splits for generation\n",
229
- " print(f'✅ Flowers-102: {len(dataset)} images (all splits)')\n",
230
- "\n",
231
- "elif DATASET == 'celeba':\n",
232
- " dataset = torchvision.datasets.CelebA(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
233
- " print(f'✅ CelebA: {len(dataset)} images')\n",
234
- "\n",
235
- "elif DATASET == 'fashion_mnist':\n",
236
- " fm_tfm = transforms.Compose([\n",
237
- " transforms.Resize(IMG_SIZE),\n",
238
- " transforms.ToTensor(),\n",
239
- " transforms.Normalize([0.5], [0.5]),\n",
240
- " GrayscaleToRGB(),\n",
241
- " ])\n",
242
- " dataset = torchvision.datasets.FashionMNIST(root=DATA_DIR, train=True, download=True, transform=fm_tfm)\n",
243
- " print(f'✅ Fashion-MNIST: {len(dataset)} images (converted to RGB)')\n",
244
- "\n",
245
- "elif DATASET == 'afhq':\n",
246
- " # Download AFHQ from Kaggle or manual\n",
247
- " afhq_dir = os.path.join(DATA_DIR, 'afhq', 'train')\n",
248
- " if not os.path.exists(afhq_dir):\n",
249
- " print('⬇️ Downloading AFHQ...')\n",
250
- " !pip install -q gdown\n",
251
- " !gdown 1Gof5BaELXlmSJIlvKMYCe9ONYPebkNsf -O {DATA_DIR}/afhq.zip\n",
252
- " !unzip -q {DATA_DIR}/afhq.zip -d {DATA_DIR}/afhq\n",
253
- " dataset = ImageFolderFlat(afhq_dir, tfm)\n",
254
- " print(f'✅ AFHQ: {len(dataset)} images')\n",
255
- "\n",
256
- "elif DATASET == 'lsun_churches':\n",
257
- " # LSUN requires manual download — point to extracted folder\n",
258
- " lsun_dir = os.path.join(DATA_DIR, 'lsun_churches')\n",
259
- " if not os.path.exists(lsun_dir):\n",
260
- " print('❌ LSUN churches not found. Please download and extract to', lsun_dir)\n",
261
- " print(' See: https://github.com/fyu/lsun')\n",
262
- " raise FileNotFoundError(lsun_dir)\n",
263
- " dataset = ImageFolderFlat(lsun_dir, tfm)\n",
264
- " print(f'✅ LSUN Churches: {len(dataset)} images')\n",
265
- "\n",
266
- "elif DATASET == 'folder':\n",
267
- " dataset = ImageFolderFlat(CUSTOM_DATA_DIR, tfm)\n",
268
- " print(f'✅ Custom folder: {len(dataset)} images from {CUSTOM_DATA_DIR}')\n",
269
- "\n",
270
- "else:\n",
271
- " raise ValueError(f'Unknown dataset: {DATASET}')\n",
272
- "\n",
273
- "# Show a few samples\n",
274
- "import matplotlib.pyplot as plt\n",
275
- "import numpy as np\n",
276
- "\n",
277
- "fig, axes = plt.subplots(1, 8, figsize=(16, 2))\n",
278
- "for i, ax in enumerate(axes):\n",
279
- " sample = dataset[i]\n",
280
- " if isinstance(sample, (list, tuple)):\n",
281
- " sample = sample[0]\n",
282
- " img = sample * 0.5 + 0.5 # denormalize\n",
283
- " ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())\n",
284
- " ax.axis('off')\n",
285
- "plt.suptitle(f'{DATASET} samples ({IMG_SIZE}×{IMG_SIZE})', fontsize=14)\n",
286
- "plt.tight_layout()\n",
287
- "plt.show()"
288
- ]
289
- },
290
- {
291
- "cell_type": "markdown",
292
- "metadata": {},
293
- "source": [
294
- "---\n",
295
- "## 3. 🏗️ Build Model"
296
- ]
297
- },
298
- {
299
- "cell_type": "code",
300
- "execution_count": null,
301
- "metadata": {},
302
- "outputs": [],
303
- "source": [
304
- "import torch\n",
305
- "from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
306
- "\n",
307
- "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
308
- "\n",
309
- "model_factories = {\n",
310
- " 'tiny': liquidflow_tiny,\n",
311
- " 'small': liquidflow_small,\n",
312
- " 'base': liquidflow_base,\n",
313
- " '512': liquidflow_512,\n",
314
- "}\n",
315
- "\n",
316
- "model = model_factories[MODEL_SIZE](img_size=IMG_SIZE).to(device)\n",
317
- "\n",
318
- "num_params = model.count_params()\n",
319
- "print(f'🏗️ LiquidFlow-{MODEL_SIZE}')\n",
320
- "print(f' Parameters: {num_params:,} ({num_params/1e6:.1f}M)')\n",
321
- "print(f' Image size: {IMG_SIZE}×{IMG_SIZE}')\n",
322
- "print(f' Patch size: {model.patch_size}')\n",
323
- "print(f' Num patches: {model.num_patches}')\n",
324
- "print(f' Model dim: {model.d_model}')\n",
325
- "print(f' Depth: {model.depth}')\n",
326
- "print(f' Device: {device}')\n",
327
- "\n",
328
- "# Quick forward pass test\n",
329
- "with torch.no_grad():\n",
330
- " test_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
331
- " test_t = torch.tensor([0.5], device=device)\n",
332
- " test_v = model(test_x, test_t)\n",
333
- " assert test_v.shape == test_x.shape\n",
334
- " print(f' ✅ Forward pass OK: {test_x.shape} → {test_v.shape}')"
335
- ]
336
- },
337
- {
338
- "cell_type": "markdown",
339
- "metadata": {},
340
- "source": [
341
- "---\n",
342
- "## 4. 🚀 Train"
343
- ]
344
- },
345
- {
346
- "cell_type": "code",
347
- "execution_count": null,
348
- "metadata": {},
349
- "outputs": [],
350
- "source": [
351
- "import math\n",
352
- "import time\n",
353
- "import json\n",
354
- "import torch.nn as nn\n",
355
- "from torch.cuda.amp import autocast, GradScaler\n",
356
- "from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
357
- "from liquidflow.sampling import euler_sample, make_grid_image\n",
358
- "from IPython.display import display, clear_output\n",
359
- "import matplotlib.pyplot as plt\n",
360
- "\n",
361
- "# Prepare\n",
362
- "os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\n",
363
- "os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n",
364
- "\n",
365
- "dataloader = DataLoader(\n",
366
- " dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
367
- " num_workers=2, pin_memory=True, drop_last=True\n",
368
- ")\n",
369
- "\n",
370
- "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,\n",
371
- " betas=(0.9, 0.999), weight_decay=0.01)\n",
372
- "\n",
373
- "total_steps = EPOCHS * len(dataloader) // GRAD_ACCUM\n",
374
- "warmup_steps = min(500, total_steps // 10)\n",
375
- "\n",
376
- "def cosine_lr(step):\n",
377
- " if step < warmup_steps:\n",
378
- " return step / max(1, warmup_steps)\n",
379
- " progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n",
380
- " return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))\n",
381
- "\n",
382
- "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)\n",
383
- "criterion = PhysicsInformedFlowLoss(\n",
384
- " lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV\n",
385
- ").to(device)\n",
386
- "ema = EMAModel(model, decay=0.9999)\n",
387
- "scaler = GradScaler(enabled=USE_AMP)\n",
388
- "\n",
389
- "# Training log\n",
390
- "all_losses = []\n",
391
- "global_step = 0\n",
392
- "\n",
393
- "print(f'🚀 Training {EPOCHS} epochs, {total_steps} steps')\n",
394
- "print(f' Effective batch: {BATCH_SIZE} × {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}')\n",
395
- "print(f' LR: {LEARNING_RATE} → warmup {warmup_steps} steps → cosine decay')\n",
396
- "print()\n",
397
- "\n",
398
- "t_start = time.time()\n",
399
- "\n",
400
- "for epoch in range(EPOCHS):\n",
401
- " model.train()\n",
402
- " epoch_loss = 0.0\n",
403
- " epoch_flow = 0.0\n",
404
- " n_batches = 0\n",
405
- "\n",
406
- " for batch_idx, batch_data in enumerate(dataloader):\n",
407
- " if isinstance(batch_data, (list, tuple)):\n",
408
- " x1 = batch_data[0].to(device)\n",
409
- " else:\n",
410
- " x1 = batch_data.to(device)\n",
411
- "\n",
412
- " B = x1.shape[0]\n",
413
- " x0 = torch.randn_like(x1)\n",
414
- " t = torch.rand(B, device=device)\n",
415
- " t_e = t.view(B, 1, 1, 1)\n",
416
- " x_t = t_e * x1 + (1 - t_e) * x0\n",
417
- "\n",
418
- " with autocast(enabled=USE_AMP):\n",
419
- " v_pred = model(x_t, t)\n",
420
- " loss, ld = criterion(v_pred, x0, x1, t, step=global_step)\n",
421
- " loss = loss / GRAD_ACCUM\n",
422
- "\n",
423
- " scaler.scale(loss).backward()\n",
424
- "\n",
425
- " if (batch_idx + 1) % GRAD_ACCUM == 0:\n",
426
- " scaler.unscale_(optimizer)\n",
427
- " gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
428
- " scaler.step(optimizer)\n",
429
- " scaler.update()\n",
430
- " optimizer.zero_grad()\n",
431
- " scheduler.step()\n",
432
- " ema.update(model)\n",
433
- " global_step += 1\n",
434
- "\n",
435
- " epoch_loss += ld['total'].item()\n",
436
- " epoch_flow += ld['flow'].item()\n",
437
- " n_batches += 1\n",
438
- "\n",
439
- " if global_step % LOG_EVERY == 0:\n",
440
- " avg = epoch_loss / n_batches\n",
441
- " avg_f = epoch_flow / n_batches\n",
442
- " lr_now = scheduler.get_last_lr()[0]\n",
443
- " elapsed = time.time() - t_start\n",
444
- " it_s = global_step / elapsed\n",
445
- " all_losses.append({'step': global_step, 'loss': avg, 'flow': avg_f,\n",
446
- " 'lr': lr_now, 'epoch': epoch})\n",
447
- " print(f' E{epoch+1} step {global_step}/{total_steps} | '\n",
448
- " f'loss={avg:.4f} flow={avg_f:.4f} lr={lr_now:.2e} '\n",
449
- " f'gn={gn:.2f} [{it_s:.1f} it/s]')\n",
450
- "\n",
451
- " # End of epoch\n",
452
- " avg_epoch = epoch_loss / max(1, n_batches)\n",
453
- " print(f'\\n📊 Epoch {epoch+1}/{EPOCHS} — avg loss: {avg_epoch:.4f}\\n')\n",
454
- "\n",
455
- " # Sample\n",
456
- " if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:\n",
457
- " model.eval()\n",
458
- " ema.apply_shadow(model)\n",
459
- " with torch.no_grad():\n",
460
- " n_samples = min(16, BATCH_SIZE)\n",
461
- " imgs = euler_sample(model, (n_samples, 3, IMG_SIZE, IMG_SIZE),\n",
462
- " num_steps=SAMPLE_STEPS, device=device)\n",
463
- " imgs = imgs.clamp(-1, 1) * 0.5 + 0.5\n",
464
- " grid = make_grid_image(imgs, nrow=4)\n",
465
- " grid.save(f'{OUTPUT_DIR}/samples/epoch_{epoch+1:04d}.png')\n",
466
- "\n",
467
- " # Display inline\n",
468
- " fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
469
- " ax.imshow(grid)\n",
470
- " ax.set_title(f'Epoch {epoch+1} — {MODEL_SIZE}-{IMG_SIZE} on {DATASET}')\n",
471
- " ax.axis('off')\n",
472
- " plt.tight_layout()\n",
473
- " plt.show()\n",
474
- "\n",
475
- " ema.restore(model)\n",
476
- " model.train()\n",
477
- "\n",
478
- " # Checkpoint\n",
479
- " if (epoch + 1) % SAVE_EVERY == 0:\n",
480
- " ckpt = {\n",
481
- " 'model': model.state_dict(),\n",
482
- " 'optimizer': optimizer.state_dict(),\n",
483
- " 'scheduler': scheduler.state_dict(),\n",
484
- " 'ema': ema.state_dict(),\n",
485
- " 'epoch': epoch,\n",
486
- " 'global_step': global_step,\n",
487
- " }\n",
488
- " torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/epoch_{epoch+1:04d}.pt')\n",
489
- " torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n",
490
- " print(f'💾 Checkpoint saved: epoch {epoch+1}')\n",
491
- "\n",
492
- "# Save final\n",
493
- "ema.apply_shadow(model)\n",
494
- "torch.save({'model': model.state_dict(), 'config': {\n",
495
- " 'model_size': MODEL_SIZE, 'img_size': IMG_SIZE, 'dataset': DATASET,\n",
496
- " 'num_params': num_params, 'epochs': EPOCHS,\n",
497
- "}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\n",
498
- "ema.restore(model)\n",
499
- "\n",
500
- "elapsed = time.time() - t_start\n",
501
- "print(f'\\n✅ Training complete! {elapsed/60:.1f} min total')\n",
502
- "print(f' Final model: {OUTPUT_DIR}/liquidflow_final.pt')"
503
- ]
504
- },
505
- {
506
- "cell_type": "markdown",
507
- "metadata": {},
508
- "source": [
509
- "---\n",
510
- "## 5. 📈 Training Curves"
511
- ]
512
- },
513
- {
514
- "cell_type": "code",
515
- "execution_count": null,
516
- "metadata": {},
517
- "outputs": [],
518
- "source": [
519
- "import matplotlib.pyplot as plt\n",
520
- "\n",
521
- "if all_losses:\n",
522
- " steps = [d['step'] for d in all_losses]\n",
523
- " losses = [d['loss'] for d in all_losses]\n",
524
- " flows = [d['flow'] for d in all_losses]\n",
525
- " lrs = [d['lr'] for d in all_losses]\n",
526
- "\n",
527
- " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
528
- "\n",
529
- " ax1.plot(steps, losses, label='Total Loss', alpha=0.8)\n",
530
- " ax1.plot(steps, flows, label='Flow Loss', alpha=0.8)\n",
531
- " ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n",
532
- " ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True, alpha=0.3)\n",
533
- "\n",
534
- " ax2.plot(steps, lrs, color='orange')\n",
535
- " ax2.set_xlabel('Step'); ax2.set_ylabel('LR')\n",
536
- " ax2.set_title('Learning Rate Schedule'); ax2.grid(True, alpha=0.3)\n",
537
- "\n",
538
- " plt.tight_layout()\n",
539
- " plt.savefig(f'{OUTPUT_DIR}/training_curves.png', dpi=150)\n",
540
- " plt.show()\n",
541
- "else:\n",
542
- " print('No training logs yet.')"
543
- ]
544
- },
545
- {
546
- "cell_type": "markdown",
547
- "metadata": {},
548
- "source": [
549
- "---\n",
550
- "## 6. 🎨 Generate Images"
551
- ]
552
- },
553
- {
554
- "cell_type": "code",
555
- "execution_count": null,
556
- "metadata": {},
557
- "outputs": [],
558
- "source": [
559
- "#@title 🎨 Generation Settings { display-mode: \"form\" }\n",
560
- "NUM_IMAGES = 16 #@param {type:\"integer\"}\n",
561
- "GEN_STEPS = 50 #@param [10, 25, 50, 100, 200] {type:\"integer\"}\n",
562
- "SAMPLER = 'euler' #@param ['euler', 'heun']\n",
563
- "SEED = 42 #@param {type:\"integer\"}\n",
564
- "\n",
565
- "import torch\n",
566
- "from liquidflow.sampling import euler_sample, heun_sample, make_grid_image\n",
567
- "import matplotlib.pyplot as plt\n",
568
- "\n",
569
- "# Load best model\n",
570
- "ckpt_path = f'{OUTPUT_DIR}/liquidflow_final.pt'\n",
571
- "if os.path.exists(ckpt_path):\n",
572
- " ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)\n",
573
- " model.load_state_dict(ckpt['model'])\n",
574
- " print(f'Loaded: {ckpt_path}')\n",
575
- "else:\n",
576
- " print(f'No checkpoint found, using current model weights')\n",
577
- "\n",
578
- "model.eval()\n",
579
- "torch.manual_seed(SEED)\n",
580
- "\n",
581
- "shape = (NUM_IMAGES, 3, IMG_SIZE, IMG_SIZE)\n",
582
- "\n",
583
- "with torch.no_grad():\n",
584
- " if SAMPLER == 'euler':\n",
585
- " images = euler_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
586
- " else:\n",
587
- " images = heun_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
588
- "\n",
589
- "images = images.clamp(-1, 1) * 0.5 + 0.5\n",
590
- "grid = make_grid_image(images, nrow=int(NUM_IMAGES**0.5))\n",
591
- "grid.save(f'{OUTPUT_DIR}/generated_final.png')\n",
592
- "\n",
593
- "plt.figure(figsize=(10, 10))\n",
594
- "plt.imshow(grid)\n",
595
- "plt.title(f'LiquidFlow-{MODEL_SIZE} | {DATASET} {IMG_SIZE}×{IMG_SIZE} | {GEN_STEPS} steps ({SAMPLER})')\n",
596
- "plt.axis('off')\n",
597
- "plt.show()"
598
- ]
599
- },
600
- {
601
- "cell_type": "markdown",
602
- "metadata": {},
603
- "source": [
604
- "---\n",
605
- "## 7. 📱 Export for Mobile (ONNX + TorchScript)"
606
- ]
607
- },
608
- {
609
- "cell_type": "code",
610
- "execution_count": null,
611
- "metadata": {},
612
- "outputs": [],
613
- "source": [
614
- "# Export to TorchScript for mobile deployment\n",
615
- "model.eval()\n",
616
- "\n",
617
- "# TorchScript (for PyTorch Mobile / ExecuTorch)\n",
618
- "example_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
619
- "example_t = torch.tensor([0.5], device=device)\n",
620
- "\n",
621
- "try:\n",
622
- " traced = torch.jit.trace(model, (example_x, example_t))\n",
623
- " ts_path = f'{OUTPUT_DIR}/liquidflow_mobile.pt'\n",
624
- " traced.save(ts_path)\n",
625
- " ts_size_mb = os.path.getsize(ts_path) / 1e6\n",
626
- " print(f'✅ TorchScript saved: {ts_path} ({ts_size_mb:.1f} MB)')\n",
627
- "except Exception as e:\n",
628
- " print(f'⚠️ TorchScript export failed: {e}')\n",
629
- "\n",
630
- "# ONNX\n",
631
- "try:\n",
632
- " onnx_path = f'{OUTPUT_DIR}/liquidflow.onnx'\n",
633
- " torch.onnx.export(\n",
634
- " model.cpu(), (example_x.cpu(), example_t.cpu()),\n",
635
- " onnx_path, opset_version=14,\n",
636
- " input_names=['image', 'timestep'],\n",
637
- " output_names=['velocity'],\n",
638
- " dynamic_axes={'image': {0: 'batch'}, 'timestep': {0: 'batch'}, 'velocity': {0: 'batch'}}\n",
639
- " )\n",
640
- " onnx_size_mb = os.path.getsize(onnx_path) / 1e6\n",
641
- " print(f'✅ ONNX saved: {onnx_path} ({onnx_size_mb:.1f} MB)')\n",
642
- " model.to(device)\n",
643
- "except Exception as e:\n",
644
- " print(f'⚠️ ONNX export failed: {e}')\n",
645
- " model.to(device)"
646
- ]
647
- },
648
- {
649
- "cell_type": "markdown",
650
- "metadata": {},
651
- "source": [
652
- "---\n",
653
- "## 8. 🔬 Architecture Deep Dive\n",
654
- "\n",
655
- "### How LiquidFlow works\n",
656
- "\n",
657
- "```\n",
658
- "Noise x₀ ~ N(0,I) ──→ LiquidFlow v_θ(xₜ, t) ──→ Image x₁\n",
659
- " │\n",
660
- " ┌──────┴──────┐\n",
661
- " │ Patchify │ (img → non-overlapping patches)\n",
662
- " │ + PosEmb │ (2D learnable positions)\n",
663
- " │ + DepthConv│ (local structure)\n",
664
- " └──────┬──────┘\n",
665
- " │\n",
666
- " ┌────────────┼────────────┐\n",
667
- " │ L × LiquidSSM Block │\n",
668
- " │ ┌──────────────────┐ │\n",
669
- " │ │ AdaLN (t-cond) │ │\n",
670
- " │ │ Zigzag Scan │ │ ← rotates scan pattern per layer\n",
671
- " │ │ SelectiveSSM │ │ ← Mamba-style, input-dependent\n",
672
- " │ │ + LiquidCfC │ │ ← CfC gating, bounded dynamics\n",
673
- " │ │ + FFN │ │\n",
674
- " │ │ + Skip Connect │ │ ← U-Net style long skips\n",
675
- " │ └──────────────────┘ │\n",
676
- " └────────────┼────────────┘\n",
677
- " │\n",
678
- " ┌──────┴──────┐\n",
679
- " │ DepthConv │\n",
680
- " │ Unpatchify │ (patches → img)\n",
681
- " └──────┬──────┘\n",
682
- " │\n",
683
- " velocity v_θ\n",
684
- "```\n",
685
- "\n",
686
- "### Key Innovations\n",
687
- "\n",
688
- "1. **Liquid CfC Cell**: Instead of solving the ODE `dx/dt = f(x,t)` numerically, we use the\n",
689
- " closed-form solution `x(t+Δt) = σ(-f_τ) ⊙ x(t) + (1 - σ(-f_τ)) ⊙ f_x`.\n",
690
- " The sigmoid gating **guarantees bounded dynamics** — no training explosion possible.\n",
691
- "\n",
692
- "2. **SSM + Liquid dual path**: The SSM branch captures long-range spatial dependencies\n",
693
- " via selective scanning; the Liquid branch adds continuous-time adaptive dynamics.\n",
694
- " A learnable mixing coefficient balances them.\n",
695
- "\n",
696
- "3. **Physics-informed loss**: Smoothness (Laplacian) and Total Variation regularizers\n",
697
- " act as soft PDE constraints on generated images, improving training stability\n",
698
- " and reducing artifacts without domain-specific physics knowledge.\n",
699
- "\n",
700
- "4. **Flow Matching = Liquid ODE**: Rectified flow trains `v_θ` to follow straight paths\n",
701
- " from noise to data. This is structurally identical to the LTC ODE, making Liquid\n",
702
- " networks a natural fit as the velocity field parameterization."
703
- ]
704
- },
705
- {
706
- "cell_type": "markdown",
707
- "metadata": {},
708
- "source": [
709
- "---\n",
710
- "## 9. 🧪 Recommended Experiments\n",
711
- "\n",
712
- "| Experiment | Dataset | Model | IMG_SIZE | Epochs | Notes |\n",
713
- "|------------|---------|-------|----------|--------|-------|\n",
714
- "| Quick sanity check | CIFAR-10 | tiny | 32 | 20 | ~5 min on T4 |\n",
715
- "| Baseline 128×128 | CIFAR-10 | tiny | 128 | 100 | ~2 hrs on T4 |\n",
716
- "| Quality 128×128 | Flowers-102 | small | 128 | 200 | ~4 hrs on T4 |\n",
717
- "| Faces 128×128 | CelebA | small | 128 | 50 | ~6 hrs on T4 |\n",
718
- "| High-res 512×512 | CelebA | 512 | 512 | 100 | needs ≥16GB |\n",
719
- "| Production | Your data | small | 128 | 300+ | best quality |\n",
720
- "\n",
721
- "### Tips for best results:\n",
722
- "- Start with `tiny` + low epochs to verify everything works\n",
723
- "- Use `small` for 128×128 production quality\n",
724
- "- Increase `SAMPLE_STEPS` to 100+ for final generation\n",
725
- "- `heun` sampler gives better quality at half the steps vs `euler`\n",
726
- "- Physics loss warmup is automatic — don't increase λ too much"
727
- ]
728
- }
729
- ],
730
- "metadata": {
731
- "accelerator": "GPU",
732
- "colab": {
733
- "gpuType": "T4",
734
- "provenance": []
735
- },
736
- "kernelspec": {
737
- "display_name": "Python 3",
738
- "name": "python3"
739
- },
740
- "language_info": {
741
- "name": "python",
742
- "version": "3.10.12"
743
- }
744
- },
745
- "nbformat": 4,
746
- "nbformat_minor": 4
747
- }
 
1
+ {"cells":[{"cell_type":"markdown","metadata":{},"source":["# 🌊 LiquidFlow — Liquid-SSM Flow Matching Image Generator\n","\n","A **novel architecture** combining:\n","- **Liquid Time-Constant Networks** (CfC closed-form) — adaptive ODE dynamics, bounded by construction\n","- **Selective State Space Models** (Mamba-style) — linear-time long-range context, parallelizable\n","- **Zigzag Scanning** — 2D spatial awareness for image patches\n","- **Physics-Informed Regularization** — smoothness + total variation constraints\n","- **Rectified Flow Matching** — ODE-based generation (no noise schedule tuning)\n","\n","### 💻 Hardware Requirements\n","| Config | GPU VRAM | Best For |\n","|--------|----------|----------|\n","| tiny-128 (bs=32) | ~4 GB | Colab free T4, Kaggle |\n","| small-128 (bs=16) | ~8 GB | Colab free T4, Kaggle |\n","| base-256 (bs=8) | ~12 GB | Colab Pro, Kaggle |\n","| 512 (bs=4) | ~14 GB | Colab Pro, A100 |"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 0. Setup & Install"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU — CPU only'\nimport torch\nprint(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available():\n print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\nif not os.path.exists('liquidflow'):\n !git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n !cp -r liquidflow_repo/liquidflow .\nelse:\n print('liquidflow/ already exists — updating...')\n !cd liquidflow_repo && git pull && cp -r liquidflow/* ../liquidflow/\n\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\nfrom 
liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, generate_grid, make_grid_image\nprint('✅ LiquidFlow imported successfully!')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 1. ⚙️ Configuration"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎛️ Training Configuration { display-mode: \"form\" }\n\nDATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist', 'afhq', 'lsun_churches']\nCUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\nMODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\nIMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\nEPOCHS = 100 #@param {type:\"integer\"}\nBATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\nLEARNING_RATE = 3e-4 #@param {type:\"number\"}\nGRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\nUSE_AMP = True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\nLAMBDA_TV = 0.001 #@param {type:\"number\"}\nSAMPLE_EVERY = 5 #@param {type:\"integer\"}\nSAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\nLOG_EVERY = 50 #@param {type:\"integer\"}\nSAVE_EVERY = 10 #@param {type:\"integer\"}\nOUTPUT_DIR = './outputs'\nDATA_DIR = './data'\n\nimport torch\nif torch.cuda.is_available():\n vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f'GPU VRAM: {vram_gb:.1f} GB')\n recommended = {(32,'tiny'):128,(64,'tiny'):64,(128,'tiny'):32,(32,'small'):64,(64,'small'):32,(128,'small'):16,(256,'base'):8,(512,'512'):4}\n key = (IMG_SIZE, MODEL_SIZE)\n if key in recommended and vram_gb < 16:\n rec_bs = recommended[key]\n if BATCH_SIZE > rec_bs:\n print(f'⚠️ Reducing batch size {BATCH_SIZE} → {rec_bs} for {vram_gb:.0f}GB VRAM')\n BATCH_SIZE = rec_bs\nelse:\n print('⚠️ No GPU detected'); USE_AMP = False\n\nprint(f'\\n📋 Config: {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, 
lr={LEARNING_RATE}, epochs={EPOCHS}')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 2. 📦 Load Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision\nimport torchvision.transforms as transforms\nfrom torch.utils.data import DataLoader, Dataset, ConcatDataset\nfrom pathlib import Path\nfrom PIL import Image\nimport os, matplotlib.pyplot as plt, numpy as np\n\ndef get_transform(img_size):\n    return transforms.Compose([transforms.Resize(img_size+img_size//8), transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5]*3,[0.5]*3)])\n\nclass ImageFolderFlat(Dataset):\n    def __init__(self, root, transform):\n        self.transform = transform\n        self.files = []\n        for ext in ['*.png','*.jpg','*.jpeg','*.webp','*.bmp']: self.files.extend(Path(root).rglob(ext))\n        self.files = sorted(self.files); print(f'Found {len(self.files)} images in {root}')\n    def __len__(self): return len(self.files)\n    def __getitem__(self, idx): return self.transform(Image.open(self.files[idx]).convert('RGB'))\n\nclass GrayscaleToRGB:\n    def __call__(self, x): return x.repeat(3,1,1) if x.shape[0]==1 else x\n\ntfm = get_transform(IMG_SIZE)\n\nif DATASET == 'cifar10':\n    dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=tfm)\nelif DATASET == 'flowers':\n    dataset = ConcatDataset([torchvision.datasets.Flowers102(root=DATA_DIR, split=s, download=True, transform=tfm) for s in ['train','val','test']])\nelif DATASET == 'celeba':\n    dataset = torchvision.datasets.CelebA(root=DATA_DIR, split='train', download=True, transform=tfm)\nelif DATASET == 'fashion_mnist':\n    fm_tfm = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.5],[0.5]), GrayscaleToRGB()])\n    dataset = torchvision.datasets.FashionMNIST(root=DATA_DIR, train=True, download=True, transform=fm_tfm)\nelif DATASET == 'folder':\n    dataset = 
ImageFolderFlat(CUSTOM_DATA_DIR, tfm)\nelse:\n    raise ValueError(f'Unknown dataset: {DATASET}')\n\nprint(f'✅ {DATASET}: {len(dataset)} images')\n\nfig, axes = plt.subplots(1, 8, figsize=(16, 2))\nfor i, ax in enumerate(axes):\n    sample = dataset[i]; sample = sample[0] if isinstance(sample,(list,tuple)) else sample\n    ax.imshow((sample*0.5+0.5).permute(1,2,0).clamp(0,1).numpy()); ax.axis('off')\nplt.suptitle(f'{DATASET} samples ({IMG_SIZE}×{IMG_SIZE})'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 3. 🏗️ Build Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torch\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = {'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\n\nnum_params = model.count_params()\nprint(f'🏗️ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M) params')\nprint(f' {IMG_SIZE}×{IMG_SIZE}, patch={model.patch_size}, patches={model.num_patches}, dim={model.d_model}, depth={model.depth}')\n\nwith torch.no_grad():\n    v = model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device), torch.tensor([0.5],device=device))\n    print(f' ✅ Forward pass OK')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 4. 
🚀 Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math, time, json\nimport torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, make_grid_image\nimport matplotlib.pyplot as plt\n\nos.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\nos.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\ndataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9,0.999), weight_decay=0.01)\n\ntotal_steps = EPOCHS * len(dataloader) // GRAD_ACCUM\nwarmup_steps = min(500, total_steps // 10)\ndef cosine_lr(step):\n    if step < warmup_steps: return step / max(1, warmup_steps)\n    p = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * p))\n\nscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)\ncriterion = PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV).to(device)\nema = EMAModel(model, decay=0.9999)\n\n# Use modern AMP API (no deprecation warnings)\namp_device = 'cuda' if device.type == 'cuda' else 'cpu'\nscaler = torch.amp.GradScaler(amp_device, enabled=USE_AMP)\n\nall_losses = []\nglobal_step = 0\nprint(f'🚀 Training {EPOCHS} epochs, {total_steps} steps')\nprint(f' Effective batch: {BATCH_SIZE} × {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}')\nprint(f' LR: {LEARNING_RATE} → warmup {warmup_steps} steps → cosine decay\\n')\n\nt_start = time.time()\nfor epoch in range(EPOCHS):\n    model.train()\n    epoch_loss = epoch_flow = 0.0; n_batches = 0\n\n    for batch_idx, batch_data in enumerate(dataloader):\n        x1 = (batch_data[0] if isinstance(batch_data,(list,tuple)) else batch_data).to(device)\n        B = x1.shape[0]\n        x0 = torch.randn_like(x1)\n        t = torch.rand(B, device=device)\n        x_t = t.view(B,1,1,1) * x1 + (1 - t.view(B,1,1,1)) * x0\n\n        with 
torch.amp.autocast(amp_device, enabled=USE_AMP):\n            v_pred = model(x_t, t)\n            loss, ld = criterion(v_pred, x0, x1, t, step=global_step)\n            loss = loss / GRAD_ACCUM\n\n        scaler.scale(loss).backward()\n\n        if (batch_idx + 1) % GRAD_ACCUM == 0:\n            scaler.unscale_(optimizer)\n            gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n            scaler.step(optimizer); scaler.update(); optimizer.zero_grad()\n            scheduler.step(); ema.update(model); global_step += 1\n\n            epoch_loss += ld['total'].item(); epoch_flow += ld['flow'].item(); n_batches += 1\n\n            if global_step % LOG_EVERY == 0:\n                avg = epoch_loss/n_batches; avg_f = epoch_flow/n_batches\n                lr_now = scheduler.get_last_lr()[0]; it_s = global_step/(time.time()-t_start)\n                all_losses.append({'step':global_step,'loss':avg,'flow':avg_f,'lr':lr_now,'epoch':epoch})\n                print(f' E{epoch+1} step {global_step}/{total_steps} | loss={avg:.4f} flow={avg_f:.4f} lr={lr_now:.2e} gn={gn:.2f} [{it_s:.1f} it/s]')\n\n    avg_epoch = epoch_loss / max(1, n_batches)\n    print(f'\\n📊 Epoch {epoch+1}/{EPOCHS} — avg loss: {avg_epoch:.4f}\\n')\n\n    if (epoch+1) % SAMPLE_EVERY == 0 or epoch == 0:\n        model.eval(); ema.apply_shadow(model)\n        with torch.no_grad():\n            imgs = euler_sample(model, (min(16,BATCH_SIZE),3,IMG_SIZE,IMG_SIZE), num_steps=SAMPLE_STEPS, device=device)\n        imgs = imgs.clamp(-1,1)*0.5+0.5\n        grid = make_grid_image(imgs, nrow=4)\n        grid.save(f'{OUTPUT_DIR}/samples/epoch_{epoch+1:04d}.png')\n        fig, ax = plt.subplots(1,1,figsize=(8,8)); ax.imshow(grid)\n        ax.set_title(f'Epoch {epoch+1} — {MODEL_SIZE}-{IMG_SIZE} on {DATASET}'); ax.axis('off'); plt.tight_layout(); plt.show()\n        ema.restore(model); model.train()\n\n    if (epoch+1) % SAVE_EVERY == 0:\n        ckpt = {'model':model.state_dict(),'optimizer':optimizer.state_dict(),'scheduler':scheduler.state_dict(),'ema':ema.state_dict(),'epoch':epoch,'global_step':global_step}\n        torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/epoch_{epoch+1:04d}.pt')\n        torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n        print(f'💾 Checkpoint saved: 
epoch {epoch+1}')\n\nema.apply_shadow(model)\ntorch.save({'model':model.state_dict(),'config':{'model_size':MODEL_SIZE,'img_size':IMG_SIZE,'dataset':DATASET,'num_params':num_params,'epochs':EPOCHS}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model)\nprint(f'\\n✅ Training complete! {(time.time()-t_start)/60:.1f} min total')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 5. 📈 Training Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import matplotlib.pyplot as plt\nif all_losses:\n    steps=[d['step'] for d in all_losses]; losses=[d['loss'] for d in all_losses]; flows=[d['flow'] for d in all_losses]; lrs=[d['lr'] for d in all_losses]\n    fig,(ax1,ax2) = plt.subplots(1,2,figsize=(14,5))\n    ax1.plot(steps,losses,label='Total',alpha=0.8); ax1.plot(steps,flows,label='Flow',alpha=0.8)\n    ax1.set_xlabel('Step'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True,alpha=0.3)\n    ax2.plot(steps,lrs,color='orange'); ax2.set_xlabel('Step'); ax2.set_ylabel('LR'); ax2.set_title('LR Schedule'); ax2.grid(True,alpha=0.3)\n    plt.tight_layout(); plt.savefig(f'{OUTPUT_DIR}/training_curves.png',dpi=150); plt.show()\nelse: print('No training logs yet.')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 6. 
🎨 Generate Images"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎨 Generation Settings { display-mode: \"form\" }\nNUM_IMAGES = 16 #@param {type:\"integer\"}\nGEN_STEPS = 50 #@param [10, 25, 50, 100, 200] {type:\"integer\"}\nSAMPLER = 'euler' #@param ['euler', 'heun']\nSEED = 42 #@param {type:\"integer\"}\n\nimport torch\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nimport matplotlib.pyplot as plt\n\nckpt_path = f'{OUTPUT_DIR}/liquidflow_final.pt'\nif os.path.exists(ckpt_path):\n    model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=False)['model'])\n    print(f'Loaded: {ckpt_path}')\nmodel.eval(); torch.manual_seed(SEED)\nshape = (NUM_IMAGES, 3, IMG_SIZE, IMG_SIZE)\nwith torch.no_grad():\n    images = (euler_sample if SAMPLER=='euler' else heun_sample)(model, shape, num_steps=GEN_STEPS, device=device)\nimages = images.clamp(-1,1)*0.5+0.5\ngrid = make_grid_image(images, nrow=int(NUM_IMAGES**0.5))\ngrid.save(f'{OUTPUT_DIR}/generated_final.png')\nplt.figure(figsize=(10,10)); plt.imshow(grid)\nplt.title(f'LiquidFlow-{MODEL_SIZE} | {DATASET} {IMG_SIZE}×{IMG_SIZE} | {GEN_STEPS} steps ({SAMPLER})'); plt.axis('off'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 7. 
📱 Export for Mobile"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval()\nexample_x = torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device)\nexample_t = torch.tensor([0.5],device=device)\ntry:\n    traced = torch.jit.trace(model, (example_x, example_t))\n    ts_path = f'{OUTPUT_DIR}/liquidflow_mobile.pt'; traced.save(ts_path)\n    print(f'✅ TorchScript: {ts_path} ({os.path.getsize(ts_path)/1e6:.1f} MB)')\nexcept Exception as e: print(f'⚠️ TorchScript failed: {e}')\ntry:\n    onnx_path = f'{OUTPUT_DIR}/liquidflow.onnx'\n    torch.onnx.export(model.cpu(), (example_x.cpu(), example_t.cpu()), onnx_path, opset_version=14,\n        input_names=['image','timestep'], output_names=['velocity'],\n        dynamic_axes={'image':{0:'batch'},'timestep':{0:'batch'},'velocity':{0:'batch'}})\n    print(f'✅ ONNX: {onnx_path} ({os.path.getsize(onnx_path)/1e6:.1f} MB)'); model.to(device)\nexcept Exception as e: print(f'⚠️ ONNX failed: {e}'); model.to(device)"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 8. 🧪 Recommended Experiments\n","\n","| Goal | Dataset | Model | Size | Epochs | Time (T4) |\n","|------|---------|-------|------|--------|----------|\n","| Sanity check | CIFAR-10 | tiny | 32 | 20 | ~5 min |\n","| Baseline | CIFAR-10 | tiny | 128 | 100 | ~2 hrs |\n","| Quality | Flowers-102 | small | 128 | 200 | ~4 hrs |\n","| Faces | CelebA | small | 128 | 50 | ~6 hrs |\n","| High-res | CelebA | 512 | 512 | 100 | ~12 hrs |"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4}