| """ |
| PMA-VAE Training Script |
| ======================== |
| Progressive resolution training with: |
| - KL warmup (prevents posterior collapse) |
| - Discriminator cold start |
| - Mixed precision (fp16/bf16) |
| - Gradient checkpointing option |
| - Colab-friendly (T4 15GB VRAM) |
| - Checkpoint saving/resuming |
| """ |
|
|
| import os |
| import math |
| import time |
| import json |
| import argparse |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
| from torch.amp import GradScaler, autocast |
| from PIL import Image |
| import random |
|
|
| from model import PMAVAE, pmavae_tiny, pmavae_small, pmavae_base |
| from losses import PMAVAELoss |
|
|
|
|
| |
| |
| |
|
|
class ImageFolderDataset(Dataset):
    """Recursively index a folder of images and serve them as training tensors.

    Every file under ``root`` (searched recursively) whose lower-cased
    extension is in ``IMAGE_EXTENSIONS`` is indexed, in sorted path order.
    Each item is converted to RGB and returned as a float tensor of shape
    ``(3, resolution, resolution)`` normalized to [-1, 1].

    Args:
        root: directory to search recursively.
        resolution: output side length in pixels.
        random_crop: if True, resize the shorter side to 1.15x ``resolution``,
            take a random ``resolution`` crop and randomly flip horizontally
            (training augmentation). If False, resize directly to
            ``resolution x resolution`` (aspect ratio is not preserved).
    """

    # Compared against the lower-cased file extension.
    IMAGE_EXTENSIONS = frozenset(
        {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff'}
    )

    def __init__(self, root, resolution=256, random_crop=True):
        self.root = root
        self.resolution = resolution
        self.random_crop = random_crop

        self.files = self._find_image_files(root)
        print(f"Found {len(self.files)} images in {root}")

        if random_crop:
            spatial = [
                transforms.Resize(int(resolution * 1.15),
                                  interpolation=transforms.InterpolationMode.LANCZOS,
                                  antialias=True),
                transforms.RandomCrop(resolution),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            spatial = [
                transforms.Resize((resolution, resolution),
                                  interpolation=transforms.InterpolationMode.LANCZOS,
                                  antialias=True),
            ]
        # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
        self.transform = transforms.Compose(spatial + [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    @classmethod
    def _find_image_files(cls, root):
        """Return the sorted paths of all files under ``root`` with an image extension."""
        found = []
        for dirpath, _, filenames in os.walk(root):
            for name in filenames:
                if os.path.splitext(name)[1].lower() in cls.IMAGE_EXTENSIONS:
                    found.append(os.path.join(dirpath, name))
        found.sort()
        return found

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager closes the file handle right away instead of
        # leaving it to the garbage collector; convert() loads the pixel data
        # into a new image before the file is closed.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)
|
|
|
|
class HFDatasetWrapper(Dataset):
    """Serve the image column of a HuggingFace dataset as training tensors.

    Each item is converted to RGB. Its shorter side is resized to 1.15x
    ``resolution``, a random ``resolution`` square is cropped, and the
    crop is randomly flipped horizontally. The result is a tensor
    normalized to [-1, 1].

    Args:
        hf_dataset: indexable dataset whose rows are mappings containing
            ``image_column`` (presumably a ``datasets.Dataset`` — confirm).
        image_column: row key that holds the image; either a PIL image or
            something ``Image.fromarray`` accepts.
        resolution: output square side length in pixels.
    """

    def __init__(self, hf_dataset, image_column='image', resolution=256):
        self.dataset = hf_dataset
        self.image_column = image_column
        resize_to = int(resolution * 1.15)
        pipeline = [
            transforms.Resize(resize_to,
                              interpolation=transforms.InterpolationMode.LANCZOS,
                              antialias=True),
            transforms.RandomCrop(resolution),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        raw = row[self.image_column]
        # Non-PIL values (e.g. numpy arrays) are converted to PIL first.
        pil = raw if isinstance(raw, Image.Image) else Image.fromarray(raw)
        return self.transform(pil.convert('RGB'))
|
|
|
|
| |
| |
| |
|
|
class KLWarmup:
    """
    Linear schedule for the KL loss weight.

    The weight rises linearly from 0 at step 0 to ``target_weight`` at
    ``warmup_steps`` and stays constant after that. Keeping the KL term
    small early on is meant to avoid posterior collapse.
    """
    def __init__(self, target_weight, warmup_steps=10000):
        self.target_weight = target_weight
        self.warmup_steps = warmup_steps

    def get_weight(self, step):
        """Return the KL weight to use at ``step``."""
        if step < self.warmup_steps:
            progress = step / self.warmup_steps
            return self.target_weight * progress
        return self.target_weight
|
|
|
|
| |
| |
| |
|
|
| class PMAVAETrainer: |
| """ |
| Full training pipeline for PMA-VAE. |
| |
| Features: |
| - Progressive resolution training |
| - KL warmup |
| - Discriminator cold start |
| - Mixed precision |
| - Checkpoint save/resume |
| - Logging |
| """ |
| def __init__(self, config): |
| self.config = config |
| self.device = torch.device(config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')) |
| self.global_step = 0 |
| self.current_epoch = 0 |
| |
| |
| model_fn = { |
| 'tiny': pmavae_tiny, |
| 'small': pmavae_small, |
| 'base': pmavae_base, |
| }[config.get('model_size', 'small')] |
| |
| self.model = model_fn( |
| use_parallel_scan=config.get('use_parallel_scan', True) |
| ).to(self.device) |
| |
| params = self.model.count_parameters() |
| print(f"Model: {config.get('model_size', 'small')}") |
| print(f" Encoder: {params['encoder_M']:.2f}M params") |
| print(f" Decoder: {params['decoder_M']:.2f}M params") |
| print(f" Total: {params['total_M']:.2f}M params") |
| |
| |
| self.criterion = PMAVAELoss( |
| disc_start=config.get('disc_start', 10000), |
| kl_weight=config.get('kl_weight', 1e-6), |
| perceptual_weight=config.get('perceptual_weight', 0.5), |
| disc_weight=config.get('disc_weight', 0.5), |
| edge_weight=config.get('edge_weight', 0.1), |
| free_bits=config.get('free_bits', 0.25), |
| ).to(self.device) |
| |
| |
| lr = config.get('lr', 4.5e-6) |
| self.opt_vae = torch.optim.AdamW( |
| self.model.parameters(), |
| lr=lr * config.get('batch_size', 4), |
| betas=(0.5, 0.9), |
| weight_decay=config.get('weight_decay', 0.01), |
| ) |
| self.opt_disc = torch.optim.AdamW( |
| self.criterion.discriminator.parameters(), |
| lr=lr * config.get('batch_size', 4), |
| betas=(0.5, 0.9), |
| weight_decay=config.get('weight_decay', 0.01), |
| ) |
| |
| |
| self.use_amp = config.get('use_amp', True) |
| self.scaler_vae = GradScaler('cpu' if self.device.type == 'cpu' else 'cuda', enabled=self.use_amp) |
| self.scaler_disc = GradScaler('cpu' if self.device.type == 'cpu' else 'cuda', enabled=self.use_amp) |
| |
| |
| self.kl_warmup = KLWarmup( |
| target_weight=config.get('kl_weight', 1e-6), |
| warmup_steps=config.get('kl_warmup_steps', 5000), |
| ) |
| |
| |
| if config.get('gradient_checkpointing', False): |
| self._enable_gradient_checkpointing() |
| |
| |
| self.log_every = config.get('log_every', 50) |
| self.save_every = config.get('save_every', 5000) |
| self.output_dir = config.get('output_dir', './checkpoints') |
| os.makedirs(self.output_dir, exist_ok=True) |
| |
| self.train_log = [] |
| |
| def _enable_gradient_checkpointing(self): |
| """Enable gradient checkpointing for encoder (saves ~30% VRAM).""" |
| from torch.utils.checkpoint import checkpoint |
| |
| for stage in [self.model.encoder.stage1]: |
| for module in stage: |
| module._original_forward = module.forward |
| module.forward = lambda x, m=module: checkpoint(m._original_forward, x, use_reentrant=False) |
| |
| def train_step(self, batch): |
| """Single training step with both VAE and discriminator updates.""" |
| batch = batch.to(self.device) |
| |
| |
| current_kl_weight = self.kl_warmup.get_weight(self.global_step) |
| self.criterion.kl_weight = current_kl_weight |
| |
| |
| self.opt_vae.zero_grad() |
| |
| with autocast(device_type=self.device.type, enabled=self.use_amp): |
| recon, posteriors = self.model(batch) |
| loss_vae, log_vae = self.criterion( |
| batch, recon, posteriors, |
| optimizer_idx=0, |
| global_step=self.global_step, |
| last_layer=self.model.get_last_decoder_layer() |
| ) |
| |
| self.scaler_vae.scale(loss_vae).backward() |
| |
| |
| self.scaler_vae.unscale_(self.opt_vae) |
| grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) |
| |
| self.scaler_vae.step(self.opt_vae) |
| self.scaler_vae.update() |
| |
| |
| self.opt_disc.zero_grad() |
| |
| with autocast(device_type=self.device.type, enabled=self.use_amp): |
| |
| with torch.no_grad(): |
| recon_detached, _ = self.model(batch) |
| loss_disc, log_disc = self.criterion( |
| batch, recon_detached, posteriors, |
| optimizer_idx=1, |
| global_step=self.global_step, |
| ) |
| |
| if self.global_step >= self.criterion.disc_start: |
| self.scaler_disc.scale(loss_disc).backward() |
| self.scaler_disc.unscale_(self.opt_disc) |
| torch.nn.utils.clip_grad_norm_(self.criterion.discriminator.parameters(), 1.0) |
| self.scaler_disc.step(self.opt_disc) |
| self.scaler_disc.update() |
| |
| self.global_step += 1 |
| |
| |
| log = {**log_vae, **log_disc} |
| log['grad_norm'] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm |
| log['kl_weight'] = current_kl_weight |
| log['step'] = self.global_step |
| |
| return log |
| |
| def train_epoch(self, dataloader): |
| """Train for one epoch.""" |
| self.model.train() |
| epoch_logs = [] |
| |
| for batch_idx, batch in enumerate(dataloader): |
| log = self.train_step(batch) |
| epoch_logs.append(log) |
| |
| if self.global_step % self.log_every == 0: |
| avg_log = self._average_logs(epoch_logs[-self.log_every:]) |
| print(f"Step {self.global_step:6d} | " |
| f"L1: {avg_log.get('l1_loss', 0):.4f} | " |
| f"Perc: {avg_log.get('perceptual_loss', 0):.4f} | " |
| f"KL: {avg_log.get('kl_total', 0):.2f} | " |
| f"D: {avg_log.get('d_loss', 0):.4f} | " |
| f"G: {avg_log.get('g_loss', 0):.4f} | " |
| f"GN: {avg_log.get('grad_norm', 0):.2f}") |
| |
| if self.global_step % self.save_every == 0: |
| self.save_checkpoint() |
| |
| self.current_epoch += 1 |
| return epoch_logs |
| |
| def train(self, dataloader, num_epochs=100): |
| """Full training loop.""" |
| print(f"\nStarting training for {num_epochs} epochs") |
| print(f" Steps per epoch: {len(dataloader)}") |
| print(f" Device: {self.device}") |
| print(f" AMP: {self.use_amp}") |
| print(f" Disc starts at step: {self.criterion.disc_start}") |
| print(f" KL warmup steps: {self.kl_warmup.warmup_steps}") |
| print() |
| |
| all_logs = [] |
| start_time = time.time() |
| |
| for epoch in range(num_epochs): |
| epoch_start = time.time() |
| epoch_logs = self.train_epoch(dataloader) |
| all_logs.extend(epoch_logs) |
| |
| epoch_time = time.time() - epoch_start |
| avg = self._average_logs(epoch_logs) |
| |
| print(f"\n{'='*60}") |
| print(f"Epoch {epoch+1}/{num_epochs} completed in {epoch_time:.1f}s") |
| print(f" Avg L1: {avg.get('l1_loss', 0):.4f}") |
| print(f" Avg Perceptual: {avg.get('perceptual_loss', 0):.4f}") |
| print(f" Avg KL: {avg.get('kl_total', 0):.2f}") |
| print(f" Total time: {(time.time()-start_time)/60:.1f} min") |
| print(f"{'='*60}\n") |
| |
| self.save_checkpoint(f'epoch_{epoch+1}') |
| |
| self.save_checkpoint('final') |
| |
| |
| with open(os.path.join(self.output_dir, 'train_log.json'), 'w') as f: |
| json.dump(all_logs, f) |
| |
| total_time = time.time() - start_time |
| print(f"\nTraining complete! Total time: {total_time/60:.1f} min") |
| return all_logs |
| |
| def save_checkpoint(self, tag='latest'): |
| """Save model and optimizer states.""" |
| path = os.path.join(self.output_dir, f'checkpoint_{tag}.pt') |
| torch.save({ |
| 'model_state': self.model.state_dict(), |
| 'disc_state': self.criterion.discriminator.state_dict(), |
| 'opt_vae_state': self.opt_vae.state_dict(), |
| 'opt_disc_state': self.opt_disc.state_dict(), |
| 'global_step': self.global_step, |
| 'epoch': self.current_epoch, |
| 'config': self.config, |
| }, path) |
| print(f" Saved checkpoint: {path}") |
| |
| def load_checkpoint(self, path): |
| """Load checkpoint.""" |
| ckpt = torch.load(path, map_location=self.device, weights_only=False) |
| self.model.load_state_dict(ckpt['model_state']) |
| self.criterion.discriminator.load_state_dict(ckpt['disc_state']) |
| self.opt_vae.load_state_dict(ckpt['opt_vae_state']) |
| self.opt_disc.load_state_dict(ckpt['opt_disc_state']) |
| self.global_step = ckpt['global_step'] |
| self.current_epoch = ckpt['epoch'] |
| print(f"Loaded checkpoint from {path} (step {self.global_step})") |
| |
| def _average_logs(self, logs): |
| """Average a list of log dicts.""" |
| if not logs: |
| return {} |
| avg = {} |
| for key in logs[0]: |
| if key == 'step': |
| continue |
| vals = [l[key] for l in logs if key in l] |
| if vals: |
| avg[key] = sum(vals) / len(vals) |
| return avg |
| |
| @torch.no_grad() |
| def validate(self, dataloader, max_batches=50): |
| """Run validation.""" |
| self.model.eval() |
| logs = [] |
| |
| for i, batch in enumerate(dataloader): |
| if i >= max_batches: |
| break |
| batch = batch.to(self.device) |
| recon, posteriors = self.model(batch) |
| |
| |
| l1 = F.l1_loss(recon, batch).item() |
| |
| mse = F.mse_loss(recon, batch).item() |
| psnr = -10 * math.log10(mse + 1e-8) |
| |
| logs.append({'l1': l1, 'psnr': psnr}) |
| |
| avg = {k: sum(l[k] for l in logs) / len(logs) for k in logs[0]} |
| print(f"Validation: L1={avg['l1']:.4f}, PSNR={avg['psnr']:.2f}dB") |
| self.model.train() |
| return avg |
|
|
|
|
| |
| |
| |
|
|
class SyntheticDataset(Dataset):
    """Random smooth-noise images for smoke-testing the training loop.

    Every ``__getitem__`` call draws fresh Gaussian noise from the global
    torch RNG, so the same index returns a different image each time. When
    ``resolution >= 8`` the noise is box-blurred and resized back to full
    size. Each image is divided by its max absolute value (plus 1e-6), so
    values lie strictly inside (-1, 1).
    """
    def __init__(self, num_samples=1000, resolution=128):
        self.num_samples = num_samples
        self.resolution = resolution

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        size = self.resolution
        noise = torch.randn(3, size, size)

        kernel = min(8, size // 4)
        if kernel >= 2:
            # avg_pool2d shrinks the map; interpolate restores (size, size).
            pooled = F.avg_pool2d(noise[None], kernel, stride=1, padding=0)
            smoothed = F.interpolate(pooled, size=(size, size),
                                     mode='bilinear', align_corners=False)
            noise = smoothed[0]

        peak = noise.abs().max() + 1e-6
        return noise / peak
|
|
|
|
| |
| |
| |
|
|
def create_default_config():
    """Return a new dict of default hyper-parameters for a small training run.

    ``'device'`` is resolved at call time: ``'cuda'`` when PyTorch can see
    a GPU, otherwise ``'cpu'``. A fresh dict is built on every call.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return dict(
        # model / data
        model_size='tiny',
        resolution=128,
        # optimization
        batch_size=4,
        num_epochs=5,
        lr=4.5e-6,
        weight_decay=0.01,
        # loss weights and schedules
        kl_weight=1e-6,
        kl_warmup_steps=2000,
        perceptual_weight=0.5,
        disc_weight=0.5,
        edge_weight=0.1,
        free_bits=0.25,
        disc_start=5000,
        # runtime
        use_amp=True,
        use_parallel_scan=False,
        gradient_checkpointing=False,
        log_every=10,
        save_every=1000,
        output_dir='./checkpoints',
        device=device,
    )
|
|
|
|
def parse_args(argv=None):
    """Parse and validate the training script's command-line options.

    Args:
        argv: argument strings to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``test``, ``model_size``, ``resolution``,
        ``batch_size``, ``epochs``, ``data_dir`` and ``resume``.

    Exits via ``parser.error`` (SystemExit) when ``--data_dir`` is given but
    is not a directory, or ``--resume`` is given but is not a file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--test', action='store_true', help='Quick test run')
    parser.add_argument('--model_size', default='tiny', choices=['tiny', 'small', 'base'])
    parser.add_argument('--resolution', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--data_dir', default=None,
                        help='Folder of training images; omit to use a synthetic dataset')
    parser.add_argument('--resume', default=None,
                        help='Checkpoint file to resume training from')
    args = parser.parse_args(argv)

    # Fail loudly instead of silently training on synthetic noise when the
    # image folder path is wrong.
    if args.data_dir and not os.path.isdir(args.data_dir):
        parser.error(f"--data_dir {args.data_dir!r} is not a directory")
    if args.resume and not os.path.isfile(args.resume):
        parser.error(f"--resume {args.resume!r} is not a file")
    return args


def main(argv=None):
    """Build the config, dataset and trainer from CLI arguments, then train."""
    args = parse_args(argv)

    config = create_default_config()
    config['model_size'] = args.model_size
    config['resolution'] = args.resolution
    config['batch_size'] = args.batch_size
    config['num_epochs'] = args.epochs

    if args.test:
        # Tiny, fast configuration for checking that the loop runs end to end.
        config['resolution'] = 128
        config['batch_size'] = 2
        config['num_epochs'] = 1
        config['log_every'] = 5
        config['disc_start'] = 5
        config['kl_warmup_steps'] = 10
        config['use_amp'] = False
        config['use_parallel_scan'] = False
        config['perceptual_weight'] = 0.0
        config['edge_weight'] = 0.0

    if args.data_dir:
        dataset = ImageFolderDataset(args.data_dir, resolution=config['resolution'])
    else:
        print("Using synthetic dataset for testing")
        dataset = SyntheticDataset(num_samples=40, resolution=config['resolution'])

    # With drop_last=True a dataset smaller than one batch yields zero
    # batches, and training would silently do nothing.
    if len(dataset) < config['batch_size']:
        raise SystemExit(
            f"Dataset has {len(dataset)} samples, fewer than "
            f"batch_size={config['batch_size']}; no batches would be produced."
        )

    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=config['device'] == 'cuda',
        drop_last=True,
    )

    trainer = PMAVAETrainer(config)
    if args.resume:
        trainer.load_checkpoint(args.resume)
    trainer.train(dataloader, num_epochs=config['num_epochs'])


if __name__ == '__main__':
    main()
|
|