"""
LiquidFlow Training Script

Designed for:
- Google Colab free tier (T4 16GB VRAM)
- Kaggle free tier (P100 16GB / T4x2)
- Any GPU with ≥8GB VRAM (128x128)
- Any GPU with ≥16GB VRAM (512x512)

Key training features:
- Mixed precision (fp16/bf16) for memory efficiency
- Gradient accumulation for large effective batch sizes
- EMA for stable generation quality
- Physics-informed loss with warmup
- Cosine learning rate schedule with warmup
- Checkpoint saving/resuming
- Wandb/Trackio logging support
"""

import os
import sys
import math
import time
import json
import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# Add parent to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from model import (
    LiquidFlowNet, liquidflow_tiny, liquidflow_small,
    liquidflow_base, liquidflow_512
)
from losses import PhysicsInformedFlowLoss, EMAModel
from sampling import euler_sample, heun_sample, make_grid_image


# ============================================================
# DATASET UTILITIES
# ============================================================

class ImageFolderDataset(Dataset):
    """Simple image dataset from folder.

    Recursively collects image files under ``root`` and returns each one
    as the output of ``transform`` applied to the RGB-converted image.
    No labels are returned.

    Args:
        root: Directory searched recursively for images.
        img_size: Target side length used by the default transform.
        transform: Optional callable applied to each PIL image. When
            ``None``, a default pipeline is used: resize, center crop,
            random horizontal flip, ``ToTensor`` and normalization with
            mean/std 0.5 per channel.
    """

    # Matched case-insensitively so that e.g. ``IMG_0001.JPG`` is found too.
    EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')

    def __init__(self, root, img_size=128, transform=None):
        self.root = Path(root)
        self.img_size = img_size

        # Find all images (sorted for a deterministic index -> file mapping)
        self.files = sorted(
            p for p in self.root.rglob('*')
            if p.suffix.lower() in self.EXTENSIONS and p.is_file()
        )

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager releases the file handle promptly;
        # convert() returns an independent image that outlives it.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)


def get_cifar10_dataset(img_size=32, data_dir='./data'):
    """CIFAR-10 for quick experiments.

    Args:
        img_size: Output side length; a resize is added only when it
            differs from the native 32.
        data_dir: Download / cache directory passed to torchvision.

    Returns:
        The ``torchvision.datasets.CIFAR10`` training split.
    """
    # Build the pipeline from a list instead of inserting an identity
    # ``Lambda``: lambdas cannot be pickled, which breaks DataLoader worker
    # processes on platforms that use the "spawn" start method.
    steps = []
    if img_size != 32:
        steps.append(transforms.Resize(img_size))
    steps.extend([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    transform = transforms.Compose(steps)
    dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )
    return dataset


def get_celeba_dataset(img_size=128, data_dir='./data'):
    """CelebA for face generation.

    Args:
        img_size: Side length used for resize and center crop.
        data_dir: Download / cache directory passed to torchvision.

    Returns:
        The ``torchvision.datasets.CelebA`` ``'train'`` split.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = torchvision.datasets.CelebA(
        root=data_dir, split='train', download=True, transform=transform
    )
    return dataset


def get_flowers_dataset(img_size=128, data_dir='./data'):
    """Oxford Flowers 102 - small but beautiful dataset.

    Images are resized to ``img_size + img_size // 8`` before the center
    crop, so the crop trims a margin instead of the full frame.

    Args:
        img_size: Side length of the final center crop.
        data_dir: Download / cache directory passed to torchvision.

    Returns:
        The ``torchvision.datasets.Flowers102`` ``'train'`` split.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size + img_size // 8),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = torchvision.datasets.Flowers102(
        root=data_dir, split='train', download=True, transform=transform
    )
    return dataset


# ============================================================
# LEARNING RATE SCHEDULE
# ============================================================

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Cosine annealing with linear warmup.

    The LR multiplier rises linearly from 0 to 1 over ``warmup_steps``,
    then follows a half cosine from 1 down to ``min_lr_ratio`` at
    ``total_steps``. Past ``total_steps`` it stays at ``min_lr_ratio``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the cosine decay reaches its floor.
        min_lr_ratio: Final LR as a fraction of the base LR.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Clamp so the cosine does not swing back up if training runs
        # longer than total_steps (e.g. after resuming with more epochs).
        progress = min(1.0, progress)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# ============================================================
