| """ |
| LRF v2 Training on CIFAR-10 with pre-trained TAESD VAE. |
| |
| This script: |
| 1. Loads TAESD (pre-trained, frozen) as the image encoder/decoder |
| 2. Pre-computes all CIFAR-10 latents (fast, ~30s) |
| 3. Trains the RecursiveLatentCore denoiser on real latents |
| 4. Generates real images and saves them |
| """ |
|
|
| import os |
| import sys |
| import time |
| import json |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, TensorDataset |
| import torchvision |
| import torchvision.transforms as T |
| import numpy as np |
| from pathlib import Path |
|
|
| sys.path.insert(0, '/app') |
| from lrf.model_v2 import LRFv2, RectifiedFlowScheduler |
|
|
|
|
def load_taesd(device='cpu'):
    """Return the pre-trained TAESD autoencoder, frozen and in eval mode.

    The weights are fetched from ``madebyollin/taesd`` through diffusers'
    ``AutoencoderTiny`` in float32 and moved to ``device``. Every parameter
    has gradients disabled, so the VAE stays fixed while the denoiser trains.
    """
    from diffusers import AutoencoderTiny

    autoencoder = AutoencoderTiny.from_pretrained('madebyollin/taesd', torch_dtype=torch.float32)
    autoencoder.eval()
    autoencoder.to(device)
    # Module-level helper: flips requires_grad on every parameter in one call.
    autoencoder.requires_grad_(False)
    return autoencoder
|
|
|
|
def precompute_latents(vae, dataset, batch_size=64, device='cpu'):
    """Encode every image in ``dataset`` with ``vae`` in one sequential pass.

    Doing this once up front lets training iterate over stored latents instead
    of running the encoder on every step.

    Args:
        vae: Encoder whose ``encode(images)`` result exposes ``.latents``.
        dataset: Map-style dataset yielding ``(image, label)`` pairs.
        batch_size: Number of images per encoder call.
        device: Device the images are moved to before encoding.

    Returns:
        ``(latents, labels)``, each concatenated in dataset order; the latents
        are moved back to the CPU.
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    n_batches = len(batches)
    print(f"Pre-computing latents for {len(dataset)} images ({n_batches} batches)...", flush=True)

    latent_chunks, label_chunks = [], []
    start = time.time()

    with torch.no_grad():
        for step, (images, labels) in enumerate(batches, start=1):
            encoded = vae.encode(images.to(device)).latents
            latent_chunks.append(encoded.cpu())
            label_chunks.append(labels)
            # Report progress on the first batch and then on every 50th.
            if step == 1 or step % 50 == 0:
                elapsed = time.time() - start
                print(f" Batch {step}/{n_batches} ({elapsed:.0f}s)", flush=True)

    latent_tensor = torch.cat(latent_chunks, dim=0)
    label_tensor = torch.cat(label_chunks, dim=0)

    took = time.time() - start
    print(f"Done in {took:.1f}s. Latent shape: {latent_tensor.shape}", flush=True)
    print(f"Latent stats: mean={latent_tensor.mean():.4f}, std={latent_tensor.std():.4f}, "
          f"min={latent_tensor.min():.4f}, max={latent_tensor.max():.4f}", flush=True)

    return latent_tensor, label_tensor
|
|
|
|
def _load_or_compute_latents(vae, trainset, testset, cache_path, device):
    """Return ``(train_latents, train_labels, test_latents, test_labels)``.

    Reads them from ``cache_path`` when that file exists. Otherwise it encodes
    both datasets with ``vae`` via ``precompute_latents`` and writes the cache.
    The cache is looked up by path only, so delete it after changing the VAE
    or the image transform.
    """
    if os.path.exists(cache_path):
        print(" Loading cached latents...", flush=True)
        cache = torch.load(cache_path, weights_only=True)
        train_latents = cache['train_latents']
        train_labels = cache['train_labels']
        test_latents = cache['test_latents']
        test_labels = cache['test_labels']
        print(f" Loaded from cache. Train: {train_latents.shape}, Test: {test_latents.shape}", flush=True)
    else:
        train_latents, train_labels = precompute_latents(vae, trainset, batch_size=256, device=device)
        test_latents, test_labels = precompute_latents(vae, testset, batch_size=256, device=device)
        torch.save({
            'train_latents': train_latents, 'train_labels': train_labels,
            'test_latents': test_latents, 'test_labels': test_labels,
        }, cache_path)
        print(f" Cached latents to {cache_path}", flush=True)
    return train_latents, train_labels, test_latents, test_labels


def _clone_params(model):
    """Return ``{name: detached copy}`` for every named parameter of ``model``."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}


