krystv commited on
Commit
f143208
·
verified ·
1 Parent(s): cbeb545

Upload test_verify.py

Browse files
Files changed (1) hide show
  1. test_verify.py +357 -0
test_verify.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comprehensive verification test for LiquidFlow.
3
+ Tests: syntax, imports, forward pass, backward pass, dimension correctness,
4
+ gradient flow, training step, and performance.
5
+
6
+ Run: python test_verify.py
7
+ """
8
+ import sys
9
+ import os
10
+ import time
11
+ import traceback
12
+
13
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+ print(f"Device: {DEVICE}")
21
+ print(f"PyTorch: {torch.__version__}")
22
+ if DEVICE == 'cuda':
23
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
24
+ print("=" * 70)
25
+
26
+ errors = []
27
+ passed = 0
28
+
29
+ def test(name, fn):
30
+ global passed, errors
31
+ try:
32
+ fn()
33
+ print(f" ✓ {name}")
34
+ passed += 1
35
+ except Exception as e:
36
+ msg = f" ✗ {name}: {e}"
37
+ print(msg)
38
+ traceback.print_exc()
39
+ errors.append(msg)
40
+
41
+ # ============================================================
42
+ # TEST 1: CfC Cell
43
+ # ============================================================
44
+ print("\n=== 1. CfC Cell ===")
45
+
46
+ def test_cfc_forward():
47
+ from liquid_flow.cfc_cell import CfCCell
48
+ cell = CfCCell(dim=64).to(DEVICE)
49
+ x = torch.randn(2, 256, 64, device=DEVICE)
50
+ out = cell(x)
51
+ assert out.shape == (2, 256, 64), f"Expected (2,256,64), got {out.shape}"
52
+
53
+ def test_cfc_backward():
54
+ from liquid_flow.cfc_cell import CfCCell
55
+ cell = CfCCell(dim=64).to(DEVICE)
56
+ x = torch.randn(2, 256, 64, device=DEVICE, requires_grad=True)
57
+ out = cell(x)
58
+ loss = out.sum()
59
+ loss.backward()
60
+ assert x.grad is not None, "No gradient on input"
61
+ assert not torch.isnan(x.grad).any(), "NaN in gradients"
62
+
63
+ def test_cfc_block_2d():
64
+ from liquid_flow.cfc_cell import CfCBlock
65
+ block = CfCBlock(dim=64).to(DEVICE)
66
+ x = torch.randn(2, 64, 16, 16, device=DEVICE)
67
+ out = block(x)
68
+ assert out.shape == (2, 64, 16, 16), f"Expected (2,64,16,16), got {out.shape}"
69
+
70
+ def test_cfc_block_backward():
71
+ from liquid_flow.cfc_cell import CfCBlock
72
+ block = CfCBlock(dim=64).to(DEVICE)
73
+ x = torch.randn(2, 64, 16, 16, device=DEVICE, requires_grad=True)
74
+ out = block(x)
75
+ loss = out.sum()
76
+ loss.backward()
77
+ assert x.grad is not None
78
+
79
+ test("CfC forward [B,L,D]", test_cfc_forward)
80
+ test("CfC backward (grad flow)", test_cfc_backward)
81
+ test("CfC Block 2D [B,C,H,W]", test_cfc_block_2d)
82
+ test("CfC Block backward", test_cfc_block_backward)
83
+
84
+ # ============================================================
85
+ # TEST 2: Mamba-2 SSD
86
+ # ============================================================
87
+ print("\n=== 2. Mamba-2 SSD ===")
88
+
89
+ def test_mamba2_forward():
90
+ from liquid_flow.mamba2_ssd import Mamba2SSD
91
+ ssd = Mamba2SSD(dim=64, d_state=8, expand=2).to(DEVICE)
92
+ x = torch.randn(2, 256, 64, device=DEVICE)
93
+ out = ssd(x)
94
+ assert out.shape == (2, 256, 64), f"Expected (2,256,64), got {out.shape}"
95
+
96
+ def test_mamba2_backward():
97
+ from liquid_flow.mamba2_ssd import Mamba2SSD
98
+ ssd = Mamba2SSD(dim=64, d_state=8, expand=2).to(DEVICE)
99
+ x = torch.randn(2, 256, 64, device=DEVICE, requires_grad=True)
100
+ out = ssd(x)
101
+ loss = out.sum()
102
+ loss.backward()
103
+ assert x.grad is not None, "No gradient on input"
104
+ assert not torch.isnan(x.grad).any(), "NaN in gradients"
105
+
106
+ def test_mamba2_block_2d():
107
+ from liquid_flow.mamba2_ssd import Mamba2Block
108
+ block = Mamba2Block(dim=64, d_state=8, expand=2).to(DEVICE)
109
+ x = torch.randn(2, 64, 16, 16, device=DEVICE)
110
+ out = block(x)
111
+ assert out.shape == (2, 64, 16, 16), f"Expected (2,64,16,16), got {out.shape}"
112
+
113
+ def test_mamba2_block_backward():
114
+ from liquid_flow.mamba2_ssd import Mamba2Block
115
+ block = Mamba2Block(dim=64, d_state=8, expand=2).to(DEVICE)
116
+ x = torch.randn(2, 64, 16, 16, device=DEVICE, requires_grad=True)
117
+ out = block(x)
118
+ loss = out.sum()
119
+ loss.backward()
120
+ assert x.grad is not None
121
+
122
+ def test_mamba2_odd_length():
123
+ """Test with non-power-of-2 sequence length."""
124
+ from liquid_flow.mamba2_ssd import Mamba2SSD
125
+ ssd = Mamba2SSD(dim=64, d_state=8, expand=2, chunk_size=16).to(DEVICE)
126
+ x = torch.randn(2, 253, 64, device=DEVICE) # Odd length
127
+ out = ssd(x)
128
+ assert out.shape == (2, 253, 64), f"Expected (2,253,64), got {out.shape}"
129
+
130
+ test("Mamba2 SSD forward", test_mamba2_forward)
131
+ test("Mamba2 SSD backward (no in-place crash)", test_mamba2_backward)
132
+ test("Mamba2 Block 2D", test_mamba2_block_2d)
133
+ test("Mamba2 Block backward", test_mamba2_block_backward)
134
+ test("Mamba2 odd sequence length", test_mamba2_odd_length)
135
+
136
+ # ============================================================
137
+ # TEST 3: LiquidMamba Block
138
+ # ============================================================
139
+ print("\n=== 3. LiquidMamba Block ===")
140
+
141
+ def test_liquid_mamba_forward():
142
+ from liquid_flow.liquid_flow_block import LiquidMambaBlock
143
+ block = LiquidMambaBlock(dim=64, d_state=8, expand=2).to(DEVICE)
144
+ x = torch.randn(2, 64, 16, 16, device=DEVICE)
145
+ out = block(x)
146
+ assert out.shape == (2, 64, 16, 16), f"Expected (2,64,16,16), got {out.shape}"
147
+
148
+ def test_liquid_mamba_backward():
149
+ from liquid_flow.liquid_flow_block import LiquidMambaBlock
150
+ block = LiquidMambaBlock(dim=64, d_state=8, expand=2).to(DEVICE)
151
+ x = torch.randn(2, 64, 16, 16, device=DEVICE, requires_grad=True)
152
+ out = block(x)
153
+ loss = out.mean()
154
+ loss.backward()
155
+ assert x.grad is not None
156
+ assert not torch.isnan(x.grad).any()
157
+
158
+ test("LiquidMamba forward", test_liquid_mamba_forward)
159
+ test("LiquidMamba backward", test_liquid_mamba_backward)
160
+
161
+ # ============================================================
162
+ # TEST 4: Full Backbone
163
+ # ============================================================
164
+ print("\n=== 4. LiquidFlow Backbone ===")
165
+
166
+ def test_backbone_forward():
167
+ from liquid_flow.liquid_flow_block import LiquidFlowBackbone
168
+ model = LiquidFlowBackbone(
169
+ in_channels=4, hidden_dim=64, num_stages=2, blocks_per_stage=2, d_state=8
170
+ ).to(DEVICE)
171
+ x = torch.randn(2, 4, 16, 16, device=DEVICE) # 128px latent
172
+ t = torch.tensor([100, 500], device=DEVICE)
173
+ out = model(x, t)
174
+ assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}"
175
+
176
+ def test_backbone_backward():
177
+ from liquid_flow.liquid_flow_block import LiquidFlowBackbone
178
+ model = LiquidFlowBackbone(
179
+ in_channels=4, hidden_dim=64, num_stages=2, blocks_per_stage=2, d_state=8
180
+ ).to(DEVICE)
181
+ x = torch.randn(2, 4, 16, 16, device=DEVICE, requires_grad=True)
182
+ t = torch.tensor([100, 500], device=DEVICE)
183
+ out = model(x, t)
184
+ loss = F.mse_loss(out, torch.randn_like(out))
185
+ loss.backward()
186
+ assert x.grad is not None
187
+ # Check model params have gradients
188
+ grads_ok = sum(1 for p in model.parameters() if p.grad is not None and not torch.isnan(p.grad).any())
189
+ total_params = sum(1 for p in model.parameters() if p.requires_grad)
190
+ assert grads_ok == total_params, f"Only {grads_ok}/{total_params} params have valid gradients"
191
+
192
+ def test_backbone_512():
193
+ """Test with 512px image (latent = 64×64)."""
194
+ from liquid_flow.liquid_flow_block import LiquidFlowBackbone
195
+ model = LiquidFlowBackbone(
196
+ in_channels=4, hidden_dim=64, num_stages=2, blocks_per_stage=1, d_state=8
197
+ ).to(DEVICE)
198
+ x = torch.randn(1, 4, 64, 64, device=DEVICE) # 512px latent
199
+ t = torch.tensor([500], device=DEVICE)
200
+ out = model(x, t)
201
+ assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}"
202
+
203
+ test("Backbone forward (128px)", test_backbone_forward)
204
+ test("Backbone backward (all grads valid)", test_backbone_backward)
205
+ test("Backbone 512px (64×64 latent)", test_backbone_512)
206
+
207
+ # ============================================================
208
+ # TEST 5: Full Generator + Training Step
209
+ # ============================================================
210
+ print("\n=== 5. Generator + Training ===")
211
+
212
+ def test_generator_forward():
213
+ from liquid_flow.generator import create_liquidflow
214
+ model = create_liquidflow(variant='tiny', image_size=128).to(DEVICE)
215
+ x = torch.randn(2, 4, 16, 16, device=DEVICE)
216
+ t = torch.tensor([100, 500], device=DEVICE)
217
+ out = model(x, t)
218
+ assert out.shape == x.shape
219
+
220
+ def test_training_step():
221
+ """Full training step: forward + loss + backward + optimizer step."""
222
+ from liquid_flow.generator import create_liquidflow
223
+ model = create_liquidflow(variant='tiny', image_size=128).to(DEVICE)
224
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
225
+
226
+ x0 = torch.randn(4, 4, 16, 16, device=DEVICE)
227
+ loss_dict = model.training_step(x0, optimizer, scaler=None, use_amp=False)
228
+
229
+ assert 'total' in loss_dict
230
+ assert 'diffusion' in loss_dict
231
+ assert 'physics' in loss_dict
232
+ assert loss_dict['total'] > 0
233
+ assert not any(v != v for v in loss_dict.values()), "NaN in losses" # NaN check
234
+
235
+ def test_training_step_multiple():
236
+ """Multiple training steps to verify no accumulation/state bugs."""
237
+ from liquid_flow.generator import create_liquidflow
238
+ model = create_liquidflow(variant='tiny', image_size=128).to(DEVICE)
239
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
240
+
241
+ losses = []
242
+ for _ in range(5):
243
+ x0 = torch.randn(4, 4, 16, 16, device=DEVICE)
244
+ loss_dict = model.training_step(x0, optimizer, scaler=None, use_amp=False)
245
+ losses.append(loss_dict['total'])
246
+ assert not (loss_dict['total'] != loss_dict['total']), "NaN loss"
247
+
248
+ # Losses should not explode
249
+ assert all(l < 100 for l in losses), f"Loss explosion: {losses}"
250
+
251
+ def test_sampling():
252
+ """Test DDIM sampling produces correct output."""
253
+ from liquid_flow.generator import create_liquidflow
254
+ model = create_liquidflow(variant='tiny', image_size=128).to(DEVICE)
255
+ model.eval()
256
+
257
+ with torch.no_grad():
258
+ samples = model.sample(batch_size=2, steps=5, ddim=True, progress=False)
259
+
260
+ assert samples.shape == (2, 4, 16, 16), f"Expected (2,4,16,16), got {samples.shape}"
261
+ assert not torch.isnan(samples).any(), "NaN in samples"
262
+
263
+ test("Generator forward", test_generator_forward)
264
+ test("Full training step (fwd+bwd+optim)", test_training_step)
265
+ test("5 training steps (no explosion)", test_training_step_multiple)
266
+ test("DDIM sampling", test_sampling)
267
+
268
+ # ============================================================
269
+ # TEST 6: Physics Loss
270
+ # ============================================================
271
+ print("\n=== 6. Physics Loss ===")
272
+
273
+ def test_physics_loss():
274
+ from liquid_flow.physics_loss import PhysicsRegularizer
275
+ phys = PhysicsRegularizer().to(DEVICE)
276
+ phys.train()
277
+ x = torch.randn(2, 4, 16, 16, device=DEVICE, requires_grad=True)
278
+ total, losses = phys(x)
279
+ assert total.requires_grad, "Physics loss not differentiable"
280
+ total.backward()
281
+ assert x.grad is not None
282
+
283
+ def test_ddim_estimator():
284
+ from liquid_flow.physics_loss import DDIMEstimator
285
+ x_t = torch.randn(2, 4, 16, 16, device=DEVICE)
286
+ eps = torch.randn(2, 4, 16, 16, device=DEVICE)
287
+ alpha_bar = torch.tensor([0.9, 0.5], device=DEVICE)
288
+ x0 = DDIMEstimator.estimate_x0(x_t, eps, alpha_bar)
289
+ assert x0.shape == x_t.shape
290
+ assert not torch.isnan(x0).any()
291
+
292
+ test("Physics loss (differentiable)", test_physics_loss)
293
+ test("DDIM estimator", test_ddim_estimator)
294
+
295
+ # ============================================================
296
+ # TEST 7: Performance / Speed
297
+ # ============================================================
298
+ print("\n=== 7. Performance ===")
299
+
300
+ def test_speed():
301
+ """Measure forward+backward time for one batch."""
302
+ from liquid_flow.generator import create_liquidflow
303
+ model = create_liquidflow(variant='tiny', image_size=128).to(DEVICE)
304
+ model.train()
305
+
306
+ x = torch.randn(4, 4, 16, 16, device=DEVICE, requires_grad=True)
307
+ t = torch.randint(0, 1000, (4,), device=DEVICE)
308
+
309
+ # Warmup
310
+ out = model(x, t)
311
+ loss = out.sum()
312
+ loss.backward()
313
+
314
+ if DEVICE == 'cuda':
315
+ torch.cuda.synchronize()
316
+
317
+ # Timed run
318
+ start = time.time()
319
+ for _ in range(5):
320
+ out = model(x, t)
321
+ loss = out.sum()
322
+ loss.backward()
323
+ if DEVICE == 'cuda':
324
+ torch.cuda.synchronize()
325
+ elapsed = (time.time() - start) / 5
326
+
327
+ print(f" → Forward+backward: {elapsed*1000:.1f} ms/batch (tiny, bs=4, 16×16)")
328
+ assert elapsed < 60, f"Too slow: {elapsed:.1f}s per step"
329
+
330
+ def test_param_count():
331
+ from liquid_flow.generator import create_liquidflow
332
+ for variant in ['tiny', 'small', 'base']:
333
+ model = create_liquidflow(variant=variant, image_size=128)
334
+ n = sum(p.numel() for p in model.parameters())
335
+ print(f" → {variant}: {n:,} params ({n/1e6:.1f}M)")
336
+
337
+ test("Speed (< 60s per step)", test_speed)
338
+ test("Param counts", test_param_count)
339
+
340
+ # ============================================================
341
+ # SUMMARY
342
+ # ============================================================
343
+ print("\n" + "=" * 70)
344
+ total = passed + len(errors)
345
+ print(f"Results: {passed}/{total} tests passed")
346
+
347
+ if errors:
348
+ print(f"\n{'='*70}")
349
+ print("FAILURES:")
350
+ for e in errors:
351
+ print(f" {e}")
352
+ print(f"{'='*70}")
353
+ sys.exit(1)
354
+ else:
355
+ print("ALL TESTS PASSED ✓")
356
+ print("Model is GPU-trainable, no sequential bottlenecks, gradients flow correctly.")
357
+ print("=" * 70)