File size: 4,724 Bytes
421b295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
LiquidDiffusion — Complete Test Suite
Tests model construction, forward/backward, training stability, and sampling.
Run: python test_model.py
"""
import sys
import math
import torch
import torch.nn.functional as F

# Add parent directory to path
sys.path.insert(0, '.')

from liquid_diffusion.model import (
    LiquidDiffusionUNet, liquid_diffusion_tiny,
    liquid_diffusion_small, liquid_diffusion_base
)

print("=" * 70)
print("LiquidDiffusion: Novel Attention-Free Image Generation")
print("Based on Liquid Neural Networks (CfC) + Rectified Flow")
print("=" * 70)

all_passed = True

# Test 1: Model construction
# Builds each model size and reports its parameter counts. Previously the
# trainable count was unpacked and then discarded; it is now included in the
# report so frozen-parameter regressions are visible at a glance.
print("\n--- Test 1: Model Construction & Parameter Count ---")
for name, factory in [("tiny", liquid_diffusion_tiny), ("small", liquid_diffusion_small), ("base", liquid_diffusion_base)]:
    m = factory()
    total, trainable = m.count_params()
    print(f"  {name:8s}: {total:>12,} params ({total/1e6:.1f}M), {trainable:,} trainable")
    del m  # free memory before constructing the next (larger) model

# Test 2: Forward pass
# The model must map (B, 3, H, W) -> (B, 3, H, W) at several resolutions.
print("\n--- Test 2: Forward Pass (multiple resolutions) ---")
model = liquid_diffusion_tiny()
for res in (32, 64, 128):
    noise = torch.randn(2, 3, res, res)
    timesteps = torch.rand(2)
    pred = model(noise, timesteps)
    shape_ok = pred.shape == noise.shape
    print(f"  {res}x{res}: {'OK' if shape_ok else 'FAIL'} shape={pred.shape}")
    if not shape_ok:
        all_passed = False

# Test 3: Backward pass
# Verifies gradients flow to the parameters and contain no non-finite values.
# Broadened from a NaN-only check: Inf gradients are just as fatal to training
# stability and should also fail this test. Also collects gradients in a single
# pass instead of scanning model.parameters() twice.
print("\n--- Test 3: Backward Pass (gradient flow) ---")
model = liquid_diffusion_tiny()
x = torch.randn(2, 3, 64, 64)
t = torch.rand(2)
out = model(x, t)
loss = out.mean()
loss.backward()
grads = [p.grad for p in model.parameters() if p.grad is not None]
num_params_with_grad = len(grads)
bad_grads = sum(1 for g in grads if not torch.isfinite(g).all())
print(f"  Params with gradients: {num_params_with_grad}")
print(f"  Non-finite gradients: {bad_grads}")
if bad_grads > 0: all_passed = False

# Test 4: Training stability (20 steps)
# Runs 20 optimization steps of the rectified-flow objective on random data
# and checks the loss stays finite and bounded. Fixes: model.train() hoisted
# out of the loop (the mode never changes inside it), and the finiteness
# check uses math.isfinite instead of the equivalent but noisier
# `not isnan and not isinf` pair.
print("\n--- Test 4: Training Stability (20 steps, random data) ---")
model = liquid_diffusion_tiny()
model.train()  # invariant across steps; no need to re-enter train mode per iteration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
losses = []
for step in range(20):
    x0 = torch.randn(4, 3, 64, 64)
    x1 = torch.randn_like(x0)
    t_val = torch.rand(4)
    # Linear interpolation between endpoints; broadcast t over C, H, W.
    t_b = t_val[:, None, None, None]
    x_t = (1 - t_b) * x0 + t_b * x1
    v_target = x1 - x0  # rectified-flow velocity target
    v_pred = model(x_t, t_val)
    loss = F.mse_loss(v_pred, v_target)
    optimizer.zero_grad()
    loss.backward()
    gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())
    if step % 5 == 0:
        print(f"  Step {step:3d}: loss={loss.item():.4f}, grad_norm={gn.item():.4f}")

stable = all(math.isfinite(loss_val) for loss_val in losses)
not_exploding = max(losses) < 100
print(f"  Stable (no NaN/Inf): {'OK' if stable else 'FAIL'}")
print(f"  Not exploding: {'OK' if not_exploding else 'FAIL'} (max={max(losses):.4f})")
if not stable or not not_exploding: all_passed = False

# Test 5: Sampling
# Euler integration of the learned velocity field from t=1.0 (noise) down to
# t=0.0 in 10 fixed steps of size 0.1, then clamp to the image value range.
# Reuses the model trained in Test 4.
print("\n--- Test 5: Sampling (10 Euler steps) ---")
model.eval()
with torch.no_grad():
    z = torch.randn(2, 3, 64, 64)
    for step_idx in range(10, 0, -1):
        t_s = torch.full((2,), step_idx / 10.0)
        velocity = model(z, t_s)
        z = z - velocity * 0.1
    z = z.clamp(-1, 1)
print(f"  Shape: {z.shape}, range: [{z.min():.3f}, {z.max():.3f}]")

# Test 6: Timestep sensitivity
# Same input at several timesteps; the printed statistics should differ if the
# model actually conditions on t. no_grad wraps the whole sweep since no
# backward pass occurs here.
print("\n--- Test 6: Timestep Sensitivity ---")
model.eval()
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    for t_val in (0.01, 0.25, 0.5, 0.75, 0.99):
        out = model(x, torch.tensor([t_val]))
        print(f"  t={t_val:.2f}: mean={out.mean():.6f}, std={out.std():.6f}")

# Test 7: Architecture properties
# Reports structural facts about the network: it is attention-free and built
# entirely from CfC blocks across encoder, bottleneck, and decoder.
print("\n--- Test 7: Architecture Properties ---")
m = liquid_diffusion_tiny()
encoder_count = sum(len(stage) for stage in m.encoder_blocks)
decoder_count = sum(len(stage) for stage in m.decoder_blocks)
total_blocks = encoder_count + len(m.bottleneck) + decoder_count
print(f"  Attention layers: 0")
print(f"  Sequential loops: 0")
print(f"  CfC blocks: {total_blocks}")
print(f"  Training objective: Rectified Flow (MSE velocity)")

# Test 8: VRAM estimates
# Rough back-of-envelope VRAM figures for fp16 training at several
# config/resolution/batch-size combinations.
print("\n--- Test 8: VRAM Estimates (fp16 training) ---")
configs = [
    ("tiny 256px bs4", liquid_diffusion_tiny, 256, 4),
    ("small 256px bs4", liquid_diffusion_small, 256, 4),
    ("base 256px bs2", liquid_diffusion_base, 256, 2),
    ("tiny 512px bs2", liquid_diffusion_tiny, 512, 2),
]
for name, factory, res, bs in configs:
    m = factory()
    tp = sum(p.numel() for p in m.parameters())
    # Per-parameter bytes: presumably 2 (fp16 weights) + 4 (fp32 master copy)
    # + 8 (optimizer moments) — TODO confirm against the training setup.
    param_gb = (tp * 2 + tp * 4 + tp * 8) / 1e9
    # Heuristic activation-memory term scaled by an empirical 0.3 factor.
    act_gb = bs * 3 * res * res * 4 * len(m.channels) * max(m.channels) / 1e9 * 0.3
    est = param_gb + act_gb
    print(f"  {name:20s}: {tp/1e6:.1f}M params, ~{est:.1f}GB VRAM")
    del m

print("\n" + "=" * 70)
print(f"ALL TESTS {'PASSED' if all_passed else 'SOME FAILURES'}")
print("=" * 70)