"""
============================================================================
LatentRecurrentFlow (LRF) — End-to-End Test Script
============================================================================

Tests the full pipeline on CPU:
  1. Model creation and parameter counting
  2. VAE forward/backward pass
  3. Gated linear attention forward pass
  4. Recursive latent core forward pass
  5. IFT training mode
  6. Flow matching scheduler
  7. Full training loop (a few steps), sample generation, and
     checkpoint save/load
  8. Memory estimation

Run: python test_lrf.py
"""

import os
import sys
import time
import traceback

import torch
import torch.nn.functional as F

# Ensure the lrf package is importable.
sys.path.insert(0, '/app')


def test_model_creation():
    """Test model creation with different configs."""
    print("\n[TEST 1] Model Creation")
    print("-" * 40)

    from lrf.model import LatentRecurrentFlow

    model = LatentRecurrentFlow(LatentRecurrentFlow.tiny_config())
    counts = model.count_parameters()
    print("Tiny config parameters:")
    for name, count in counts.items():
        print(f"  {name}: {count:,}")
    assert counts['total'] > 0, "Model has no parameters!"

    model_default = LatentRecurrentFlow(LatentRecurrentFlow.default_config())
    counts_default = model_default.count_parameters()
    print("\nDefault config parameters:")
    for name, count in counts_default.items():
        print(f"  {name}: {count:,}")
    assert counts_default['total'] > counts['total'], "Default config should be larger than tiny"

    print("✓ Model creation passed")
    return True


def test_vae():
    """Test VAE forward and backward."""
    print("\n[TEST 2] VAE Forward/Backward")
    print("-" * 40)

    from lrf.model import CompactVAE

    vae = CompactVAE(in_channels=3, latent_channels=16, encoder_base_ch=32, decoder_base_ch=64)

    enc_params = sum(p.numel() for p in vae.encoder.parameters())
    dec_params = sum(p.numel() for p in vae.decoder.parameters())
    print(f"Encoder params: {enc_params:,}")
    print(f"Decoder params: {dec_params:,}")

    x = torch.randn(2, 3, 64, 64)
    recon, mean, logvar = vae(x)
    print(f"Input shape: {x.shape}")
    print(f"Latent shape: {mean.shape}")
    print(f"Recon shape: {recon.shape}")

    assert recon.shape == x.shape, f"Reconstruction shape mismatch: {recon.shape} != {x.shape}"
    assert mean.shape[1] == 16, f"Latent channels mismatch: {mean.shape[1]}"

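    # Loss below = L1 reconstruction + a tiny KL penalty. The KL term is the
    # closed form for KL(N(mean, exp(logvar)) || N(0, I)), i.e.
    # -0.5 * mean(1 + logvar - mean^2 - exp(logvar)); the 1e-6 weight keeps it
    # from dominating reconstruction in this smoke test.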
    loss = F.l1_loss(recon, x) - 0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp()) * 1e-6
    loss.backward()

    grad_ok = all(p.grad is not None for p in vae.parameters() if p.requires_grad)
    print(f"Gradients computed: {grad_ok}")
    assert grad_ok, "Some VAE parameters did not receive gradients"

    print("✓ VAE test passed")
    return True


def test_gla():
    """Test Gated Linear Attention."""
    print("\n[TEST 3] Gated Linear Attention")
    print("-" * 40)

    from lrf.model import GatedLinearAttention

    gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)

    B, H, W, D = 2, 8, 8, 64
    x = torch.randn(B, H * W, D)

    t0 = time.time()
    out = gla(x, h=H, w=W)
    dt = time.time() - t0

    print(f"Input: {x.shape}")
    print(f"Output: {out.shape}")
    print(f"Time: {dt*1000:.1f}ms")

    assert out.shape == x.shape, f"Shape mismatch: {out.shape}"

    B, H, W, D = 1, 32, 32, 64
    x_large = torch.randn(B, H * W, D)
    t0 = time.time()
    out_large = gla(x_large, h=H, w=W)
    dt_large = time.time() - t0
    print(f"\nLarger input (32x32={H*W} tokens):")
    print(f"  Time: {dt_large*1000:.1f}ms")

    print("✓ GLA test passed")
    return True


def test_recursive_core():
    """Test the Recursive Latent Core."""
    print("\n[TEST 4] Recursive Latent Core")
    print("-" * 40)

    from lrf.model import RecursiveLatentCore

    core = RecursiveLatentCore(
        dim=32,
        cond_dim=64,
        num_blocks=2,
        num_heads=2,
        head_dim=16,
        T_inner=2,
        T_outer=1,
        use_ift_training=False,
    )

    params = sum(p.numel() for p in core.parameters())
    print(f"Core params: {params:,}")

    B, C, H, W = 2, 32, 4, 4
    z_t = torch.randn(B, C, H, W)
    t = torch.rand(B)
    text_emb = torch.randn(B, 10, 64)
    text_global = torch.randn(B, 64)

    t0 = time.time()
    v = core(z_t, t, text_emb, text_global)
    dt = time.time() - t0

    print(f"Input shape: {z_t.shape}")
    print(f"Output shape: {v.shape}")
    print(f"Time: {dt*1000:.1f}ms")

    assert v.shape == z_t.shape, f"Shape mismatch: {v.shape}"

    loss = v.pow(2).mean()
    loss.backward()

    grad_ok = sum(1 for p in core.parameters() if p.grad is not None and p.requires_grad)
    total_params = sum(1 for p in core.parameters() if p.requires_grad)
    print(f"Params with grad: {grad_ok}/{total_params}")

    print("✓ Recursive core test passed")
    return True


def test_ift_training():
    """Test IFT (Implicit Function Theorem) training mode."""
    print("\n[TEST 5] IFT Training Mode")
    print("-" * 40)

    from lrf.model import RecursiveLatentCore

    core_ift = RecursiveLatentCore(
        dim=32, cond_dim=64, num_blocks=2, num_heads=2, head_dim=16,
        T_inner=3, T_outer=2, use_ift_training=True,
    )
    core_ift.train()

    z_t = torch.randn(2, 32, 4, 4, requires_grad=True)
    t = torch.rand(2)

    v = core_ift(z_t, t)
    loss = v.pow(2).mean()
    loss.backward()
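    # IFT-style training typically backpropagates only through the final
    # iteration(s), treating the earlier recurrent steps as an approximate
    # fixed-point solve, which keeps backward memory flat in recursion depth.
    # That is an assumption about RecursiveLatentCore's internals; this test
    # only verifies that backward runs in IFT mode.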

    print(f"IFT mode: loss={loss.item():.4f}")
    print(f"  T_outer={core_ift.T_outer}, T_inner={core_ift.T_inner}")
    print(f"  Effective depth: {core_ift.T_outer * core_ift.T_inner * core_ift.num_blocks} layers")
    print(f"  Actual blocks: {core_ift.num_blocks}")

    print("✓ IFT training test passed")
    return True


def test_flow_matching():
    """Test flow matching scheduler."""
    print("\n[TEST 6] Flow Matching Scheduler")
    print("-" * 40)

    from lrf.training import RectifiedFlowScheduler

    scheduler = RectifiedFlowScheduler(shift=1.0)

    z_0 = torch.randn(2, 16, 4, 4)
    noise = torch.randn_like(z_0)
    t = torch.tensor([0.0, 0.5])

    z_t = scheduler.add_noise(z_0, noise, t)
    v_target = scheduler.get_velocity_target(z_0, noise)

    print(f"z_0 shape: {z_0.shape}")
    print(f"z_t shape: {z_t.shape}")
    print(f"v_target shape: {v_target.shape}")

    # Endpoint checks: t=0 should return clean data, t=1 pure noise.
    t_zero = torch.tensor([0.0, 0.0])
    z_t_zero = scheduler.add_noise(z_0, noise, t_zero)
    diff = (z_t_zero - z_0).abs().max().item()
    print(f"At t=0, |z_t - z_0| max = {diff:.6f}")
    assert diff < 1e-5, f"At t=0, z_t should equal z_0, got diff={diff}"

    t_one = torch.tensor([1.0, 1.0])
    z_t_one = scheduler.add_noise(z_0, noise, t_one)
    diff_one = (z_t_one - noise).abs().max().item()
    print(f"At t=1, |z_t - noise| max = {diff_one:.6f}")
    assert diff_one < 1e-5, f"At t=1, z_t should equal noise, got diff={diff_one}"

    print("✓ Flow matching test passed")
    return True


def test_full_training():
    """Test full training pipeline."""
    print("\n[TEST 7] Full Training Pipeline")
    print("-" * 40)

    from lrf.model import LatentRecurrentFlow
    from lrf.training import LRFTrainer, SyntheticImageTextDataset
    from torch.utils.data import DataLoader

    config = LatentRecurrentFlow.tiny_config()
    model = LatentRecurrentFlow(config)

    trainer = LRFTrainer(model, torch.device('cpu'), '/app/test_checkpoints')

    dataset = SyntheticImageTextDataset(num_samples=16, image_size=64, max_text_length=32)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    print("  Training VAE...")
    vae_opt = torch.optim.AdamW(model.vae.parameters(), lr=1e-3)
    for i, batch in enumerate(dataloader):
        if i >= 3:
            break
        losses = trainer.train_vae_step(batch['image'], vae_opt)
        print(f"    VAE step {i}: loss={losses['total']:.4f}")

    print("  Training flow matching...")
    for p in model.vae.parameters():
        p.requires_grad = False
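    # Standard two-stage setup: with the VAE frozen, flow-matching gradients
    # only reach the recursive core and the text encoder, so generation is
    # learned entirely in the fixed latent space.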

    flow_params = list(model.core.parameters()) + list(model.text_encoder.parameters())
    flow_opt = torch.optim.AdamW(flow_params, lr=1e-3)

    for i, batch in enumerate(dataloader):
        if i >= 3:
            break
        losses = trainer.train_flow_step(
            batch['image'], batch['token_ids'], batch['attention_mask'],
            flow_opt,
        )
        print(f"    Flow step {i}: loss={losses['flow_loss']:.4f}")

    print("  Generating samples...")
    sample_tokens = torch.randint(1, 31999, (2, 32))
    sample_mask = torch.ones(2, 32)

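    # num_steps is the number of ODE integration steps; cfg_scale=1.0
    # presumably disables classifier-free guidance (both are assumptions
    # about LRFTrainer.generate's interface).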
    images = trainer.generate(
        sample_tokens, sample_mask,
        num_steps=5, cfg_scale=1.0,
        latent_h=4, latent_w=4,
    )
    print(f"  Generated: {images.shape}, range=[{images.min():.3f}, {images.max():.3f}]")

    print("  Saving checkpoint...")
    trainer.save_checkpoint('/app/test_checkpoints/test.pt', 'test', 0)
    trainer.load_checkpoint('/app/test_checkpoints/test.pt')

    print("✓ Full training pipeline test passed")
    return True


def test_memory_estimate():
    """Estimate memory usage for different configs."""
    print("\n[TEST 8] Memory Estimation")
    print("-" * 40)

    from lrf.model import LatentRecurrentFlow

    configs = {
        'tiny': LatentRecurrentFlow.tiny_config(),
        'default': LatentRecurrentFlow.default_config(),
    }

    for name, config in configs.items():
        model = LatentRecurrentFlow(config)
        counts = model.count_parameters()

        param_bytes = counts['total'] * 4  # 4 bytes per FP32 parameter
        param_mb = param_bytes / (1024 * 1024)

        param_int8_mb = counts['total'] / (1024 * 1024)  # 1 byte per INT8 parameter

        print(f"\n{name} config:")
        print(f"  Total params: {counts['total']:,}")
        print(f"  FP32 size: {param_mb:.1f} MB")
        print(f"  INT8 size: {param_int8_mb:.1f} MB")

        latent_h = 256 // 16  # assumes 16x spatial downsampling at 256x256
        latent_w = 256 // 16
        latent_tokens = latent_h * latent_w
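        # Rough floor: two FP32 latent-sized buffers (e.g. the current latent
        # plus the predicted velocity). Real activation memory inside the core
        # will be higher; this is only an order-of-magnitude estimate.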
        act_bytes = 2 * latent_tokens * config['latent_channels'] * 4
        act_mb = act_bytes / (1024 * 1024)
        print(f"  Est. activation memory (256x256): {act_mb:.1f} MB")

        del model

    print("\n✓ Memory estimation passed")
    return True


def main():
    """Run all tests."""
    print("=" * 60)
    print("LatentRecurrentFlow (LRF) - End-to-End Tests")
    print("=" * 60)

    tests = [
        ("Model Creation", test_model_creation),
        ("VAE", test_vae),
        ("GLA", test_gla),
        ("Recursive Core", test_recursive_core),
        ("IFT Training", test_ift_training),
        ("Flow Matching", test_flow_matching),
        ("Full Training", test_full_training),
        ("Memory Estimate", test_memory_estimate),
    ]

    results = []
    for name, test_fn in tests:
        try:
            passed = test_fn()
            results.append((name, passed))
        except Exception as e:
            print(f"\n✗ {name} FAILED: {e}")
            traceback.print_exc()
            results.append((name, False))

    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)

    all_passed = True
    for name, passed in results:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"  {status}: {name}")
        if not passed:
            all_passed = False

    if all_passed:
        print("\n✓ ALL TESTS PASSED!")
    else:
        print("\n✗ SOME TESTS FAILED!")
        sys.exit(1)

    return all_passed


if __name__ == '__main__':
    main()