"""
LRF v2 Training on CIFAR-10 with pre-trained TAESD VAE.

This script:
1. Loads TAESD (pre-trained, frozen) as the image encoder/decoder
2. Pre-computes all CIFAR-10 latents (fast, ~30s)
3. Trains the RecursiveLatentCore denoiser on real latents
4. Generates real images and saves them
"""

import os
import sys
import time
import json

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import numpy as np
from pathlib import Path

sys.path.insert(0, '/app')
from lrf.model_v2 import LRFv2, RectifiedFlowScheduler


def load_taesd(device='cpu'):
    """Load the pre-trained TAESD VAE, put it in eval mode and freeze it.

    Args:
        device: Device the VAE is moved to.

    Returns:
        The ``AutoencoderTiny`` instance with ``requires_grad`` disabled on
        every parameter.
    """
    # Imported lazily so the module can be imported without diffusers.
    from diffusers import AutoencoderTiny

    vae = AutoencoderTiny.from_pretrained('madebyollin/taesd', torch_dtype=torch.float32)
    vae.eval()
    vae.to(device)
    for p in vae.parameters():
        p.requires_grad_(False)
    return vae


def precompute_latents(vae, dataset, batch_size=64, device='cpu'):
    """Pre-compute all latent representations. Much faster than encoding on-the-fly.

    Args:
        vae: Encoder whose ``encode(images).latents`` yields the latents.
        dataset: Dataset yielding ``(image, label)`` pairs.
        batch_size: Number of images encoded per forward pass.
        device: Device the images are moved to before encoding.

    Returns:
        Tuple ``(latents, labels)``; both are concatenated along dim 0 in
        dataset order, and the latents are moved to the CPU.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    all_latents = []
    all_labels = []
    total = len(loader)
    print(f"Pre-computing latents for {len(dataset)} images ({total} batches)...", flush=True)
    t0 = time.time()

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(loader):
            images = images.to(device)
            latents = vae.encode(images).latents
            all_latents.append(latents.cpu())
            all_labels.append(labels)
            # Progress line for the first batch and every 50th batch after it.
            if (batch_idx + 1) % 50 == 0 or batch_idx == 0:
                elapsed = time.time() - t0
                print(f" Batch {batch_idx+1}/{total} ({elapsed:.0f}s)", flush=True)

    all_latents = torch.cat(all_latents, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    dt = time.time() - t0
    print(f"Done in {dt:.1f}s. Latent shape: {all_latents.shape}", flush=True)
    print(f"Latent stats: mean={all_latents.mean():.4f}, std={all_latents.std():.4f}, "
          f"min={all_latents.min():.4f}, max={all_latents.max():.4f}", flush=True)
    return all_latents, all_labels


def train_denoiser(
    config=None,
    num_epochs=50,
    batch_size=128,
    lr=2e-4,
    device='cpu',
    output_dir='/app/lrf_v2_output',
    save_every=10,
):
    """Train the LRF denoiser on CIFAR-10 latents.

    Args:
        config: Model configuration mapping; ``LRFv2.small_config()`` is used
            when it is ``None`` or empty. The mapping is copied, so the
            caller's object is not modified.
        num_epochs: Number of passes over the training latents.
        batch_size: Training batch size (incomplete last batch is dropped).
        lr: Peak AdamW learning rate; cosine-annealed to ``lr * 0.01``.
        device: Device used for the VAE, the model and the training batches.
        output_dir: Directory receiving the latent cache, image grids,
            checkpoint, loss plot and config JSON. Created if missing.
        save_every: Sample grid + checkpoint are written every this many
            epochs, and always after the last epoch.

    Returns:
        Tuple ``(model, vae, loss_history)``. On return the model holds the
        EMA weights, and ``loss_history`` has one mean loss per epoch.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("=" * 60)
    print("LRF v2 - Training on CIFAR-10")
    print("=" * 60)

    # 1. Load TAESD
    print("\n[Step 1] Loading TAESD VAE...")
    vae = load_taesd(device)
    print(f" TAESD loaded: {sum(p.numel() for p in vae.parameters()):,} params (frozen)")

    # 2. Load CIFAR-10
    print("\n[Step 2] Loading CIFAR-10...")
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # -> [-1, 1]
    ])
    # CIFAR-10 is 32x32; TAESD works on any size, 32x32 -> 4x4 latent (f=8)
    trainset = torchvision.datasets.CIFAR10(
        root='/app/data', train=True, download=True, transform=transform,
    )
    testset = torchvision.datasets.CIFAR10(
        root='/app/data', train=False, download=True, transform=transform,
    )
    print(f" Train: {len(trainset)}, Test: {len(testset)}")
    print(f" Image size: {trainset[0][0].shape}")

    # 3. Pre-compute latents (or load from cache)
    print("\n[Step 3] Pre-computing latents...", flush=True)
    cache_path = os.path.join(output_dir, 'latent_cache.pt')
    if os.path.exists(cache_path):
        print(" Loading cached latents...", flush=True)
        cache = torch.load(cache_path, weights_only=True)
        train_latents = cache['train_latents']
        train_labels = cache['train_labels']
        test_latents = cache['test_latents']
        test_labels = cache['test_labels']
        print(f" Loaded from cache. Train: {train_latents.shape}, Test: {test_latents.shape}", flush=True)
    else:
        train_latents, train_labels = precompute_latents(vae, trainset, batch_size=256, device=device)
        test_latents, test_labels = precompute_latents(vae, testset, batch_size=256, device=device)
        torch.save({
            'train_latents': train_latents,
            'train_labels': train_labels,
            'test_latents': test_latents,
            'test_labels': test_labels,
        }, cache_path)
        print(f" Cached latents to {cache_path}", flush=True)

    # Per-sample latent shape (everything after the batch dim). Used for
    # sampling so generation always matches the latents the model is trained on.
    latent_shape = tuple(train_latents.shape[1:])

    # Verify VAE reconstruction works
    print("\n[Step 3b] Verifying VAE reconstruction...")
    with torch.no_grad():
        sample_imgs = torch.stack([trainset[i][0] for i in range(8)]).to(device)
        sample_lats = vae.encode(sample_imgs).latents
        sample_recs = vae.decode(sample_lats).sample
        recon_mse = F.mse_loss(sample_recs, sample_imgs).item()
    print(f" VAE reconstruction MSE on real images: {recon_mse:.4f}")

    # Save reconstruction grid
    save_image_grid(
        torch.cat([sample_imgs[:4], sample_recs[:4]], dim=0),
        os.path.join(output_dir, 'vae_reconstruction.png'),
        nrow=4,
        title='Top: Original, Bottom: TAESD Reconstruction'
    )
    print(f" Saved VAE reconstruction grid to {output_dir}/vae_reconstruction.png")

    # Latent statistics
    lat_mean = train_latents.mean()
    lat_std = train_latents.std()
    print(f"\n Latent mean: {lat_mean:.4f}, std: {lat_std:.4f}")

    # NOTE(review): latent_scale is recorded in the checkpoint but is NOT
    # applied to the latents anywhere in this script -- training and sampling
    # both operate on raw TAESD latents. Confirm whether scaling to unit
    # variance was intended before relying on this value downstream.
    latent_scale = lat_std.item()

    # Create dataset of (latent, label)
    train_ds = TensorDataset(train_latents, train_labels)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=0, drop_last=True)

    # 4. Create model
    print("\n[Step 4] Creating LRF denoiser...")
    # Copy so that setting 'latent_ch' below does not mutate the caller's dict.
    config = dict(config) if config else LRFv2.small_config()
    config['latent_ch'] = train_latents.shape[1]  # Should be 4
    model = LRFv2(config).to(device)
    params = model.count_params()
    print(f" Config: dim={config['dim']}, blocks={config['num_blocks']}, "
          f"T_inner={config['T_inner']}, T_outer={config['T_outer']}")
    print(f" Parameters: {params['total']:,} total, {params['core']:,} core")
    print(f" Effective depth: {config['T_outer'] * config['T_inner'] * config['num_blocks']} layers "
          f"from {config['num_blocks']} blocks")

    # 5. Training
    print(f"\n[Step 5] Training for {num_epochs} epochs...")
    scheduler = RectifiedFlowScheduler()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.95))

    # Cosine annealing, stepped once per batch
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(train_loader), eta_min=lr * 0.01
    )

    # EMA for stable sampling (tracks parameters only)
    ema_decay = 0.999
    ema_params = {name: p.clone().detach() for name, p in model.named_parameters()}

    loss_history = []
    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for latents, labels in train_loader:
            latents = latents.to(device)
            labels = labels.to(device)
            B = latents.shape[0]

            # Sample timesteps and noise
            t = scheduler.sample_timesteps(B, device)
            noise = torch.randn_like(latents)

            # Create noisy latent
            z_t = scheduler.add_noise(latents, noise, t)

            # Predict velocity (with 10% CFG dropout)
            v_pred = model.predict_velocity(z_t, t, labels, cfg_dropout=0.1)

            # Velocity target
            v_target = scheduler.get_velocity_target(latents, noise)

            # Per-sample MSE with a timestep-dependent weight
            loss_per_sample = (v_pred - v_target).pow(2).mean(dim=[1, 2, 3])
            # For t in [0, 1], t * (1 - t) peaks at t = 0.5, so this weight is
            # smallest for middle timesteps and largest near t = 0 and t = 1.
            # It is normalized to mean 1 within the batch.
            w = 1.0 / (t * (1 - t) + 0.01)
            w = w / w.mean()
            loss = (loss_per_sample * w).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()

            # EMA update
            with torch.no_grad():
                for name, p in model.named_parameters():
                    ema_params[name].mul_(ema_decay).add_(p, alpha=1 - ema_decay)

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss

        if (epoch + 1) % 5 == 0 or epoch == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f" Epoch {epoch+1:3d}/{num_epochs}: loss={avg_loss:.4f}, "
                  f"best={best_loss:.4f}, lr={current_lr:.2e}", flush=True)

        # Save and generate samples periodically
        if (epoch + 1) % save_every == 0 or epoch == num_epochs - 1:
            # Swap to EMA for sampling
            saved_params = {}
            with torch.no_grad():
                for name, p in model.named_parameters():
                    saved_params[name] = p.clone()
                    p.copy_(ema_params[name])

            try:
                # Generate samples
                model.eval()
                samples = generate_samples(model, vae, scheduler, device,
                                           num_samples=16, num_steps=10, cfg_scale=2.0,
                                           latent_shape=latent_shape)
                save_image_grid(
                    samples,
                    os.path.join(output_dir, f'samples_epoch{epoch+1:03d}.png'),
                    nrow=8,
                    title=f'Epoch {epoch+1}, Loss={avg_loss:.4f}'
                )
            finally:
                # Restore original params even if sampling or saving failed,
                # so the model is never left holding EMA weights mid-training.
                with torch.no_grad():
                    for name, p in model.named_parameters():
                        p.copy_(saved_params[name])

            # Save checkpoint (raw training weights + EMA copy)
            torch.save({
                'model_state': model.state_dict(),
                'ema_params': ema_params,
                'config': config,
                'epoch': epoch + 1,
                'loss': avg_loss,
                'latent_scale': latent_scale,
                'loss_history': loss_history,
            }, os.path.join(output_dir, 'checkpoint.pt'))

    # Final generation with EMA (the model keeps the EMA weights from here on)
    with torch.no_grad():
        for name, p in model.named_parameters():
            p.copy_(ema_params[name])
    model.eval()

    # Generate class-conditional samples
    print("\n[Step 6] Generating final samples...")
    cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                     'dog', 'frog', 'horse', 'ship', 'truck']

    all_samples = []
    for cls_idx in range(len(cifar_classes)):
        samples = generate_samples(model, vae, scheduler, device,
                                   num_samples=4, num_steps=50, cfg_scale=3.0,
                                   class_label=cls_idx, latent_shape=latent_shape)
        all_samples.append(samples)

    all_samples = torch.cat(all_samples, dim=0)
    save_image_grid(
        all_samples,
        os.path.join(output_dir, 'final_class_conditional.png'),
        nrow=4,
        title='Class-conditional generation (rows: airplane, auto, bird, cat, deer, dog, frog, horse, ship, truck)'
    )

    # Save loss plot
    save_loss_plot(loss_history, os.path.join(output_dir, 'loss.png'))

    # Save config
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Best loss: {best_loss:.4f}")
    print(f"Output directory: {output_dir}")
    print(f"{'='*60}")

    return model, vae, loss_history


