#!/usr/bin/env python3
"""Run all notebook cells as a script to verify everything works."""
import sys, os

sys.path.insert(0, '/app')
os.environ['MPLBACKEND'] = 'Agg'  # Non-interactive matplotlib

import torch
import numpy as np

device = torch.device('cpu')
print(f'Using device: {device}')

# === Cell 1: Architecture Overview ===
print("\n=== Cell 1: Architecture Overview ===")
from lrf.model import LatentRecurrentFlow, RecursiveLatentCore, CompactVAE, GatedLinearAttention
from lrf.training import LRFTrainer, RectifiedFlowScheduler, SyntheticImageTextDataset
from lrf.pipeline import LRFPipeline, LRFTrainingPipeline

configs = {
    'Tiny (5.7M)': LatentRecurrentFlow.tiny_config(),
    'Default (16.3M)': LatentRecurrentFlow.default_config(),
}
for name, config in configs.items():
    model = LatentRecurrentFlow(config)
    counts = model.count_parameters()
    print(f'\n{name}:')
    for module, count in counts.items():
        print(f'  {module:20s}: {count:>12,}')
    del model

# === Cell 2: VAE Training ===
print("\n=== Cell 2: VAE Training ===")
config = LatentRecurrentFlow.tiny_config()
model = LatentRecurrentFlow(config).to(device)

from torch.utils.data import DataLoader

dataset = SyntheticImageTextDataset(num_samples=100, image_size=64, max_text_length=32)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

trainer = LRFTrainer(model, device, '/app/nb_checkpoints')
vae_optimizer = torch.optim.AdamW(model.vae.parameters(), lr=1e-3, weight_decay=0.01)

for i, batch in enumerate(dataloader):
    if i >= 10:
        break
    losses = trainer.train_vae_step(batch['image'], vae_optimizer)
    if i % 5 == 0:
        print(f'  VAE step {i}: loss={losses["total"]:.4f}')

trainer.save_checkpoint('/app/nb_checkpoints/vae.pt', 'vae', 0)

# VAE reconstruction
model.eval()
with torch.no_grad():
    sample_batch = next(iter(dataloader))
    images = sample_batch['image'].to(device)
    recon, _, _ = model.vae(images)
    print(f'  Reconstruction MSE: {((recon - images)**2).mean():.4f}')

# === Cell 3: Flow Matching Training ===
print("\n=== Cell 3: Flow Matching Training ===")
for p in model.vae.parameters():
    p.requires_grad = False

flow_params = list(model.core.parameters()) + list(model.text_encoder.parameters())
flow_optimizer = torch.optim.AdamW(flow_params, lr=1e-3, weight_decay=0.01)

model.core.train()
model.text_encoder.train()
for i, batch in enumerate(dataloader):
    if i >= 10:
        break
    losses = trainer.train_flow_step(
        batch['image'], batch['token_ids'], batch['attention_mask'],
        flow_optimizer, cfg_dropout=0.1
    )
    if i % 5 == 0:
        print(f'  Flow step {i}: loss={losses["flow_loss"]:.4f}')

trainer.save_checkpoint('/app/nb_checkpoints/flow.pt', 'flow', 0)

# === Cell 4: Generation ===
print("\n=== Cell 4: Generation ===")
model.eval()
pipe = LRFPipeline(model, device=device)
prompts = ['a sunset', 'a cat', 'mountains', 'abstract art']
images = pipe(prompts, num_steps=5, cfg_scale=1.0, height=64, width=64, seed=42)
print(f'  Generated {images.shape[0]} images: {images.shape}')
print(f'  Range: [{images.min():.3f}, {images.max():.3f}]')

# === Cell 5: Save & Load ===
print("\n=== Cell 5: Save & Load ===")
pipe.save_pretrained('/app/nb_model')
print('  Model saved to /app/nb_model/')
for f in os.listdir('/app/nb_model'):
    size = os.path.getsize(f'/app/nb_model/{f}')
    print(f'    {f}: {size/1024:.1f} KB')

pipe_loaded = LRFPipeline.from_pretrained('/app/nb_model', device=str(device))
images_loaded = pipe_loaded('test prompt', num_steps=5, height=64, width=64, seed=42)
print(f'  Reloaded model generates: {images_loaded.shape}')

# === Cell 6: Training Curriculum ===
print("\n=== Cell 6: Training Curriculum ===")
curriculum = LRFTrainingPipeline.get_curriculum()
for i, stage_name in enumerate(curriculum):
    stage = LRFTrainingPipeline.get_stage_config(stage_name)
    print(f'  Stage {i+1}: {stage_name} - {stage["description"]}')

# === Cell 7: Core Architecture ===
print("\n=== Cell 7: Core Architecture ===")
core = RecursiveLatentCore(
    dim=32, cond_dim=64, num_blocks=2, num_heads=2, head_dim=16,
    T_inner=4, T_outer=2, use_ift_training=False
)
print(f'  Effective depth: {core.T_outer * core.T_inner * core.num_blocks} layers')
print(f'  Parameters: {sum(p.numel() for p in core.parameters()):,}')

# === Cell 8: GLA Scaling ===
print("\n=== Cell 8: GLA Scaling ===")
import time

gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)
for s in [4, 8, 16, 32]:
    x = torch.randn(1, s*s, 64)
    _ = gla(x, h=s, w=s)  # warmup
    t0 = time.time()
    for _ in range(5):
        _ = gla(x, h=s, w=s)
    dt = (time.time() - t0) / 5
    print(f'  {s}×{s} = {s*s:>5} tokens: {dt*1000:.2f}ms')

print("\n✅ ALL NOTEBOOK CELLS VERIFIED!")