# LatentRecurrentFlow / lrf/train_v2.py (author: krystv)
# Uploaded with huggingface_hub, commit a80be80.
"""
LRF v2 Training on CIFAR-10 with pre-trained TAESD VAE.
This script:
1. Loads TAESD (pre-trained, frozen) as the image encoder/decoder
2. Pre-computes all CIFAR-10 latents (fast, ~30s)
3. Trains the RecursiveLatentCore denoiser on real latents
4. Generates real images and saves them
"""
import os
import sys
import time
import json
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import numpy as np
from pathlib import Path
sys.path.insert(0, '/app')
from lrf.model_v2 import LRFv2, RectifiedFlowScheduler
def load_taesd(device='cpu'):
    """Return the pre-trained TAESD autoencoder, frozen and in eval mode.

    Args:
        device: torch device the model is moved to.

    Returns:
        The ``diffusers.AutoencoderTiny`` loaded from ``madebyollin/taesd`` in
        float32, with every parameter's ``requires_grad`` set to False.
    """
    from diffusers import AutoencoderTiny

    autoencoder = AutoencoderTiny.from_pretrained('madebyollin/taesd', torch_dtype=torch.float32)
    # eval() and to() both return the module itself, so they can be chained.
    autoencoder.eval().to(device)
    # Freeze: the VAE is only used as a fixed encoder/decoder.
    autoencoder.requires_grad_(False)
    return autoencoder
def precompute_latents(vae, dataset, batch_size=64, device='cpu'):
    """Encode every image in ``dataset`` once with ``vae``, up front.

    Doing this once is much faster than encoding on the fly every epoch.

    Args:
        vae: object whose ``encode(images)`` returns something with a
            ``.latents`` tensor (e.g. TAESD).
        dataset: map-style dataset yielding ``(image, label)`` pairs.
        batch_size: encoding batch size (order is preserved; no shuffling).
        device: device the images are moved to before encoding.

    Returns:
        ``(latents, labels)``: both concatenated along dim 0, latents on CPU.
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    n_batches = len(batches)
    print(f"Pre-computing latents for {len(dataset)} images ({n_batches} batches)...", flush=True)

    latent_chunks = []
    label_chunks = []
    start = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(batches, start=1):
            encoded = vae.encode(images.to(device)).latents
            latent_chunks.append(encoded.cpu())
            label_chunks.append(labels)
            # Progress on the first batch and every 50th batch after that.
            if step == 1 or step % 50 == 0:
                print(f" Batch {step}/{n_batches} ({time.time() - start:.0f}s)", flush=True)

    stacked_latents = torch.cat(latent_chunks, dim=0)
    stacked_labels = torch.cat(label_chunks, dim=0)
    print(f"Done in {time.time() - start:.1f}s. Latent shape: {stacked_latents.shape}", flush=True)
    print(
        f"Latent stats: mean={stacked_latents.mean():.4f}, std={stacked_latents.std():.4f}, "
        f"min={stacked_latents.min():.4f}, max={stacked_latents.max():.4f}",
        flush=True,
    )
    return stacked_latents, stacked_labels
def _load_cifar10(root='/app/data'):
    """Return the CIFAR-10 ``(trainset, testset)`` with images normalized to [-1, 1]."""
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
    ])
    # No resize needed: with TAESD's f=8 downsampling, 32x32 images give 4x4 latents.
    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform,
    )
    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform,
    )
    return trainset, testset


def _load_or_compute_latents(vae, trainset, testset, cache_path, device):
    """Return ``((train_latents, train_labels), (test_latents, test_labels))``.

    Reads ``cache_path`` if it exists. Otherwise encodes both splits with ``vae``
    and writes the cache. NOTE(review): the cache is keyed only by path; delete it
    after changing the VAE or the preprocessing.
    """
    if os.path.exists(cache_path):
        print(" Loading cached latents...", flush=True)
        cache = torch.load(cache_path, weights_only=True)
        train = (cache['train_latents'], cache['train_labels'])
        test = (cache['test_latents'], cache['test_labels'])
        print(f" Loaded from cache. Train: {train[0].shape}, Test: {test[0].shape}", flush=True)
        return train, test

    train = precompute_latents(vae, trainset, batch_size=256, device=device)
    test = precompute_latents(vae, testset, batch_size=256, device=device)
    torch.save({
        'train_latents': train[0], 'train_labels': train[1],
        'test_latents': test[0], 'test_labels': test[1],
    }, cache_path)
    print(f" Cached latents to {cache_path}", flush=True)
    return train, test


def _verify_vae_reconstruction(vae, dataset, output_dir, device):
    """Encode/decode 8 dataset images, print the MSE and save an original-vs-recon grid."""
    with torch.no_grad():
        sample_imgs = torch.stack([dataset[i][0] for i in range(8)]).to(device)
        sample_recs = vae.decode(vae.encode(sample_imgs).latents).sample
        recon_mse = F.mse_loss(sample_recs, sample_imgs).item()
    print(f" VAE reconstruction MSE on real images: {recon_mse:.4f}")
    # .cpu(): the grid writer converts to numpy, which fails on CUDA tensors.
    save_image_grid(
        torch.cat([sample_imgs[:4], sample_recs[:4]], dim=0).cpu(),
        os.path.join(output_dir, 'vae_reconstruction.png'),
        nrow=4, title='Top: Original, Bottom: TAESD Reconstruction'
    )
    print(f" Saved VAE reconstruction grid to {output_dir}/vae_reconstruction.png")


def _timestep_loss_weights(t, eps=0.01):
    """Per-sample loss weights that peak at t=0.5, normalized to mean 1 over the batch.

    Uses ``t * (1 - t) + eps``. Near t=0 and t=1 the velocity target is mostly
    unpredictable noise. An inverse ``1 / (t(1-t) + eps)`` would upweight those
    endpoints about 26x and let irreducible error dominate the gradient.
    Assumes one timestep per sample (``t`` has batch-size elements).
    """
    t = t.reshape(-1)
    w = t * (1 - t) + eps
    return w / w.mean()


def _snapshot_params(model):
    """Return detached clones of all named parameters (name -> tensor)."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}