def generate_samples(model, vae, scheduler, device, num_samples=8, num_steps=20,
                     cfg_scale=2.0, class_label=None, latent_shape=(4, 4, 4),
                     num_classes=10):
    """Generate images from the model.

    Args:
        model: Denoiser handed to ``scheduler.sample``; switched to eval mode.
        vae: Decoder whose ``decode(z).sample`` yields images.
        scheduler: Object providing ``sample(model, shape, labels, ...)``.
        device: Device used for the labels, sampling and decoding.
        num_samples: Number of images to generate.
        num_steps: Number of sampling steps passed to the scheduler.
        cfg_scale: Classifier-free guidance scale passed to the scheduler.
        class_label: If given, every sample uses this class index; otherwise
            labels are drawn uniformly from ``[0, num_classes)``.
        latent_shape: Per-sample latent shape. The default ``(4, 4, 4)``
            matches CIFAR-10 (32x32 image, f=8).
        num_classes: Number of classes used when sampling random labels.

    Returns:
        CPU tensor of decoded images clamped to ``[-1, 1]``.
    """
    model.eval()

    shape = (num_samples, *latent_shape)

    if class_label is not None:
        labels = torch.full((num_samples,), class_label, dtype=torch.long, device=device)
    else:
        labels = torch.randint(0, num_classes, (num_samples,), device=device)

    # Sampling is pure inference: run it under no_grad together with the
    # decode so no autograd graph is built across the sampling steps.
    with torch.no_grad():
        z = scheduler.sample(model, shape, labels, num_steps=num_steps,
                             cfg_scale=cfg_scale, device=device)
        # Decode through TAESD
        images = vae.decode(z.to(device)).sample

    return images.clamp(-1, 1).cpu()


