v0.3: Notebook — force re-clone, log step 1 immediately, show progress instantly
Browse files
LiquidFlow_Training.ipynb
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
{"cells":[{"cell_type":"markdown","metadata":{},"source":["# 🌊 LiquidFlow — Liquid-SSM Flow Matching Image Generator\n","\n","A **novel architecture** combining:\n","- **Liquid Time-Constant Networks** (CfC closed-form) — adaptive ODE dynamics, bounded by construction\n","- **Selective State Space Models** (Mamba-style) — linear-time long-range context, parallelizable\n","- **Zigzag Scanning** — 2D spatial awareness for image patches\n","- **Physics-Informed Regularization** — smoothness + total variation constraints\n","- **Rectified Flow Matching** — ODE-based generation (no noise schedule tuning)\n","\n","### 💻 Hardware Requirements\n","| Config | GPU VRAM | Best For |\n","|--------|----------|----------|\n","| tiny-128 (bs=32) | ~4 GB | Colab free T4, Kaggle |\n","| small-128 (bs=16) | ~8 GB | Colab free T4, Kaggle |\n","| base-256 (bs=8) | ~12 GB | Colab Pro, Kaggle |\n","| 512 (bs=4) | ~14 GB | Colab Pro, A100 |"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 0. Setup & Install"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU — CPU only'\nimport torch\nprint(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available():\n print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\nif not os.path.exists('liquidflow'):\n !git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n !cp -r liquidflow_repo/liquidflow .\nelse:\n print('liquidflow/ already exists — updating...')\n !cd liquidflow_repo && git pull && cp -r liquidflow/* ../liquidflow/\n\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, generate_grid, make_grid_image\nprint('✅ LiquidFlow imported successfully!')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 1. ⚙️ Configuration"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎛️ Training Configuration { display-mode: \"form\" }\n\nDATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist', 'afhq', 'lsun_churches']\nCUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\nMODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\nIMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\nEPOCHS = 100 #@param {type:\"integer\"}\nBATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\nLEARNING_RATE = 3e-4 #@param {type:\"number\"}\nGRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\nUSE_AMP = True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\nLAMBDA_TV = 0.001 #@param {type:\"number\"}\nSAMPLE_EVERY = 5 #@param {type:\"integer\"}\nSAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\nLOG_EVERY = 50 #@param {type:\"integer\"}\nSAVE_EVERY = 10 #@param {type:\"integer\"}\nOUTPUT_DIR = './outputs'\nDATA_DIR = './data'\n\nimport torch\nif torch.cuda.is_available():\n vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9\n print(f'GPU VRAM: {vram_gb:.1f} GB')\n recommended = {(32,'tiny'):128,(64,'tiny'):64,(128,'tiny'):32,(32,'small'):64,(64,'small'):32,(128,'small'):16,(256,'base'):8,(512,'512'):4}\n key = (IMG_SIZE, MODEL_SIZE)\n if key in recommended and vram_gb < 16:\n rec_bs = recommended[key]\n if BATCH_SIZE > rec_bs:\n print(f'⚠️ Reducing batch size {BATCH_SIZE} → {rec_bs} for {vram_gb:.0f}GB VRAM')\n BATCH_SIZE = rec_bs\nelse:\n print('⚠️ No GPU detected'); USE_AMP = False\n\nprint(f'\\n📋 Config: {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 2. 📦 Load Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision\nimport torchvision.transforms as transforms\nfrom torch.utils.data import DataLoader, Dataset, ConcatDataset\nfrom pathlib import Path\nfrom PIL import Image\nimport os, matplotlib.pyplot as plt, numpy as np\n\ndef get_transform(img_size):\n return transforms.Compose([transforms.Resize(img_size+img_size//8), transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5]*3,[0.5]*3)])\n\nclass ImageFolderFlat(Dataset):\n def __init__(self, root, transform):\n self.transform = transform\n self.files = []\n for ext in ['*.png','*.jpg','*.jpeg','*.webp','*.bmp']: self.files.extend(Path(root).rglob(ext))\n self.files = sorted(self.files); print(f'Found {len(self.files)} images in {root}')\n def __len__(self): return len(self.files)\n def __getitem__(self, idx): return self.transform(Image.open(self.files[idx]).convert('RGB'))\n\nclass GrayscaleToRGB:\n def __call__(self, x): return x.repeat(3,1,1) if x.shape[0]==1 else x\n\ntfm = get_transform(IMG_SIZE)\n\nif DATASET == 'cifar10':\n dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=tfm)\nelif DATASET == 'flowers':\n dataset = ConcatDataset([torchvision.datasets.Flowers102(root=DATA_DIR, split=s, download=True, transform=tfm) for s in ['train','val','test']])\nelif DATASET == 'celeba':\n dataset = torchvision.datasets.CelebA(root=DATA_DIR, split='train', download=True, transform=tfm)\nelif DATASET == 'fashion_mnist':\n fm_tfm = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.5],[0.5]), GrayscaleToRGB()])\n dataset = torchvision.datasets.FashionMNIST(root=DATA_DIR, train=True, download=True, transform=fm_tfm)\nelif DATASET == 'folder':\n dataset = ImageFolderFlat(CUSTOM_DATA_DIR, tfm)\nelse:\n raise ValueError(f'Unknown dataset: {DATASET}')\n\nprint(f'✅ {DATASET}: {len(dataset)} images')\n\nfig, axes = plt.subplots(1, 8, figsize=(16, 2))\nfor i, ax in enumerate(axes):\n sample = dataset[i]; sample = sample[0] if isinstance(sample,(list,tuple)) else sample\n ax.imshow((sample*0.5+0.5).permute(1,2,0).clamp(0,1).numpy()); ax.axis('off')\nplt.suptitle(f'{DATASET} samples ({IMG_SIZE}×{IMG_SIZE})'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 3. 🏗️ Build Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torch\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = {'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\n\nnum_params = model.count_params()\nprint(f'🏗️ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M) params')\nprint(f' {IMG_SIZE}×{IMG_SIZE}, patch={model.patch_size}, patches={model.num_patches}, dim={model.d_model}, depth={model.depth}')\n\nwith torch.no_grad():\n v = model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device), torch.tensor([0.5],device=device))\n print(f' ✅ Forward pass OK')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 4. 🚀 Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math, time, json\nimport torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, make_grid_image\nimport matplotlib.pyplot as plt\n\nos.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\nos.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\ndataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9,0.999), weight_decay=0.01)\n\ntotal_steps = EPOCHS * len(dataloader) // GRAD_ACCUM\nwarmup_steps = min(500, total_steps // 10)\ndef cosine_lr(step):\n if step < warmup_steps: return step / max(1, warmup_steps)\n p = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * p))\n\nscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)\ncriterion = PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV).to(device)\nema = EMAModel(model, decay=0.9999)\n\n# Use modern AMP API (no deprecation warnings)\namp_device = 'cuda' if device.type == 'cuda' else 'cpu'\nscaler = torch.amp.GradScaler(amp_device, enabled=USE_AMP)\n\nall_losses = []\nglobal_step = 0\nprint(f'🚀 Training {EPOCHS} epochs, {total_steps} steps')\nprint(f' Effective batch: {BATCH_SIZE} × {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}')\nprint(f' LR: {LEARNING_RATE} → warmup {warmup_steps} steps → cosine decay\\n')\n\nt_start = time.time()\nfor epoch in range(EPOCHS):\n model.train()\n epoch_loss = epoch_flow = 0.0; n_batches = 0\n\n for batch_idx, batch_data in enumerate(dataloader):\n x1 = (batch_data[0] if isinstance(batch_data,(list,tuple)) else batch_data).to(device)\n B = x1.shape[0]\n x0 = torch.randn_like(x1)\n t = torch.rand(B, device=device)\n x_t = t.view(B,1,1,1) * x1 + (1 - t.view(B,1,1,1)) * x0\n\n with torch.amp.autocast(amp_device, enabled=USE_AMP):\n v_pred = model(x_t, t)\n loss, ld = criterion(v_pred, x0, x1, t, step=global_step)\n loss = loss / GRAD_ACCUM\n\n scaler.scale(loss).backward()\n\n if (batch_idx + 1) % GRAD_ACCUM == 0:\n scaler.unscale_(optimizer)\n gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n scaler.step(optimizer); scaler.update(); optimizer.zero_grad()\n scheduler.step(); ema.update(model); global_step += 1\n\n epoch_loss += ld['total'].item(); epoch_flow += ld['flow'].item(); n_batches += 1\n\n if global_step % LOG_EVERY == 0:\n avg = epoch_loss/n_batches; avg_f = epoch_flow/n_batches\n lr_now = scheduler.get_last_lr()[0]; it_s = global_step/(time.time()-t_start)\n all_losses.append({'step':global_step,'loss':avg,'flow':avg_f,'lr':lr_now,'epoch':epoch})\n print(f' E{epoch+1} step {global_step}/{total_steps} | loss={avg:.4f} flow={avg_f:.4f} lr={lr_now:.2e} gn={gn:.2f} [{it_s:.1f} it/s]')\n\n avg_epoch = epoch_loss / max(1, n_batches)\n print(f'\\n📊 Epoch {epoch+1}/{EPOCHS} — avg loss: {avg_epoch:.4f}\\n')\n\n if (epoch+1) % SAMPLE_EVERY == 0 or epoch == 0:\n model.eval(); ema.apply_shadow(model)\n with torch.no_grad():\n imgs = euler_sample(model, (min(16,BATCH_SIZE),3,IMG_SIZE,IMG_SIZE), num_steps=SAMPLE_STEPS, device=device)\n imgs = imgs.clamp(-1,1)*0.5+0.5\n grid = make_grid_image(imgs, nrow=4)\n grid.save(f'{OUTPUT_DIR}/samples/epoch_{epoch+1:04d}.png')\n fig, ax = plt.subplots(1,1,figsize=(8,8)); ax.imshow(grid)\n ax.set_title(f'Epoch {epoch+1} — {MODEL_SIZE}-{IMG_SIZE} on {DATASET}'); ax.axis('off'); plt.tight_layout(); plt.show()\n ema.restore(model); model.train()\n\n if (epoch+1) % SAVE_EVERY == 0:\n ckpt = {'model':model.state_dict(),'optimizer':optimizer.state_dict(),'scheduler':scheduler.state_dict(),'ema':ema.state_dict(),'epoch':epoch,'global_step':global_step}\n torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/epoch_{epoch+1:04d}.pt')\n torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n print(f'💾 Checkpoint saved: epoch {epoch+1}')\n\nema.apply_shadow(model)\ntorch.save({'model':model.state_dict(),'config':{'model_size':MODEL_SIZE,'img_size':IMG_SIZE,'dataset':DATASET,'num_params':num_params,'epochs':EPOCHS}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model)\nprint(f'\\n✅ Training complete! {(time.time()-t_start)/60:.1f} min total')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 5. 📈 Training Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import matplotlib.pyplot as plt\nif all_losses:\n steps=[d['step'] for d in all_losses]; losses=[d['loss'] for d in all_losses]; flows=[d['flow'] for d in all_losses]; lrs=[d['lr'] for d in all_losses]\n fig,(ax1,ax2) = plt.subplots(1,2,figsize=(14,5))\n ax1.plot(steps,losses,label='Total',alpha=0.8); ax1.plot(steps,flows,label='Flow',alpha=0.8)\n ax1.set_xlabel('Step'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True,alpha=0.3)\n ax2.plot(steps,lrs,color='orange'); ax2.set_xlabel('Step'); ax2.set_ylabel('LR'); ax2.set_title('LR Schedule'); ax2.grid(True,alpha=0.3)\n plt.tight_layout(); plt.savefig(f'{OUTPUT_DIR}/training_curves.png',dpi=150); plt.show()\nelse: print('No training logs yet.')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 6. 🎨 Generate Images"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎨 Generation Settings { display-mode: \"form\" }\nNUM_IMAGES = 16 #@param {type:\"integer\"}\nGEN_STEPS = 50 #@param [10, 25, 50, 100, 200] {type:\"integer\"}\nSAMPLER = 'euler' #@param ['euler', 'heun']\nSEED = 42 #@param {type:\"integer\"}\n\nimport torch\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nimport matplotlib.pyplot as plt\n\nckpt_path = f'{OUTPUT_DIR}/liquidflow_final.pt'\nif os.path.exists(ckpt_path):\n model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=False)['model'])\n print(f'Loaded: {ckpt_path}')\nmodel.eval(); torch.manual_seed(SEED)\nshape = (NUM_IMAGES, 3, IMG_SIZE, IMG_SIZE)\nwith torch.no_grad():\n images = (euler_sample if SAMPLER=='euler' else heun_sample)(model, shape, num_steps=GEN_STEPS, device=device)\nimages = images.clamp(-1,1)*0.5+0.5\ngrid = make_grid_image(images, nrow=int(NUM_IMAGES**0.5))\ngrid.save(f'{OUTPUT_DIR}/generated_final.png')\nplt.figure(figsize=(10,10)); plt.imshow(grid)\nplt.title(f'LiquidFlow-{MODEL_SIZE} | {DATASET} {IMG_SIZE}×{IMG_SIZE} | {GEN_STEPS} steps ({SAMPLER})'); plt.axis('off'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 7. 📱 Export for Mobile"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval()\nexample_x = torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device)\nexample_t = torch.tensor([0.5],device=device)\ntry:\n traced = torch.jit.trace(model, (example_x, example_t))\n ts_path = f'{OUTPUT_DIR}/liquidflow_mobile.pt'; traced.save(ts_path)\n print(f'✅ TorchScript: {ts_path} ({os.path.getsize(ts_path)/1e6:.1f} MB)')\nexcept Exception as e: print(f'⚠️ TorchScript failed: {e}')\ntry:\n onnx_path = f'{OUTPUT_DIR}/liquidflow.onnx'\n torch.onnx.export(model.cpu(), (example_x.cpu(), example_t.cpu()), onnx_path, opset_version=14,\n input_names=['image','timestep'], output_names=['velocity'],\n dynamic_axes={'image':{0:'batch'},'timestep':{0:'batch'},'velocity':{0:'batch'}})\n print(f'✅ ONNX: {onnx_path} ({os.path.getsize(onnx_path)/1e6:.1f} MB)'); model.to(device)\nexcept Exception as e: print(f'⚠️ ONNX failed: {e}'); model.to(device)"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 8. 🧪 Recommended Experiments\n","\n","| Goal | Dataset | Model | Size | Epochs | Time (T4) |\n","|------|---------|-------|------|--------|----------|\n","| Sanity check | CIFAR-10 | tiny | 32 | 20 | ~5 min |\n","| Baseline | CIFAR-10 | tiny | 128 | 100 | ~2 hrs |\n","| Quality | Flowers-102 | small | 128 | 200 | ~4 hrs |\n","| Faces | CelebA | small | 128 | 50 | ~6 hrs |\n","| High-res | CelebA | 512 | 512 | 100 | ~12 hrs |"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4}
|
|
|
|
| 1 |
+
{"cells":[{"cell_type":"markdown","metadata":{},"source":["# 🌊 LiquidFlow v0.3 — Liquid-SSM Flow Matching Image Generator\n","\n","**v0.3**: Parallel SSM scan via `torch.associative_scan` (O(log L) not O(L)) — **no sequential Python loops**\n","\n","| Config | GPU VRAM | Tokens (L) | Best For |\n","|--------|----------|-----------|----------|\n","| tiny-128 (bs=32) | ~4 GB | 256 | Colab free T4 |\n","| small-128 (bs=16) | ~6 GB | 256 | Colab free T4 |\n","| base-256 (bs=8) | ~12 GB | 1024 | Colab Pro |\n","| 512 (bs=4) | ~14 GB | 1024 | A100 |"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 0. Setup"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU'\nimport torch\nprint(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available(): print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n# ALWAYS re-clone to get latest version\n!rm -rf liquidflow liquidflow_repo\n!git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n!cp -r liquidflow_repo/liquidflow .\n\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512, HAS_NATIVE_SCAN\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nprint(f'✅ LiquidFlow v0.3 imported! Parallel scan (native): {HAS_NATIVE_SCAN}')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 1. ⚙️ Configuration"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎛️ Training Configuration { display-mode: \"form\" }\nDATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist']\nCUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\nMODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\nIMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\nEPOCHS = 100 #@param {type:\"integer\"}\nBATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\nLEARNING_RATE = 3e-4 #@param {type:\"number\"}\nGRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\nUSE_AMP = True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\nLAMBDA_TV = 0.001 #@param {type:\"number\"}\nSAMPLE_EVERY = 5 #@param {type:\"integer\"}\nSAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\nLOG_EVERY = 25 #@param {type:\"integer\"}\nSAVE_EVERY = 10 #@param {type:\"integer\"}\nOUTPUT_DIR = './outputs'; DATA_DIR = './data'\n\nimport torch\nif torch.cuda.is_available():\n vram = torch.cuda.get_device_properties(0).total_mem/1e9\n rec = {(32,'tiny'):128,(64,'tiny'):64,(128,'tiny'):32,(32,'small'):64,(64,'small'):32,(128,'small'):16,(256,'base'):8,(512,'512'):4}\n if (IMG_SIZE,MODEL_SIZE) in rec and vram<16 and BATCH_SIZE>rec[(IMG_SIZE,MODEL_SIZE)]:\n BATCH_SIZE=rec[(IMG_SIZE,MODEL_SIZE)]; print(f'⚠️ Auto batch size: {BATCH_SIZE}')\nelse: USE_AMP=False\nprint(f'📋 {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 2. 📦 Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision, torchvision.transforms as T\nfrom torch.utils.data import DataLoader, Dataset, ConcatDataset\nfrom pathlib import Path; from PIL import Image\nimport os, matplotlib.pyplot as plt, numpy as np\n\ndef tfm(s): return T.Compose([T.Resize(s+s//8),T.CenterCrop(s),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize([.5]*3,[.5]*3)])\nclass FolderDS(Dataset):\n def __init__(s,r,t): s.t=t; s.f=sorted(sum([list(Path(r).rglob(e)) for e in['*.png','*.jpg','*.jpeg','*.webp']],[]))\n def __len__(s): return len(s.f)\n def __getitem__(s,i): return s.t(Image.open(s.f[i]).convert('RGB'))\nclass G2R:\n def __call__(s,x): return x.repeat(3,1,1) if x.shape[0]==1 else x\n\ntf = tfm(IMG_SIZE)\nif DATASET=='cifar10': dataset=torchvision.datasets.CIFAR10(DATA_DIR,True,download=True,transform=tf)\nelif DATASET=='flowers': dataset=ConcatDataset([torchvision.datasets.Flowers102(DATA_DIR,s,download=True,transform=tf) for s in['train','val','test']])\nelif DATASET=='celeba': dataset=torchvision.datasets.CelebA(DATA_DIR,'train',download=True,transform=tf)\nelif DATASET=='fashion_mnist': dataset=torchvision.datasets.FashionMNIST(DATA_DIR,True,download=True,transform=T.Compose([T.Resize(IMG_SIZE),T.ToTensor(),T.Normalize([.5],[.5]),G2R()]))\nelif DATASET=='folder': dataset=FolderDS(CUSTOM_DATA_DIR,tf)\nelse: raise ValueError(DATASET)\nprint(f'✅ {DATASET}: {len(dataset)} images')\nfig,ax=plt.subplots(1,8,figsize=(16,2))\nfor i,a in enumerate(ax):\n s=dataset[i]; s=s[0] if isinstance(s,(list,tuple)) else s\n a.imshow((s*.5+.5).permute(1,2,0).clamp(0,1).numpy()); a.axis('off')\nplt.suptitle(f'{DATASET} ({IMG_SIZE}×{IMG_SIZE})'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 3. 🏗️ Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = {'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\nnum_params = model.count_params()\nprint(f'🏗️ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M)')\nprint(f' {IMG_SIZE}×{IMG_SIZE}, patch={model.patch_size}, L={model.num_patches} tokens')\nwith torch.no_grad(): model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device), torch.tensor([.5],device=device))\nprint(' ✅ Forward pass OK')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 4. 🚀 Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math, time, torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, make_grid_image\nimport matplotlib.pyplot as plt\n\nos.makedirs(f'{OUTPUT_DIR}/samples',exist_ok=True); os.makedirs(f'{OUTPUT_DIR}/checkpoints',exist_ok=True)\ndl = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2,pin_memory=True,drop_last=True)\nopt = torch.optim.AdamW(model.parameters(),lr=LEARNING_RATE,betas=(.9,.999),weight_decay=.01)\ntot = EPOCHS*len(dl)//GRAD_ACCUM; wu = min(500,tot//10)\ndef clr(s):\n if s<wu: return s/max(1,wu)\n return .1+.9*.5*(1+math.cos(math.pi*(s-wu)/max(1,tot-wu)))\nsch = torch.optim.lr_scheduler.LambdaLR(opt,clr)\ncrit = PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH,lambda_tv=LAMBDA_TV).to(device)\nema = EMAModel(model,decay=.9999)\namp_dev = 'cuda' if device.type=='cuda' else 'cpu'\nscaler = torch.amp.GradScaler(amp_dev,enabled=USE_AMP)\nall_losses=[]; gs=0\nprint(f'🚀 {EPOCHS} epochs, {tot} steps, batch {BATCH_SIZE}×{GRAD_ACCUM}={BATCH_SIZE*GRAD_ACCUM}')\nprint(f' LR {LEARNING_RATE}, warmup {wu} steps\\n')\nt0=time.time()\nfor ep in range(EPOCHS):\n model.train(); el=ef=0.; nb=0\n for bi,bd in enumerate(dl):\n x1=(bd[0] if isinstance(bd,(list,tuple)) else bd).to(device)\n B=x1.shape[0]; x0=torch.randn_like(x1); t=torch.rand(B,device=device)\n xt=t.view(B,1,1,1)*x1+(1-t.view(B,1,1,1))*x0\n with torch.amp.autocast(amp_dev,enabled=USE_AMP):\n vp=model(xt,t); loss,ld=crit(vp,x0,x1,t,step=gs); loss=loss/GRAD_ACCUM\n scaler.scale(loss).backward()\n if (bi+1)%GRAD_ACCUM==0:\n scaler.unscale_(opt); gn=nn.utils.clip_grad_norm_(model.parameters(),1.0)\n scaler.step(opt); scaler.update(); opt.zero_grad(); sch.step(); ema.update(model); gs+=1\n el+=ld['total'].item(); ef+=ld['flow'].item(); nb+=1\n # LOG IMMEDIATELY on step 1, then every LOG_EVERY\n if gs==1 or gs%LOG_EVERY==0:\n a=el/nb; af=ef/nb; lr=sch.get_last_lr()[0]; its=gs/(time.time()-t0)\n all_losses.append({'step':gs,'loss':a,'flow':af,'lr':lr,'epoch':ep})\n print(f' E{ep+1} step {gs}/{tot} | loss={a:.4f} flow={af:.4f} lr={lr:.2e} gn={gn:.2f} [{its:.1f} it/s]')\n print(f'\\n📊 Epoch {ep+1}/{EPOCHS} — loss: {el/max(1,nb):.4f}\\n')\n if (ep+1)%SAMPLE_EVERY==0 or ep==0:\n model.eval(); ema.apply_shadow(model)\n with torch.no_grad():\n im=euler_sample(model,(min(16,BATCH_SIZE),3,IMG_SIZE,IMG_SIZE),num_steps=SAMPLE_STEPS,device=device)\n im=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4)\n g.save(f'{OUTPUT_DIR}/samples/ep{ep+1:04d}.png')\n fig,ax=plt.subplots(1,1,figsize=(8,8)); ax.imshow(g)\n ax.set_title(f'Epoch {ep+1}'); ax.axis('off'); plt.show()\n ema.restore(model); model.train()\n if (ep+1)%SAVE_EVERY==0:\n torch.save({'model':model.state_dict(),'opt':opt.state_dict(),'sch':sch.state_dict(),'ema':ema.state_dict(),'ep':ep,'gs':gs},f'{OUTPUT_DIR}/checkpoints/latest.pt')\n print(f'💾 Saved epoch {ep+1}')\nema.apply_shadow(model)\ntorch.save({'model':model.state_dict(),'config':{'model':MODEL_SIZE,'img':IMG_SIZE,'data':DATASET,'params':num_params}},f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model); print(f'\\n✅ Done! {(time.time()-t0)/60:.1f} min')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 5. 📈 Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if all_losses:\n s=[d['step'] for d in all_losses]; l=[d['loss'] for d in all_losses]; f=[d['flow'] for d in all_losses]\n fig,(a1,a2)=plt.subplots(1,2,figsize=(14,5))\n a1.plot(s,l,label='Total'); a1.plot(s,f,label='Flow'); a1.legend(); a1.grid(True,alpha=.3); a1.set_title('Loss')\n a2.plot(s,[d['lr'] for d in all_losses],color='orange'); a2.grid(True,alpha=.3); a2.set_title('LR')\n plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 6. 🎨 Generate"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval(); torch.manual_seed(42)\nwith torch.no_grad():\n im=euler_sample(model,(16,3,IMG_SIZE,IMG_SIZE),num_steps=50,device=device)\nim=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/final.png')\nplt.figure(figsize=(10,10)); plt.imshow(g); plt.axis('off'); plt.title(f'LiquidFlow-{MODEL_SIZE} {IMG_SIZE}px'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 7. 📱 Export"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval()\ntry:\n tr=torch.jit.trace(model,(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device),torch.tensor([.5],device=device)))\n tr.save(f'{OUTPUT_DIR}/mobile.pt'); print(f'✅ TorchScript: {os.path.getsize(f\"{OUTPUT_DIR}/mobile.pt\")/1e6:.1f}MB')\nexcept Exception as e: print(f'⚠️ {e}')"]},{"cell_type":"markdown","metadata":{},"source":["---\n","## 8. 🧪 Experiments\n","| Goal | Dataset | Model | Size | Epochs | Time (T4) |\n","|------|---------|-------|------|--------|----------|\n","| Quick test | CIFAR-10 | tiny | 32 | 20 | ~3 min |\n","| Baseline | CIFAR-10 | tiny | 128 | 100 | ~1.5 hrs |\n","| Quality | Flowers | small | 128 | 200 | ~3 hrs |\n","| Faces | CelebA | small | 128 | 50 | ~4 hrs |\n","| Hi-res | CelebA | 512 | 512 | 100 | ~10 hrs |"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4"},"kernelspec":{"display_name":"Python 3","name":"python3"}},"nbformat":4,"nbformat_minor":4}
|