def _copy_params_(model, source):
    """Overwrite the model's parameters in place with tensors from ``source`` (name -> tensor)."""
    with torch.no_grad():
        for name, p in model.named_parameters():
            p.copy_(source[name])


def _ema_update_(model, ema_params, decay):
    """In place: ema = decay * ema + (1 - decay) * current, for every named parameter."""
    with torch.no_grad():
        for name, p in model.named_parameters():
            ema_params[name].mul_(decay).add_(p, alpha=1 - decay)


def train_denoiser(
    config=None,
    num_epochs=50,
    batch_size=128,
    lr=2e-4,
    device='cpu',
    output_dir='/app/lrf_v2_output',
    save_every=10,
):
    """Train the LRF denoiser on TAESD latents of CIFAR-10.

    Steps:
    1. Load a frozen TAESD.
    2. Load CIFAR-10.
    3. Encode the latents, or load them from a cache.
    4. Build ``LRFv2``.
    5. Train with a rectified-flow velocity objective while keeping an EMA of
       the weights.
    6. Write sample grids, a loss plot, ``checkpoint.pt`` and ``config.json``
       to ``output_dir``.

    Args:
        config: LRFv2 config dict; defaults to ``LRFv2.small_config()``. It is
            copied before ``latent_ch`` is set, so the caller's dict is untouched.
        num_epochs: number of passes over the training latents.
        batch_size: training batch size (the last partial batch is dropped).
        lr: peak AdamW learning rate, cosine-annealed per step to ``lr * 0.01``.
        device: torch device string.
        output_dir: created if missing; receives all artifacts.
        save_every: epoch interval for EMA sample grids and the checkpoint
            (the last epoch is always saved).

    Returns:
        ``(model, vae, loss_history)``. The model is left holding the EMA
        weights, in eval mode.

    Raises:
        ValueError: if ``batch_size`` exceeds the number of training latents.
    """
    os.makedirs(output_dir, exist_ok=True)
    print("=" * 60)
    print("LRF v2 - Training on CIFAR-10")
    print("=" * 60)

    # 1. Load TAESD
    print("\n[Step 1] Loading TAESD VAE...")
    vae = load_taesd(device)
    print(f" TAESD loaded: {sum(p.numel() for p in vae.parameters()):,} params (frozen)")

    # 2. Load CIFAR-10
    print("\n[Step 2] Loading CIFAR-10...")
    trainset, testset = _load_cifar10()
    print(f" Train: {len(trainset)}, Test: {len(testset)}")
    print(f" Image size: {trainset[0][0].shape}")

    # 3. Pre-compute latents (or load from cache)
    print("\n[Step 3] Pre-computing latents...", flush=True)
    cache_path = os.path.join(output_dir, 'latent_cache.pt')
    (train_latents, train_labels), _ = _load_or_compute_latents(
        vae, trainset, testset, cache_path, device)

    print("\n[Step 3b] Verifying VAE reconstruction...")
    _verify_vae_reconstruction(vae, trainset, output_dir, device)

    # Latent statistics. NOTE: latents are NOT rescaled before training;
    # latent_scale is only recorded in the checkpoint for reference.
    lat_mean = train_latents.mean()
    lat_std = train_latents.std()
    print(f"\n Latent mean: {lat_mean:.4f}, std: {lat_std:.4f}")
    latent_scale = lat_std.item()

    train_ds = TensorDataset(train_latents, train_labels)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=0, drop_last=True)
    if len(train_loader) == 0:
        # drop_last=True would yield no batches: avoid division by zero below.
        raise ValueError(
            f"batch_size={batch_size} exceeds the {len(train_ds)} training latents")

    # 4. Create model
    print("\n[Step 4] Creating LRF denoiser...")
    config = dict(config or LRFv2.small_config())  # copy: don't mutate the caller's dict
    config['latent_ch'] = train_latents.shape[1]
    model = LRFv2(config).to(device)
    params = model.count_params()
    print(f" Config: dim={config['dim']}, blocks={config['num_blocks']}, "
          f"T_inner={config['T_inner']}, T_outer={config['T_outer']}")
    print(f" Parameters: {params['total']:,} total, {params['core']:,} core")
    print(f" Effective depth: {config['T_outer'] * config['T_inner'] * config['num_blocks']} layers "
          f"from {config['num_blocks']} blocks")

    # 5. Training
    print(f"\n[Step 5] Training for {num_epochs} epochs...")
    scheduler = RectifiedFlowScheduler()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.95))
    # Cosine annealing, stepped once per batch.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(train_loader), eta_min=lr * 0.01
    )
    # EMA of the weights, used for all sampling.
    ema_decay = 0.999
    ema_params = _snapshot_params(model)

    loss_history = []
    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for latents, labels in train_loader:
            latents = latents.to(device)
            labels = labels.to(device)

            t = scheduler.sample_timesteps(latents.shape[0], device)
            noise = torch.randn_like(latents)
            z_t = scheduler.add_noise(latents, noise, t)
            # 10% label dropout so classifier-free guidance works at sampling time.
            v_pred = model.predict_velocity(z_t, t, labels, cfg_dropout=0.1)
            v_target = scheduler.get_velocity_target(latents, noise)

            loss_per_sample = (v_pred - v_target).pow(2).mean(dim=[1, 2, 3])
            loss = (loss_per_sample * _timestep_loss_weights(t)).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            _ema_update_(model, ema_params, ema_decay)

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)
        best_loss = min(best_loss, avg_loss)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f" Epoch {epoch+1:3d}/{num_epochs}: loss={avg_loss:.4f}, "
                  f"best={best_loss:.4f}, lr={current_lr:.2e}", flush=True)

        # Save and generate samples periodically
        if (epoch + 1) % save_every == 0 or epoch == num_epochs - 1:
            # Temporarily swap in the EMA weights; always restore the trained ones.
            trained_params = _snapshot_params(model)
            _copy_params_(model, ema_params)
            try:
                model.eval()
                samples = generate_samples(model, vae, scheduler, device,
                                           num_samples=16, num_steps=10, cfg_scale=2.0)
                save_image_grid(
                    samples,
                    os.path.join(output_dir, f'samples_epoch{epoch+1:03d}.png'),
                    nrow=8, title=f'Epoch {epoch+1}, Loss={avg_loss:.4f}'
                )
            finally:
                _copy_params_(model, trained_params)

            # model_state holds the trained (non-EMA) weights; EMA is stored separately.
            torch.save({
                'model_state': model.state_dict(),
                'ema_params': ema_params,
                'config': config,
                'epoch': epoch + 1,
                'loss': avg_loss,
                'latent_scale': latent_scale,
                'loss_history': loss_history,
            }, os.path.join(output_dir, 'checkpoint.pt'))

    # Final generation with EMA weights.
    _copy_params_(model, ema_params)
    model.eval()

    print("\n[Step 6] Generating final samples...")
    # 4 samples per class for labels 0..9 -> one grid row per class (nrow=4).
    all_samples = torch.cat([
        generate_samples(model, vae, scheduler, device,
                         num_samples=4, num_steps=50, cfg_scale=3.0,
                         class_label=cls_idx)
        for cls_idx in range(10)
    ], dim=0)
    save_image_grid(
        all_samples,
        os.path.join(output_dir, 'final_class_conditional.png'),
        nrow=4, title='Class-conditional generation (rows: airplane, auto, bird, cat, deer, dog, frog, horse, ship, truck)'
    )

    save_loss_plot(loss_history, os.path.join(output_dir, 'loss.png'))
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Best loss: {best_loss:.4f}")
    print(f"Output directory: {output_dir}")
    print(f"{'='*60}")
    return model, vae, loss_history