def _copy_params_into(model, params):
    """Overwrite each parameter of ``model`` in place with ``params[name]``."""
    with torch.no_grad():
        for name, p in model.named_parameters():
            p.copy_(params[name])


def _ema_update(ema_params, model, decay):
    """In place, per parameter: ``ema = decay * ema + (1 - decay) * current``."""
    with torch.no_grad():
        for name, p in model.named_parameters():
            ema_params[name].mul_(decay).add_(p, alpha=1 - decay)


def train_denoiser(
    config=None,
    num_epochs=50,
    batch_size=128,
    lr=2e-4,
    device='cpu',
    output_dir='/app/lrf_v2_output',
    save_every=10,
):
    """Train the LRF denoiser on CIFAR-10 latents.

    Pipeline:
    1. Load the frozen TAESD VAE.
    2. Load CIFAR-10, scaled to [-1, 1].
    3. Encode it to latents, cached in ``output_dir/latent_cache.pt``.
    4. Train an ``LRFv2`` velocity predictor with a rectified-flow objective,
       keeping an EMA of its weights and sampling periodically with them.

    Args:
        config: ``LRFv2`` config dict; defaults to ``LRFv2.small_config()``.
            A shallow copy is used, so the caller's dict is not modified.
        num_epochs: Number of passes over the training latents.
        batch_size: Training minibatch size (the incomplete last batch is
            dropped).
        lr: Peak AdamW learning rate, cosine-annealed to ``lr * 0.01``.
        device: Torch device for the VAE and the denoiser.
        output_dir: Directory for the latent cache, image grids, checkpoint,
            loss plot and ``config.json``.
        save_every: Sample and checkpoint every this many epochs, and always
            after the final epoch.

    Returns:
        ``(model, vae, loss_history)``. ``model`` holds the EMA weights;
        ``loss_history`` is the mean weighted loss of each epoch.

    Raises:
        ValueError: If the training set does not yield a single full batch
            of ``batch_size``.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("=" * 60)
    print("LRF v2 - Training on CIFAR-10")
    print("=" * 60)

    print("\n[Step 1] Loading TAESD VAE...")
    vae = load_taesd(device)
    print(f" TAESD loaded: {sum(p.numel() for p in vae.parameters()):,} params (frozen)")

    print("\n[Step 2] Loading CIFAR-10...")
    # Map pixels from [0, 1] to [-1, 1].
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='/app/data', train=True, download=True, transform=transform,
    )
    testset = torchvision.datasets.CIFAR10(
        root='/app/data', train=False, download=True, transform=transform,
    )
    print(f" Train: {len(trainset)}, Test: {len(testset)}")
    print(f" Image size: {trainset[0][0].shape}")

    print("\n[Step 3] Pre-computing latents...", flush=True)
    cache_path = os.path.join(output_dir, 'latent_cache.pt')
    train_latents, train_labels, test_latents, test_labels = _load_or_compute_latents(
        vae, trainset, testset, cache_path, device,
    )

    print("\n[Step 3b] Verifying VAE reconstruction...")
    with torch.no_grad():
        sample_imgs = torch.stack([trainset[i][0] for i in range(8)]).to(device)
        sample_lats = vae.encode(sample_imgs).latents
        sample_recs = vae.decode(sample_lats).sample
        recon_mse = F.mse_loss(sample_recs, sample_imgs).item()
    print(f" VAE reconstruction MSE on real images: {recon_mse:.4f}")

    # Move to CPU first: the grid writer converts to NumPy, which fails on CUDA tensors.
    save_image_grid(
        torch.cat([sample_imgs[:4], sample_recs[:4]], dim=0).cpu(),
        os.path.join(output_dir, 'vae_reconstruction.png'),
        nrow=4, title='Top: Original, Bottom: TAESD Reconstruction'
    )
    print(f" Saved VAE reconstruction grid to {output_dir}/vae_reconstruction.png")

    lat_mean = train_latents.mean()
    lat_std = train_latents.std()
    print(f"\n Latent mean: {lat_mean:.4f}, std: {lat_std:.4f}")
    # NOTE(review): latents are NOT rescaled by this std before training or
    # sampling; it is only recorded in the checkpoint.
    latent_scale = lat_std.item()

    train_ds = TensorDataset(train_latents, train_labels)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=0, drop_last=True)
    if len(train_loader) == 0:
        # With drop_last=True a too-large batch yields no batches, which would
        # otherwise surface later as a ZeroDivisionError on the epoch mean.
        raise ValueError(
            f"batch_size={batch_size} exceeds the {len(train_ds)} training latents; "
            "no full batch to train on"
        )

    print("\n[Step 4] Creating LRF denoiser...")
    config = dict(config) if config else LRFv2.small_config()
    config['latent_ch'] = train_latents.shape[1]
    model = LRFv2(config).to(device)
    param_counts = model.count_params()
    print(f" Config: dim={config['dim']}, blocks={config['num_blocks']}, "
          f"T_inner={config['T_inner']}, T_outer={config['T_outer']}")
    print(f" Parameters: {param_counts['total']:,} total, {param_counts['core']:,} core")
    print(f" Effective depth: {config['T_outer'] * config['T_inner'] * config['num_blocks']} layers "
          f"from {config['num_blocks']} blocks")

    print(f"\n[Step 5] Training for {num_epochs} epochs...")
    scheduler = RectifiedFlowScheduler()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.95))
    # Stepped once per batch, so T_max counts optimizer steps, not epochs.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(train_loader), eta_min=lr * 0.01
    )

    ema_decay = 0.999
    ema_params = _clone_params(model)

    loss_history = []
    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for latents, labels in train_loader:
            latents = latents.to(device)
            labels = labels.to(device)
            B = latents.shape[0]

            t = scheduler.sample_timesteps(B, device)
            noise = torch.randn_like(latents)
            z_t = scheduler.add_noise(latents, noise, t)
            # cfg_dropout=0.1 presumably drops the label for ~10% of samples so
            # the model also learns the unconditional velocity used by CFG at
            # sampling time — TODO confirm in model_v2.
            v_pred = model.predict_velocity(z_t, t, labels, cfg_dropout=0.1)
            v_target = scheduler.get_velocity_target(latents, noise)

            loss_per_sample = (v_pred - v_target).pow(2).mean(dim=[1, 2, 3])

            # Weight 1 / (t(1-t) + 0.01) up-weights timesteps near 0 and 1.
            # Dividing by the batch mean keeps the loss on an MSE-like scale.
            # Assumes t has shape (B,) — TODO confirm.
            w = 1.0 / (t * (1 - t) + 0.01)
            w = w / w.mean()
            loss = (loss_per_sample * w).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()

            _ema_update(ema_params, model, ema_decay)

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)
        best_loss = min(best_loss, avg_loss)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f" Epoch {epoch+1:3d}/{num_epochs}: loss={avg_loss:.4f}, "
                  f"best={best_loss:.4f}, lr={current_lr:.2e}", flush=True)

        if (epoch + 1) % save_every == 0 or epoch == num_epochs - 1:
            # Sample with the EMA weights, then always put the training weights
            # back, even if sampling raises.
            saved_params = _clone_params(model)
            _copy_params_into(model, ema_params)
            try:
                model.eval()
                samples = generate_samples(model, vae, scheduler, device,
                                           num_samples=16, num_steps=10, cfg_scale=2.0)
            finally:
                _copy_params_into(model, saved_params)

            save_image_grid(
                samples,
                os.path.join(output_dir, f'samples_epoch{epoch+1:03d}.png'),
                nrow=8, title=f'Epoch {epoch+1}, Loss={avg_loss:.4f}'
            )

            # model_state holds the training weights; the EMA weights are stored separately.
            torch.save({
                'model_state': model.state_dict(),
                'ema_params': ema_params,
                'config': config,
                'epoch': epoch + 1,
                'loss': avg_loss,
                'latent_scale': latent_scale,
                'loss_history': loss_history,
            }, os.path.join(output_dir, 'checkpoint.pt'))

    # The returned model and the final samples use the EMA weights.
    _copy_params_into(model, ema_params)
    model.eval()

    print("\n[Step 6] Generating final samples...")
    # 4 samples per class id 0..9; each class fills one row of the grid.
    all_samples = torch.cat([
        generate_samples(model, vae, scheduler, device,
                         num_samples=4, num_steps=50, cfg_scale=3.0,
                         class_label=cls_idx)
        for cls_idx in range(10)
    ], dim=0)
    save_image_grid(
        all_samples,
        os.path.join(output_dir, 'final_class_conditional.png'),
        nrow=4, title='Class-conditional generation (rows: airplane, auto, bird, cat, deer, dog, frog, horse, ship, truck)'
    )

    save_loss_plot(loss_history, os.path.join(output_dir, 'loss.png'))

    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Best loss: {best_loss:.4f}")
    print(f"Output directory: {output_dir}")
    print(f"{'='*60}")

    return model, vae, loss_history
|
|
|
|
def generate_samples(model, vae, scheduler, device, num_samples=8,
                     num_steps=20, cfg_scale=2.0, class_label=None,
                     latent_shape=(4, 4, 4)):
    """Sample latents with ``scheduler`` and decode them to images with ``vae``.

    Args:
        model: Denoiser passed to ``scheduler.sample``. It is switched to eval
            mode and left there.
        vae: Decoder whose ``decode(z)`` result exposes ``.sample``.
        scheduler: Object with ``sample(model, shape, labels, num_steps=...,
            cfg_scale=..., device=...)`` returning latents of ``shape``.
        device: Device for the class labels and for decoding.
        num_samples: Number of images to generate.
        num_steps: Number of sampling steps passed to the scheduler.
        cfg_scale: Classifier-free guidance scale passed to the scheduler.
        class_label: If given, every sample uses this class id. Otherwise ids
            are drawn uniformly from 0..9.
        latent_shape: Per-sample latent shape ``(C, H, W)``. The default
            ``(4, 4, 4)`` is the value previously hard-coded here; it presumably
            matches TAESD latents of 32x32 CIFAR-10 images (8x downsampling).

    Returns:
        CPU tensor of decoded images, clamped to [-1, 1].
    """
    model.eval()
    shape = (num_samples, *latent_shape)

    if class_label is not None:
        labels = torch.full((num_samples,), class_label, dtype=torch.long, device=device)
    else:
        labels = torch.randint(0, 10, (num_samples,), device=device)

    # Sampling is inference only. Disabling autograd stops the multi-step
    # sampling loop from retaining a graph; this is harmless if the scheduler
    # already disables it.
    with torch.no_grad():
        z = scheduler.sample(model, shape, labels, num_steps=num_steps,
                             cfg_scale=cfg_scale, device=device)
        images = vae.decode(z.to(device)).sample

    return images.clamp(-1, 1).cpu()
|
|
|
|
def save_image_grid(images, path, nrow=8, title=''):
    """Write a batch of images with values in [-1, 1] to ``path`` as one grid image.

    Args:
        images: Batch of images in [-1, 1] in the layout ``make_grid`` expects
            (a ``(B, C, H, W)`` tensor). It may live on any device and may
            require grad.
        path: Output file path; PIL picks the format from the extension.
        nrow: Number of images per grid row.
        title: Accepted for call-site compatibility but NOT drawn on the
            image; it is currently ignored.
    """
    from PIL import Image

    # Detach and move to CPU first: Tensor.numpy() raises on CUDA tensors and
    # on tensors that require grad.
    images = images.detach().cpu()
    images = ((images + 1) / 2).clamp(0, 1)

    grid = torchvision.utils.make_grid(images, nrow=nrow, padding=2, normalize=False)

    # Round to the nearest 8-bit level, as torchvision.utils.save_image does,
    # rather than truncating.
    grid_u8 = grid.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8)
    Image.fromarray(grid_u8.permute(1, 2, 0).numpy()).save(path)
|
|
|
|
def save_loss_plot(losses, path):
    """Plot the per-epoch loss values in ``losses`` and save the figure to ``path``.

    If matplotlib cannot be imported, it prints a notice and skips the plot
    instead of failing.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: no display needed
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(losses)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss')
        ax.grid(True, alpha=0.3)
        fig.savefig(path, dpi=100, bbox_inches='tight')
        plt.close(fig)
    except ImportError:
        print("matplotlib not available, skipping loss plot")
|
|
|
|
if __name__ == '__main__':
    # Prefer the GPU when one is visible.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print(f"Device: {device}")

    # Quick run: fast model config, 30 epochs, sample + checkpoint every 5.
    run_settings = dict(
        config=LRFv2.fast_config(),
        num_epochs=30,
        batch_size=64,
        lr=3e-4,
        device=device,
        output_dir='/app/lrf_v2_output',
        save_every=5,
    )
    train_denoiser(**run_settings)
|
|