krystv commited on
Commit
c0ad812
·
verified ·
1 Parent(s): 1c47b5f

v0.5: notebook — document auto patch_size, fix total_memory, all configs L≤256"

Browse files
Files changed (1) hide show
  1. LiquidFlow_Training.ipynb +1 -1
LiquidFlow_Training.ipynb CHANGED
@@ -1 +1 @@
1
- {"cells":[{"cell_type":"markdown","metadata":{},"source":["# 🌊 LiquidFlow v0.4 — Liquid-SSM Flow Matching Image Generator\n","\n","**Parallel SSM scan** via `mambapy.pscan` — O(log L), full gradient support.\n","\n","| Config | VRAM | Tokens | Recommended Epochs |\n","|--------|------|--------|-------------------|\n","| tiny-128 CIFAR-10 | ~4GB | 256 | 15-20 (~3h T4) |\n","| small-128 Flowers | ~6GB | 256 | 50-100 (~6h T4) |\n","| tiny-128 CelebA | ~4GB | 256 | 10-15 (~4h T4) |"]},{"cell_type":"markdown","metadata":{},"source":["## 0. Setup"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU'\nimport torch; print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available(): print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm mambapy"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n!rm -rf liquidflow liquidflow_repo\n!git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n!cp -r liquidflow_repo/liquidflow .\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512, HAS_PSCAN\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nprint(f'✅ LiquidFlow v0.4 | Parallel scan: {HAS_PSCAN}')"]},{"cell_type":"markdown","metadata":{},"source":["## 1. Config"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎛️ Training Configuration { display-mode: \"form\" }\nDATASET='cifar10' #@param ['cifar10','flowers','celeba','folder','fashion_mnist']\nCUSTOM_DATA_DIR='/content/my_images' #@param {type:\"string\"}\nMODEL_SIZE='tiny' #@param ['tiny','small','base','512']\nIMG_SIZE=128 #@param [32,64,128,256,512] {type:\"integer\"}\nEPOCHS=20 #@param {type:\"integer\"}\nBATCH_SIZE=32 #@param [4,8,16,32,64,128] {type:\"integer\"}\nLEARNING_RATE=3e-4 #@param {type:\"number\"}\nGRAD_ACCUM=1 #@param [1,2,4,8] {type:\"integer\"}\nUSE_AMP=True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH=0.01 #@param {type:\"number\"}\nLAMBDA_TV=0.001 #@param {type:\"number\"}\nSAMPLE_EVERY=5 #@param {type:\"integer\"}\nSAMPLE_STEPS=50 #@param [10,25,50,100] {type:\"integer\"}\nLOG_EVERY=100 #@param {type:\"integer\"}\nSAVE_EVERY=5 #@param {type:\"integer\"}\nOUTPUT_DIR='./outputs'; DATA_DIR='./data'\nimport torch\nif torch.cuda.is_available():\n vram=torch.cuda.get_device_properties(0).total_memory/1e9\n print(f'GPU VRAM: {vram:.1f} GB')\n rec={(32,'tiny'):128,(64,'tiny'):64,(128,'tiny'):32,(32,'small'):64,(64,'small'):32,(128,'small'):16,(256,'base'):8,(512,'512'):4}\n if (IMG_SIZE,MODEL_SIZE) in rec and vram<16 and BATCH_SIZE>rec[(IMG_SIZE,MODEL_SIZE)]: BATCH_SIZE=rec[(IMG_SIZE,MODEL_SIZE)]; print(f'⚠️ bs→{BATCH_SIZE}')\nelse: USE_AMP=False\nprint(f'📋 {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')"]},{"cell_type":"markdown","metadata":{},"source":["## 2. Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision, torchvision.transforms as T\nfrom torch.utils.data import DataLoader, Dataset, ConcatDataset\nfrom pathlib import Path; from PIL import Image\nimport os, matplotlib.pyplot as plt, numpy as np\ndef tfm(s): return T.Compose([T.Resize(s+s//8),T.CenterCrop(s),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize([.5]*3,[.5]*3)])\nclass FolderDS(Dataset):\n def __init__(s,r,t): s.t=t; s.f=sorted(sum([list(Path(r).rglob(e)) for e in['*.png','*.jpg','*.jpeg','*.webp']],[]))\n def __len__(s): return len(s.f)\n def __getitem__(s,i): return s.t(Image.open(s.f[i]).convert('RGB'))\nclass G2R:\n def __call__(s,x): return x.repeat(3,1,1) if x.shape[0]==1 else x\ntf=tfm(IMG_SIZE)\nif DATASET=='cifar10': dataset=torchvision.datasets.CIFAR10(DATA_DIR,True,download=True,transform=tf)\nelif DATASET=='flowers': dataset=ConcatDataset([torchvision.datasets.Flowers102(DATA_DIR,s,download=True,transform=tf) for s in['train','val','test']])\nelif DATASET=='celeba': dataset=torchvision.datasets.CelebA(DATA_DIR,'train',download=True,transform=tf)\nelif DATASET=='fashion_mnist': dataset=torchvision.datasets.FashionMNIST(DATA_DIR,True,download=True,transform=T.Compose([T.Resize(IMG_SIZE),T.ToTensor(),T.Normalize([.5],[.5]),G2R()]))\nelif DATASET=='folder': dataset=FolderDS(CUSTOM_DATA_DIR,tf)\nelse: raise ValueError(DATASET)\nprint(f'✅ {DATASET}: {len(dataset)} images')\nfig,ax=plt.subplots(1,8,figsize=(16,2))\nfor i,a in enumerate(ax):\n s=dataset[i]; s=s[0] if isinstance(s,(list,tuple)) else s\n a.imshow((s*.5+.5).permute(1,2,0).clamp(0,1).numpy()); a.axis('off')\nplt.suptitle(f'{DATASET} ({IMG_SIZE}×{IMG_SIZE})'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 3. Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel={'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\nnum_params=model.count_params()\nprint(f'🏗️ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M), L={model.num_patches} tokens, patch={model.patch_size}')\nwith torch.no_grad(): model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device),torch.tensor([.5],device=device))\nprint('✅ Forward OK')"]},{"cell_type":"markdown","metadata":{},"source":["## 4. Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math,time,torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss,EMAModel\nfrom liquidflow.sampling import euler_sample,make_grid_image\nimport matplotlib.pyplot as plt\nos.makedirs(f'{OUTPUT_DIR}/samples',exist_ok=True); os.makedirs(f'{OUTPUT_DIR}/checkpoints',exist_ok=True)\ndl=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2,pin_memory=True,drop_last=True)\nopt=torch.optim.AdamW(model.parameters(),lr=LEARNING_RATE,betas=(.9,.999),weight_decay=.01)\ntot=EPOCHS*len(dl)//GRAD_ACCUM; wu=min(500,tot//10)\ndef clr(s):\n if s<wu: return s/max(1,wu)\n return .1+.9*.5*(1+math.cos(math.pi*(s-wu)/max(1,tot-wu)))\nsch=torch.optim.lr_scheduler.LambdaLR(opt,clr)\ncrit=PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH,lambda_tv=LAMBDA_TV).to(device)\nema=EMAModel(model,decay=.9999)\namp_dev='cuda' if device.type=='cuda' else 'cpu'\nscaler=torch.amp.GradScaler(amp_dev,enabled=USE_AMP)\nall_losses=[]; gs=0\nsteps_per_epoch=len(dl)//GRAD_ACCUM\nprint(f'🚀 {EPOCHS} epochs × {steps_per_epoch} steps/epoch = {tot} total steps')\nprint(f' batch {BATCH_SIZE}×{GRAD_ACCUM}={BATCH_SIZE*GRAD_ACCUM}, LR {LEARNING_RATE}, warmup {wu}\\n')\nt0=time.time()\nfor ep in range(EPOCHS):\n model.train(); el=ef=0.; nb=0\n for bi,bd in enumerate(dl):\n x1=(bd[0] if isinstance(bd,(list,tuple)) else bd).to(device)\n B=x1.shape[0]; x0=torch.randn_like(x1); t=torch.rand(B,device=device)\n xt=t.view(B,1,1,1)*x1+(1-t.view(B,1,1,1))*x0\n with torch.amp.autocast(amp_dev,enabled=USE_AMP):\n vp=model(xt,t); loss,ld=crit(vp,x0,x1,t,step=gs); loss=loss/GRAD_ACCUM\n scaler.scale(loss).backward()\n if (bi+1)%GRAD_ACCUM==0:\n scaler.unscale_(opt); gn=nn.utils.clip_grad_norm_(model.parameters(),1.0)\n scaler.step(opt); scaler.update(); opt.zero_grad(); sch.step(); ema.update(model); gs+=1\n el+=ld['total'].item(); ef+=ld['flow'].item(); nb+=1\n if gs==1 or gs%LOG_EVERY==0:\n a=el/nb; af=ef/nb; lr=sch.get_last_lr()[0]\n elapsed=time.time()-t0; its=gs/elapsed\n eta_s=(tot-gs)/max(its,0.01); eta_m=eta_s/60; eta_h=eta_m/60\n if eta_h>=1: eta_str=f'{eta_h:.1f}h'\n else: eta_str=f'{eta_m:.0f}min'\n all_losses.append({'step':gs,'loss':a,'flow':af,'lr':lr,'epoch':ep})\n print(f' E{ep+1}/{EPOCHS} step {gs}/{tot} | loss={a:.4f} flow={af:.4f} | lr={lr:.2e} gn={gn:.2f} | {its:.1f} it/s | ETA {eta_str}')\n avg=el/max(1,nb)\n elapsed_m=(time.time()-t0)/60\n print(f'\\n📊 Epoch {ep+1}/{EPOCHS} — loss: {avg:.4f} — elapsed: {elapsed_m:.1f}min\\n')\n if (ep+1)%SAMPLE_EVERY==0 or ep==0:\n model.eval(); ema.apply_shadow(model)\n with torch.no_grad():\n im=euler_sample(model,(min(16,BATCH_SIZE),3,IMG_SIZE,IMG_SIZE),num_steps=SAMPLE_STEPS,device=device)\n im=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/samples/ep{ep+1:04d}.png')\n fig,ax=plt.subplots(1,1,figsize=(8,8)); ax.imshow(g); ax.set_title(f'Epoch {ep+1} | loss={avg:.4f}'); ax.axis('off'); plt.show()\n ema.restore(model); model.train()\n if (ep+1)%SAVE_EVERY==0:\n torch.save({'model':model.state_dict(),'opt':opt.state_dict(),'sch':sch.state_dict(),'ema':ema.state_dict(),'ep':ep,'gs':gs},f'{OUTPUT_DIR}/checkpoints/latest.pt')\n print(f'💾 Saved epoch {ep+1}')\nema.apply_shadow(model)\ntorch.save({'model':model.state_dict(),'config':{'model':MODEL_SIZE,'img':IMG_SIZE,'data':DATASET,'params':num_params}},f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model); print(f'\\n✅ Done! {(time.time()-t0)/60:.1f}min total')"]},{"cell_type":"markdown","metadata":{},"source":["## 5. Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if all_losses:\n s=[d['step'] for d in all_losses]; l=[d['loss'] for d in all_losses]; f=[d['flow'] for d in all_losses]\n fig,(a1,a2)=plt.subplots(1,2,figsize=(14,5))\n a1.plot(s,l,label='Total'); a1.plot(s,f,label='Flow'); a1.legend(); a1.grid(alpha=.3); a1.set_title('Loss')\n a2.plot(s,[d['lr'] for d in all_losses],color='orange'); a2.grid(alpha=.3); a2.set_title('LR')\n plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 6. Generate"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval(); torch.manual_seed(42)\nwith torch.no_grad(): im=euler_sample(model,(16,3,IMG_SIZE,IMG_SIZE),num_steps=50,device=device)\nim=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/final.png')\nplt.figure(figsize=(10,10)); plt.imshow(g); plt.axis('off'); plt.title(f'{MODEL_SIZE}-{IMG_SIZE}'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 7. Export"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval()\ntry:\n tr=torch.jit.trace(model,(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device),torch.tensor([.5],device=device)))\n tr.save(f'{OUTPUT_DIR}/mobile.pt'); print(f'✅ TorchScript: {os.path.getsize(f\"{OUTPUT_DIR}/mobile.pt\")/1e6:.1f}MB')\nexcept Exception as e: print(f'⚠️ {e}')"]},{"cell_type":"markdown","metadata":{},"source":["## 8. Experiment Guide\n","| Goal | Dataset | Model | Size | Epochs | Colab Time |\n","|------|---------|-------|------|--------|------------|\n","| Quick test | CIFAR-10 | tiny | 32 | 10 | ~10 min |\n","| **Good baseline** | **CIFAR-10** | **tiny** | **128** | **20** | **~3h** |\n","| Quality | Flowers | small | 128 | 50 | ~6h |\n","| Faces | CelebA | small | 128 | 15 | ~4h |\n","\n","### Tips\n","- Loss ~0.05-0.07 at epoch 3 is **good** for flow matching on CIFAR-10\n","- Generated images will look noisy until epoch 5-10, then sharpen\n","- For **Kaggle** (30h limit): use `EPOCHS=50` with `tiny` for best results\n","- `heun` sampler gives better quality at half the steps vs `euler`"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4"},"kernelspec":{"display_name":"Python 3","name":"python3"}},"nbformat":4,"nbformat_minor":4}
 
1
+ {"cells":[{"cell_type":"markdown","metadata":{},"source":["# 🌊 LiquidFlow v0.5 — Liquid-SSM Flow Matching Image Generator\n","\n","**All configs now run at the same speed** — `_auto_patch` scales patch_size so L≤256 tokens always.\n","\n","| Config | Image | Patch | L | Params | VRAM (bs=32) | Speed (T4) |\n","|--------|-------|-------|---|--------|-------------|------------|\n","| tiny | 128 | 8 | 256 | 4.9M | ~3GB | ~10 it/s |\n","| small | 128 | 8 | 256 | 10M | ~5GB | ~5 it/s |\n","| small | 256 | 16 | 256 | 10M | ~5GB | ~5 it/s |\n","| base | 256 | 16 | 256 | 24M | ~10GB | ~2 it/s |\n","| 512 | 512 | 32 | 256 | 24M | ~10GB | ~2 it/s |"]},{"cell_type":"markdown","metadata":{},"source":["## 0. Setup"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU'\nimport torch; print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available(): print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm mambapy"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n!rm -rf liquidflow liquidflow_repo\n!git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n!cp -r liquidflow_repo/liquidflow .\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512, HAS_PSCAN\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nprint(f'✅ LiquidFlow v0.5 | pscan: {HAS_PSCAN}')"]},{"cell_type":"markdown","metadata":{},"source":["## 1. Config"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title 🎛️ Configuration { display-mode: \"form\" }\nDATASET='cifar10' #@param ['cifar10','flowers','celeba','folder','fashion_mnist']\nCUSTOM_DATA_DIR='/content/my_images' #@param {type:\"string\"}\nMODEL_SIZE='tiny' #@param ['tiny','small','base','512']\nIMG_SIZE=128 #@param [32,64,128,256,512] {type:\"integer\"}\nEPOCHS=20 #@param {type:\"integer\"}\nBATCH_SIZE=32 #@param [4,8,16,32,64,128] {type:\"integer\"}\nLEARNING_RATE=3e-4 #@param {type:\"number\"}\nGRAD_ACCUM=1 #@param [1,2,4,8] {type:\"integer\"}\nUSE_AMP=True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH=0.01 #@param {type:\"number\"}\nLAMBDA_TV=0.001 #@param {type:\"number\"}\nSAMPLE_EVERY=5 #@param {type:\"integer\"}\nSAMPLE_STEPS=50 #@param [10,25,50,100] {type:\"integer\"}\nLOG_EVERY=100 #@param {type:\"integer\"}\nSAVE_EVERY=5 #@param {type:\"integer\"}\nOUTPUT_DIR='./outputs'; DATA_DIR='./data'\nimport torch\nif torch.cuda.is_available():\n vram=torch.cuda.get_device_properties(0).total_memory/1e9\n print(f'GPU VRAM: {vram:.1f} GB')\n rec={(128,'tiny'):64,(128,'small'):32,(256,'small'):32,(256,'base'):16,(512,'base'):8,(512,'512'):8}\n if (IMG_SIZE,MODEL_SIZE) in rec and BATCH_SIZE>rec[(IMG_SIZE,MODEL_SIZE)]:\n BATCH_SIZE=rec[(IMG_SIZE,MODEL_SIZE)]; print(f'⚠️ bs→{BATCH_SIZE}')\nelse: USE_AMP=False\nprint(f'📋 {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, epochs={EPOCHS}')"]},{"cell_type":"markdown","metadata":{},"source":["## 2. Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision, torchvision.transforms as T\nfrom torch.utils.data import DataLoader, Dataset, ConcatDataset\nfrom pathlib import Path; from PIL import Image\nimport os, matplotlib.pyplot as plt, numpy as np\ndef tfm(s): return T.Compose([T.Resize(s+s//8),T.CenterCrop(s),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize([.5]*3,[.5]*3)])\nclass FolderDS(Dataset):\n def __init__(s,r,t): s.t=t; s.f=sorted(sum([list(Path(r).rglob(e)) for e in['*.png','*.jpg','*.jpeg','*.webp']],[]))\n def __len__(s): return len(s.f)\n def __getitem__(s,i): return s.t(Image.open(s.f[i]).convert('RGB'))\nclass G2R:\n def __call__(s,x): return x.repeat(3,1,1) if x.shape[0]==1 else x\ntf=tfm(IMG_SIZE)\nif DATASET=='cifar10': dataset=torchvision.datasets.CIFAR10(DATA_DIR,True,download=True,transform=tf)\nelif DATASET=='flowers': dataset=ConcatDataset([torchvision.datasets.Flowers102(DATA_DIR,s,download=True,transform=tf) for s in['train','val','test']])\nelif DATASET=='celeba': dataset=torchvision.datasets.CelebA(DATA_DIR,'train',download=True,transform=tf)\nelif DATASET=='fashion_mnist': dataset=torchvision.datasets.FashionMNIST(DATA_DIR,True,download=True,transform=T.Compose([T.Resize(IMG_SIZE),T.ToTensor(),T.Normalize([.5],[.5]),G2R()]))\nelif DATASET=='folder': dataset=FolderDS(CUSTOM_DATA_DIR,tf)\nelse: raise ValueError(DATASET)\nprint(f'✅ {DATASET}: {len(dataset)} images')\nfig,ax=plt.subplots(1,8,figsize=(16,2))\nfor i,a in enumerate(ax):\n s=dataset[i]; s=s[0] if isinstance(s,(list,tuple)) else s\n a.imshow((s*.5+.5).permute(1,2,0).clamp(0,1).numpy()); a.axis('off')\nplt.suptitle(f'{DATASET} ({IMG_SIZE}×{IMG_SIZE})'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 3. Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel={'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\nnum_params=model.count_params()\nprint(f'🏗️ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M)')\nprint(f' {IMG_SIZE}px, patch={model.patch_size}, L={model.num_patches} tokens, d={model.d_model}, depth={model.depth}')\nwith torch.no_grad(): model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device),torch.tensor([.5],device=device))\nprint('✅ Forward OK')"]},{"cell_type":"markdown","metadata":{},"source":["## 4. Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math,time,torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss,EMAModel\nfrom liquidflow.sampling import euler_sample,make_grid_image\nimport matplotlib.pyplot as plt\nos.makedirs(f'{OUTPUT_DIR}/samples',exist_ok=True); os.makedirs(f'{OUTPUT_DIR}/checkpoints',exist_ok=True)\ndl=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2,pin_memory=True,drop_last=True)\nopt=torch.optim.AdamW(model.parameters(),lr=LEARNING_RATE,betas=(.9,.999),weight_decay=.01)\ntot=EPOCHS*len(dl)//GRAD_ACCUM; wu=min(500,tot//10)\ndef clr(s):\n if s<wu: return s/max(1,wu)\n return .1+.9*.5*(1+math.cos(math.pi*(s-wu)/max(1,tot-wu)))\nsch=torch.optim.lr_scheduler.LambdaLR(opt,clr)\ncrit=PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH,lambda_tv=LAMBDA_TV).to(device)\nema=EMAModel(model,decay=.9999)\namp_dev='cuda' if device.type=='cuda' else 'cpu'\nscaler=torch.amp.GradScaler(amp_dev,enabled=USE_AMP)\nall_losses=[]; gs=0; spe=len(dl)//GRAD_ACCUM\nprint(f'🚀 {EPOCHS} epochs × {spe} steps = {tot} total | batch {BATCH_SIZE}×{GRAD_ACCUM}={BATCH_SIZE*GRAD_ACCUM}\\n')\nt0=time.time()\nfor ep in range(EPOCHS):\n model.train(); el=ef=0.; nb=0\n for bi,bd in enumerate(dl):\n x1=(bd[0] if isinstance(bd,(list,tuple)) else bd).to(device)\n B=x1.shape[0]; x0=torch.randn_like(x1); t=torch.rand(B,device=device)\n xt=t.view(B,1,1,1)*x1+(1-t.view(B,1,1,1))*x0\n with torch.amp.autocast(amp_dev,enabled=USE_AMP):\n vp=model(xt,t); loss,ld=crit(vp,x0,x1,t,step=gs); loss=loss/GRAD_ACCUM\n scaler.scale(loss).backward()\n if (bi+1)%GRAD_ACCUM==0:\n scaler.unscale_(opt); gn=nn.utils.clip_grad_norm_(model.parameters(),1.0)\n scaler.step(opt); scaler.update(); opt.zero_grad(); sch.step(); ema.update(model); gs+=1\n el+=ld['total'].item(); ef+=ld['flow'].item(); nb+=1\n if gs==1 or gs%LOG_EVERY==0:\n a=el/nb; af=ef/nb; lr=sch.get_last_lr()[0]\n elapsed=time.time()-t0; its=gs/elapsed\n eta=((tot-gs)/max(its,.01)); eta_str=f'{eta/3600:.1f}h' if eta>3600 else f'{eta/60:.0f}min'\n all_losses.append({'step':gs,'loss':a,'flow':af,'lr':lr,'epoch':ep})\n print(f' E{ep+1}/{EPOCHS} [{gs}/{tot}] loss={a:.4f} flow={af:.4f} lr={lr:.2e} gn={gn:.2f} | {its:.1f}it/s ETA={eta_str}')\n print(f' ── Epoch {ep+1} done ── loss={el/max(1,nb):.4f} ({(time.time()-t0)/60:.0f}min elapsed)\\n')\n if (ep+1)%SAMPLE_EVERY==0 or ep==0:\n model.eval(); ema.apply_shadow(model)\n with torch.no_grad():\n im=euler_sample(model,(min(16,BATCH_SIZE),3,IMG_SIZE,IMG_SIZE),num_steps=SAMPLE_STEPS,device=device)\n im=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/samples/ep{ep+1:04d}.png')\n fig,ax=plt.subplots(1,1,figsize=(8,8)); ax.imshow(g); ax.set_title(f'Epoch {ep+1} loss={el/max(1,nb):.4f}'); ax.axis('off'); plt.show()\n ema.restore(model); model.train()\n if (ep+1)%SAVE_EVERY==0:\n torch.save({'model':model.state_dict(),'opt':opt.state_dict(),'sch':sch.state_dict(),'ema':ema.state_dict(),'ep':ep,'gs':gs},f'{OUTPUT_DIR}/checkpoints/latest.pt')\n print(f' 💾 Saved')\nema.apply_shadow(model)\ntorch.save({'model':model.state_dict(),'config':{'model':MODEL_SIZE,'img':IMG_SIZE,'data':DATASET,'params':num_params}},f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model); print(f'\\n✅ Done! {(time.time()-t0)/60:.1f}min')"]},{"cell_type":"markdown","metadata":{},"source":["## 5. Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if all_losses:\n s=[d['step'] for d in all_losses]; l=[d['loss'] for d in all_losses]; f=[d['flow'] for d in all_losses]\n fig,(a1,a2)=plt.subplots(1,2,figsize=(14,5))\n a1.plot(s,l,label='Total'); a1.plot(s,f,label='Flow'); a1.legend(); a1.grid(alpha=.3); a1.set_title('Loss')\n a2.plot(s,[d['lr'] for d in all_losses],color='orange'); a2.grid(alpha=.3); a2.set_title('LR')\n plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 6. Generate"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval(); torch.manual_seed(42)\nwith torch.no_grad(): im=euler_sample(model,(16,3,IMG_SIZE,IMG_SIZE),num_steps=50,device=device)\nim=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/final.png')\nplt.figure(figsize=(10,10)); plt.imshow(g); plt.axis('off'); plt.title(f'{MODEL_SIZE}-{IMG_SIZE}'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 7. Experiments\n","| Goal | Dataset | Model | Size | Epochs | Colab |\n","|------|---------|-------|------|--------|-------|\n","| Quick test | CIFAR-10 | tiny | 32 | 10 | 5min |\n","| **Baseline** | **CIFAR-10** | **tiny** | **128** | **20** | **1.5h** |\n","| Quality | Flowers | small | 128 | 30 | 3h |\n","| Hi-res | CIFAR-10 | small | 256 | 20 | 3h |\n","| Faces | CelebA | small | 128 | 15 | 4h |\n","| Kaggle long | any | small | 256 | 50 | 8h |\n","\n","All configs have L=256 tokens. Speed is the same for 128/256/512."]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4"},"kernelspec":{"display_name":"Python 3","name":"python3"}},"nbformat":4,"nbformat_minor":4}