def generate_samples(model, vae, scheduler, device, num_samples=8,
                     num_steps=20, cfg_scale=2.0, class_label=None,
                     latent_shape=(4, 4, 4), num_classes=10):
    """Sample latents with ``scheduler`` and decode them to images through ``vae``.

    Args:
        model: denoiser passed through to ``scheduler.sample``; switched to eval mode.
        vae: object whose ``decode(z)`` returns something with a ``.sample`` tensor.
        scheduler: object providing ``sample(model, shape, labels, num_steps=...,
            cfg_scale=..., device=...)`` that returns latents.
        device: device for the labels and for decoding.
        num_samples: batch size to generate.
        num_steps: sampler steps.
        cfg_scale: classifier-free guidance scale forwarded to the sampler.
        class_label: if given, every sample uses this label; otherwise labels
            are drawn uniformly from ``[0, num_classes)``.
        latent_shape: per-sample latent shape. Default ``(4, 4, 4)`` = CIFAR-10
            32x32 through TAESD (4 channels, f=8).
        num_classes: number of classes for random labels (CIFAR-10: 10).

    Returns:
        Decoded images clamped to [-1, 1], on CPU.
    """
    model.eval()
    shape = (num_samples, *latent_shape)
    if class_label is not None:
        labels = torch.full((num_samples,), class_label, dtype=torch.long, device=device)
    else:
        labels = torch.randint(0, num_classes, (num_samples,), device=device)
    # Sampling runs under no_grad too: a multi-step sampler over a recurrent core
    # would otherwise build a large autograd graph that is never used.
    with torch.no_grad():
        z = scheduler.sample(model, shape, labels, num_steps=num_steps,
                             cfg_scale=cfg_scale, device=device)
        images = vae.decode(z.to(device)).sample
    return images.clamp(-1, 1).cpu()
