"""
LiquidDiffusion — Complete Test Suite
Tests model construction, forward/backward, training stability, and sampling.
Run: python test_model.py
"""
import sys
import math
import torch
import torch.nn.functional as F
# Ensure the repo root (current directory) is on the import path
sys.path.insert(0, '.')
from liquid_diffusion.model import (
LiquidDiffusionUNet, liquid_diffusion_tiny,
liquid_diffusion_small, liquid_diffusion_base
)
print("=" * 70)
print("LiquidDiffusion: Novel Attention-Free Image Generation")
print("Based on Liquid Neural Networks (CfC) + Rectified Flow")
print("=" * 70)
all_passed = True
# Test 1: Model construction
print("\n--- Test 1: Model Construction & Parameter Count ---")
for name, factory in [
    ("tiny", liquid_diffusion_tiny),
    ("small", liquid_diffusion_small),
    ("base", liquid_diffusion_base),
]:
    m = factory()
    total, trainable = m.count_params()
    print(f" {name:8s}: {total:>12,} params ({total/1e6:.1f}M)")
    del m
# Test 2: Forward pass
print("\n--- Test 2: Forward Pass (multiple resolutions) ---")
model = liquid_diffusion_tiny()
for res in [32, 64, 128]:
    x = torch.randn(2, 3, res, res)
    t = torch.rand(2)
    out = model(x, t)
    ok = out.shape == x.shape
    print(f" {res}x{res}: {'OK' if ok else 'FAIL'} shape={out.shape}")
    if not ok: all_passed = False
# Test 3: Backward pass
print("\n--- Test 3: Backward Pass (gradient flow) ---")
model = liquid_diffusion_tiny()
x = torch.randn(2, 3, 64, 64)
t = torch.rand(2)
out = model(x, t)
loss = out.mean()
loss.backward()
num_params_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
nan_grads = sum(1 for p in model.parameters() if p.grad is not None and torch.isnan(p.grad).any())
print(f" Params with gradients: {num_params_with_grad}")
print(f" NaN gradients: {nan_grads}")
if nan_grads > 0: all_passed = False
# Test 4: Training stability (20 steps)
print("\n--- Test 4: Training Stability (20 steps, random data) ---")
model = liquid_diffusion_tiny()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
losses = []
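# Rectified flow objective: train on straight-line interpolants
# x_t = (1 - t) * x0 + t * x1 with constant velocity target v* = x1 - x0.
# Both endpoints are random tensors here, since this test only checks
# optimization stability, not sample quality.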
for step in range(20):
    model.train()
    x0 = torch.randn(4, 3, 64, 64)
    x1 = torch.randn_like(x0)
    t_val = torch.rand(4)
    x_t = (1 - t_val[:, None, None, None]) * x0 + t_val[:, None, None, None] * x1
    v_target = x1 - x0
    v_pred = model(x_t, t_val)
    loss = F.mse_loss(v_pred, v_target)
    optimizer.zero_grad()
    loss.backward()
    gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())
    if step % 5 == 0:
        print(f" Step {step:3d}: loss={loss.item():.4f}, grad_norm={gn.item():.4f}")
stable = all(math.isfinite(v) for v in losses)
not_exploding = max(losses) < 100
print(f" Stable (no NaN/Inf): {'OK' if stable else 'FAIL'}")
print(f" Not exploding: {'OK' if not_exploding else 'FAIL'} (max={max(losses):.4f})")
if not stable or not not_exploding: all_passed = False
# Test 5: Sampling
print("\n--- Test 5: Sampling (10 Euler steps) ---")
model.eval()
with torch.no_grad():
    z = torch.randn(2, 3, 64, 64)
    for i in range(10, 0, -1):
        t_s = torch.full((2,), i / 10.0)
        v = model(z, t_s)
        z = z - v * 0.1
    z = z.clamp(-1, 1)
print(f" Shape: {z.shape}, range: [{z.min():.3f}, {z.max():.3f}]")
# Test 6: Timestep sensitivity
print("\n--- Test 6: Timestep Sensitivity ---")
model.eval()
x = torch.randn(1, 3, 64, 64)
for t_val in [0.01, 0.25, 0.5, 0.75, 0.99]:
    with torch.no_grad():
        out = model(x, torch.tensor([t_val]))
    print(f" t={t_val:.2f}: mean={out.mean():.6f}, std={out.std():.6f}")
# Test 7: Architecture properties
print("\n--- Test 7: Architecture Properties ---")
m = liquid_diffusion_tiny()
total_blocks = (
    sum(len(s) for s in m.encoder_blocks)
    + len(m.bottleneck)
    + sum(len(s) for s in m.decoder_blocks)
)
print(f" Attention layers: 0")
print(f" Sequential loops: 0")
print(f" CfC blocks: {total_blocks}")
print(f" Training objective: Rectified Flow (MSE velocity)")
# Test 8: VRAM estimates
print("\n--- Test 8: VRAM Estimates (fp16 training) ---")
for name, factory, res, bs in [
    ("tiny 256px bs4", liquid_diffusion_tiny, 256, 4),
    ("small 256px bs4", liquid_diffusion_small, 256, 4),
    ("base 256px bs2", liquid_diffusion_base, 256, 2),
    ("tiny 512px bs2", liquid_diffusion_tiny, 512, 2),
]:
    m = factory()
    tp = sum(p.numel() for p in m.parameters())
    est = (tp * 2 + tp * 4 + tp * 8) / 1e9
    est += bs * 3 * res * res * 4 * len(m.channels) * max(m.channels) / 1e9 * 0.3
    print(f" {name:20s}: {tp/1e6:.1f}M params, ~{est:.1f}GB VRAM")
    del m
print("\n" + "=" * 70)
print(f"ALL TESTS {'PASSED' if all_passed else 'SOME FAILURES'}")
print("=" * 70)