"""
LiquidDiffusion — Complete Test Suite
Tests model construction, forward/backward, training stability, and sampling.
Run: python test_model.py
"""
import sys
import math
import torch
import torch.nn.functional as F
# Make the repo root importable so the liquid_diffusion package resolves
sys.path.insert(0, '.')
from liquid_diffusion.model import (
    LiquidDiffusionUNet, liquid_diffusion_tiny,
    liquid_diffusion_small, liquid_diffusion_base,
)
print("=" * 70)
print("LiquidDiffusion: Novel Attention-Free Image Generation")
print("Based on Liquid Neural Networks (CfC) + Rectified Flow")
print("=" * 70)
all_passed = True
# Test 1: Model construction
print("\n--- Test 1: Model Construction & Parameter Count ---")
for name, factory in [
    ("tiny", liquid_diffusion_tiny),
    ("small", liquid_diffusion_small),
    ("base", liquid_diffusion_base),
]:
    m = factory()
    total, trainable = m.count_params()
    print(f" {name:8s}: {total:>12,} params ({total/1e6:.1f}M)")
    del m
# Test 2: Forward pass
print("\n--- Test 2: Forward Pass (multiple resolutions) ---")
model = liquid_diffusion_tiny()
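# The network predicts a per-pixel velocity field, so the output must have
# exactly the input shape at every resolution.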
for res in [32, 64, 128]:
    x = torch.randn(2, 3, res, res)
    t = torch.rand(2)
    out = model(x, t)
    ok = out.shape == x.shape
    print(f" {res}x{res}: {'OK' if ok else 'FAIL'} shape={out.shape}")
    if not ok: all_passed = False
# Test 3: Backward pass
print("\n--- Test 3: Backward Pass (gradient flow) ---")
model = liquid_diffusion_tiny()
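# Single forward/backward pass on random inputs; every parameter reached by
# the loss should receive a finite gradient.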
x = torch.randn(2, 3, 64, 64)
t = torch.rand(2)
out = model(x, t)
loss = out.mean()
loss.backward()
num_params_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
nan_grads = sum(1 for p in model.parameters() if p.grad is not None and torch.isnan(p.grad).any())
print(f" Params with gradients: {num_params_with_grad}")
print(f" NaN gradients: {nan_grads}")
if nan_grads > 0: all_passed = False
# Test 4: Training stability (20 steps)
print("\n--- Test 4: Training Stability (20 steps, random data) ---")
model = liquid_diffusion_tiny()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
losses = []
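# Rectified-flow objective on random stand-in data: draw endpoints x0 ("data")
# and x1 (noise), form the linear interpolation x_t = (1 - t) * x0 + t * x1,
# and regress the model output onto the constant target velocity v = x1 - x0.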
for step in range(20):
    model.train()
    x0 = torch.randn(4, 3, 64, 64)
    x1 = torch.randn_like(x0)
    t_val = torch.rand(4)
    x_t = (1 - t_val[:, None, None, None]) * x0 + t_val[:, None, None, None] * x1
    v_target = x1 - x0
    v_pred = model(x_t, t_val)
    loss = F.mse_loss(v_pred, v_target)
    optimizer.zero_grad()
    loss.backward()
    gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())
    if step % 5 == 0:
        print(f" Step {step:3d}: loss={loss.item():.4f}, grad_norm={gn.item():.4f}")
stable = all(math.isfinite(l) for l in losses)
not_exploding = max(losses) < 100
print(f" Stable (no NaN/Inf): {'OK' if stable else 'FAIL'}")
print(f" Not exploding: {'OK' if not_exploding else 'FAIL'} (max={max(losses):.4f})")
if not stable or not not_exploding: all_passed = False
# Test 5: Sampling
print("\n--- Test 5: Sampling (10 Euler steps) ---")
model.eval()
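# Euler sampler: integrate dz/dt = v_theta(z, t) from t=1 (pure noise) back to
# t=0 in 10 uniform steps of size 0.1, matching the interpolation convention
# of Test 4 (t=0 is the data endpoint, t=1 the noise endpoint).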
with torch.no_grad():
    z = torch.randn(2, 3, 64, 64)
    for i in range(10, 0, -1):
        t_s = torch.full((2,), i / 10.0)
        v = model(z, t_s)
        z = z - v * 0.1
    z = z.clamp(-1, 1)
print(f" Shape: {z.shape}, range: [{z.min():.3f}, {z.max():.3f}]")
# Test 6: Timestep sensitivity
print("\n--- Test 6: Timestep Sensitivity ---")
model.eval()
x = torch.randn(1, 3, 64, 64)
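# The timestep conditioning should modulate the output: the printed statistics
# are expected to differ across t values rather than stay constant.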
for t_val in [0.01, 0.25, 0.5, 0.75, 0.99]:
    with torch.no_grad():
        out = model(x, torch.tensor([t_val]))
    print(f" t={t_val:.2f}: mean={out.mean():.6f}, std={out.std():.6f}")
# Test 7: Architecture properties
print("\n--- Test 7: Architecture Properties ---")
m = liquid_diffusion_tiny()
total_blocks = (sum(len(s) for s in m.encoder_blocks) + len(m.bottleneck) + sum(len(s) for s in m.decoder_blocks))
print(f" Attention layers: 0")
print(f" Sequential loops: 0")
print(f" CfC blocks: {total_blocks}")
print(f" Training objective: Rectified Flow (MSE velocity)")
# Test 8: VRAM estimates
print("\n--- Test 8: VRAM Estimates (fp16 training) ---")
for name, factory, res, bs in [
    ("tiny 256px bs4", liquid_diffusion_tiny, 256, 4),
    ("small 256px bs4", liquid_diffusion_small, 256, 4),
    ("base 256px bs2", liquid_diffusion_base, 256, 2),
    ("tiny 512px bs2", liquid_diffusion_tiny, 512, 2),
]:
    m = factory()
    tp = sum(p.numel() for p in m.parameters())
    est = (tp * 2 + tp * 4 + tp * 8) / 1e9 + (
        bs * 3 * res * res * 4 * len(m.channels) * max(m.channels) / 1e9 * 0.3
    )
    print(f" {name:20s}: {tp/1e6:.1f}M params, ~{est:.1f}GB VRAM")
    del m
print("\n" + "=" * 70)
print(f"ALL TESTS {'PASSED' if all_passed else 'SOME FAILURES'}")
print("=" * 70)