# LiquidFlow / liquidflow/train.py
# Uploaded by krystv; commit 04ada50 (verified): "Add train.py — full training script with CLI"
"""
LiquidFlow Training Script
Designed for:
- Google Colab free tier (T4 16GB VRAM)
- Kaggle free tier (P100 16GB / T4x2)
- Any GPU with ≥8GB VRAM (128x128)
- Any GPU with ≥16GB VRAM (512x512)
Key training features:
- Mixed precision (fp16/bf16) for memory efficiency
- Gradient accumulation for large effective batch sizes
- EMA for stable generation quality
- Physics-informed loss with warmup
- Cosine learning rate schedule with warmup
- Checkpoint saving/resuming
- Console logging plus a JSON training log (Wandb/Trackio integration is not wired up in this script)
"""
import os
import sys
import math
import time
import json
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# Add parent to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model import (
LiquidFlowNet, liquidflow_tiny, liquidflow_small,
liquidflow_base, liquidflow_512
)
from losses import PhysicsInformedFlowLoss, EMAModel
from sampling import euler_sample, heun_sample, make_grid_image
# ============================================================
# DATASET UTILITIES
# ============================================================
class ImageFolderDataset(Dataset):
    """Unlabelled image dataset built by recursively scanning a folder.

    Every regular file below ``root`` whose name ends with one of
    ``IMAGE_EXTENSIONS`` (compared case-insensitively, so ``photo.JPG`` is
    picked up as well as ``photo.jpg``) becomes one sample. Paths are sorted
    so that indices are stable between runs.

    Args:
        root: Directory that is searched recursively for image files.
        img_size: Side length used by the default transform.
        transform: Optional callable applied to each loaded RGB PIL image.
            When ``None``, a default pipeline is used: resize, center-crop to
            ``img_size``, random horizontal flip, tensor conversion and
            per-channel normalization with mean 0.5 / std 0.5.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')

    def __init__(self, root, img_size=128, transform=None):
        self.root = Path(root)
        self.img_size = img_size

        # One pass over the tree with a lower-cased suffix test, instead of
        # one case-sensitive glob per extension.
        self.files = sorted(
            path for path in self.root.rglob('*')
            if path.is_file() and path.name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            self.transform = transform

    def __len__(self):
        """Return the number of image files that were discovered."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and apply the transform."""
        # The context manager releases the underlying file handle as soon as
        # the pixel data has been converted.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)
