# LatentRecurrentFlow / tests / test_notebook.py
# Uploaded by krystv via huggingface_hub (commit 0034111, verified)
#!/usr/bin/env python3
"""Run all notebook cells as a script to verify everything works."""
import sys, os
# Make the project package (lrf) importable regardless of the working directory.
sys.path.insert(0, '/app')
os.environ['MPLBACKEND'] = 'Agg' # Non-interactive matplotlib
import torch
import numpy as np
# CPU-only on purpose: keeps this smoke test runnable on machines without GPUs.
device = torch.device('cpu')
print(f'Using device: {device}')
# === Cell 1: Architecture Overview ===
print("\n=== Cell 1: Architecture Overview ===")
from lrf.model import LatentRecurrentFlow, RecursiveLatentCore, CompactVAE, GatedLinearAttention
from lrf.training import LRFTrainer, RectifiedFlowScheduler, SyntheticImageTextDataset
from lrf.pipeline import LRFPipeline, LRFTrainingPipeline
# Build each preset configuration and report its per-module parameter counts.
configs = {
    'Tiny (5.7M)': LatentRecurrentFlow.tiny_config(),
    'Default (16.3M)': LatentRecurrentFlow.default_config(),
}
for label, cfg in configs.items():
    net = LatentRecurrentFlow(cfg)
    param_counts = net.count_parameters()
    print(f'\n{label}:')
    for module, count in param_counts.items():
        print(f' {module:20s}: {count:>12,}')
    # Drop the model before instantiating the next config to keep memory flat.
    del net
# === Cell 2: VAE Training ===
print("\n=== Cell 2: VAE Training ===")
# NOTE: config/model/dataloader/trainer are module-level state reused by later cells.
config = LatentRecurrentFlow.tiny_config()
model = LatentRecurrentFlow(config).to(device)
from torch.utils.data import DataLoader
synthetic_ds = SyntheticImageTextDataset(num_samples=100, image_size=64, max_text_length=32)
dataloader = DataLoader(synthetic_ds, batch_size=8, shuffle=True, num_workers=0)
trainer = LRFTrainer(model, device, '/app/nb_checkpoints')
opt_vae = torch.optim.AdamW(model.vae.parameters(), lr=1e-3, weight_decay=0.01)
# Run ten optimizer steps on the VAE alone, logging every fifth step.
for step, batch in zip(range(10), dataloader):
    loss_dict = trainer.train_vae_step(batch['image'], opt_vae)
    if step % 5 == 0:
        print(f' VAE step {step}: loss={loss_dict["total"]:.4f}')
trainer.save_checkpoint('/app/nb_checkpoints/vae.pt', 'vae', 0)
# VAE reconstruction
model.eval()
with torch.no_grad():
    eval_batch = next(iter(dataloader))
    imgs = eval_batch['image'].to(device)
    reconstruction, _, _ = model.vae(imgs)
    print(f' Reconstruction MSE: {((reconstruction - imgs)**2).mean():.4f}')
# === Cell 3: Flow Matching Training ===
print("\n=== Cell 3: Flow Matching Training ===")
# Freeze the VAE so only the flow components receive gradients.
for param in model.vae.parameters():
    param.requires_grad_(False)
trainable = [*model.core.parameters(), *model.text_encoder.parameters()]
flow_optimizer = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=0.01)
for module in (model.core, model.text_encoder):
    module.train()
# Ten flow-matching steps with classifier-free-guidance dropout.
for step, batch in zip(range(10), dataloader):
    loss_dict = trainer.train_flow_step(
        batch['image'], batch['token_ids'], batch['attention_mask'],
        flow_optimizer, cfg_dropout=0.1
    )
    if step % 5 == 0:
        print(f' Flow step {step}: loss={loss_dict["flow_loss"]:.4f}')
trainer.save_checkpoint('/app/nb_checkpoints/flow.pt', 'flow', 0)
# === Cell 4: Generation ===
print("\n=== Cell 4: Generation ===")
model.eval()
# NOTE: `pipe` is reused by Cell 5 for save/load round-tripping.
pipe = LRFPipeline(model, device=device)
test_prompts = ['a sunset', 'a cat', 'mountains', 'abstract art']
generated = pipe(test_prompts, num_steps=5, cfg_scale=1.0, height=64, width=64, seed=42)
print(f' Generated {generated.shape[0]} images: {generated.shape}')
print(f' Range: [{generated.min():.3f}, {generated.max():.3f}]')
# === Cell 5: Save & Load ===
print("\n=== Cell 5: Save & Load ===")
pipe.save_pretrained('/app/nb_model')
print(' Model saved to /app/nb_model/')
# List every saved artifact with its size.
for fname in os.listdir('/app/nb_model'):
    nbytes = os.path.getsize(os.path.join('/app/nb_model', fname))
    print(f' {fname}: {nbytes/1024:.1f} KB')
# Round-trip: reload from disk and confirm the restored pipeline generates.
restored = LRFPipeline.from_pretrained('/app/nb_model', device=str(device))
restored_images = restored('test prompt', num_steps=5, height=64, width=64, seed=42)
print(f' Reloaded model generates: {restored_images.shape}')
# === Cell 6: Training Curriculum ===
print("\n=== Cell 6: Training Curriculum ===")
curriculum = LRFTrainingPipeline.get_curriculum()
# Print the curriculum stages with 1-based numbering.
for idx, stage_name in enumerate(curriculum, start=1):
    stage_cfg = LRFTrainingPipeline.get_stage_config(stage_name)
    print(f' Stage {idx}: {stage_name} - {stage_cfg["description"]}')
# === Cell 7: Core Architecture ===
print("\n=== Cell 7: Core Architecture ===")
# A small recursive core: effective depth = outer loops × inner loops × blocks.
core = RecursiveLatentCore(
    dim=32, cond_dim=64, num_blocks=2, num_heads=2, head_dim=16,
    T_inner=4, T_outer=2, use_ift_training=False
)
effective_depth = core.T_outer * core.T_inner * core.num_blocks
print(f' Effective depth: {effective_depth} layers')
total_params = sum(p.numel() for p in core.parameters())
print(f' Parameters: {total_params:,}')
# === Cell 8: GLA Scaling ===
print("\n=== Cell 8: GLA Scaling ===")
import time
# Rough latency check: gated linear attention should scale ~linearly in tokens.
gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)
# Inference-only benchmark: torch.no_grad() avoids building autograd graphs,
# which would otherwise skew timings and grow memory across iterations.
with torch.no_grad():
    for s in [4, 8, 16, 32]:
        x = torch.randn(1, s*s, 64)
        _ = gla(x, h=s, w=s) # warmup
        # perf_counter is monotonic and high-resolution; time.time() is
        # wall-clock and can jump (e.g. NTP), corrupting short benchmarks.
        t0 = time.perf_counter()
        for _ in range(5):
            _ = gla(x, h=s, w=s)
        dt = (time.perf_counter() - t0) / 5
        print(f' {s}×{s} = {s*s:>5} tokens: {dt*1000:.2f}ms')
print("\n✅ ALL NOTEBOOK CELLS VERIFIED!")