| """ |
| Comprehensive smoke test for LiquidFlow. |
| Tests: all model sizes, forward/backward, gradient health, |
| loss convergence direction, sampling, checkpoint save/load. |
| NO actual training — just confirms everything is wired correctly. |
| """ |
| import sys, os, json, tempfile |
| sys.path.insert(0, '/app') |
|
|
| import torch |
| import torch.nn as nn |
| from liquidflow.model import ( |
| liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512, |
| LiquidCfCCell, SelectiveSSM, LiquidSSMBlock, create_scan_patterns |
| ) |
| from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel |
| from liquidflow.sampling import euler_sample, heun_sample, make_grid_image |
|
|
| PASS = 0 |
| FAIL = 0 |
|
|
| def check(name, condition): |
| global PASS, FAIL |
| if condition: |
| PASS += 1 |
| print(f" ✅ {name}") |
| else: |
| FAIL += 1 |
| print(f" ❌ {name}") |
|
|
| |
| print("=" * 60) |
| print("1. MODEL VARIANTS — forward pass + shapes") |
| print("=" * 60) |
|
|
| configs = [ |
| ("tiny-128", liquidflow_tiny, 128, 2), |
| ("small-128", liquidflow_small, 128, 2), |
| ("base-256", liquidflow_base, 256, 1), |
| ("512", liquidflow_512, 512, 1), |
| ] |
|
|
| for tag, factory, img_sz, bs in configs: |
| m = factory(img_size=img_sz) |
| p = m.count_params() |
| x = torch.randn(bs, 3, img_sz, img_sz) |
| t = torch.rand(bs) |
| v = m(x, t) |
| check(f"{tag}: {p/1e6:.1f}M params, output shape {v.shape}", |
| v.shape == x.shape) |
|
|
| |
| print("\n" + "=" * 60) |
| print("2. BACKWARD PASS — gradients exist for every param") |
| print("=" * 60) |
|
|
| m = liquidflow_tiny(32) |
| x1 = torch.randn(2, 3, 32, 32) |
| x0 = torch.randn(2, 3, 32, 32) |
| t = torch.rand(2) |
| t_e = t.view(2,1,1,1) |
| x_t = t_e * x1 + (1-t_e) * x0 |
| v = m(x_t, t) |
| loss_fn = PhysicsInformedFlowLoss() |
| loss, ld = loss_fn(v, x0, x1, t, step=100) |
| loss.backward() |
|
|
| no_grad_params = [] |
| for name, p in m.named_parameters(): |
| if p.requires_grad and p.grad is None: |
| no_grad_params.append(name) |
| check("All parameters receive gradients", len(no_grad_params) == 0) |
| if no_grad_params: |
| print(f" Missing grads: {no_grad_params[:5]}...") |
|
|
| |
| print("\n" + "=" * 60) |
| print("3. GRADIENT HEALTH — no NaN, no Inf, reasonable norms") |
| print("=" * 60) |
|
|
| has_nan = any(torch.isnan(p.grad).any() for p in m.parameters() if p.grad is not None) |
| has_inf = any(torch.isinf(p.grad).any() for p in m.parameters() if p.grad is not None) |
| max_grad = max(p.grad.abs().max().item() for p in m.parameters() if p.grad is not None) |
|
|
| check("No NaN gradients", not has_nan) |
| check("No Inf gradients", not has_inf) |
| check(f"Max grad norm reasonable ({max_grad:.4f} < 100)", max_grad < 100) |
|
|
| |
| print("\n" + "=" * 60) |
| print("4. LOSS CONVERGENCE DIRECTION — 3 optimizer steps") |
| print("=" * 60) |
|
|
| m2 = liquidflow_tiny(32) |
| opt = torch.optim.AdamW(m2.parameters(), lr=1e-3) |
| losses_track = [] |
| for step in range(3): |
| x1 = torch.randn(4, 3, 32, 32) |
| x0 = torch.randn(4, 3, 32, 32) |
| t = torch.rand(4); t_e = t.view(4,1,1,1) |
| x_t = t_e*x1 + (1-t_e)*x0 |
| v = m2(x_t, t) |
| loss, _ = loss_fn(v, x0, x1, t, step=step) |
| opt.zero_grad(); loss.backward(); opt.step() |
| losses_track.append(loss.item()) |
| |
| check(f"Loss finite across steps: {[f'{l:.4f}' for l in losses_track]}", |
| all(not (l != l or abs(l) > 1e6) for l in losses_track)) |
|
|
| |
| print("\n" + "=" * 60) |
| print("5. INDIVIDUAL COMPONENTS") |
| print("=" * 60) |
|
|
| |
| cell = LiquidCfCCell(64, 64) |
| out = cell(torch.randn(2, 16, 64)) |
| check(f"LiquidCfCCell: input (2,16,64) → output {out.shape}", out.shape == (2,16,64)) |
|
|
| |
| ssm = SelectiveSSM(64, d_state=8) |
| out = ssm(torch.randn(2, 16, 64)) |
| check(f"SelectiveSSM: input (2,16,64) → output {out.shape}", out.shape == (2,16,64)) |
|
|
| |
| block = LiquidSSMBlock(64, d_state=8) |
| out = block(torch.randn(2, 16, 64)) |
| check(f"LiquidSSMBlock: input (2,16,64) → output {out.shape}", out.shape == (2,16,64)) |
|
|
| |
| patterns, inv = create_scan_patterns(8, 8) |
| check(f"Scan patterns: {len(patterns)} patterns of length {len(patterns[0])}", |
| len(patterns) == 4 and len(patterns[0]) == 64) |
|
|
| |
| for i, (p, ip) in enumerate(zip(patterns, inv)): |
| dummy = torch.arange(64) |
| recovered = dummy[p][ip] |
| check(f"Scan pattern {i}: scan→unscan is identity", torch.equal(recovered, dummy)) |
|
|
| |
| print("\n" + "=" * 60) |
| print("6. SAMPLING — Euler & Heun produce valid images") |
| print("=" * 60) |
|
|
| m3 = liquidflow_tiny(32) |
| m3.eval() |
|
|
| with torch.no_grad(): |
| imgs_euler = euler_sample(m3, (4,3,32,32), num_steps=5) |
| check(f"Euler sample shape {imgs_euler.shape}, finite", |
| imgs_euler.shape == (4,3,32,32) and torch.isfinite(imgs_euler).all()) |
| |
| imgs_heun = heun_sample(m3, (4,3,32,32), num_steps=5) |
| check(f"Heun sample shape {imgs_heun.shape}, finite", |
| imgs_heun.shape == (4,3,32,32) and torch.isfinite(imgs_heun).all()) |
| |
| clamped = imgs_euler.clamp(-1,1)*0.5+0.5 |
| grid = make_grid_image(clamped, nrow=2) |
| grid.save('/app/smoke_test_grid.png') |
| check(f"Grid image saved ({grid.size})", grid.size[0] > 0) |
|
|
| |
| print("\n" + "=" * 60) |
| print("7. EMA — shadow copy matches, save/load works") |
| print("=" * 60) |
|
|
| m4 = liquidflow_tiny(32) |
| ema = EMAModel(m4, decay=0.999) |
| ema.update(m4) |
| ema.update(m4) |
| ema.apply_shadow(m4) |
| |
| ema.restore(m4) |
| check("EMA apply/restore cycle completes", True) |
|
|
| sd = ema.state_dict() |
| check("EMA state_dict has shadow and step", |
| 'shadow' in sd and 'step' in sd) |
|
|
| |
| print("\n" + "=" * 60) |
| print("8. CHECKPOINT — save & reload matches") |
| print("=" * 60) |
|
|
| m5 = liquidflow_tiny(32) |
| opt5 = torch.optim.AdamW(m5.parameters(), lr=1e-3) |
| ckpt = { |
| 'model': m5.state_dict(), |
| 'optimizer': opt5.state_dict(), |
| 'epoch': 5, |
| 'global_step': 100, |
| } |
| tmp = tempfile.mktemp(suffix='.pt') |
| torch.save(ckpt, tmp) |
|
|
| m6 = liquidflow_tiny(32) |
| loaded = torch.load(tmp, map_location='cpu', weights_only=False) |
| m6.load_state_dict(loaded['model']) |
| check("Checkpoint save/load cycle works", loaded['epoch'] == 5) |
| os.remove(tmp) |
|
|
| |
| print("\n" + "=" * 60) |
| print("9. PHYSICS LOSS COMPONENTS — each term finite & positive") |
| print("=" * 60) |
|
|
| x_fake = torch.randn(2, 3, 32, 32) |
| lf = PhysicsInformedFlowLoss(lambda_smooth=0.01, lambda_tv=0.001) |
| sm = lf.smoothness_loss(x_fake) |
| tv = lf.total_variation_loss(x_fake) |
| check(f"Smoothness loss: {sm.item():.4f} (finite, positive)", |
| torch.isfinite(sm) and sm.item() > 0) |
| check(f"TV loss: {tv.item():.4f} (finite, positive)", |
| torch.isfinite(tv) and tv.item() > 0) |
|
|
| |
| print("\n" + "=" * 60) |
| print("10. MEMORY FOOTPRINT SUMMARY") |
| print("=" * 60) |
|
|
| for tag, factory, img_sz in [("tiny-32",liquidflow_tiny,32), |
| ("tiny-128",liquidflow_tiny,128), |
| ("small-128",liquidflow_small,128), |
| ("base-256",liquidflow_base,256), |
| ("512",liquidflow_512,512)]: |
| m = factory(img_size=img_sz) |
| p = m.count_params() |
| |
| model_gb = p * 2 / 1e9 |
| opt_gb = p * 8 / 1e9 |
| tokens = (img_sz // m.patch_size) ** 2 |
| print(f" {tag:12s}: {p/1e6:6.1f}M params | " |
| f"model={model_gb*1000:.0f}MB | opt={opt_gb*1000:.0f}MB | " |
| f"tokens={tokens:5d} | patch={m.patch_size}") |
|
|
| |
| print("\n" + "=" * 60) |
| print(f"RESULTS: {PASS} passed, {FAIL} failed") |
| print("=" * 60) |
| sys.exit(0 if FAIL == 0 else 1) |
|
|