def save_image_grid(images, path, nrow=8, title=''):
    """Save image grid using torchvision.

    Args:
        images: Batch of images with values in ``[-1, 1]``; may live on any
            device and may require grad.
        path: Destination file path handed to ``PIL.Image.save``.
        nrow: Number of images per grid row.
        title: Accepted for call-site readability; it is not rendered into
            the saved image.
    """
    # .numpy() below only works on CPU tensors that do not require grad.
    images = images.detach().cpu()

    # Convert from [-1,1] to [0,1]
    images = (images + 1) / 2
    images = images.clamp(0, 1)

    grid = torchvision.utils.make_grid(images, nrow=nrow, padding=2, normalize=False)

    # Save using PIL
    from PIL import Image
    grid_np = grid.permute(1, 2, 0).numpy()  # CHW -> HWC
    grid_np = (grid_np * 255).astype(np.uint8)
    img = Image.fromarray(grid_np)
    img.save(path)


def save_loss_plot(losses, path):
    """Save loss curve.

    Args:
        losses: Sequence of per-epoch loss values.
        path: Destination image path.

    If matplotlib is not installed, a message is printed and nothing is saved.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend; no display required
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 4))
        plt.plot(losses)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.grid(True, alpha=0.3)
        plt.savefig(path, dpi=100, bbox_inches='tight')
        plt.close()
    except ImportError:
        print("matplotlib not available, skipping loss plot")


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    train_denoiser(
        config=LRFv2.fast_config(),
        num_epochs=30,
        batch_size=64,
        lr=3e-4,
        device=device,
        output_dir='/app/lrf_v2_output',
        save_every=5,
    )