"""
Comprehensive smoke test for LiquidFlow.
Tests: all model sizes, forward/backward, gradient health,
loss convergence direction, sampling, checkpoint save/load.
NO actual training — just confirms everything is wired correctly.
"""
import sys, os, tempfile
sys.path.insert(0, '/app')
import torch
import torch.nn as nn
from liquidflow.model import (
liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512,
LiquidCfCCell, SelectiveSSM, LiquidSSMBlock, create_scan_patterns
)
from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel
from liquidflow.sampling import euler_sample, heun_sample, make_grid_image
PASS = 0
FAIL = 0
def check(name, condition):
global PASS, FAIL
if condition:
PASS += 1
print(f" ✅ {name}")
else:
FAIL += 1
print(f" ❌ {name}")
# =========================================================
print("=" * 60)
print("1. MODEL VARIANTS — forward pass + shapes")
print("=" * 60)
configs = [
("tiny-128", liquidflow_tiny, 128, 2),
("small-128", liquidflow_small, 128, 2),
("base-256", liquidflow_base, 256, 1),
("512", liquidflow_512, 512, 1),
]
for tag, factory, img_sz, bs in configs:
m = factory(img_size=img_sz)
p = m.count_params()
x = torch.randn(bs, 3, img_sz, img_sz)
t = torch.rand(bs)
v = m(x, t)
check(f"{tag}: {p/1e6:.1f}M params, output shape {v.shape}",
v.shape == x.shape)
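# Extra sanity check (an addition, not one of the original tests): in eval
# mode a forward pass should be bit-for-bit reproducible, assuming no
# stochastic layers remain active after .eval().
m_det = liquidflow_tiny(img_size=32)
m_det.eval()
with torch.no_grad():
    x_det = torch.randn(1, 3, 32, 32)
    t_det = torch.rand(1)
    check("Eval-mode forward is deterministic",
          torch.equal(m_det(x_det, t_det), m_det(x_det, t_det)))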
# =========================================================
print("\n" + "=" * 60)
print("2. BACKWARD PASS — gradients exist for every param")
print("=" * 60)
m = liquidflow_tiny(32)
x1 = torch.randn(2, 3, 32, 32)
x0 = torch.randn(2, 3, 32, 32)
t = torch.rand(2)
t_e = t.view(2,1,1,1)
x_t = t_e * x1 + (1-t_e) * x0
v = m(x_t, t)
loss_fn = PhysicsInformedFlowLoss()
loss, ld = loss_fn(v, x0, x1, t, step=100)
loss.backward()
no_grad_params = []
for name, p in m.named_parameters():
if p.requires_grad and p.grad is None:
no_grad_params.append(name)
check("All parameters receive gradients", len(no_grad_params) == 0)
if no_grad_params:
print(f" Missing grads: {no_grad_params[:5]}...")
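# Hedged target-direction probe: with x_t = t*x1 + (1-t)*x0 the straight-line
# velocity is x1 - x0, so (assuming the loss regresses onto that target, the
# usual rectified-flow convention) the exact target should score well below a
# random prediction.
loss_target, _ = loss_fn(x1 - x0, x0, x1, t, step=100)
loss_random, _ = loss_fn(torch.randn_like(x1) * 3, x0, x1, t, step=100)
check(f"Loss prefers the true velocity ({loss_target.item():.4f} < {loss_random.item():.4f})",
      loss_target.item() < loss_random.item())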
# =========================================================
print("\n" + "=" * 60)
print("3. GRADIENT HEALTH — no NaN, no Inf, reasonable norms")
print("=" * 60)
has_nan = any(torch.isnan(p.grad).any() for p in m.parameters() if p.grad is not None)
has_inf = any(torch.isinf(p.grad).any() for p in m.parameters() if p.grad is not None)
max_grad = max(p.grad.abs().max().item() for p in m.parameters() if p.grad is not None)
check("No NaN gradients", not has_nan)
check("No Inf gradients", not has_inf)
check(f"Max grad magnitude reasonable ({max_grad:.4f} < 100)", max_grad < 100)
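# Supplementary (hedged) measurement: the max above is per-element, not the
# global L2 norm. clip_grad_norm_ with an infinite max_norm returns the total
# norm without actually modifying any gradients.
total_norm = torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=float('inf'))
check(f"Global grad L2 norm finite ({total_norm:.4f})", torch.isfinite(total_norm))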
# =========================================================
print("\n" + "=" * 60)
print("4. LOSS CONVERGENCE DIRECTION — 3 optimizer steps")
print("=" * 60)
m2 = liquidflow_tiny(32)
opt = torch.optim.AdamW(m2.parameters(), lr=1e-3)
losses_track = []
for step in range(3):
x1 = torch.randn(4, 3, 32, 32)
x0 = torch.randn(4, 3, 32, 32)
t = torch.rand(4); t_e = t.view(4,1,1,1)
x_t = t_e*x1 + (1-t_e)*x0
v = m2(x_t, t)
loss, _ = loss_fn(v, x0, x1, t, step=step)
opt.zero_grad(); loss.backward(); opt.step()
losses_track.append(loss.item())
check(f"Loss finite across steps: {[f'{l:.4f}' for l in losses_track]}",
      all(l == l and abs(l) < 1e6 for l in losses_track))  # l == l filters NaN (NaN != NaN)
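# Hedged follow-up: finiteness alone says nothing about direction. Repeating
# steps on one fixed batch should usually drive the loss down; the dynamics
# are stochastic, so treat this as a soft signal rather than a guarantee.
m2b = liquidflow_tiny(32)
opt_b = torch.optim.AdamW(m2b.parameters(), lr=1e-3)
x1_b, x0_b, t_b = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32), torch.rand(4)
x_tb = t_b.view(4,1,1,1)*x1_b + (1-t_b.view(4,1,1,1))*x0_b
fixed_losses = []
for s in range(10):
    l_b, _ = loss_fn(m2b(x_tb, t_b), x0_b, x1_b, t_b, step=s)
    opt_b.zero_grad(); l_b.backward(); opt_b.step()
    fixed_losses.append(l_b.item())
check(f"Fixed-batch loss decreases ({fixed_losses[0]:.4f} -> {fixed_losses[-1]:.4f})",
      fixed_losses[-1] < fixed_losses[0])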
# =========================================================
print("\n" + "=" * 60)
print("5. INDIVIDUAL COMPONENTS")
print("=" * 60)
# LiquidCfCCell
cell = LiquidCfCCell(64, 64)
out = cell(torch.randn(2, 16, 64))
check(f"LiquidCfCCell: input (2,16,64) → output {out.shape}", out.shape == (2,16,64))
# SelectiveSSM
ssm = SelectiveSSM(64, d_state=8)
out = ssm(torch.randn(2, 16, 64))
check(f"SelectiveSSM: input (2,16,64) → output {out.shape}", out.shape == (2,16,64))
# LiquidSSMBlock
block = LiquidSSMBlock(64, d_state=8)
out = block(torch.randn(2, 16, 64))
check(f"LiquidSSMBlock: input (2,16,64) → output {out.shape}", out.shape == (2,16,64))
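# Supplementary check (an addition): gradients should flow through a
# standalone block all the way back to its input tensor.
blk_in = torch.randn(2, 16, 64, requires_grad=True)
block(blk_in).sum().backward()
check("LiquidSSMBlock backward reaches its input",
      blk_in.grad is not None and torch.isfinite(blk_in.grad).all())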
# Scan patterns
patterns, inv = create_scan_patterns(8, 8)
check(f"Scan patterns: {len(patterns)} patterns of length {len(patterns[0])}",
len(patterns) == 4 and len(patterns[0]) == 64)
# Verify scan ↔ unscan is identity
for i, (p, ip) in enumerate(zip(patterns, inv)):
dummy = torch.arange(64)
recovered = dummy[p][ip]
check(f"Scan pattern {i}: scan→unscan is identity", torch.equal(recovered, dummy))
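# Illustrative usage (assumes patterns/inv are flat index sequences over the
# 64 token positions, exactly as the identity check above treats them):
# applying a pattern to a (B, L, D) token tensor permutes the sequence dim.
tok = torch.randn(2, 64, 16)
scanned = tok[:, patterns[0], :]
check("Scan pattern permutes (B,L,D) tokens and inv[0] undoes it",
      torch.equal(scanned[:, inv[0], :], tok))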
# =========================================================
print("\n" + "=" * 60)
print("6. SAMPLING — Euler & Heun produce valid images")
print("=" * 60)
m3 = liquidflow_tiny(32)
m3.eval()
with torch.no_grad():
imgs_euler = euler_sample(m3, (4,3,32,32), num_steps=5)
check(f"Euler sample shape {imgs_euler.shape}, finite",
imgs_euler.shape == (4,3,32,32) and torch.isfinite(imgs_euler).all())
imgs_heun = heun_sample(m3, (4,3,32,32), num_steps=5)
check(f"Heun sample shape {imgs_heun.shape}, finite",
imgs_heun.shape == (4,3,32,32) and torch.isfinite(imgs_heun).all())
clamped = imgs_euler.clamp(-1,1)*0.5+0.5
grid = make_grid_image(clamped, nrow=2)
grid.save('/app/smoke_test_grid.png')
check(f"Grid image saved ({grid.size})", grid.size[0] > 0)
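# Hedged reproducibility probe: if euler_sample draws its starting noise from
# torch's global RNG (an assumption about its internals), reseeding should
# reproduce the batch exactly.
with torch.no_grad():
    torch.manual_seed(0)
    rep_a = euler_sample(m3, (2,3,32,32), num_steps=5)
    torch.manual_seed(0)
    rep_b = euler_sample(m3, (2,3,32,32), num_steps=5)
check("Euler sampling reproducible under a fixed seed", torch.equal(rep_a, rep_b))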
# =========================================================
print("\n" + "=" * 60)
print("7. EMA — shadow copy matches, save/load works")
print("=" * 60)
m4 = liquidflow_tiny(32)
ema = EMAModel(m4, decay=0.999)
ema.update(m4)
ema.update(m4)
# apply_shadow swaps the EMA weights in; restore must hand the originals back.
before = {k: v.clone() for k, v in m4.state_dict().items()}
ema.apply_shadow(m4)
ema.restore(m4)
check("EMA apply/restore round-trips the original weights",
      all(torch.equal(before[k], v) for k, v in m4.state_dict().items()))
sd = ema.state_dict()
check("EMA state_dict has shadow and step",
'shadow' in sd and 'step' in sd)
# =========================================================
print("\n" + "=" * 60)
print("8. CHECKPOINT — save & reload matches")
print("=" * 60)
m5 = liquidflow_tiny(32)
opt5 = torch.optim.AdamW(m5.parameters(), lr=1e-3)
ckpt = {
'model': m5.state_dict(),
'optimizer': opt5.state_dict(),
'epoch': 5,
'global_step': 100,
}
fd, tmp = tempfile.mkstemp(suffix='.pt')  # mkstemp: mktemp is deprecated and racy
os.close(fd)
torch.save(ckpt, tmp)
m6 = liquidflow_tiny(32)
loaded = torch.load(tmp, map_location='cpu', weights_only=False)
m6.load_state_dict(loaded['model'])
check("Checkpoint save/load round-trips weights and metadata",
      loaded['epoch'] == 5 and
      all(torch.equal(a, b) for a, b in
          zip(m5.state_dict().values(), m6.state_dict().values())))
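# Hedged extra: a fresh AdamW created with a different lr should get 1e-3
# back when the saved optimizer state is loaded into it.
opt6 = torch.optim.AdamW(m6.parameters(), lr=2e-3)
opt6.load_state_dict(loaded['optimizer'])
check("Optimizer state_dict reloads (lr restored)",
      opt6.param_groups[0]['lr'] == 1e-3)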
os.remove(tmp)
# =========================================================
print("\n" + "=" * 60)
print("9. PHYSICS LOSS COMPONENTS — each term finite & positive")
print("=" * 60)
x_fake = torch.randn(2, 3, 32, 32)
lf = PhysicsInformedFlowLoss(lambda_smooth=0.01, lambda_tv=0.001)
sm = lf.smoothness_loss(x_fake)
tv = lf.total_variation_loss(x_fake)
check(f"Smoothness loss: {sm.item():.4f} (finite, positive)",
torch.isfinite(sm) and sm.item() > 0)
check(f"TV loss: {tv.item():.4f} (finite, positive)",
torch.isfinite(tv) and tv.item() > 0)
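# Hedged edge case: whatever smoothing epsilon the implementation uses, total
# variation of a constant image should come out well below that of pure noise.
tv_const = lf.total_variation_loss(torch.ones(2, 3, 32, 32))
check(f"TV loss on constant image ({tv_const.item():.6f}) < TV on noise ({tv.item():.4f})",
      tv_const.item() < tv.item())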
# =========================================================
print("\n" + "=" * 60)
print("10. MEMORY FOOTPRINT SUMMARY")
print("=" * 60)
for tag, factory, img_sz in [("tiny-32",liquidflow_tiny,32),
("tiny-128",liquidflow_tiny,128),
("small-128",liquidflow_small,128),
("base-256",liquidflow_base,256),
("512",liquidflow_512,512)]:
m = factory(img_size=img_sz)
p = m.count_params()
    # Rough fp16-training estimate: fp16 params plus fp32 AdamW states
    # (momentum + variance). Gradients and activations are excluded,
    # so real usage runs higher.
    model_gb = p * 2 / 1e9  # fp16 params, 2 bytes each
    opt_gb = p * 8 / 1e9    # 2 fp32 optimizer states, 4 bytes each
tokens = (img_sz // m.patch_size) ** 2
print(f" {tag:12s}: {p/1e6:6.1f}M params | "
f"model={model_gb*1000:.0f}MB | opt={opt_gb*1000:.0f}MB | "
f"tokens={tokens:5d} | patch={m.patch_size}")
# =========================================================
print("\n" + "=" * 60)
print(f"RESULTS: {PASS} passed, {FAIL} failed")
print("=" * 60)
sys.exit(0 if FAIL == 0 else 1)