v0.5: Add streaming HF dataset support for art/anime training, new notebook with dataset options
Browse files
LiquidFlow_Training.ipynb
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
{"cells":[{"cell_type":"markdown","metadata":{},"source":["# π LiquidFlow v0.5 β Liquid-SSM Flow Matching Image Generator\n","\n","**All configs now run at the same speed** β `_auto_patch` scales patch_size so Lβ€256 tokens always.\n","\n","| Config | Image | Patch | L | Params | VRAM (bs=32) | Speed (T4) |\n","|--------|-------|-------|---|--------|-------------|------------|\n","| tiny | 128 | 8 | 256 | 4.9M | ~3GB | ~10 it/s |\n","| small | 128 | 8 | 256 | 10M | ~5GB | ~5 it/s |\n","| small | 256 | 16 | 256 | 10M | ~5GB | ~5 it/s |\n","| base | 256 | 16 | 256 | 24M | ~10GB | ~2 it/s |\n","| 512 | 512 | 32 | 256 | 24M | ~10GB | ~2 it/s |"]},{"cell_type":"markdown","metadata":{},"source":["## 0. Setup"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU'\nimport torch; print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available(): print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm mambapy"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n!rm -rf liquidflow liquidflow_repo\n!git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n!cp -r liquidflow_repo/liquidflow .\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512, HAS_PSCAN\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nprint(f'β
LiquidFlow v0.5 | pscan: {HAS_PSCAN}')"]},{"cell_type":"markdown","metadata":{},"source":["## 1. Config"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title ποΈ Configuration { display-mode: \"form\" }\nDATASET='cifar10' #@param ['cifar10','flowers','celeba','folder','fashion_mnist']\nCUSTOM_DATA_DIR='/content/my_images' #@param {type:\"string\"}\nMODEL_SIZE='tiny' #@param ['tiny','small','base','512']\nIMG_SIZE=128 #@param [32,64,128,256,512] {type:\"integer\"}\nEPOCHS=20 #@param {type:\"integer\"}\nBATCH_SIZE=32 #@param [4,8,16,32,64,128] {type:\"integer\"}\nLEARNING_RATE=3e-4 #@param {type:\"number\"}\nGRAD_ACCUM=1 #@param [1,2,4,8] {type:\"integer\"}\nUSE_AMP=True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH=0.01 #@param {type:\"number\"}\nLAMBDA_TV=0.001 #@param {type:\"number\"}\nSAMPLE_EVERY=5 #@param {type:\"integer\"}\nSAMPLE_STEPS=50 #@param [10,25,50,100] {type:\"integer\"}\nLOG_EVERY=100 #@param {type:\"integer\"}\nSAVE_EVERY=5 #@param {type:\"integer\"}\nOUTPUT_DIR='./outputs'; DATA_DIR='./data'\nimport torch\nif torch.cuda.is_available():\n vram=torch.cuda.get_device_properties(0).total_memory/1e9\n print(f'GPU VRAM: {vram:.1f} GB')\n rec={(128,'tiny'):64,(128,'small'):32,(256,'small'):32,(256,'base'):16,(512,'base'):8,(512,'512'):8}\n if (IMG_SIZE,MODEL_SIZE) in rec and BATCH_SIZE>rec[(IMG_SIZE,MODEL_SIZE)]:\n BATCH_SIZE=rec[(IMG_SIZE,MODEL_SIZE)]; print(f'β οΈ bsβ{BATCH_SIZE}')\nelse: USE_AMP=False\nprint(f'π {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, epochs={EPOCHS}')"]},{"cell_type":"markdown","metadata":{},"source":["## 2. Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision, torchvision.transforms as T\nfrom torch.utils.data import DataLoader, Dataset, ConcatDataset\nfrom pathlib import Path; from PIL import Image\nimport os, matplotlib.pyplot as plt, numpy as np\ndef tfm(s): return T.Compose([T.Resize(s+s//8),T.CenterCrop(s),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize([.5]*3,[.5]*3)])\nclass FolderDS(Dataset):\n def __init__(s,r,t): s.t=t; s.f=sorted(sum([list(Path(r).rglob(e)) for e in['*.png','*.jpg','*.jpeg','*.webp']],[]))\n def __len__(s): return len(s.f)\n def __getitem__(s,i): return s.t(Image.open(s.f[i]).convert('RGB'))\nclass G2R:\n def __call__(s,x): return x.repeat(3,1,1) if x.shape[0]==1 else x\ntf=tfm(IMG_SIZE)\nif DATASET=='cifar10': dataset=torchvision.datasets.CIFAR10(DATA_DIR,True,download=True,transform=tf)\nelif DATASET=='flowers': dataset=ConcatDataset([torchvision.datasets.Flowers102(DATA_DIR,s,download=True,transform=tf) for s in['train','val','test']])\nelif DATASET=='celeba': dataset=torchvision.datasets.CelebA(DATA_DIR,'train',download=True,transform=tf)\nelif DATASET=='fashion_mnist': dataset=torchvision.datasets.FashionMNIST(DATA_DIR,True,download=True,transform=T.Compose([T.Resize(IMG_SIZE),T.ToTensor(),T.Normalize([.5],[.5]),G2R()]))\nelif DATASET=='folder': dataset=FolderDS(CUSTOM_DATA_DIR,tf)\nelse: raise ValueError(DATASET)\nprint(f'β
{DATASET}: {len(dataset)} images')\nfig,ax=plt.subplots(1,8,figsize=(16,2))\nfor i,a in enumerate(ax):\n s=dataset[i]; s=s[0] if isinstance(s,(list,tuple)) else s\n a.imshow((s*.5+.5).permute(1,2,0).clamp(0,1).numpy()); a.axis('off')\nplt.suptitle(f'{DATASET} ({IMG_SIZE}Γ{IMG_SIZE})'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 3. Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel={'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\nnum_params=model.count_params()\nprint(f'ποΈ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M)')\nprint(f' {IMG_SIZE}px, patch={model.patch_size}, L={model.num_patches} tokens, d={model.d_model}, depth={model.depth}')\nwith torch.no_grad(): model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device),torch.tensor([.5],device=device))\nprint('β
Forward OK')"]},{"cell_type":"markdown","metadata":{},"source":["## 4. Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math,time,torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss,EMAModel\nfrom liquidflow.sampling import euler_sample,make_grid_image\nimport matplotlib.pyplot as plt\nos.makedirs(f'{OUTPUT_DIR}/samples',exist_ok=True); os.makedirs(f'{OUTPUT_DIR}/checkpoints',exist_ok=True)\ndl=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2,pin_memory=True,drop_last=True)\nopt=torch.optim.AdamW(model.parameters(),lr=LEARNING_RATE,betas=(.9,.999),weight_decay=.01)\ntot=EPOCHS*len(dl)//GRAD_ACCUM; wu=min(500,tot//10)\ndef clr(s):\n if s<wu: return s/max(1,wu)\n return .1+.9*.5*(1+math.cos(math.pi*(s-wu)/max(1,tot-wu)))\nsch=torch.optim.lr_scheduler.LambdaLR(opt,clr)\ncrit=PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH,lambda_tv=LAMBDA_TV).to(device)\nema=EMAModel(model,decay=.9999)\namp_dev='cuda' if device.type=='cuda' else 'cpu'\nscaler=torch.amp.GradScaler(amp_dev,enabled=USE_AMP)\nall_losses=[]; gs=0; spe=len(dl)//GRAD_ACCUM\nprint(f'π {EPOCHS} epochs Γ {spe} steps = {tot} total | batch {BATCH_SIZE}Γ{GRAD_ACCUM}={BATCH_SIZE*GRAD_ACCUM}\\n')\nt0=time.time()\nfor ep in range(EPOCHS):\n model.train(); el=ef=0.; nb=0\n for bi,bd in enumerate(dl):\n x1=(bd[0] if isinstance(bd,(list,tuple)) else bd).to(device)\n B=x1.shape[0]; x0=torch.randn_like(x1); t=torch.rand(B,device=device)\n xt=t.view(B,1,1,1)*x1+(1-t.view(B,1,1,1))*x0\n with torch.amp.autocast(amp_dev,enabled=USE_AMP):\n vp=model(xt,t); loss,ld=crit(vp,x0,x1,t,step=gs); loss=loss/GRAD_ACCUM\n scaler.scale(loss).backward()\n if (bi+1)%GRAD_ACCUM==0:\n scaler.unscale_(opt); gn=nn.utils.clip_grad_norm_(model.parameters(),1.0)\n scaler.step(opt); scaler.update(); opt.zero_grad(); sch.step(); ema.update(model); gs+=1\n el+=ld['total'].item(); ef+=ld['flow'].item(); nb+=1\n if gs==1 or gs%LOG_EVERY==0:\n a=el/nb; af=ef/nb; lr=sch.get_last_lr()[0]\n elapsed=time.time()-t0; its=gs/elapsed\n eta=((tot-gs)/max(its,.01)); eta_str=f'{eta/3600:.1f}h' if eta>3600 else f'{eta/60:.0f}min'\n all_losses.append({'step':gs,'loss':a,'flow':af,'lr':lr,'epoch':ep})\n print(f' E{ep+1}/{EPOCHS} [{gs}/{tot}] loss={a:.4f} flow={af:.4f} lr={lr:.2e} gn={gn:.2f} | {its:.1f}it/s ETA={eta_str}')\n print(f' ββ Epoch {ep+1} done ββ loss={el/max(1,nb):.4f} ({(time.time()-t0)/60:.0f}min elapsed)\\n')\n if (ep+1)%SAMPLE_EVERY==0 or ep==0:\n model.eval(); ema.apply_shadow(model)\n with torch.no_grad():\n im=euler_sample(model,(min(16,BATCH_SIZE),3,IMG_SIZE,IMG_SIZE),num_steps=SAMPLE_STEPS,device=device)\n im=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/samples/ep{ep+1:04d}.png')\n fig,ax=plt.subplots(1,1,figsize=(8,8)); ax.imshow(g); ax.set_title(f'Epoch {ep+1} loss={el/max(1,nb):.4f}'); ax.axis('off'); plt.show()\n ema.restore(model); model.train()\n if (ep+1)%SAVE_EVERY==0:\n torch.save({'model':model.state_dict(),'opt':opt.state_dict(),'sch':sch.state_dict(),'ema':ema.state_dict(),'ep':ep,'gs':gs},f'{OUTPUT_DIR}/checkpoints/latest.pt')\n print(f' πΎ Saved')\nema.apply_shadow(model)\ntorch.save({'model':model.state_dict(),'config':{'model':MODEL_SIZE,'img':IMG_SIZE,'data':DATASET,'params':num_params}},f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model); print(f'\\nβ
Done! {(time.time()-t0)/60:.1f}min')"]},{"cell_type":"markdown","metadata":{},"source":["## 5. Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if all_losses:\n s=[d['step'] for d in all_losses]; l=[d['loss'] for d in all_losses]; f=[d['flow'] for d in all_losses]\n fig,(a1,a2)=plt.subplots(1,2,figsize=(14,5))\n a1.plot(s,l,label='Total'); a1.plot(s,f,label='Flow'); a1.legend(); a1.grid(alpha=.3); a1.set_title('Loss')\n a2.plot(s,[d['lr'] for d in all_losses],color='orange'); a2.grid(alpha=.3); a2.set_title('LR')\n plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 6. Generate"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model.eval(); torch.manual_seed(42)\nwith torch.no_grad(): im=euler_sample(model,(16,3,IMG_SIZE,IMG_SIZE),num_steps=50,device=device)\nim=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/final.png')\nplt.figure(figsize=(10,10)); plt.imshow(g); plt.axis('off'); plt.title(f'{MODEL_SIZE}-{IMG_SIZE}'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 7. Experiments\n","| Goal | Dataset | Model | Size | Epochs | Colab |\n","|------|---------|-------|------|--------|-------|\n","| Quick test | CIFAR-10 | tiny | 32 | 10 | 5min |\n","| **Baseline** | **CIFAR-10** | **tiny** | **128** | **20** | **1.5h** |\n","| Quality | Flowers | small | 128 | 30 | 3h |\n","| Hi-res | CIFAR-10 | small | 256 | 20 | 3h |\n","| Faces | CelebA | small | 128 | 15 | 4h |\n","| Kaggle long | any | small | 256 | 50 | 8h |\n","\n","All configs have L=256 tokens. Speed is the same for 128/256/512."]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4"},"kernelspec":{"display_name":"Python 3","name":"python3"}},"nbformat":4,"nbformat_minor":4}
|
|
|
|
| 1 |
+
{"cells":[{"cell_type":"markdown","metadata":{},"source":["# π LiquidFlow v0.5 β Art & Anime Image Generator\n","\n","**Supports HuggingFace streaming datasets** β no download, no storage limits. Train on 512px art on free Kaggle GPU.\n","\n","## Recommended Art Datasets\n","| Dataset | Size | Content | Quality | Best For |\n","|---------|------|---------|---------|----------|\n","| `huggan/wikiart` | 80K imgs, 5GB | Classical paintings, 27 styles | ββββ | Fine art generation |\n","| `fantasyfish/laion-art` | 8.5K imgs, 11GB | High-aesthetic mixed art | βββββ | General art |\n","| `huggan/few-shot-art-painting` | 1K imgs, 500MB | Oil paintings | βββ | Quick experiments |\n","| `huggan/anime-faces` | 63K imgs | Anime faces (64px) | βββ | Anime faces |\n","| `reach-vb/pokemon-blip-captions` | 833 imgs, 95MB | Pokemon art | βββ | Quick test |\n","| `jainr3/diffusiondb-pixelart` | 2K imgs, 20MB | Pixel art style | βββ | Pixel art |\n","\n","All use **streaming** β data loads on-the-fly, zero disk space needed."]},{"cell_type":"markdown","metadata":{},"source":["## 0. Setup"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!nvidia-smi || echo 'No GPU'\nimport torch; print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\nif torch.cuda.is_available(): print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision einops pillow matplotlib tqdm mambapy datasets"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n!rm -rf liquidflow liquidflow_repo\n!git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n!cp -r liquidflow_repo/liquidflow .\nfrom liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512, HAS_PSCAN\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, heun_sample, make_grid_image\nprint(f'β
LiquidFlow v0.5 | pscan: {HAS_PSCAN}')"]},{"cell_type":"markdown","metadata":{},"source":["## 1. Config"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title ποΈ Configuration { display-mode: \"form\" }\n\n#@markdown ### Dataset (HF streaming β no download needed)\nDATASET='huggan/wikiart' #@param ['huggan/wikiart','fantasyfish/laion-art','huggan/few-shot-art-painting','huggan/anime-faces','reach-vb/pokemon-blip-captions','jainr3/diffusiondb-pixelart','cifar10','folder']\nCUSTOM_DATA_DIR='/content/my_images' #@param {type:\"string\"}\n#@markdown ### Model\nMODEL_SIZE='small' #@param ['tiny','small','base','512']\nIMG_SIZE=256 #@param [128,256,512] {type:\"integer\"}\n#@markdown ### Training\nEPOCHS=10 #@param {type:\"integer\"}\nSTEPS_PER_EPOCH=1000 #@param {type:\"integer\"}\nBATCH_SIZE=16 #@param [4,8,16,32] {type:\"integer\"}\nLEARNING_RATE=3e-4 #@param {type:\"number\"}\nGRAD_ACCUM=2 #@param [1,2,4,8] {type:\"integer\"}\nUSE_AMP=True #@param {type:\"boolean\"}\nLAMBDA_SMOOTH=0.01 #@param {type:\"number\"}\nLAMBDA_TV=0.001 #@param {type:\"number\"}\nSAMPLE_EVERY=2 #@param {type:\"integer\"}\nSAMPLE_STEPS=50 #@param [10,25,50] {type:\"integer\"}\nLOG_EVERY=50 #@param {type:\"integer\"}\nOUTPUT_DIR='./outputs'\nimport torch\ndevice=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nif not torch.cuda.is_available(): USE_AMP=False\nprint(f'π {MODEL_SIZE}-{IMG_SIZE}px, {DATASET}, bs={BATCH_SIZE}Γ{GRAD_ACCUM}, epochs={EPOCHS}Γ{STEPS_PER_EPOCH}steps')"]},{"cell_type":"markdown","metadata":{},"source":["## 2. Load Dataset (Streaming β no disk needed)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torchvision.transforms as T\nfrom torch.utils.data import DataLoader, Dataset, IterableDataset\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nimport numpy as np\n\ntf = T.Compose([\n T.Resize(IMG_SIZE + IMG_SIZE//8),\n T.CenterCrop(IMG_SIZE),\n T.RandomHorizontalFlip(),\n T.ToTensor(),\n T.Normalize([.5]*3, [.5]*3),\n])\n\nif DATASET == 'cifar10':\n import torchvision\n ds = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tf)\n dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n USE_STREAMING = False\n print(f'β
CIFAR-10: {len(ds)} images (local)')\n\nelif DATASET == 'folder':\n from pathlib import Path\n class FolderDS(Dataset):\n def __init__(s,r,t): s.t=t; s.f=sorted(sum([list(Path(r).rglob(e)) for e in['*.png','*.jpg','*.jpeg','*.webp']],[]))\n def __len__(s): return len(s.f)\n def __getitem__(s,i): return s.t(Image.open(s.f[i]).convert('RGB'))\n ds = FolderDS(CUSTOM_DATA_DIR, tf)\n dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n USE_STREAMING = False\n print(f'β
Folder: {len(ds)} images')\n\nelse:\n # HuggingFace streaming dataset β loads on-the-fly, zero disk space\n from datasets import load_dataset\n \n class HFStreamDataset(IterableDataset):\n \"\"\"Wraps a HF streaming dataset for PyTorch DataLoader.\"\"\"\n def __init__(self, hf_name, transform, img_col='image', split='train'):\n self.hf_ds = load_dataset(hf_name, split=split, streaming=True).shuffle(seed=42, buffer_size=1000)\n self.transform = transform\n self.img_col = img_col\n \n def __iter__(self):\n for item in self.hf_ds:\n try:\n img = item[self.img_col]\n if isinstance(img, dict) and 'bytes' in img:\n import io\n img = Image.open(io.BytesIO(img['bytes'])).convert('RGB')\n elif not isinstance(img, Image.Image):\n img = Image.open(img).convert('RGB')\n else:\n img = img.convert('RGB')\n # Skip tiny images\n if min(img.size) < IMG_SIZE // 2:\n continue\n yield self.transform(img)\n except Exception:\n continue\n \n ds = HFStreamDataset(DATASET, tf)\n dl = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)\n USE_STREAMING = True\n print(f'β
{DATASET}: streaming (no download)')\n\n# Preview samples\nprint('Loading preview...')\nprev_batch = next(iter(dl))\nif isinstance(prev_batch, (list, tuple)): prev_batch = prev_batch[0]\nfig, axes = plt.subplots(1, min(8, prev_batch.shape[0]), figsize=(16, 2))\nfor i, ax in enumerate(axes):\n ax.imshow((prev_batch[i]*.5+.5).permute(1,2,0).clamp(0,1).numpy()); ax.axis('off')\nplt.suptitle(f'{DATASET} @ {IMG_SIZE}px'); plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 3. Model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model={'tiny':liquidflow_tiny,'small':liquidflow_small,'base':liquidflow_base,'512':liquidflow_512}[MODEL_SIZE](img_size=IMG_SIZE).to(device)\nnum_params=model.count_params()\nprint(f'ποΈ LiquidFlow-{MODEL_SIZE}: {num_params:,} ({num_params/1e6:.1f}M)')\nprint(f' {IMG_SIZE}px, patch={model.patch_size}, L={model.num_patches} tokens')\nwith torch.no_grad(): model(torch.randn(1,3,IMG_SIZE,IMG_SIZE,device=device),torch.tensor([.5],device=device))\nprint('β
Forward OK')"]},{"cell_type":"markdown","metadata":{},"source":["## 4. Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import math, time, torch.nn as nn\nfrom liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\nfrom liquidflow.sampling import euler_sample, make_grid_image\nimport matplotlib.pyplot as plt\nos.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\nos.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\nopt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(.9,.999), weight_decay=.01)\ntot = EPOCHS * STEPS_PER_EPOCH\nwu = min(500, tot // 10)\ndef clr(s):\n if s < wu: return s / max(1, wu)\n return .1 + .9 * .5 * (1 + math.cos(math.pi * (s - wu) / max(1, tot - wu)))\nsch = torch.optim.lr_scheduler.LambdaLR(opt, clr)\ncrit = PhysicsInformedFlowLoss(lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV).to(device)\nema = EMAModel(model, decay=.9999)\namp_dev = 'cuda' if device.type == 'cuda' else 'cpu'\nscaler = torch.amp.GradScaler(amp_dev, enabled=USE_AMP)\nall_losses = []; gs = 0\nprint(f'π {EPOCHS} epochs Γ {STEPS_PER_EPOCH} steps = {tot} total')\nprint(f' batch {BATCH_SIZE}Γ{GRAD_ACCUM}={BATCH_SIZE*GRAD_ACCUM}, LR {LEARNING_RATE}\\n')\nt0 = time.time()\n\nfor ep in range(EPOCHS):\n model.train(); el = ef = 0.; nb = 0\n \n if USE_STREAMING:\n # For streaming: iterate until STEPS_PER_EPOCH reached\n data_iter = iter(dl)\n \n for step_in_epoch in range(STEPS_PER_EPOCH * GRAD_ACCUM):\n # Get batch\n try:\n if USE_STREAMING:\n batch = next(data_iter)\n else:\n if step_in_epoch == 0:\n data_iter = iter(dl)\n batch = next(data_iter)\n except StopIteration:\n data_iter = iter(dl) # restart epoch\n batch = next(data_iter)\n \n x1 = (batch[0] if isinstance(batch, (list, tuple)) else batch).to(device)\n B = x1.shape[0]\n x0 = torch.randn_like(x1)\n t = torch.rand(B, device=device)\n xt = t.view(B,1,1,1) * x1 + (1 - t.view(B,1,1,1)) * x0\n \n with torch.amp.autocast(amp_dev, enabled=USE_AMP):\n vp = model(xt, t)\n loss, ld = crit(vp, x0, x1, t, step=gs)\n loss = loss / GRAD_ACCUM\n scaler.scale(loss).backward()\n \n if (step_in_epoch + 1) % GRAD_ACCUM == 0:\n scaler.unscale_(opt)\n gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n scaler.step(opt); scaler.update(); opt.zero_grad()\n sch.step(); ema.update(model); gs += 1\n el += ld['total'].item(); ef += ld['flow'].item(); nb += 1\n \n if gs == 1 or gs % LOG_EVERY == 0:\n a = el/nb; af = ef/nb; lr = sch.get_last_lr()[0]\n elapsed = time.time() - t0; its = gs / elapsed\n eta = (tot - gs) / max(its, .01)\n eta_str = f'{eta/3600:.1f}h' if eta > 3600 else f'{eta/60:.0f}min'\n all_losses.append({'step': gs, 'loss': a, 'flow': af, 'lr': lr, 'epoch': ep})\n print(f' E{ep+1}/{EPOCHS} [{gs}/{tot}] loss={a:.4f} flow={af:.4f} lr={lr:.2e} gn={gn:.2f} | {its:.1f}it/s ETA={eta_str}')\n \n avg = el / max(1, nb)\n print(f' ββ Epoch {ep+1} done ββ loss={avg:.4f} ({(time.time()-t0)/60:.0f}min)\\n')\n \n # Sample\n if (ep + 1) % SAMPLE_EVERY == 0 or ep == 0:\n model.eval(); ema.apply_shadow(model)\n with torch.no_grad():\n im = euler_sample(model, (min(16, BATCH_SIZE), 3, IMG_SIZE, IMG_SIZE), num_steps=SAMPLE_STEPS, device=device)\n im = im.clamp(-1,1) * .5 + .5\n g = make_grid_image(im, nrow=4)\n g.save(f'{OUTPUT_DIR}/samples/ep{ep+1:04d}.png')\n fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n ax.imshow(g); ax.set_title(f'Epoch {ep+1} | loss={avg:.4f}'); ax.axis('off')\n plt.show()\n ema.restore(model); model.train()\n \n # Checkpoint\n torch.save({\n 'model': model.state_dict(), 'opt': opt.state_dict(),\n 'sch': sch.state_dict(), 'ema': ema.state_dict(),\n 'ep': ep, 'gs': gs, 'config': {'model': MODEL_SIZE, 'img': IMG_SIZE, 'data': DATASET}\n }, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n print(f' πΎ Checkpoint saved')\n\n# Save final with EMA\nema.apply_shadow(model)\ntorch.save({'model': model.state_dict(), 'config': {'model': MODEL_SIZE, 'img': IMG_SIZE, 'data': DATASET, 'params': num_params}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\nema.restore(model)\nprint(f'\\nβ
Done! {(time.time()-t0)/60:.1f}min total')"]},{"cell_type":"markdown","metadata":{},"source":["## 5. Curves"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if all_losses:\n s=[d['step'] for d in all_losses]; l=[d['loss'] for d in all_losses]; f=[d['flow'] for d in all_losses]\n fig,(a1,a2)=plt.subplots(1,2,figsize=(14,5))\n a1.plot(s,l,label='Total'); a1.plot(s,f,label='Flow'); a1.legend(); a1.grid(alpha=.3); a1.set_title('Loss')\n a2.plot(s,[d['lr'] for d in all_losses],color='orange'); a2.grid(alpha=.3); a2.set_title('LR')\n plt.tight_layout(); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 6. Generate"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#@title Generation { display-mode: \"form\" }\nNUM=16 #@param {type:\"integer\"}\nSTEPS=50 #@param [25,50,100] {type:\"integer\"}\nmodel.eval(); torch.manual_seed(42)\nwith torch.no_grad(): im=euler_sample(model,(NUM,3,IMG_SIZE,IMG_SIZE),num_steps=STEPS,device=device)\nim=im.clamp(-1,1)*.5+.5; g=make_grid_image(im,nrow=4); g.save(f'{OUTPUT_DIR}/final.png')\nplt.figure(figsize=(12,12)); plt.imshow(g); plt.axis('off'); plt.title(f'{MODEL_SIZE}-{IMG_SIZE} on {DATASET}'); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 7. Recommended Configs for Kaggle (30h GPU)\n","\n","| Goal | Dataset | Model | Size | Steps/ep | Epochs | Time |\n","|------|---------|-------|------|----------|--------|------|\n","| **Art paintings** | `huggan/wikiart` | small | 256 | 1000 | 30 | ~6h |\n","| **Anime faces** | `huggan/anime-faces` | small | 128 | 1000 | 30 | ~4h |\n","| **Hi-res art** | `fantasyfish/laion-art` | base | 512 | 500 | 20 | ~8h |\n","| **Quick test** | `reach-vb/pokemon-blip-captions` | tiny | 128 | 200 | 5 | 15min |\n","| **Pixel art** | `jainr3/diffusiondb-pixelart` | small | 256 | 500 | 20 | ~4h |\n","\n","### Why streaming?\n","- **Zero storage** β images load directly from HuggingFace servers\n","- **No waiting** β training starts immediately, no download phase\n","- **Any dataset size** β even 700GB datasets work (loads only what's needed)\n","- **Auto-shuffle** β built-in buffer shuffle for training\n","\n","### Tips for art training\n","- **Specific style = better results** β wikiart with one style > random internet images\n","- **More epochs on small dataset** beats fewer epochs on huge dataset for our model size\n","- **Check samples at epoch 5** β if you see structure forming, it's working\n","- **512px needs base/512 model** β tiny/small patch too aggressively at 512"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4"},"kernelspec":{"display_name":"Python 3","name":"python3"}},"nbformat":4,"nbformat_minor":4}
|