"""
PMA-VAE Training Script
========================
Training pipeline for PMA-VAE with:
- KL warmup (prevents posterior collapse)
- Discriminator cold start
- Mixed precision via torch.amp autocast + GradScaler
- Gradient checkpointing option (encoder stage1)
- Colab-friendly (T4 15GB VRAM)
- Checkpoint saving/resuming (``--resume``)

Note: the trainer trains at whatever resolution the dataloader yields; it
does not schedule resolution changes itself. A progressive-resolution run has
to be driven by the caller (e.g. calling ``PMAVAETrainer.train`` repeatedly
with dataloaders of increasing resolution).
"""

import argparse
import json
import math
import os
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from model import PMAVAE, pmavae_tiny, pmavae_small, pmavae_base
from losses import PMAVAELoss


# ==============================================================================
# Dataset
# ==============================================================================

IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif'}


def _build_transform(resolution, random_crop=True):
    """Build the image preprocessing pipeline shared by the dataset classes.

    Output is a CHW float tensor in [-1, 1]: ``ToTensor`` yields [0, 1] and
    ``Normalize`` with mean = std = 0.5 maps that to [-1, 1].

    Args:
        resolution: Side length of the square output image.
        random_crop: If True, resize the shorter side to ~1.15x ``resolution``,
            take a random square crop and randomly flip horizontally
            (training augmentation). If False, resize directly to
            ``resolution x resolution`` (aspect ratio is not preserved).
    """
    interpolation = transforms.InterpolationMode.LANCZOS
    if random_crop:
        geometry = [
            transforms.Resize(int(resolution * 1.15), interpolation=interpolation, antialias=True),
            transforms.RandomCrop(resolution),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        geometry = [
            transforms.Resize((resolution, resolution), interpolation=interpolation, antialias=True),
        ]
    return transforms.Compose(geometry + [
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])


class ImageFolderDataset(Dataset):
    """Dataset over every image file found (recursively) under ``root``.

    Args:
        root: Directory to walk for image files.
        resolution: Output side length in pixels.
        random_crop: Use random-crop/flip augmentation (see ``_build_transform``).

    Raises:
        ValueError: If no image files are found under ``root``.
    """

    def __init__(self, root, resolution=256, random_crop=True):
        self.root = root
        self.resolution = resolution
        self.random_crop = random_crop
        self.files = sorted(
            os.path.join(dirpath, name)
            for dirpath, _, filenames in os.walk(root)
            for name in filenames
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        )
        print(f"Found {len(self.files)} images in {root}")
        if not self.files:
            # Fail early with a clear message instead of an opaque sampler error later.
            raise ValueError(f"No images with extensions {sorted(IMAGE_EXTENSIONS)} found in {root!r}")
        self.transform = _build_transform(resolution, random_crop=random_crop)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager closes the file handle; convert() loads the pixel
        # data first. Without it, multi-frame formats (e.g. TIFF) keep files open.
        with Image.open(self.files[idx]) as img:
            img = img.convert('RGB')
        return self.transform(img)


class HFDatasetWrapper(Dataset):
    """Wrap a HuggingFace-style dataset whose rows hold an image column.

    Args:
        hf_dataset: Indexable dataset; ``hf_dataset[idx][image_column]`` must be a
            PIL image or something ``Image.fromarray`` accepts (e.g. a numpy array).
        image_column: Name of the column holding the image.
        resolution: Output side length in pixels (random crop/flip augmentation).
    """

    def __init__(self, hf_dataset, image_column='image', resolution=256):
        self.dataset = hf_dataset
        self.image_column = image_column
        self.transform = _build_transform(resolution, random_crop=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx][self.image_column]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        return self.transform(img.convert('RGB'))


# ==============================================================================
# KL Warmup Schedule
# ==============================================================================

class KLWarmup:
    """Linear KL-weight warmup to prevent posterior collapse.

    The weight rises linearly from 0 at step 0 to ``target_weight`` at
    ``warmup_steps`` and stays there afterwards.
    """

    def __init__(self, target_weight, warmup_steps=10000):
        self.target_weight = target_weight
        self.warmup_steps = warmup_steps

    def get_weight(self, step):
        """Return the KL weight to use at optimizer step ``step``."""
        if step >= self.warmup_steps:
            return self.target_weight
        return self.target_weight * (step / self.warmup_steps)


# ==============================================================================
# Training Loop
# ==============================================================================

def _to_scalar(value):
    """Convert a log value to something safe to keep for the whole run.

    Single-element tensors become Python numbers (JSON-serializable, and they no
    longer pin device memory or an autograd graph). Other tensors are detached.
    Non-tensor values pass through unchanged.
    """
    if torch.is_tensor(value):
        return value.item() if value.numel() == 1 else value.detach()
    return value


class PMAVAETrainer:
    """
    Full training pipeline for PMA-VAE.

    Features:
    - KL warmup
    - Discriminator cold start (discriminator is only updated from ``disc_start``)
    - Mixed precision (autocast + separate GradScalers for VAE and discriminator)
    - Checkpoint save/resume (including GradScaler state)
    - Logging

    ``config`` is a plain dict; see ``create_default_config`` for the keys read.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device(config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu'))
        self.global_step = 0
        self.current_epoch = 0

        # Build model
        model_size = config.get('model_size', 'small')
        model_fn = {
            'tiny': pmavae_tiny,
            'small': pmavae_small,
            'base': pmavae_base,
        }[model_size]
        self.model = model_fn(
            use_parallel_scan=config.get('use_parallel_scan', True)
        ).to(self.device)

        params = self.model.count_parameters()
        print(f"Model: {model_size}")
        print(f"  Encoder: {params['encoder_M']:.2f}M params")
        print(f"  Decoder: {params['decoder_M']:.2f}M params")
        print(f"  Total:   {params['total_M']:.2f}M params")

        # Build loss (also owns the discriminator)
        self.criterion = PMAVAELoss(
            disc_start=config.get('disc_start', 10000),
            kl_weight=config.get('kl_weight', 1e-6),
            perceptual_weight=config.get('perceptual_weight', 0.5),
            disc_weight=config.get('disc_weight', 0.5),
            edge_weight=config.get('edge_weight', 0.1),
            free_bits=config.get('free_bits', 0.25),
        ).to(self.device)

        # Optimizers: base lr is scaled linearly with batch size.
        effective_lr = config.get('lr', 4.5e-6) * config.get('batch_size', 4)
        weight_decay = config.get('weight_decay', 0.01)
        self.opt_vae = torch.optim.AdamW(
            self.model.parameters(),
            lr=effective_lr,
            betas=(0.5, 0.9),
            weight_decay=weight_decay,
        )
        self.opt_disc = torch.optim.AdamW(
            self.criterion.discriminator.parameters(),
            lr=effective_lr,
            betas=(0.5, 0.9),
            weight_decay=weight_decay,
        )
        self.grad_clip = config.get('grad_clip', 1.0)

        # Mixed precision
        self.use_amp = config.get('use_amp', True)
        scaler_device = 'cpu' if self.device.type == 'cpu' else 'cuda'
        self.scaler_vae = GradScaler(scaler_device, enabled=self.use_amp)
        self.scaler_disc = GradScaler(scaler_device, enabled=self.use_amp)

        # KL warmup
        self.kl_warmup = KLWarmup(
            target_weight=config.get('kl_weight', 1e-6),
            warmup_steps=config.get('kl_warmup_steps', 5000),
        )

        # Gradient checkpointing
        if config.get('gradient_checkpointing', False):
            self._enable_gradient_checkpointing()

        # Logging
        self.log_every = config.get('log_every', 50)
        self.save_every = config.get('save_every', 5000)
        self.output_dir = config.get('output_dir', './checkpoints')
        os.makedirs(self.output_dir, exist_ok=True)
        self.train_log = []

    def _enable_gradient_checkpointing(self):
        """Wrap each submodule of ``model.encoder.stage1`` in activation checkpointing.

        Activations of the wrapped submodules are recomputed during backward
        instead of stored, trading compute for memory. Only stage1 is wrapped.
        Calling this twice is a no-op for modules already wrapped.
        """
        from torch.utils.checkpoint import checkpoint

        def make_checkpointed(fn):
            # Forward all positional/keyword args so multi-input modules work too.
            def forward(*args, **kwargs):
                return checkpoint(fn, *args, use_reentrant=False, **kwargs)
            return forward

        for stage in [self.model.encoder.stage1]:
            for module in stage:  # assumes stage is iterable (Sequential/ModuleList)
                if hasattr(module, '_original_forward'):
                    continue  # already wrapped; don't checkpoint the wrapper
                module._original_forward = module.forward
                module.forward = make_checkpointed(module._original_forward)

    def train_step(self, batch):
        """Run one VAE update followed by one discriminator update.

        The discriminator is only stepped once ``global_step >= criterion.disc_start``;
        before that its loss is still computed so it appears in the logs.

        Returns:
            Dict of Python scalars merged from both loss logs plus ``grad_norm``,
            ``kl_weight`` and ``step`` (the step count after this update).
        """
        batch = batch.to(self.device)

        current_kl_weight = self.kl_warmup.get_weight(self.global_step)
        self.criterion.kl_weight = current_kl_weight

        # ==================== VAE Update ====================
        self.opt_vae.zero_grad()
        with autocast(device_type=self.device.type, enabled=self.use_amp):
            recon, posteriors = self.model(batch)
            loss_vae, log_vae = self.criterion(
                batch, recon, posteriors,
                optimizer_idx=0,
                global_step=self.global_step,
                last_layer=self.model.get_last_decoder_layer(),
            )
        self.scaler_vae.scale(loss_vae).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        self.scaler_vae.unscale_(self.opt_vae)
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.scaler_vae.step(self.opt_vae)
        self.scaler_vae.update()

        # ==================== Discriminator Update ====================
        # zero_grad also clears discriminator grads accumulated by the VAE backward.
        self.opt_disc.zero_grad()
        # Reuse this step's reconstruction, detached so no gradient reaches the
        # VAE; avoids a second full forward pass through the model.
        recon_detached = recon.detach()
        with autocast(device_type=self.device.type, enabled=self.use_amp):
            loss_disc, log_disc = self.criterion(
                batch, recon_detached, posteriors,
                optimizer_idx=1,
                global_step=self.global_step,
            )
        if self.global_step >= self.criterion.disc_start:
            self.scaler_disc.scale(loss_disc).backward()
            self.scaler_disc.unscale_(self.opt_disc)
            torch.nn.utils.clip_grad_norm_(self.criterion.discriminator.parameters(), self.grad_clip)
            self.scaler_disc.step(self.opt_disc)
            self.scaler_disc.update()

        self.global_step += 1

        # Convert to plain scalars so per-step logs don't retain tensors/graphs.
        log = {key: _to_scalar(value) for key, value in {**log_vae, **log_disc}.items()}
        log['grad_norm'] = _to_scalar(grad_norm)
        log['kl_weight'] = current_kl_weight
        log['step'] = self.global_step
        return log

    def train_epoch(self, dataloader):
        """Train for one pass over ``dataloader``; returns the per-step logs."""
        self.model.train()
        epoch_logs = []

        for batch in dataloader:
            log = self.train_step(batch)
            epoch_logs.append(log)

            if self.global_step % self.log_every == 0:
                avg_log = self._average_logs(epoch_logs[-self.log_every:])
                print(f"Step {self.global_step:6d} | "
                      f"L1: {avg_log.get('l1_loss', 0):.4f} | "
                      f"Perc: {avg_log.get('perceptual_loss', 0):.4f} | "
                      f"KL: {avg_log.get('kl_total', 0):.2f} | "
                      f"D: {avg_log.get('d_loss', 0):.4f} | "
                      f"G: {avg_log.get('g_loss', 0):.4f} | "
                      f"GN: {avg_log.get('grad_norm', 0):.2f}")

            if self.global_step % self.save_every == 0:
                self.save_checkpoint()

        self.current_epoch += 1
        return epoch_logs

    def train(self, dataloader, num_epochs=100):
        """Train for ``num_epochs`` more epochs, checkpointing after each.

        Epoch numbers continue from ``self.current_epoch``, so after
        ``load_checkpoint`` the ``epoch_N`` checkpoints do not overwrite earlier ones.

        Returns:
            List of all per-step log dicts (also written to ``train_log.json``).
        """
        first_epoch = self.current_epoch
        last_epoch = first_epoch + num_epochs

        print(f"\nStarting training for {num_epochs} epochs")
        print(f"  Steps per epoch: {len(dataloader)}")
        print(f"  Device: {self.device}")
        print(f"  AMP: {self.use_amp}")
        print(f"  Disc starts at step: {self.criterion.disc_start}")
        print(f"  KL warmup steps: {self.kl_warmup.warmup_steps}")
        print()

        all_logs = []
        start_time = time.time()

        for _ in range(num_epochs):
            epoch_start = time.time()
            epoch_logs = self.train_epoch(dataloader)  # increments self.current_epoch
            all_logs.extend(epoch_logs)

            epoch_time = time.time() - epoch_start
            avg = self._average_logs(epoch_logs)
            print(f"\n{'='*60}")
            print(f"Epoch {self.current_epoch}/{last_epoch} completed in {epoch_time:.1f}s")
            print(f"  Avg L1: {avg.get('l1_loss', 0):.4f}")
            print(f"  Avg Perceptual: {avg.get('perceptual_loss', 0):.4f}")
            print(f"  Avg KL: {avg.get('kl_total', 0):.2f}")
            print(f"  Total time: {(time.time()-start_time)/60:.1f} min")
            print(f"{'='*60}\n")

            self.save_checkpoint(f'epoch_{self.current_epoch}')

        self.save_checkpoint('final')

        with open(os.path.join(self.output_dir, 'train_log.json'), 'w') as f:
            json.dump(all_logs, f)

        total_time = time.time() - start_time
        print(f"\nTraining complete! Total time: {total_time/60:.1f} min")
        return all_logs

    def save_checkpoint(self, tag='latest'):
        """Save model, discriminator, optimizer and GradScaler states.

        Writes to a temporary file and renames it into place, so an interrupted
        save does not leave a truncated ``checkpoint_<tag>.pt`` behind.
        """
        path = os.path.join(self.output_dir, f'checkpoint_{tag}.pt')
        tmp_path = path + '.tmp'
        torch.save({
            'model_state': self.model.state_dict(),
            'disc_state': self.criterion.discriminator.state_dict(),
            'opt_vae_state': self.opt_vae.state_dict(),
            'opt_disc_state': self.opt_disc.state_dict(),
            'scaler_vae_state': self.scaler_vae.state_dict(),
            'scaler_disc_state': self.scaler_disc.state_dict(),
            'global_step': self.global_step,
            'epoch': self.current_epoch,
            'config': self.config,
        }, tmp_path)
        os.replace(tmp_path, path)
        print(f"  Saved checkpoint: {path}")

    def load_checkpoint(self, path):
        """Restore trainer state from a checkpoint written by ``save_checkpoint``.

        GradScaler states are restored when present (older checkpoints lack them).
        """
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt['model_state'])
        self.criterion.discriminator.load_state_dict(ckpt['disc_state'])
        self.opt_vae.load_state_dict(ckpt['opt_vae_state'])
        self.opt_disc.load_state_dict(ckpt['opt_disc_state'])
        if 'scaler_vae_state' in ckpt:
            self.scaler_vae.load_state_dict(ckpt['scaler_vae_state'])
        if 'scaler_disc_state' in ckpt:
            self.scaler_disc.load_state_dict(ckpt['scaler_disc_state'])
        self.global_step = ckpt['global_step']
        self.current_epoch = ckpt['epoch']
        print(f"Loaded checkpoint from {path} (step {self.global_step})")

    def _average_logs(self, logs):
        """Average each key (taken from the first log, except ``step``) over ``logs``.

        Keys missing from some logs are averaged over the logs that have them.
        Returns an empty dict for an empty list.
        """
        if not logs:
            return {}
        averaged = {}
        for key in logs[0]:
            if key == 'step':
                continue
            values = [entry[key] for entry in logs if key in entry]
            if values:
                averaged[key] = sum(values) / len(values)
        return averaged

    @torch.no_grad()
    def validate(self, dataloader, max_batches=50):
        """Evaluate reconstruction quality on up to ``max_batches`` batches.

        Reports mean L1 in the model's [-1, 1] space and mean per-batch PSNR
        computed on images mapped to [0, 1] (peak value 1.0). Reconstructions are
        clamped to the valid range before PSNR.

        Returns:
            ``{'l1': ..., 'psnr': ...}``, or ``{}`` if no batches were evaluated.
        """
        was_training = self.model.training
        self.model.eval()
        logs = []
        try:
            for i, batch in enumerate(dataloader):
                if i >= max_batches:
                    break
                batch = batch.to(self.device)
                recon, _ = self.model(batch)

                l1 = F.l1_loss(recon, batch).item()

                # Map [-1, 1] -> [0, 1] so PSNR uses the conventional peak of 1.0
                # (computing it on [-1, 1] with peak 1 under-reports by ~6 dB).
                recon01 = (recon.float().clamp(-1, 1) + 1) / 2
                target01 = (batch.float() + 1) / 2
                mse = F.mse_loss(recon01, target01).item()
                psnr = -10 * math.log10(mse + 1e-8)

                logs.append({'l1': l1, 'psnr': psnr})
        finally:
            self.model.train(was_training)

        if not logs:
            print("Validation: no batches evaluated")
            return {}

        avg = {key: sum(entry[key] for entry in logs) / len(logs) for key in logs[0]}
        print(f"Validation: L1={avg['l1']:.4f}, PSNR={avg['psnr']:.2f}dB")
        return avg


# ==============================================================================
# Synthetic data for testing
# ==============================================================================

class SyntheticDataset(Dataset):
    """Synthetic dataset for smoke-testing the training loop.

    Each item is freshly sampled (not fixed per index): smoothed Gaussian noise
    of shape (3, resolution, resolution), scaled so max |value| is ~1.
    """

    def __init__(self, num_samples=1000, resolution=128):
        self.num_samples = num_samples
        self.resolution = resolution

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        r = self.resolution
        img = torch.randn(3, r, r)
        # Smooth with a k x k average pool (no padding, so it shrinks the image),
        # then bilinearly resize back to r x r.
        k = min(8, r // 4)
        if k >= 2:
            img = F.interpolate(
                F.avg_pool2d(img.unsqueeze(0), k, stride=1, padding=0),
                size=(r, r), mode='bilinear', align_corners=False
            ).squeeze(0)
        # Scale into [-1, 1]
        img = img / (img.abs().max() + 1e-6)
        return img


# ==============================================================================
# Main
# ==============================================================================

def create_default_config():
    """Return the default training config dict read by ``PMAVAETrainer``."""
    return {
        'model_size': 'tiny',           # tiny/small/base
        'resolution': 128,
        'batch_size': 4,
        'num_epochs': 5,
        'lr': 4.5e-6,
        'weight_decay': 0.01,
        'kl_weight': 1e-6,
        'kl_warmup_steps': 2000,
        'perceptual_weight': 0.5,
        'disc_weight': 0.5,
        'edge_weight': 0.1,
        'free_bits': 0.25,
        'disc_start': 5000,
        'use_amp': True,
        'use_parallel_scan': False,      # sequential for CPU testing
        'gradient_checkpointing': False,
        'log_every': 10,
        'save_every': 1000,
        'output_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }


def main(argv=None):
    """Command-line entry point: build config, dataset and trainer, then train."""
    parser = argparse.ArgumentParser(description='Train PMA-VAE')
    parser.add_argument('--test', action='store_true', help='Quick test run')
    parser.add_argument('--model_size', default='tiny', choices=['tiny', 'small', 'base'])
    parser.add_argument('--resolution', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--data_dir', default=None)
    parser.add_argument('--num_workers', type=int, default=0, help='DataLoader worker processes')
    parser.add_argument('--resume', default=None, help='Checkpoint path to resume training from')
    args = parser.parse_args(argv)

    # An explicit but invalid --data_dir is an error, not a silent fallback to noise.
    if args.data_dir is not None and not os.path.isdir(args.data_dir):
        parser.error(f"--data_dir {args.data_dir!r} is not a directory")

    config = create_default_config()
    config['model_size'] = args.model_size
    config['resolution'] = args.resolution
    config['batch_size'] = args.batch_size
    config['num_epochs'] = args.epochs

    if args.test:
        config['resolution'] = 128      # must be divisible by 16 for PixelUnshuffle
        config['batch_size'] = 2
        config['num_epochs'] = 1
        config['log_every'] = 5
        config['disc_start'] = 5
        config['kl_warmup_steps'] = 10
        config['use_amp'] = False
        config['use_parallel_scan'] = False
        config['perceptual_weight'] = 0.0   # skip VGG in quick test for speed
        config['edge_weight'] = 0.0

    if config['resolution'] % 16 != 0:
        parser.error(f"--resolution must be divisible by 16 (got {config['resolution']})")

    # Create dataset
    if args.data_dir:
        dataset = ImageFolderDataset(args.data_dir, resolution=config['resolution'])
    else:
        print("Using synthetic dataset for testing")
        dataset = SyntheticDataset(num_samples=40, resolution=config['resolution'])

    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=args.num_workers,
        # Prefix check so 'cuda:0'-style device strings also pin memory.
        pin_memory=str(config['device']).startswith('cuda'),
        drop_last=True,
    )

    # Create trainer and train
    trainer = PMAVAETrainer(config)
    if args.resume:
        trainer.load_checkpoint(args.resume)
    trainer.train(dataloader, num_epochs=config['num_epochs'])


if __name__ == '__main__':
    main()