#!/usr/bin/env python3
"""Run all notebook cells as a script to verify everything works."""
# Environment setup. The statement order below is deliberate: the search-path
# and backend tweaks are made before the heavier imports that follow them.
import sys, os
# Give /app top priority on the module search path. The `lrf` imports later in
# this script come after this line (presumably the package lives under /app —
# TODO confirm).
sys.path.insert(0, '/app')
os.environ['MPLBACKEND'] = 'Agg' # Non-interactive matplotlib
import torch
import numpy as np
# The model, trainer and pipelines created further down are all handed this
# CPU device.
device = torch.device('cpu')
print(f'Using device: {device}')
# === Cell 1: Architecture Overview ===
# Builds each preset configuration once and prints its per-module parameter
# counts. The imports here also provide the classes used by the later cells.
print("\n=== Cell 1: Architecture Overview ===")
from lrf.model import LatentRecurrentFlow, RecursiveLatentCore, CompactVAE, GatedLinearAttention
from lrf.training import LRFTrainer, RectifiedFlowScheduler, SyntheticImageTextDataset
from lrf.pipeline import LRFPipeline, LRFTrainingPipeline

# Display label -> configuration produced by the matching factory method.
configs = {
    'Tiny (5.7M)': LatentRecurrentFlow.tiny_config(),
    'Default (16.3M)': LatentRecurrentFlow.default_config(),
}

for label, preset in configs.items():
    summary_model = LatentRecurrentFlow(preset)
    param_counts = summary_model.count_parameters()
    print(f'\n{label}:')
    for part, total in param_counts.items():
        print(f' {part:20s}: {total:>12,}')
    # Release the instance before the next preset is built.
    del summary_model
# === Cell 2: VAE Training ===
# Runs a short VAE training smoke test on synthetic data, saves a checkpoint,
# then reports the reconstruction error on one batch.
print("\n=== Cell 2: VAE Training ===")
from torch.utils.data import DataLoader

config = LatentRecurrentFlow.tiny_config()
model = LatentRecurrentFlow(config).to(device)

