# PMA-VAE / train.py
# krystv's picture
# Upload train.py with huggingface_hub
# f56debd verified
"""
PMA-VAE Training Script
========================
Progressive resolution training with:
- KL warmup (prevents posterior collapse)
- Discriminator cold start
- Mixed precision (fp16/bf16)
- Gradient checkpointing option
- Colab-friendly (T4 15GB VRAM)
- Checkpoint saving/resuming
"""
import os
import math
import time
import json
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.amp import GradScaler, autocast
from PIL import Image
import random
from model import PMAVAE, pmavae_tiny, pmavae_small, pmavae_base
from losses import PMAVAELoss
# ==============================================================================
# Dataset
# ==============================================================================
class ImageFolderDataset(Dataset):
    """Dataset over every image file found recursively under a folder.

    Files are discovered once at construction time (sorted, so the
    index -> file mapping is stable) and decoded lazily in ``__getitem__``.
    Each item is the output of ``self.transform`` applied to the RGB image.

    Args:
        root: Directory that is walked recursively for image files.
        resolution: Target size passed to the crop / resize transforms.
        random_crop: If True, resize to ``int(resolution * 1.15)``, take a
            random ``resolution`` crop and apply a random horizontal flip.
            If False, resize directly to ``(resolution, resolution)``.
    """

    # Lower-case file extensions (with leading dot) treated as images.
    IMAGE_EXTENSIONS = frozenset(
        {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff'}
    )

    def __init__(self, root, resolution=256, random_crop=True):
        self.root = root
        self.resolution = resolution
        self.random_crop = random_crop
        self.files = self._find_images(root)
        print(f"Found {len(self.files)} images in {root}")

        if random_crop:
            self.transform = transforms.Compose([
                transforms.Resize(int(resolution * 1.15),
                                  interpolation=transforms.InterpolationMode.LANCZOS,
                                  antialias=True),
                transforms.RandomCrop(resolution),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((resolution, resolution),
                                  interpolation=transforms.InterpolationMode.LANCZOS,
                                  antialias=True),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])

    @classmethod
    def _find_images(cls, root):
        """Return the sorted paths of all image files under ``root``.

        The extension comparison is case-insensitive.
        """
        return sorted(
            os.path.join(dirpath, name)
            for dirpath, _, filenames in os.walk(root)
            for name in filenames
            if os.path.splitext(name)[1].lower() in cls.IMAGE_EXTENSIONS
        )

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it transformed."""
        # The context manager closes the underlying file handle as soon as
        # the pixels are decoded; convert() returns an independent image, so
        # it remains usable after the source file is closed.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)
class HFDatasetWrapper(Dataset):
    """Expose an indexable dataset with an image column as transformed tensors.

    Args:
        hf_dataset: Object supporting ``len()`` and integer indexing, where
            each record can be indexed by ``image_column``.
        image_column: Key of the image inside each record.
        resolution: Crop size; images are first resized to
            ``int(resolution * 1.15)`` and then randomly cropped and flipped.
    """

    def __init__(self, hf_dataset, image_column='image', resolution=256):
        self.dataset = hf_dataset
        self.image_column = image_column

        oversize = int(resolution * 1.15)
        pipeline = [
            transforms.Resize(oversize,
                              interpolation=transforms.InterpolationMode.LANCZOS,
                              antialias=True),
            transforms.RandomCrop(resolution),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Fetch record ``idx``, coerce its image to RGB PIL and transform it."""
        record = self.dataset[idx]
        raw = record[self.image_column]
        # Anything that is not already a PIL image is handed to
        # Image.fromarray before the colour conversion.
        pil = raw if isinstance(raw, Image.Image) else Image.fromarray(raw)
        return self.transform(pil.convert('RGB'))
# ==============================================================================
# KL Warmup Schedule
# ==============================================================================
class KLWarmup:
    """
    Linear ramp for the KL loss weight.

    The weight is ``target_weight * step / warmup_steps`` while
    ``step < warmup_steps`` and ``target_weight`` from then on.
    """

    def __init__(self, target_weight, warmup_steps=10000):
        self.target_weight = target_weight
        self.warmup_steps = warmup_steps

    def get_weight(self, step):
        """Return the KL weight to use at training step ``step``."""
        warmup_done = step >= self.warmup_steps
        return (
            self.target_weight
            if warmup_done
            else self.target_weight * (step / self.warmup_steps)
        )
# ==============================================================================
# Training Loop
# ==============================================================================
class PMAVAETrainer:
"""
Full training pipeline for PMA-VAE.
Features:
- Progressive resolution training
- KL warmup
- Discriminator cold start
- Mixed precision
- Checkpoint save/resume
- Logging
"""
def __init__(self, config):
self.config = config
self.device = torch.device(config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu'))
self.global_step = 0
self.current_epoch = 0
# Build model
model_fn = {
'tiny': pmavae_tiny,
'small': pmavae_small,
'base': pmavae_base,
}[config.get('model_size', 'small')]
self.model = model_fn(
use_parallel_scan=config.get('use_parallel_scan', True)
).to(self.device)
params = self.model.count_parameters()
print(f"Model: {config.get('model_size', 'small')}")
print(f" Encoder: {params['encoder_M']:.2f}M params")
print(f" Decoder: {params['decoder_M']:.2f}M params")
print(f" Total: {params['total_M']:.2f}M params")
# Build loss
self.criterion = PMAVAELoss(
disc_start=config.get('disc_start', 10000),
kl_weight=config.get('kl_weight', 1e-6),
perceptual_weight=config.get('perceptual_weight', 0.5),
disc_weight=config.get('disc_weight', 0.5),
edge_weight=config.get('edge_weight', 0.1),
free_bits=config.get('free_bits', 0.25),
).to(self.device)
# Optimizers
lr = config.get('lr', 4.5e-6)
self.opt_vae = torch.optim.AdamW(
self.model.parameters(),
lr=lr * config.get('batch_size', 4), # scale with batch size
betas=(0.5, 0.9),
weight_decay=config.get('weight_decay', 0.01),
)
self.opt_disc = torch.optim.AdamW(
self.criterion.discriminator.parameters(),
lr=lr * config.get('batch_size', 4),
betas=(0.5, 0.9),
weight_decay=config.get('weight_decay', 0.01),
)
# Mixed precision
self.use_amp = config.get('use_amp', True)
self.scaler_vae = GradScaler('cpu' if self.device.type == 'cpu' else 'cuda', enabled=self.use_amp)
self.scaler_disc = GradScaler('cpu' if self.device.type == 'cpu' else 'cuda', enabled=self.use_amp)
# KL warmup
self.kl_warmup = KLWarmup(
target_weight=config.get('kl_weight', 1e-6),
warmup_steps=config.get('kl_warmup_steps', 5000),
)
# Gradient checkpointing
if config.get('gradient_checkpointing', False):
self._enable_gradient_checkpointing()
# Logging
self.log_every = config.get('log_every', 50)
self.save_every = config.get('save_every', 5000)
self.output_dir = config.get('output_dir', './checkpoints')
os.makedirs(self.output_dir, exist_ok=True)
self.train_log = []
def _enable_gradient_checkpointing(self):
"""Enable gradient checkpointing for encoder (saves ~30% VRAM)."""
from torch.utils.checkpoint import checkpoint
# Wrap encoder stages
for stage in [self.model.encoder.stage1]:
for module in stage:
module._original_forward = module.forward
module.forward = lambda x, m=module: checkpoint(m._original_forward, x, use_reentrant=False)
def train_step(self, batch):
"""Single training step with both VAE and discriminator updates."""
batch = batch.to(self.device)
# Update KL weight
current_kl_weight = self.kl_warmup.get_weight(self.global_step)
self.criterion.kl_weight = current_kl_weight
# ==================== VAE Update ====================
self.opt_vae.zero_grad()
with autocast(device_type=self.device.type, enabled=self.use_amp):
recon, posteriors = self.model(batch)
loss_vae, log_vae = self.criterion(
batch, recon, posteriors,
optimizer_idx=0,
global_step=self.global_step,
last_layer=self.model.get_last_decoder_layer()
)
self.scaler_vae.scale(loss_vae).backward()
# Gradient clipping
self.scaler_vae.unscale_(self.opt_vae)
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.scaler_vae.step(self.opt_vae)
self.scaler_vae.update()
# ==================== Discriminator Update ====================
self.opt_disc.zero_grad()
with autocast(device_type=self.device.type, enabled=self.use_amp):
# Recompute recon without grad for disc
with torch.no_grad():
recon_detached, _ = self.model(batch)
loss_disc, log_disc = self.criterion(
batch, recon_detached, posteriors,
optimizer_idx=1,
global_step=self.global_step,
)
if self.global_step >= self.criterion.disc_start:
self.scaler_disc.scale(loss_disc).backward()
self.scaler_disc.unscale_(self.opt_disc)
torch.nn.utils.clip_grad_norm_(self.criterion.discriminator.parameters(), 1.0)
self.scaler_disc.step(self.opt_disc)
self.scaler_disc.update()
self.global_step += 1
# Merge logs
log = {**log_vae, **log_disc}
log['grad_norm'] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
log['kl_weight'] = current_kl_weight
log['step'] = self.global_step
return log
def train_epoch(self, dataloader):
"""Train for one epoch."""
self.model.train()
epoch_logs = []
for batch_idx, batch in enumerate(dataloader):
log = self.train_step(batch)
epoch_logs.append(log)
if self.global_step % self.log_every == 0:
avg_log = self._average_logs(epoch_logs[-self.log_every:])
print(f"Step {self.global_step:6d} | "
f"L1: {avg_log.get('l1_loss', 0):.4f} | "
f"Perc: {avg_log.get('perceptual_loss', 0):.4f} | "
f"KL: {avg_log.get('kl_total', 0):.2f} | "
f"D: {avg_log.get('d_loss', 0):.4f} | "
f"G: {avg_log.get('g_loss', 0):.4f} | "
f"GN: {avg_log.get('grad_norm', 0):.2f}")
if self.global_step % self.save_every == 0:
self.save_checkpoint()
self.current_epoch += 1
return epoch_logs
def train(self, dataloader, num_epochs=100):
"""Full training loop."""
print(f"\nStarting training for {num_epochs} epochs")
print(f" Steps per epoch: {len(dataloader)}")
print(f" Device: {self.device}")
print(f" AMP: {self.use_amp}")
print(f" Disc starts at step: {self.criterion.disc_start}")
print(f" KL warmup steps: {self.kl_warmup.warmup_steps}")
print()
all_logs = []
start_time = time.time()
for epoch in range(num_epochs):
epoch_start = time.time()
epoch_logs = self.train_epoch(dataloader)
all_logs.extend(epoch_logs)
epoch_time = time.time() - epoch_start
avg = self._average_logs(epoch_logs)
print(f"\n{'='*60}")
print(f"Epoch {epoch+1}/{num_epochs} completed in {epoch_time:.1f}s")
print(f" Avg L1: {avg.get('l1_loss', 0):.4f}")
print(f" Avg Perceptual: {avg.get('perceptual_loss', 0):.4f}")
print(f" Avg KL: {avg.get('kl_total', 0):.2f}")
print(f" Total time: {(time.time()-start_time)/60:.1f} min")
print(f"{'='*60}\n")
self.save_checkpoint(f'epoch_{epoch+1}')
self.save_checkpoint('final')
# Save training log
with open(os.path.join(self.output_dir, 'train_log.json'), 'w') as f:
json.dump(all_logs, f)
total_time = time.time() - start_time
print(f"\nTraining complete! Total time: {total_time/60:.1f} min")
return all_logs
def save_checkpoint(self, tag='latest'):
"""Save model and optimizer states."""
path = os.path.join(self.output_dir, f'checkpoint_{tag}.pt')
torch.save({
'model_state': self.model.state_dict(),
'disc_state': self.criterion.discriminator.state_dict(),
'opt_vae_state': self.opt_vae.state_dict(),
'opt_disc_state': self.opt_disc.state_dict(),
'global_step': self.global_step,
'epoch': self.current_epoch,
'config': self.config,
}, path)
print(f" Saved checkpoint: {path}")
def load_checkpoint(self, path):
"""Load checkpoint."""
ckpt = torch.load(path, map_location=self.device, weights_only=False)
self.model.load_state_dict(ckpt['model_state'])
self.criterion.discriminator.load_state_dict(ckpt['disc_state'])
self.opt_vae.load_state_dict(ckpt['opt_vae_state'])
self.opt_disc.load_state_dict(ckpt['opt_disc_state'])
self.global_step = ckpt['global_step']
self.current_epoch = ckpt['epoch']
print(f"Loaded checkpoint from {path} (step {self.global_step})")
def _average_logs(self, logs):
"""Average a list of log dicts."""
if not logs:
return {}
avg = {}
for key in logs[0]:
if key == 'step':
continue
vals = [l[key] for l in logs if key in l]
if vals:
avg[key] = sum(vals) / len(vals)
return avg
@torch.no_grad()
def validate(self, dataloader, max_batches=50):
"""Run validation."""
self.model.eval()
logs = []
for i, batch in enumerate(dataloader):
if i >= max_batches:
break
batch = batch.to(self.device)
recon, posteriors = self.model(batch)
# Compute metrics
l1 = F.l1_loss(recon, batch).item()
# PSNR
mse = F.mse_loss(recon, batch).item()
psnr = -10 * math.log10(mse + 1e-8)
logs.append({'l1': l1, 'psnr': psnr})
avg = {k: sum(l[k] for l in logs) / len(logs) for k in logs[0]}
print(f"Validation: L1={avg['l1']:.4f}, PSNR={avg['psnr']:.2f}dB")
self.model.train()
return avg
# ==============================================================================
# Synthetic data for testing
# ==============================================================================
class SyntheticDataset(Dataset):
    """Random smooth-noise images for exercising the training loop.

    Every access draws a fresh random sample; the index is ignored.

    Args:
        num_samples: Reported dataset length.
        resolution: Side length of the square 3-channel images.
    """

    def __init__(self, num_samples=1000, resolution=128):
        self.num_samples = num_samples
        self.resolution = resolution

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        side = self.resolution
        sample = torch.randn(3, side, side)

        # Blur with an unpadded average pool (which shrinks the map), then
        # resize bilinearly back to the requested resolution.
        kernel = min(8, side // 4)
        if kernel >= 2:
            blurred = F.avg_pool2d(sample.unsqueeze(0), kernel, stride=1, padding=0)
            restored = F.interpolate(
                blurred, size=(side, side), mode='bilinear', align_corners=False
            )
            sample = restored.squeeze(0)

        # Divide by the largest magnitude so values lie within [-1, 1].
        peak = sample.abs().max()
        return sample / (peak + 1e-6)
# ==============================================================================
# Main
# ==============================================================================
def create_default_config():
    """Return a fresh dict holding the default training configuration."""
    config = {}

    # Model and data
    config.update(
        model_size='tiny',  # tiny/small/base
        resolution=128,
        batch_size=4,
        num_epochs=5,
    )
    # Optimisation
    config.update(
        lr=4.5e-6,
        weight_decay=0.01,
    )
    # Loss weights and schedules
    config.update(
        kl_weight=1e-6,
        kl_warmup_steps=2000,
        perceptual_weight=0.5,
        disc_weight=0.5,
        edge_weight=0.1,
        free_bits=0.25,
        disc_start=5000,
    )
    # Runtime
    config.update(
        use_amp=True,
        use_parallel_scan=False,  # sequential for CPU testing
        gradient_checkpointing=False,
        log_every=10,
        save_every=1000,
        output_dir='./checkpoints',
    )
    # Prefer the GPU when one is available.
    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    return config
def main():
    """Parse the command line, build the dataset and dataloader, and train.

    With ``--test`` a short smoke-test configuration overrides the
    command-line values. When ``--data_dir`` is missing or is not a
    directory, a small synthetic dataset is used instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--test', action='store_true', help='Quick test run')
    parser.add_argument('--model_size', default='tiny', choices=['tiny', 'small', 'base'])
    parser.add_argument('--resolution', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--data_dir', default=None)
    args = parser.parse_args()

    config = create_default_config()
    config['model_size'] = args.model_size
    config['resolution'] = args.resolution
    config['batch_size'] = args.batch_size
    config['num_epochs'] = args.epochs

    if args.test:
        config['resolution'] = 128  # must be divisible by 16 for PixelUnshuffle
        config['batch_size'] = 2
        config['num_epochs'] = 1
        config['log_every'] = 5
        config['disc_start'] = 5
        config['kl_warmup_steps'] = 10
        config['use_amp'] = False
        config['use_parallel_scan'] = False
        config['perceptual_weight'] = 0.0  # skip VGG in quick test for speed
        config['edge_weight'] = 0.0

    # Create dataset
    if args.data_dir and os.path.isdir(args.data_dir):
        dataset = ImageFolderDataset(args.data_dir, resolution=config['resolution'])
    else:
        if args.data_dir:
            # Make the fallback visible: a mistyped path would otherwise
            # train on noise without any hint.
            print(f"WARNING: --data_dir {args.data_dir!r} is not a directory")
        print("Using synthetic dataset for testing")
        dataset = SyntheticDataset(num_samples=40, resolution=config['resolution'])

    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=config['device'] == 'cuda',
        drop_last=True,
    )

    # Create trainer and train
    trainer = PMAVAETrainer(config)
    trainer.train(dataloader, num_epochs=config['num_epochs'])


if __name__ == '__main__':
    main()