AlexWortega committed on
Commit
3cfa9a4
·
verified ·
1 Parent(s): 5f30413

Upload test_pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_pipeline.py +321 -0
test_pipeline.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ End-to-end test: data loading → model forward → backward.
4
+ Verifies that the full pipeline works before committing to long training.
5
+
6
+ Usage:
7
+ python test_pipeline.py
8
+ python test_pipeline.py --dataset active_matter --no-streaming --local_path /data/well
9
+ """
10
+ import argparse
11
+ import sys
12
+ import time
13
+ import traceback
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
def fmt_mem():
    """Return a one-line human-readable summary of CUDA memory usage.

    Falls back to the string "CPU only" when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        return "CPU only"
    gb = 1e9  # report in decimal gigabytes
    allocated = torch.cuda.memory_allocated() / gb
    reserved = torch.cuda.memory_reserved() / gb
    capacity = torch.cuda.get_device_properties(0).total_memory / gb
    return f"alloc={allocated:.2f}GB, reserved={reserved:.2f}GB, total={capacity:.1f}GB"
26
+
27
+
28
def test_data_loading(args):
    """Test 1: Load data and print shapes."""
    print("\n" + "=" * 60)
    print("TEST 1: Data Loading")
    print("=" * 60)

    from data_pipeline import create_dataloader, prepare_batch, get_channel_info, get_data_info

    start = time.time()
    dl, ds = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        streaming=args.streaming,
        local_path=args.local_path,
    )
    print(f" Dataset created in {time.time() - start:.1f}s")
    print(f" Dataset length: {len(ds)}")

    # Probe per-sample fields and the channel layout before touching the loader.
    sample_info = get_data_info(ds)
    print(f" Sample fields:")
    for field, desc in sample_info.items():
        print(f" {field}: {desc}")

    channels = get_channel_info(ds)
    print(f" Channel info: {channels}")

    # Pull one batch to exercise the loader end-to-end and report tensor shapes.
    start = time.time()
    first_batch = next(iter(dl))
    print(f" First batch loaded in {time.time() - start:.1f}s")
    print(f" Batch keys: {list(first_batch.keys())}")
    for key, value in first_batch.items():
        if isinstance(value, torch.Tensor):
            print(f" {key}: {value.shape} ({value.dtype})")

    # Move the batch to the training device and report model-ready shapes.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_in, x_out = prepare_batch(first_batch, device)
    print(f" Model input: {x_in.shape} ({x_in.dtype})")
    print(f" Model target: {x_out.shape} ({x_out.dtype})")
    print(f" GPU memory: {fmt_mem()}")

    return channels, x_in, x_out
