| """ |
| LiquidFlow Training Script |
| |
| Designed for: |
| - Google Colab free tier (T4 16GB VRAM) |
| - Kaggle free tier (P100 16GB / T4x2) |
| - Any GPU with ≥8GB VRAM (128x128) |
| - Any GPU with ≥16GB VRAM (512x512) |
| |
| Key training features: |
| - Mixed precision (fp16/bf16) for memory efficiency |
| - Gradient accumulation for large effective batch sizes |
| - EMA for stable generation quality |
| - Physics-informed loss with warmup |
| - Cosine learning rate schedule with warmup |
| - Checkpoint saving/resuming |
| - Wandb/Trackio logging support |
| """ |
|
|
| import os |
| import sys |
| import math |
| import time |
| import json |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torch.cuda.amp import autocast, GradScaler |
| import torchvision |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import numpy as np |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from model import ( |
| LiquidFlowNet, liquidflow_tiny, liquidflow_small, |
| liquidflow_base, liquidflow_512 |
| ) |
| from losses import PhysicsInformedFlowLoss, EMAModel |
| from sampling import euler_sample, heun_sample, make_grid_image |
|
|
|
|
| |
| |
| |
|
|
class ImageFolderDataset(Dataset):
    """Recursively load images from a folder as training tensors.

    Images are resized (short side), center-cropped to ``img_size``,
    randomly flipped, and normalized to [-1, 1] unless a custom
    ``transform`` is supplied.
    """

    # Recognized image extensions; matched case-insensitively.
    EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp'}

    def __init__(self, root, img_size=128, transform=None):
        self.root = Path(root)
        self.img_size = img_size

        # rglob patterns are case-sensitive, so match on lowered suffix
        # instead — this also picks up e.g. 'IMG_001.JPG'.
        self.files = sorted(
            p for p in self.root.rglob('*')
            if p.suffix.lower() in self.EXTENSIONS
        )
        # Fail fast with a clear message instead of producing an empty
        # dataset that breaks obscurely inside the DataLoader.
        if not self.files:
            raise FileNotFoundError(
                f"No images found under {self.root} "
                f"(extensions: {sorted(self.EXTENSIONS)})"
            )

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Close the file handle promptly; with many dataloader workers,
        # leaked handles can hit the open-file limit.
        with Image.open(self.files[idx]) as img:
            return self.transform(img.convert('RGB'))
|
|
|
|
def get_cifar10_dataset(img_size=32, data_dir='./data'):
    """CIFAR-10 for quick experiments (normalized to [-1, 1]).

    The resize op is only inserted when ``img_size`` differs from the
    native 32x32. The previous identity ``transforms.Lambda(lambda x: x)``
    was a no-op that also made the dataset unpicklable under the 'spawn'
    multiprocessing start method (breaking num_workers > 0 on
    Windows/macOS).
    """
    ops = []
    if img_size != 32:  # CIFAR-10 is natively 32x32
        ops.append(transforms.Resize(img_size))
    ops += [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True,
        transform=transforms.Compose(ops),
    )
|
|
|
|
def get_celeba_dataset(img_size=128, data_dir='./data'):
    """CelebA training split for face generation.

    Images are resized so the short side equals ``img_size``,
    center-cropped, randomly flipped, and normalized to [-1, 1].
    """
    pipeline = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    return torchvision.datasets.CelebA(
        root=data_dir, split='train', download=True, transform=pipeline
    )
|
|
|
|
def get_flowers_dataset(img_size=128, data_dir='./data'):
    """Oxford Flowers 102 - small but beautiful dataset."""
    # Resize slightly larger than the crop (img_size * 9/8) so the
    # center crop trims borders while keeping most of the subject.
    resize_to = img_size + img_size // 8
    pipeline = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    return torchvision.datasets.Flowers102(
        root=data_dir, split='train', download=True, transform=pipeline
    )