def save_image_grid(images, path, nrow=8, title=''):
    """Save a batch of images with values in [-1, 1] as one grid image file.

    Args:
        images: image batch in [-1, 1], as accepted by
            ``torchvision.utils.make_grid`` (presumably (B, C, H, W)). It may be
            on any device and may require grad.
        path: output path; PIL infers the format from the extension.
        nrow: number of images per grid row.
        title: accepted for API compatibility; it is NOT rendered into the image.
    """
    from PIL import Image

    # Detach and move to CPU: .numpy() below fails on CUDA or grad-tracking tensors.
    images = ((images.detach().cpu().float() + 1) / 2).clamp(0, 1)
    grid = torchvision.utils.make_grid(images, nrow=nrow, padding=2, normalize=False)
    # Round to the nearest byte instead of truncating (as torchvision.utils.save_image does).
    grid_np = (
        grid.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8)
        .permute(1, 2, 0).contiguous().numpy()
    )
    Image.fromarray(grid_np).save(path)
def save_loss_plot(losses, path):
    """Write a line plot of per-epoch training losses to ``path``.

    If matplotlib cannot be imported, print a notice and skip the plot.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: no display needed
        from matplotlib import pyplot as plt

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(losses)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss')
        ax.grid(True, alpha=0.3)
        fig.savefig(path, dpi=100, bbox_inches='tight')
        plt.close(fig)
    except ImportError:
        print("matplotlib not available, skipping loss plot")
if __name__ == '__main__':
    # Prefer the GPU when one is visible.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    print(f"Device: {device}")
    # Quick run: fast config, 30 epochs, EMA samples + checkpoint every 5 epochs.
    run_settings = dict(
        config=LRFv2.fast_config(),
        num_epochs=30,
        batch_size=64,
        lr=3e-4,
        device=device,
        output_dir='/app/lrf_v2_output',
        save_every=5,
    )
    train_denoiser(**run_settings)