def get_cifar10_dataset(img_size=32, data_dir='./data'):
    """Build the CIFAR-10 training split for quick experiments.

    Args:
        img_size: Target side length. A resize step is added only when this
            differs from 32.
        data_dir: Root directory passed to torchvision (``download=True``).

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance whose images are randomly
        flipped, converted to tensors and normalized with mean 0.5 / std 0.5
        per channel.
    """
    # Build the pipeline conditionally instead of inserting an identity
    # ``transforms.Lambda(lambda x: x)``: a lambda cannot be pickled, which
    # breaks DataLoader worker processes on platforms that spawn them.
    steps = []
    if img_size != 32:
        steps.append(transforms.Resize(img_size))
    steps.extend([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    transform = transforms.Compose(steps)

    return torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )
def get_celeba_dataset(img_size=128, data_dir='./data'):
    """Build the CelebA training split for face generation.

    Images are resized, center-cropped to ``img_size``, randomly flipped
    horizontally, converted to tensors and normalized with mean 0.5 /
    std 0.5 per channel. The dataset is requested with ``download=True``.
    """
    pipeline = [
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return torchvision.datasets.CelebA(
        root=data_dir,
        split='train',
        download=True,
        transform=transforms.Compose(pipeline),
    )
def get_flowers_dataset(img_size=128, data_dir='./data'):
    """Build the Oxford Flowers 102 training split (a small dataset).

    Images are first resized to ``img_size + img_size // 8`` and then
    center-cropped back to ``img_size``, followed by a random horizontal
    flip, tensor conversion and normalization with mean 0.5 / std 0.5 per
    channel. The dataset is requested with ``download=True``.
    """
    # Resize slightly larger than the target so the crop trims the borders.
    oversized = img_size + img_size // 8
    pipeline = [
        transforms.Resize(oversized),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return torchvision.datasets.Flowers102(
        root=data_dir,
        split='train',
        download=True,
        transform=transforms.Compose(pipeline),
    )
# ============================================================
# LEARNING RATE SCHEDULE
# ============================================================
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Create a linear-warmup + cosine-decay learning-rate scheduler.

    The returned multiplier rises linearly from 0 to 1 over the first
    ``warmup_steps`` steps, then follows half a cosine from 1 down to
    ``min_lr_ratio`` at ``total_steps``. Past ``total_steps`` it stays at
    ``min_lr_ratio``.

    Args:
        optimizer: Optimizer whose learning rate is scaled.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the cosine decay reaches its floor.
        min_lr_ratio: Final multiplier relative to the base learning rate.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the schedule.
    """
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Clamp so that stepping beyond total_steps holds the floor instead
        # of letting the cosine turn around and raise the LR again.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# ============================================================
# TRAINING LOOP
# ============================================================
def train(args):
    """Run the LiquidFlow training loop described by ``args``.

    Builds the model, dataset, optimizer, LR schedule, loss, EMA and AMP
    scaler, optionally resumes from a checkpoint, then trains for
    ``args.epochs`` epochs. Along the way it writes ``config.json``, sample
    grids under ``samples/``, checkpoints under ``checkpoints/`` (per-epoch,
    ``latest.pt`` and ``best.pt``), and finally ``liquidflow_final.pt`` (EMA
    weights) and ``training_log.json`` inside ``args.output_dir``.

    Args:
        args: Namespace providing ``use_amp``, ``output_dir``, ``model_size``,
            ``img_size``, ``dataset``, ``data_dir``, ``batch_size``,
            ``num_workers``, ``lr``, ``weight_decay``, ``epochs``,
            ``grad_accum``, ``warmup_steps``, ``lambda_smooth``,
            ``lambda_tv``, ``ema_decay``, ``resume``, ``max_grad_norm``,
            ``log_every``, ``sample_every``, ``sample_steps`` and
            ``save_every``.

    Returns:
        The trained model (holding the raw, non-EMA weights).

    Raises:
        ValueError: If ``args.dataset`` is unknown, ``args.grad_accum`` is
            smaller than 1, or the data loader yields fewer batches per epoch
            than ``args.grad_accum`` (no optimizer step could ever happen).
    """
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda' and args.use_amp
    print(f"Device: {device}, AMP: {use_amp}")

    # Create output directory
    sample_dir = os.path.join(args.output_dir, 'samples')
    ckpt_dir = os.path.join(args.output_dir, 'checkpoints')
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    # ---- Model ----
    model_factories = {
        'tiny': liquidflow_tiny,
        'small': liquidflow_small,
        'base': liquidflow_base,
        '512': liquidflow_512,
    }
    # Unknown sizes fall back to the 'small' configuration.
    factory = model_factories.get(args.model_size, liquidflow_small)
    model = factory(img_size=args.img_size)
    model = model.to(device)
    num_params = model.count_params()
    print(f"Model: LiquidFlow-{args.model_size}, Params: {num_params/1e6:.2f}M")
    print(f"Image size: {args.img_size}x{args.img_size}")

    # ---- Dataset ----
    if args.dataset == 'cifar10':
        dataset = get_cifar10_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'flowers':
        dataset = get_flowers_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'celeba':
        dataset = get_celeba_dataset(args.img_size, args.data_dir)
    elif args.dataset == 'folder':
        dataset = ImageFolderDataset(args.data_dir, args.img_size)
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")
    print(f"Dataset: {args.dataset}, Size: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # Pinned host memory only helps (and only works) with a CUDA device.
        pin_memory=device.type == 'cuda',
        drop_last=True,
    )

    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )

    # ---- Schedule ----
    if args.grad_accum < 1:
        raise ValueError(f"grad_accum must be >= 1, got {args.grad_accum}")
    # Optimizer steps are taken once per full group of grad_accum batches;
    # a trailing partial group at the end of an epoch is discarded, so the
    # schedule length is derived from whole groups only.
    steps_per_epoch = len(dataloader) // args.grad_accum
    if steps_per_epoch == 0:
        raise ValueError(
            f"Data loader yields {len(dataloader)} batches per epoch, fewer than "
            f"grad_accum={args.grad_accum}; reduce --batch_size or --grad_accum"
        )
    total_steps = args.epochs * steps_per_epoch
    warmup_steps = min(args.warmup_steps, total_steps // 10)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # ---- Loss ----
    criterion = PhysicsInformedFlowLoss(
        lambda_smooth=args.lambda_smooth,
        lambda_tv=args.lambda_tv,
        use_adaptive_weights=True,
    ).to(device)

    # ---- EMA ----
    ema = EMAModel(model, decay=args.ema_decay)

    # ---- AMP ----
    scaler = GradScaler(enabled=use_amp)

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    best_loss = float('inf')
    if args.resume:
        if os.path.exists(args.resume):
            print(f"Resuming from {args.resume}")
            # NOTE(review): torch.load unpickles the file; only resume from
            # checkpoints you trust.
            ckpt = torch.load(args.resume, map_location=device)
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
            ema.load_state_dict(ckpt['ema'])
            # Older checkpoints have no scaler / best_loss entries, and a
            # scaler saved with AMP disabled is an empty dict: skip both.
            if ckpt.get('scaler'):
                scaler.load_state_dict(ckpt['scaler'])
            best_loss = ckpt.get('best_loss', best_loss)
            start_epoch = ckpt['epoch'] + 1
            global_step = ckpt['global_step']
            print(f"Resumed at epoch {start_epoch}, step {global_step}")
        else:
            print(f"Warning: resume checkpoint not found: {args.resume}; starting from scratch")

    # ---- Training Config ----
    config = {
        'model_size': args.model_size,
        'img_size': args.img_size,
        'dataset': args.dataset,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
        'num_params': num_params,
        'lambda_smooth': args.lambda_smooth,
        'lambda_tv': args.lambda_tv,
    }
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training for {args.epochs} epochs, {total_steps} steps")
    print(f"Batch size: {args.batch_size} x {args.grad_accum} = {args.batch_size * args.grad_accum}")
    print(f"Learning rate: {args.lr}")
    print(f"{'='*60}\n")

    # ---- Training ----
    log_losses = []
    for epoch in range(start_epoch, args.epochs):
        model.train()
        # Drop gradients left over from an incomplete accumulation group at
        # the end of the previous epoch so they cannot inflate the first
        # optimizer step of this one.
        optimizer.zero_grad()
        epoch_loss = 0.0
        epoch_flow_loss = 0.0
        epoch_physics_loss = 0.0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            # Handle different dataset formats
            if isinstance(batch_data, (list, tuple)):
                x1 = batch_data[0].to(device)  # images only, ignore labels
            else:
                x1 = batch_data.to(device)
            B = x1.shape[0]

            # Sample noise (x0) and timestep (t)
            x0 = torch.randn_like(x1)
            t = torch.rand(B, device=device)

            # Interpolate: x_t = t * x_1 + (1-t) * x_0
            t_expand = t.view(B, 1, 1, 1)
            x_t = t_expand * x1 + (1.0 - t_expand) * x0

            # Forward pass with AMP
            with autocast(enabled=use_amp):
                v_pred = model(x_t, t)
                loss, loss_dict = criterion(
                    v_pred, x0, x1, t,
                    step=global_step,
                )
                # Scale so the accumulated gradient is an average over the group.
                loss = loss / args.grad_accum

            # Backward
            scaler.scale(loss).backward()

            # Gradient accumulation step
            stepped = (batch_idx + 1) % args.grad_accum == 0
            if stepped:
                # Gradient clipping (critical for stability); gradients must
                # be unscaled first so the norm is measured in real units.
                scaler.unscale_(optimizer)
                grad_norm = nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                ema.update(model)
                global_step += 1

            # Running epoch statistics (per micro-batch)
            epoch_loss += loss_dict['total'].item()
            epoch_flow_loss += loss_dict['flow'].item()
            epoch_physics_loss += (loss_dict['smooth'].item() + loss_dict['tv'].item())
            num_batches += 1

            # Log only right after an optimizer step: this guarantees
            # grad_norm is defined and avoids repeating the same line for
            # every micro-batch while global_step is unchanged.
            if stepped and global_step % args.log_every == 0:
                avg_loss = epoch_loss / max(1, num_batches)
                avg_flow = epoch_flow_loss / max(1, num_batches)
                avg_phys = epoch_physics_loss / max(1, num_batches)
                lr_current = scheduler.get_last_lr()[0]
                print(
                    f"[Epoch {epoch+1}/{args.epochs}] "
                    f"Step {global_step}/{total_steps} | "
                    f"Loss: {avg_loss:.4f} | "
                    f"Flow: {avg_flow:.4f} | "
                    f"Physics: {avg_phys:.6f} | "
                    f"LR: {lr_current:.2e} | "
                    f"GradNorm: {grad_norm:.2f}"
                )
                log_losses.append({
                    'step': global_step,
                    'epoch': epoch,
                    'loss': avg_loss,
                    'flow_loss': avg_flow,
                    'physics_loss': avg_phys,
                    'lr': lr_current,
                    'grad_norm': grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm,
                })

        # ---- End of Epoch ----
        avg_epoch_loss = epoch_loss / max(1, num_batches)
        print(f"\n[Epoch {epoch+1}] Average Loss: {avg_epoch_loss:.4f}\n")

        # Sample images with EMA
        if (epoch + 1) % args.sample_every == 0 or epoch == 0:
            print("Generating samples...")
            model.eval()
            ema.apply_shadow(model)
            try:
                with torch.no_grad():
                    shape = (min(16, args.batch_size), 3, args.img_size, args.img_size)
                    samples = euler_sample(model, shape, num_steps=args.sample_steps, device=device)
                    # Map from [-1, 1] back to [0, 1] for saving.
                    samples = samples.clamp(-1, 1) * 0.5 + 0.5
                    grid = make_grid_image(samples, nrow=4)
                    grid.save(os.path.join(sample_dir, f'epoch_{epoch+1:04d}.png'))
                    print(f" Saved samples to samples/epoch_{epoch+1:04d}.png")
            finally:
                # Always put the raw training weights back, even if sampling
                # or saving the grid failed.
                ema.restore(model)
                model.train()

        # Save checkpoint
        if (epoch + 1) % args.save_every == 0 or avg_epoch_loss < best_loss:
            is_best = avg_epoch_loss <= best_loss
            best_loss = min(best_loss, avg_epoch_loss)
            ckpt = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'ema': ema.state_dict(),
                'scaler': scaler.state_dict(),
                'epoch': epoch,
                'global_step': global_step,
                'loss': avg_epoch_loss,
                'best_loss': best_loss,
                'config': config,
            }
            ckpt_path = os.path.join(ckpt_dir, f'epoch_{epoch+1:04d}.pt')
            torch.save(ckpt, ckpt_path)
            print(f" Saved checkpoint: {ckpt_path}")
            # Also save "latest" and "best"
            torch.save(ckpt, os.path.join(ckpt_dir, 'latest.pt'))
            if is_best:
                torch.save(ckpt, os.path.join(ckpt_dir, 'best.pt'))

    # Save final model (EMA weights)
    ema.apply_shadow(model)
    try:
        final_state = {
            'model': model.state_dict(),
            'config': config,
        }
        torch.save(final_state, os.path.join(args.output_dir, 'liquidflow_final.pt'))
    finally:
        ema.restore(model)

    # Save training log
    with open(os.path.join(args.output_dir, 'training_log.json'), 'w') as f:
        json.dump(log_losses, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Final model saved to {args.output_dir}/liquidflow_final.pt")
    print(f"{'='*60}")
    return model
def main():
    """Parse command-line options and hand them to ``train``.

    ``--use_amp`` is on by default; passing ``--no_amp`` switches it off
    before training starts.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        # Model
        ('--model_size', dict(type=str, default='small',
                              choices=['tiny', 'small', 'base', '512'])),
        ('--img_size', dict(type=int, default=128)),
        # Dataset
        ('--dataset', dict(type=str, default='cifar10',
                           choices=['cifar10', 'flowers', 'celeba', 'folder'])),
        ('--data_dir', dict(type=str, default='./data')),
        # Training
        ('--epochs', dict(type=int, default=100)),
        ('--batch_size', dict(type=int, default=32)),
        ('--lr', dict(type=float, default=3e-4)),
        ('--weight_decay', dict(type=float, default=0.01)),
        ('--grad_accum', dict(type=int, default=1)),
        ('--max_grad_norm', dict(type=float, default=1.0)),
        ('--warmup_steps', dict(type=int, default=500)),
        ('--ema_decay', dict(type=float, default=0.9999)),
        # Physics loss
        ('--lambda_smooth', dict(type=float, default=0.01)),
        ('--lambda_tv', dict(type=float, default=0.001)),
        # AMP
        ('--use_amp', dict(action='store_true', default=True)),
        ('--no_amp', dict(action='store_true')),
        # Logging & Saving
        ('--output_dir', dict(type=str, default='./outputs')),
        ('--log_every', dict(type=int, default=50)),
        ('--sample_every', dict(type=int, default=5)),
        ('--save_every', dict(type=int, default=10)),
        ('--sample_steps', dict(type=int, default=50)),
        ('--num_workers', dict(type=int, default=2)),
        # Resume
        ('--resume', dict(type=str, default=None)),
    ]

    parser = argparse.ArgumentParser(description='LiquidFlow Training')
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    # --no_amp overrides the always-on --use_amp default.
    if args.no_amp:
        args.use_amp = False
    train(args)


if __name__ == '__main__':
    main()