|
|
|
|
| |
| |
| |
|
|
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Cosine annealing with linear warmup.

    Args:
        optimizer: Optimizer whose LR is scheduled.
        warmup_steps: Steps of linear warmup from 0 to the base LR.
        total_steps: Total optimization steps covered by the schedule.
        min_lr_ratio: LR floor as a fraction of the base LR.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup: 0 -> 1 over warmup_steps.
            return step / max(1, warmup_steps)
        # FIX: clamp progress to 1.0 — without the clamp, stepping past
        # total_steps makes cos(pi * progress) curve back up and the LR
        # rises again instead of staying at the min_lr_ratio floor.
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
|
|
| |
| |
| |
|
|
def train(args):
    """Main training loop for LiquidFlow rectified-flow models.

    Builds the model/dataset/optimizer from ``args``, trains with mixed
    precision and gradient accumulation, maintains an EMA copy of the
    weights for sampling, and periodically writes samples, checkpoints,
    and a JSON training log under ``args.output_dir``.

    Returns:
        The trained model (raw weights; the final EMA weights are saved
        to ``liquidflow_final.pt``).
    """
    # --- Device / AMP setup ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda' and args.use_amp
    print(f"Device: {device}, AMP: {use_amp}")

    # --- Output directories ---
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'samples'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)

    # --- Model ---
    model_factories = {
        'tiny': liquidflow_tiny,
        'small': liquidflow_small,
        'base': liquidflow_base,
        '512': liquidflow_512,
    }
    # Unknown sizes fall back to 'small' (argparse restricts choices anyway).
    factory = model_factories.get(args.model_size, liquidflow_small)
    model = factory(img_size=args.img_size).to(device)
    num_params = model.count_params()
    print(f"Model: LiquidFlow-{args.model_size}, Params: {num_params/1e6:.2f}M")
    print(f"Image size: {args.img_size}x{args.img_size}")

    # --- Dataset ---
    if args.dataset == 'cifar10':
        dataset = get_cifar10_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'flowers':
        dataset = get_flowers_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'celeba':
        dataset = get_celeba_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'folder':
        dataset = ImageFolderDataset(args.data_dir, args.img_size)
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    print(f"Dataset: {args.dataset}, Size: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,  # constant batch shape keeps grad accumulation consistent
    )

    # --- Optimizer / LR schedule ---
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )

    # One scheduler step per *optimizer* step, not per micro-batch.
    total_steps = args.epochs * len(dataloader) // args.grad_accum
    warmup_steps = min(args.warmup_steps, total_steps // 10)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # --- Loss / EMA / AMP scaler ---
    criterion = PhysicsInformedFlowLoss(
        lambda_smooth=args.lambda_smooth,
        lambda_tv=args.lambda_tv,
        use_adaptive_weights=True,
    ).to(device)

    ema = EMAModel(model, decay=args.ema_decay)
    scaler = GradScaler(enabled=use_amp)

    # --- Resume from checkpoint ---
    start_epoch = 0
    global_step = 0

    if args.resume and os.path.exists(args.resume):
        print(f"Resuming from {args.resume}")
        ckpt = torch.load(args.resume, map_location=device)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        ema.load_state_dict(ckpt['ema'])
        # FIX: restore the GradScaler so the AMP loss scale resumes exactly;
        # checkpoints saved before this fix simply skip it.
        if 'scaler' in ckpt:
            scaler.load_state_dict(ckpt['scaler'])
        start_epoch = ckpt['epoch'] + 1
        global_step = ckpt['global_step']
        print(f"Resumed at epoch {start_epoch}, step {global_step}")

    # --- Persist run configuration ---
    config = {
        'model_size': args.model_size,
        'img_size': args.img_size,
        'dataset': args.dataset,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
        'num_params': num_params,
        'lambda_smooth': args.lambda_smooth,
        'lambda_tv': args.lambda_tv,
    }

    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training for {args.epochs} epochs, {total_steps} steps")
    print(f"Batch size: {args.batch_size} x {args.grad_accum} = {args.batch_size * args.grad_accum}")
    print(f"Learning rate: {args.lr}")
    print(f"{'='*60}\n")

    best_loss = float('inf')
    log_losses = []
    # FIX: define before the loop — with grad_accum > 1 the first log fires
    # at global_step == 0 before any optimizer step, which previously
    # raised NameError on grad_norm.
    grad_norm = 0.0

    for epoch in range(start_epoch, args.epochs):
        model.train()
        epoch_loss = 0.0
        epoch_flow_loss = 0.0
        epoch_physics_loss = 0.0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            # Labeled datasets yield (image, label); plain folders yield images.
            if isinstance(batch_data, (list, tuple)):
                x1 = batch_data[0].to(device)
            else:
                x1 = batch_data.to(device)

            B = x1.shape[0]

            # Rectified-flow pairing: Gaussian noise x0 -> data x1 at random t.
            x0 = torch.randn_like(x1)
            t = torch.rand(B, device=device)

            # Linear interpolation point on the straight path x0 -> x1.
            t_expand = t.view(B, 1, 1, 1)
            x_t = t_expand * x1 + (1.0 - t_expand) * x0

            with autocast(enabled=use_amp):
                v_pred = model(x_t, t)
                loss, loss_dict = criterion(
                    v_pred, x0, x1, t,
                    step=global_step,
                )
                # Scale so accumulated gradients average over micro-batches.
                loss = loss / args.grad_accum

            scaler.scale(loss).backward()

            # Optimizer step every grad_accum micro-batches.
            if (batch_idx + 1) % args.grad_accum == 0:
                # Unscale before clipping so the threshold applies to true grads.
                scaler.unscale_(optimizer)
                grad_norm = nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
                ema.update(model)
                global_step += 1

            # Running epoch statistics (per micro-batch, pre-accum-scaling).
            epoch_loss += loss_dict['total'].item()
            epoch_flow_loss += loss_dict['flow'].item()
            epoch_physics_loss += (loss_dict['smooth'].item() + loss_dict['tv'].item())
            num_batches += 1

            if global_step % args.log_every == 0:
                avg_loss = epoch_loss / max(1, num_batches)
                avg_flow = epoch_flow_loss / max(1, num_batches)
                avg_phys = epoch_physics_loss / max(1, num_batches)
                lr_current = scheduler.get_last_lr()[0]

                print(
                    f"[Epoch {epoch+1}/{args.epochs}] "
                    f"Step {global_step}/{total_steps} | "
                    f"Loss: {avg_loss:.4f} | "
                    f"Flow: {avg_flow:.4f} | "
                    f"Physics: {avg_phys:.6f} | "
                    f"LR: {lr_current:.2e} | "
                    f"GradNorm: {grad_norm:.2f}"
                )

                log_losses.append({
                    'step': global_step,
                    'epoch': epoch,
                    'loss': avg_loss,
                    'flow_loss': avg_flow,
                    'physics_loss': avg_phys,
                    'lr': lr_current,
                    'grad_norm': grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm,
                })

        avg_epoch_loss = epoch_loss / max(1, num_batches)
        print(f"\n[Epoch {epoch+1}] Average Loss: {avg_epoch_loss:.4f}\n")

        # --- Periodic sampling with the EMA weights ---
        if (epoch + 1) % args.sample_every == 0 or epoch == 0:
            print("Generating samples...")
            model.eval()
            ema.apply_shadow(model)

            with torch.no_grad():
                shape = (min(16, args.batch_size), 3, args.img_size, args.img_size)
                samples = euler_sample(model, shape, num_steps=args.sample_steps, device=device)
                # Map from [-1, 1] back to [0, 1] for image export.
                samples = samples.clamp(-1, 1) * 0.5 + 0.5

                grid = make_grid_image(samples, nrow=4)
                grid.save(os.path.join(args.output_dir, 'samples', f'epoch_{epoch+1:04d}.png'))
                print(f" Saved samples to samples/epoch_{epoch+1:04d}.png")

            ema.restore(model)
            model.train()

        # --- Checkpointing ---
        # FIX: decide "best" BEFORE updating best_loss; the old code compared
        # against the already-min-updated value.
        is_best = avg_epoch_loss < best_loss
        if (epoch + 1) % args.save_every == 0 or is_best:
            best_loss = min(best_loss, avg_epoch_loss)
            ckpt = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'ema': ema.state_dict(),
                'scaler': scaler.state_dict(),  # FIX: required for exact AMP resume
                'epoch': epoch,
                'global_step': global_step,
                'loss': avg_epoch_loss,
                'config': config,
            }
            ckpt_path = os.path.join(args.output_dir, 'checkpoints', f'epoch_{epoch+1:04d}.pt')
            torch.save(ckpt, ckpt_path)
            print(f" Saved checkpoint: {ckpt_path}")

            torch.save(ckpt, os.path.join(args.output_dir, 'checkpoints', 'latest.pt'))
            if is_best:
                torch.save(ckpt, os.path.join(args.output_dir, 'checkpoints', 'best.pt'))

    # --- Final export: EMA weights only (what you sample with) ---
    ema.apply_shadow(model)
    final_state = {
        'model': model.state_dict(),
        'config': config,
    }
    torch.save(final_state, os.path.join(args.output_dir, 'liquidflow_final.pt'))
    ema.restore(model)

    with open(os.path.join(args.output_dir, 'training_log.json'), 'w') as f:
        json.dump(log_losses, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Final model saved to {args.output_dir}/liquidflow_final.pt")
    print(f"{'='*60}")

    return model
|
|
|
|
def main():
    """Parse command-line arguments and launch training."""
    parser = argparse.ArgumentParser(description='LiquidFlow Training')

    # Model
    parser.add_argument('--model_size', type=str, default='small',
                        choices=['tiny', 'small', 'base', '512'],
                        help='Model capacity preset')
    parser.add_argument('--img_size', type=int, default=128,
                        help='Square training resolution')

    # Data
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'flowers', 'celeba', 'folder'],
                        help="Dataset; 'folder' loads images recursively from --data_dir")
    parser.add_argument('--data_dir', type=str, default='./data',
                        help='Dataset root directory')

    # Optimization
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Per-step batch size (before gradient accumulation)')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Peak learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='AdamW weight decay')
    parser.add_argument('--grad_accum', type=int, default=1,
                        help='Gradient accumulation steps (effective batch = batch_size * grad_accum)')
    parser.add_argument('--max_grad_norm', type=float, default=1.0,
                        help='Gradient clipping threshold')
    parser.add_argument('--warmup_steps', type=int, default=500,
                        help='LR warmup steps (capped at 10%% of total steps)')
    parser.add_argument('--ema_decay', type=float, default=0.9999,
                        help='EMA decay for the shadow weights')

    # Physics-informed loss weights
    parser.add_argument('--lambda_smooth', type=float, default=0.01,
                        help='Weight of the smoothness term')
    parser.add_argument('--lambda_tv', type=float, default=0.001,
                        help='Weight of the total-variation term')

    # Mixed precision. NOTE: --use_amp defaults to True, so passing the flag
    # is a no-op; AMP is on by default and --no_amp is the way to disable it.
    parser.add_argument('--use_amp', action='store_true', default=True,
                        help='Enable mixed precision (on by default; see --no_amp)')
    parser.add_argument('--no_amp', action='store_true',
                        help='Disable mixed precision training')

    # Logging / outputs
    parser.add_argument('--output_dir', type=str, default='./outputs',
                        help='Directory for samples, checkpoints, and logs')
    parser.add_argument('--log_every', type=int, default=50,
                        help='Log every N optimizer steps')
    parser.add_argument('--sample_every', type=int, default=5,
                        help='Generate sample grids every N epochs')
    parser.add_argument('--save_every', type=int, default=10,
                        help='Save checkpoints every N epochs')
    parser.add_argument('--sample_steps', type=int, default=50,
                        help='ODE integration steps for sampling')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='DataLoader worker processes')

    # Resume
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to a checkpoint to resume from')

    args = parser.parse_args()

    # --no_amp overrides the always-True --use_amp default.
    if args.no_amp:
        args.use_amp = False

    train(args)


if __name__ == '__main__':
    main()
|
|