dataset = SyntheticImageTextDataset(
    num_samples=100,
    image_size=64,
    max_text_length=32,
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
trainer = LRFTrainer(model, device, '/app/nb_checkpoints')
vae_optimizer = torch.optim.AdamW(
    model.vae.parameters(),
    lr=1e-3,
    weight_decay=0.01,
)

# Only the first ten batches are trained on; every fifth step is reported.
for step, batch in enumerate(dataloader):
    if step >= 10:
        break
    losses = trainer.train_vae_step(batch['image'], vae_optimizer)
    if step % 5 == 0:
        print(f' VAE step {step}: loss={losses["total"]:.4f}')
trainer.save_checkpoint('/app/nb_checkpoints/vae.pt', 'vae', 0)

# VAE reconstruction: one forward pass in eval mode without gradient tracking.
model.eval()
with torch.no_grad():
    sample_batch = next(iter(dataloader))
    images = sample_batch['image'].to(device)
    recon, _, _ = model.vae(images)
    recon_mse = ((recon - images) ** 2).mean()
print(f' Reconstruction MSE: {recon_mse:.4f}')
# === Cell 3: Flow Matching Training ===
# Runs a short flow-matching training smoke test that optimises only the core
# and the text encoder, then saves a checkpoint.
print("\n=== Cell 3: Flow Matching Training ===")

# Turn off gradient tracking for every VAE parameter; the optimizer below is
# built from the core and text-encoder parameters only.
for vae_param in model.vae.parameters():
    vae_param.requires_grad = False

flow_params = [*model.core.parameters(), *model.text_encoder.parameters()]
flow_optimizer = torch.optim.AdamW(flow_params, lr=1e-3, weight_decay=0.01)

# Cell 2 left the whole model in eval mode; re-enable train mode for the two
# sub-modules being optimised here.
for trainable in (model.core, model.text_encoder):
    trainable.train()

# Only the first ten batches are trained on; every fifth step is reported.
for step, batch in enumerate(dataloader):
    if step >= 10:
        break
    losses = trainer.train_flow_step(
        batch['image'],
        batch['token_ids'],
        batch['attention_mask'],
        flow_optimizer,
        cfg_dropout=0.1,
    )
    if step % 5 == 0:
        print(f' Flow step {step}: loss={losses["flow_loss"]:.4f}')
trainer.save_checkpoint('/app/nb_checkpoints/flow.pt', 'flow', 0)
# === Cell 4: Generation ===
# Wraps the trained model in an inference pipeline and generates one image per
# prompt, then prints the batch shape and value range.
print("\n=== Cell 4: Generation ===")
model.eval()
pipe = LRFPipeline(model, device=device)

prompts = ['a sunset', 'a cat', 'mountains', 'abstract art']
generation_kwargs = {
    'num_steps': 5,
    'cfg_scale': 1.0,
    'height': 64,
    'width': 64,
    'seed': 42,
}
images = pipe(prompts, **generation_kwargs)

num_generated = images.shape[0]
print(f' Generated {num_generated} images: {images.shape}')
lowest, highest = images.min(), images.max()
print(f' Range: [{lowest:.3f}, {highest:.3f}]')
# === Cell 5: Save & Load ===
# Saves the pipeline to disk, reports the size of every saved entry, then
# reloads the pipeline from that directory and generates one image with it.
print("\n=== Cell 5: Save & Load ===")
model_dir = '/app/nb_model'
pipe.save_pretrained(model_dir)
print(' Model saved to /app/nb_model/')

# os.listdir() returns entries in arbitrary order; sort them so the report is
# identical from run to run.
for entry in sorted(os.listdir(model_dir)):
    size = os.path.getsize(os.path.join(model_dir, entry))
    print(f' {entry}: {size/1024:.1f} KB')

# Round-trip check: the reloaded pipeline must be able to generate.
pipe_loaded = LRFPipeline.from_pretrained(model_dir, device=str(device))
images_loaded = pipe_loaded('test prompt', num_steps=5, height=64, width=64, seed=42)
print(f' Reloaded model generates: {images_loaded.shape}')
# === Cell 6: Training Curriculum ===
# Lists every curriculum stage, numbered from 1, with its description.
print("\n=== Cell 6: Training Curriculum ===")
curriculum = LRFTrainingPipeline.get_curriculum()
for stage_number, stage_name in enumerate(curriculum, start=1):
    stage = LRFTrainingPipeline.get_stage_config(stage_name)
    description = stage["description"]
    print(f' Stage {stage_number}: {stage_name} - {description}')
# === Cell 7: Core Architecture ===
# Builds a small standalone core and prints its effective depth and its
# parameter count.
print("\n=== Cell 7: Core Architecture ===")
core = RecursiveLatentCore(
    dim=32,
    cond_dim=64,
    num_blocks=2,
    num_heads=2,
    head_dim=16,
    T_inner=4,
    T_outer=2,
    use_ift_training=False,
)

# Effective depth is reported as outer iterations x inner iterations x blocks.
effective_depth = core.T_outer * core.T_inner * core.num_blocks
print(f' Effective depth: {effective_depth} layers')

total_params = sum(param.numel() for param in core.parameters())
print(f' Parameters: {total_params:,}')
# === Cell 8: GLA Scaling ===
# Times the gated linear attention forward pass on square grids of growing
# size and prints the mean latency per call, then prints the final banner.
print("\n=== Cell 8: GLA Scaling ===")
import time

gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)
num_timed_runs = 5  # forward passes averaged per grid size

for s in [4, 8, 16, 32]:
    # One sequence of s*s tokens; the feature size 64 matches dim=64 above.
    x = torch.randn(1, s*s, 64)
    gla(x, h=s, w=s)  # warmup, excluded from the measurement
    # perf_counter() is monotonic and high-resolution, which matters for the
    # sub-millisecond intervals measured here; time.time() is neither.
    t0 = time.perf_counter()
    for _ in range(num_timed_runs):
        gla(x, h=s, w=s)
    dt = (time.perf_counter() - t0) / num_timed_runs
    print(f' {s}×{s} = {s*s:>5} tokens: {dt*1000:.2f}ms')

print("\n✅ ALL NOTEBOOK CELLS VERIFIED!")