# TRAINING LOOP
# ============================================================

def train(args):
    """Main training function.

    Builds model, data, optimizer, schedule, loss, EMA and grad scaler
    from ``args``, optionally resumes from a checkpoint, then runs the
    epoch loop with gradient accumulation, periodic sampling and
    checkpointing. Writes ``config.json``, ``training_log.json``, sample
    grids, checkpoints and ``liquidflow_final.pt`` under
    ``args.output_dir``.

    Args:
        args: Namespace with the options defined in ``main()``.

    Returns:
        The trained model (raw weights; EMA weights are only swapped in
        temporarily for sampling and for the final save).

    Raises:
        ValueError: On an unknown dataset name, an empty dataset, or
            fewer batches per epoch than ``args.grad_accum``.
    """
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda' and args.use_amp
    print(f"Device: {device}, AMP: {use_amp}")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'samples'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)

    # ---- Model ----
    model_factories = {
        'tiny': liquidflow_tiny,
        'small': liquidflow_small,
        'base': liquidflow_base,
        '512': liquidflow_512,
    }
    # Unknown sizes fall back to the small model.
    factory = model_factories.get(args.model_size, liquidflow_small)
    model = factory(img_size=args.img_size)
    model = model.to(device)

    num_params = model.count_params()
    print(f"Model: LiquidFlow-{args.model_size}, Params: {num_params/1e6:.2f}M")
    print(f"Image size: {args.img_size}x{args.img_size}")

    # ---- Dataset ----
    if args.dataset == 'cifar10':
        dataset = get_cifar10_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'flowers':
        dataset = get_flowers_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'celeba':
        dataset = get_celeba_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'folder':
        dataset = ImageFolderDataset(args.data_dir, args.img_size)
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    print(f"Dataset: {args.dataset}, Size: {len(dataset)}")
    if len(dataset) == 0:
        raise ValueError(
            f"Dataset '{args.dataset}' is empty (data_dir={args.data_dir})"
        )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # Pinned memory only helps host->GPU copies.
        pin_memory=device.type == 'cuda',
        drop_last=True,
    )

    # Optimizer steps actually taken per epoch: trailing micro-batches that
    # do not fill an accumulation window never trigger a step.
    steps_per_epoch = len(dataloader) // args.grad_accum
    if steps_per_epoch == 0:
        raise ValueError(
            f"No optimizer step would ever run: {len(dataloader)} batches per "
            f"epoch is fewer than grad_accum={args.grad_accum} "
            f"(dataset size {len(dataset)}, batch_size {args.batch_size})"
        )

    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )

    # ---- Schedule ----
    total_steps = args.epochs * steps_per_epoch
    warmup_steps = min(args.warmup_steps, total_steps // 10)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # ---- Loss ----
    criterion = PhysicsInformedFlowLoss(
        lambda_smooth=args.lambda_smooth,
        lambda_tv=args.lambda_tv,
        use_adaptive_weights=True,
    ).to(device)

    # ---- EMA ----
    ema = EMAModel(model, decay=args.ema_decay)

    # ---- AMP ----
    scaler = GradScaler(enabled=use_amp)

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    best_loss = float('inf')
    if args.resume and os.path.exists(args.resume):
        print(f"Resuming from {args.resume}")
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        ckpt = torch.load(args.resume, map_location=device)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        ema.load_state_dict(ckpt['ema'])
        # Older checkpoints have no scaler entry, and one saved with AMP
        # disabled holds an empty dict; skip both cases.
        if ckpt.get('scaler'):
            scaler.load_state_dict(ckpt['scaler'])
        start_epoch = ckpt['epoch'] + 1
        global_step = ckpt['global_step']
        # Keep the best-so-far so best.pt is not overwritten by a worse
        # epoch right after resuming (absent in older checkpoints).
        best_loss = ckpt.get('best_loss', float('inf'))
        print(f"Resumed at epoch {start_epoch}, step {global_step}")
    elif args.resume:
        print(f"Warning: resume checkpoint not found: {args.resume}; starting from scratch")

    # ---- Training Config ----
    config = {
        'model_size': args.model_size,
        'img_size': args.img_size,
        'dataset': args.dataset,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
        'num_params': num_params,
        'lambda_smooth': args.lambda_smooth,
        'lambda_tv': args.lambda_tv,
    }

    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training for {args.epochs} epochs, {total_steps} steps")
    print(f"Batch size: {args.batch_size} x {args.grad_accum} = {args.batch_size * args.grad_accum}")
    print(f"Learning rate: {args.lr}")
    print(f"{'='*60}\n")

    # ---- Training ----
    log_losses = []

    for epoch in range(start_epoch, args.epochs):
        model.train()
        # Drop gradients left over from an incomplete accumulation window
        # at the end of the previous epoch so they cannot leak into this
        # epoch's first optimizer step.
        optimizer.zero_grad()

        epoch_loss = 0.0
        epoch_flow_loss = 0.0
        epoch_physics_loss = 0.0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            # Handle different dataset formats
            if isinstance(batch_data, (list, tuple)):
                x1 = batch_data[0].to(device)  # images only, ignore labels
            else:
                x1 = batch_data.to(device)

            B = x1.shape[0]

            # Sample noise (x0) and timestep (t)
            x0 = torch.randn_like(x1)
            t = torch.rand(B, device=device)

            # Interpolate: x_t = t * x_1 + (1-t) * x_0
            t_expand = t.view(B, 1, 1, 1)
            x_t = t_expand * x1 + (1.0 - t_expand) * x0

            # Forward pass with AMP
            with autocast(enabled=use_amp):
                v_pred = model(x_t, t)
                loss, loss_dict = criterion(
                    v_pred, x0, x1, t,
                    step=global_step,
                )
                # Scale so the accumulated gradient is an average over
                # the accumulation window.
                loss = loss / args.grad_accum

            # Backward
            scaler.scale(loss).backward()

            # Running epoch statistics (per micro-batch, unscaled values)
            epoch_loss += loss_dict['total'].item()
            epoch_flow_loss += loss_dict['flow'].item()
            epoch_physics_loss += (loss_dict['smooth'].item() + loss_dict['tv'].item())
            num_batches += 1

            # Gradient accumulation step
            if (batch_idx + 1) % args.grad_accum != 0:
                continue

            # Gradient clipping (critical for stability)
            scaler.unscale_(optimizer)
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            ema.update(model)
            global_step += 1

            # Logging: only right after an optimizer step, so grad_norm is
            # always defined and each global_step is logged at most once.
            if global_step % args.log_every == 0:
                avg_loss = epoch_loss / max(1, num_batches)
                avg_flow = epoch_flow_loss / max(1, num_batches)
                avg_phys = epoch_physics_loss / max(1, num_batches)
                lr_current = scheduler.get_last_lr()[0]

                print(
                    f"[Epoch {epoch+1}/{args.epochs}] "
                    f"Step {global_step}/{total_steps} | "
                    f"Loss: {avg_loss:.4f} | "
                    f"Flow: {avg_flow:.4f} | "
                    f"Physics: {avg_phys:.6f} | "
                    f"LR: {lr_current:.2e} | "
                    f"GradNorm: {grad_norm:.2f}"
                )

                log_losses.append({
                    'step': global_step,
                    'epoch': epoch,
                    'loss': avg_loss,
                    'flow_loss': avg_flow,
                    'physics_loss': avg_phys,
                    'lr': lr_current,
                    'grad_norm': grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm,
                })

        # ---- End of Epoch ----
        avg_epoch_loss = epoch_loss / max(1, num_batches)
        print(f"\n[Epoch {epoch+1}] Average Loss: {avg_epoch_loss:.4f}\n")

        # Sample images with EMA
        if (epoch + 1) % args.sample_every == 0 or epoch == 0:
            print("Generating samples...")
            model.eval()
            ema.apply_shadow(model)
            try:
                with torch.no_grad():
                    shape = (min(16, args.batch_size), 3, args.img_size, args.img_size)
                    samples = euler_sample(model, shape, num_steps=args.sample_steps, device=device)
                    # Map from the normalized [-1, 1] range back to [0, 1].
                    samples = samples.clamp(-1, 1) * 0.5 + 0.5

                    grid = make_grid_image(samples, nrow=4)
                    grid.save(os.path.join(args.output_dir, 'samples', f'epoch_{epoch+1:04d}.png'))
                    print(f"  Saved samples to samples/epoch_{epoch+1:04d}.png")
            finally:
                # Always put the raw training weights back, even if
                # sampling or saving the grid failed.
                ema.restore(model)
                model.train()

        # Save checkpoint
        if (epoch + 1) % args.save_every == 0 or avg_epoch_loss < best_loss:
            best_loss = min(best_loss, avg_epoch_loss)

            ckpt = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'ema': ema.state_dict(),
                'scaler': scaler.state_dict(),
                'epoch': epoch,
                'global_step': global_step,
                'loss': avg_epoch_loss,
                'best_loss': best_loss,
                'config': config,
            }

            ckpt_path = os.path.join(args.output_dir, 'checkpoints', f'epoch_{epoch+1:04d}.pt')
            torch.save(ckpt, ckpt_path)
            print(f"  Saved checkpoint: {ckpt_path}")

            # Also save "latest" and "best"
            torch.save(ckpt, os.path.join(args.output_dir, 'checkpoints', 'latest.pt'))
            if avg_epoch_loss <= best_loss:
                torch.save(ckpt, os.path.join(args.output_dir, 'checkpoints', 'best.pt'))

    # Save final model (EMA weights)
    ema.apply_shadow(model)
    try:
        final_state = {
            'model': model.state_dict(),
            'config': config,
        }
        torch.save(final_state, os.path.join(args.output_dir, 'liquidflow_final.pt'))
    finally:
        ema.restore(model)

    # Save training log
    with open(os.path.join(args.output_dir, 'training_log.json'), 'w') as f:
        json.dump(log_losses, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Final model saved to {args.output_dir}/liquidflow_final.pt")
    print(f"{'='*60}")

    return model


def main():
    """Parse command-line options and launch training."""
    parser = argparse.ArgumentParser(description='LiquidFlow Training')

    # Model
    parser.add_argument('--model_size', type=str, default='small',
                        choices=['tiny', 'small', 'base', '512'])
    parser.add_argument('--img_size', type=int, default=128)

    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'flowers', 'celeba', 'folder'])
    parser.add_argument('--data_dir', type=str, default='./data')

    # Training
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--grad_accum', type=int, default=1)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--warmup_steps', type=int, default=500)
    parser.add_argument('--ema_decay', type=float, default=0.9999)

    # Physics loss
    parser.add_argument('--lambda_smooth', type=float, default=0.01)
    parser.add_argument('--lambda_tv', type=float, default=0.001)

    # AMP (on by default; --no_amp turns it off)
    parser.add_argument('--use_amp', action='store_true', default=True)
    parser.add_argument('--no_amp', action='store_true')

    # Logging & Saving
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--log_every', type=int, default=50)
    parser.add_argument('--sample_every', type=int, default=5)
    parser.add_argument('--save_every', type=int, default=10)
    parser.add_argument('--sample_steps', type=int, default=50)
    parser.add_argument('--num_workers', type=int, default=2)

    # Resume
    parser.add_argument('--resume', type=str, default=None)

    args = parser.parse_args()

    if args.no_amp:
        args.use_amp = False

    # These values are used as divisors / modulo operands in train();
    # reject them here instead of failing mid-run with ZeroDivisionError.
    for name in ('grad_accum', 'log_every', 'sample_every', 'save_every'):
        if getattr(args, name) < 1:
            parser.error(f'--{name} must be >= 1')

    train(args)


if __name__ == '__main__':
    main()