73
+
74
+
75
def test_diffusion(ch, x_in, x_out):
    """Test 2: Diffusion model forward + backward.

    Args:
        ch: channel-info dict with "input_channels" / "output_channels".
        x_in: conditioning tensor, already on the target device.
        x_out: prediction-target tensor (same device as x_in).
    """
    print("\n" + "=" * 60)
    print("TEST 2: Diffusion Model")
    print("=" * 60)

    from unet import UNet
    from diffusion import GaussianDiffusion

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]

    # UNet consumes the noisy target concatenated with the conditioning input.
    unet = UNet(
        in_channels=c_out + c_in,
        out_channels=c_out,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
    )
    model = GaussianDiffusion(unet, timesteps=1000)
    device = x_in.device
    model = model.to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Parameters: {n_params:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward — autocast on the batch's actual device type so this test also
    # runs on CPU (the original hard-coded "cuda", which breaks without a GPU).
    t0 = time.time()
    with torch.amp.autocast(device.type, dtype=torch.bfloat16):
        loss = model.training_loss(x_out, x_in)
    print(f" Forward pass: loss={loss.item():.4f} ({time.time() - t0:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    # Backward
    t0 = time.time()
    loss.backward()
    print(f" Backward pass: ({time.time() - t0:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # Quick sampling test (just 5 steps for speed). NOTE(review): this truncates
    # the model's schedule buffers in place; safe only because the model is
    # discarded immediately afterwards.
    model.eval()
    model.T = 5  # temporarily reduce for testing
    for buf_name in (
        "betas",
        "alphas",
        "alpha_bar",
        "sqrt_alpha_bar",
        "sqrt_one_minus_alpha_bar",
        "sqrt_recip_alpha",
        "posterior_variance",
    ):
        setattr(model, buf_name, getattr(model, buf_name)[:5])

    t0 = time.time()
    with torch.no_grad():
        sample = model.sample(x_in[:2], shape=(2, c_out, x_in.shape[2], x_in.shape[3]))
    print(f" Sampling (5 steps, B=2): shape={sample.shape} ({time.time() - t0:.3f}s)")

    del model
    # No-op when CUDA was never initialized, so this is CPU-safe.
    torch.cuda.empty_cache()
    print(f" DIFFUSION OK")
135
+
136
+
137
def test_jepa(ch, x_in, x_out):
    """Test 3: JEPA forward + backward.

    Args:
        ch: channel-info dict with "input_channels".
        x_in: input tensor, already on the target device.
        x_out: target tensor (same device as x_in).
    """
    print("\n" + "=" * 60)
    print("TEST 3: JEPA Model")
    print("=" * 60)

    from jepa import JEPA

    c_in = ch["input_channels"]
    device = x_in.device

    model = JEPA(
        in_channels=c_in,
        latent_channels=128,
        base_ch=32,
        pred_hidden=256,
    ).to(device)

    # Two counts: presumably the EMA target's parameters carry
    # requires_grad=False — the trainable count excludes them. TODO confirm.
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Trainable parameters: {n_params:,}")
    print(f" Total parameters (incl EMA target): {total_params:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward — autocast on the batch's actual device type so this test also
    # runs on CPU (the original hard-coded "cuda", which breaks without a GPU).
    t0 = time.time()
    with torch.amp.autocast(device.type, dtype=torch.bfloat16):
        loss, metrics = model.compute_loss(x_in, x_out)
    print(f" Forward: loss={loss.item():.4f}, metrics={metrics} ({time.time() - t0:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    # Backward
    t0 = time.time()
    loss.backward()
    print(f" Backward: ({time.time() - t0:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # EMA update of the target encoder
    model.update_target()
    print(f" EMA update: OK")

    # Check latent shapes on a 2-sample slice
    model.eval()
    with torch.no_grad():
        z_pred, z_target = model(x_in[:2], x_out[:2])
    print(f" Latent shapes: pred={z_pred.shape}, target={z_target.shape}")

    del model
    torch.cuda.empty_cache()
    print(f" JEPA OK")
187
+
188
+
189
def test_training_step(ch, loader):
    """Test 4: Full training step with optimizer and grad scaling.

    Runs three optimizer steps of the diffusion model over real batches from
    `loader` to validate the complete train-loop plumbing.
    """
    print("\n" + "=" * 60)
    print("TEST 4: Full Training Step")
    print("=" * 60)

    from data_pipeline import prepare_batch
    from unet import UNet
    from diffusion import GaussianDiffusion

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    unet = UNet(in_channels=c_out + c_in, out_channels=c_out, base_ch=64)
    model = GaussianDiffusion(unet, timesteps=1000).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # NOTE(review): with bf16 autocast the scaler is effectively a pass-through
    # (GradScaler exists for fp16); kept to mirror the real training loop.
    # Disabled on CPU so the test also runs without a GPU.
    scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

    model.train()
    losses = []

    for i, batch in enumerate(loader):
        if i >= 3:  # three steps are enough to validate the loop
            break

        x_in, x_out = prepare_batch(batch, device)
        optimizer.zero_grad(set_to_none=True)

        # Autocast on the actual device (the original hard-coded "cuda",
        # which breaks without a GPU).
        with torch.amp.autocast(device, dtype=torch.bfloat16):
            loss = model.training_loss(x_out, x_in)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        losses.append(loss.item())
        print(f" Step {i}: loss={loss.item():.4f}, mem={fmt_mem()}")

    print(f" 3 training steps completed. Losses: {[f'{l:.4f}' for l in losses]}")
    del model, optimizer, scaler
    torch.cuda.empty_cache()
    print(f" TRAINING STEP OK")
234
+
235
+
236
def main():
    """Run all pipeline tests and print a pass/fail summary.

    Exits with status 1 when data loading fails (later tests depend on it)
    or when any test fails overall.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    parser.add_argument("--streaming", action="store_true", default=True)
    parser.add_argument("--no-streaming", dest="streaming", action="store_false")
    parser.add_argument("--local_path", default=None)
    parser.add_argument("--batch_size", type=int, default=4)
    args = parser.parse_args()

    print("=" * 60)
    print("THE WELL - Pipeline End-to-End Test")
    print("=" * 60)
    print(f"Dataset: {args.dataset}")
    print(f"Streaming: {args.streaming}")
    print(f"Batch: {args.batch_size}")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {fmt_mem()}")

    results = {}

    # Test 1: Data — fatal on failure since every later test needs its outputs.
    try:
        ch, x_in, x_out = test_data_loading(args)
        results["data"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["data"] = f"FAIL: {e}"
        sys.exit(1)

    # Test 2: Diffusion
    try:
        test_diffusion(ch, x_in, x_out)
        results["diffusion"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["diffusion"] = f"FAIL: {e}"

    # Test 3: JEPA
    try:
        test_jepa(ch, x_in, x_out)
        results["jepa"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["jepa"] = f"FAIL: {e}"

    # Test 4: Training step — fresh loader so we start from the first batches.
    try:
        # Plain import instead of the original `__import__("data_pipeline")` hack.
        from data_pipeline import create_dataloader

        loader, _ = create_dataloader(
            dataset_name=args.dataset,
            split="train",
            batch_size=args.batch_size,
            streaming=args.streaming,
            local_path=args.local_path,
        )
        test_training_step(ch, loader)
        results["training_step"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["training_step"] = f"FAIL: {e}"

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    all_pass = True
    for name, status in results.items():
        icon = "PASS" if status == "PASS" else "FAIL"
        print(f" [{icon}] {name}")
        if status != "PASS":
            all_pass = False

    if all_pass:
        print("\nAll tests passed! Pipeline is ready for training.")
    else:
        print("\nSome tests failed. Check output above.")
        sys.exit(1)
318
+
319
+
320
if __name__ == "__main__":  # script entry point
    main()