#!/usr/bin/env python3
"""
End-to-end test: data loading → model forward → backward.

Verifies that the full pipeline works before committing to long training.

Usage:
    python test_pipeline.py
    python test_pipeline.py --dataset active_matter --no-streaming --local_path /data/well
"""

import argparse
import sys
import time
import traceback

import torch
import torch.nn as nn


def _script_only(fn):
    """Mark a ``test_*`` stage function so pytest will not collect it.

    The stage functions below are driven by ``main()`` and take positional
    arguments (``args``, ``ch``, ``x_in`` ...). If pytest collected them it
    would try to resolve those arguments as fixtures and report errors.
    pytest skips any function whose ``__test__`` attribute is False.
    """
    fn.__test__ = False
    return fn


def _amp_enabled():
    """Return True when bf16 CUDA autocast and grad scaling should be active.

    On CPU-only hosts, ``autocast("cuda")`` / ``GradScaler("cuda")`` disable
    themselves anyway; passing ``enabled=False`` explicitly keeps that
    behaviour without torch's "CUDA is not available" warning.
    """
    return torch.cuda.is_available()


def _cuda_sync():
    """Wait for queued CUDA work so wall-clock timings cover the real kernels.

    CUDA kernel launches are asynchronous, so ``time.time()`` around e.g.
    ``loss.backward()`` would otherwise measure only launch overhead.
    No-op without CUDA.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def fmt_mem():
    """Return a one-line CUDA memory summary, or "CPU only" without CUDA.

    ``alloc``/``reserved`` come from the current CUDA device; ``total`` is
    read from device 0.
    """
    if not torch.cuda.is_available():
        return "CPU only"
    alloc = torch.cuda.memory_allocated() / 1e9
    res = torch.cuda.memory_reserved() / 1e9
    total = torch.cuda.get_device_properties(0).total_memory / 1e9
    return f"alloc={alloc:.2f}GB, reserved={res:.2f}GB, total={total:.1f}GB"


def _banner(title):
    """Print *title* framed by 60-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


@_script_only
def test_data_loading(args):
    """Test 1: Load data and print shapes.

    Args:
        args: parsed CLI namespace; reads ``dataset``, ``batch_size``,
            ``streaming`` and ``local_path``.

    Returns:
        ``(ch, x_in, x_out)``: the result of ``get_channel_info(dataset)``
        and the model input/target tensors that ``prepare_batch`` builds
        from the first batch on the test device.
    """
    _banner("TEST 1: Data Loading")

    from data_pipeline import create_dataloader, prepare_batch, get_channel_info, get_data_info

    t0 = time.time()
    loader, dataset = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        streaming=args.streaming,
        local_path=args.local_path,
    )
    print(f" Dataset created in {time.time() - t0:.1f}s")
    # Iterable/streaming datasets may not implement __len__; report that
    # instead of failing the whole stage on a purely informational print.
    try:
        n_items = len(dataset)
    except TypeError:
        n_items = "unknown (dataset has no __len__)"
    print(f" Dataset length: {n_items}")

    # Probe shapes
    info = get_data_info(dataset)
    print(" Sample fields:")
    for k, v in info.items():
        print(f" {k}: {v}")
    ch = get_channel_info(dataset)
    print(f" Channel info: {ch}")

    # Load one batch
    t0 = time.time()
    batch = next(iter(loader))
    print(f" First batch loaded in {time.time() - t0:.1f}s")
    print(f" Batch keys: {list(batch.keys())}")
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            print(f" {k}: {v.shape} ({v.dtype})")

    # Prepare for model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_in, x_out = prepare_batch(batch, device)
    print(f" Model input: {x_in.shape} ({x_in.dtype})")
    print(f" Model target: {x_out.shape} ({x_out.dtype})")
    print(f" GPU memory: {fmt_mem()}")
    return ch, x_in, x_out


@_script_only
def test_diffusion(ch, x_in, x_out):
    """Test 2: Diffusion model forward + backward, then a short sampling run.

    Args:
        ch: channel info; reads ``input_channels`` and ``output_channels``.
        x_in: conditioning batch; the model is moved to its device and
            ``x_in.shape[2]``/``x_in.shape[3]`` set the sampled spatial size.
        x_out: target batch for the training loss.
    """
    _banner("TEST 2: Diffusion Model")

    from unet import UNet
    from diffusion import GaussianDiffusion

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]
    # UNet input width is target + conditioning channels, so the diffusion
    # wrapper presumably concatenates x_t with x_in — TODO confirm.
    unet = UNet(
        in_channels=c_out + c_in,
        out_channels=c_out,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
    )
    model = GaussianDiffusion(unet, timesteps=1000).to(x_in.device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Parameters: {n_params:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward (loss.item() forces a device sync, so this timing is accurate)
    t0 = time.time()
    with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=_amp_enabled()):
        loss = model.training_loss(x_out, x_in)
    print(f" Forward pass: loss={loss.item():.4f} ({time.time() - t0:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    # Backward
    t0 = time.time()
    loss.backward()
    _cuda_sync()
    print(f" Backward pass: ({time.time() - t0:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # Quick sampling test (just a few steps for speed). The model is deleted
    # afterwards, so the schedule is truncated in place rather than restored.
    # NOTE(review): assumes these are all the per-timestep tensors that
    # model.sample() indexes — confirm against GaussianDiffusion.
    n_steps = 5
    model.eval()
    model.T = n_steps
    for attr in (
        "betas",
        "alphas",
        "alpha_bar",
        "sqrt_alpha_bar",
        "sqrt_one_minus_alpha_bar",
        "sqrt_recip_alpha",
        "posterior_variance",
    ):
        setattr(model, attr, getattr(model, attr)[:n_steps])

    # Sample at most 2 items but never more than the batch holds, so the
    # requested shape always matches the conditioning slice.
    n_samp = min(2, x_in.shape[0])
    t0 = time.time()
    with torch.no_grad():
        sample = model.sample(x_in[:n_samp], shape=(n_samp, c_out, x_in.shape[2], x_in.shape[3]))
    _cuda_sync()
    print(f" Sampling ({n_steps} steps, B={n_samp}): shape={sample.shape} ({time.time() - t0:.3f}s)")

    # ``unet`` is the same module object held by ``model``; drop every
    # reference (plus activations) so empty_cache() can actually release it.
    del model, unet, loss, sample
    torch.cuda.empty_cache()
    print(" DIFFUSION OK")


@_script_only
def test_jepa(ch, x_in, x_out):
    """Test 3: JEPA forward + backward, EMA target update and latent shapes.

    Args:
        ch: channel info; reads ``input_channels``.
        x_in: first model input batch; the model is moved to its device.
        x_out: second model input batch (target side).
    """
    _banner("TEST 3: JEPA Model")

    from jepa import JEPA

    c_in = ch["input_channels"]
    model = JEPA(
        in_channels=c_in,
        latent_channels=128,
        base_ch=32,
        pred_hidden=256,
    ).to(x_in.device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Trainable parameters: {n_params:,}")
    print(f" Total parameters (incl EMA target): {total_params:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward (loss.item() forces a device sync)
    t0 = time.time()
    with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=_amp_enabled()):
        loss, metrics = model.compute_loss(x_in, x_out)
    print(f" Forward: loss={loss.item():.4f}, metrics={metrics} ({time.time() - t0:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    # Backward
    t0 = time.time()
    loss.backward()
    _cuda_sync()
    print(f" Backward: ({time.time() - t0:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # EMA update
    model.update_target()
    print(" EMA update: OK")

    # Check latent shapes
    model.eval()
    with torch.no_grad():
        z_pred, z_target = model(x_in[:2], x_out[:2])
    print(f" Latent shapes: pred={z_pred.shape}, target={z_target.shape}")

    del model, loss, z_pred, z_target
    torch.cuda.empty_cache()
    print(" JEPA OK")


@_script_only
def test_training_step(ch, loader, n_steps=3):
    """Test 4: Full training step with optimizer, grad clipping and grad scaling.

    Args:
        ch: channel info; reads ``input_channels`` and ``output_channels``.
        loader: iterable of batches accepted by ``prepare_batch``.
        n_steps: number of optimizer steps to run (default 3).

    Raises:
        RuntimeError: if *loader* yields no batches, so an empty loader is
            reported as a failure instead of a silent pass.
    """
    _banner("TEST 4: Full Training Step")

    from data_pipeline import prepare_batch
    from unet import UNet
    from diffusion import GaussianDiffusion

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    unet = UNet(in_channels=c_out + c_in, out_channels=c_out, base_ch=64)
    model = GaussianDiffusion(unet, timesteps=1000).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # bf16 does not strictly need loss scaling; the scaler is kept so this
    # test exercises the scale -> unscale -> clip -> step sequence.
    scaler = torch.amp.GradScaler("cuda", enabled=_amp_enabled())

    model.train()
    losses = []
    # zip() exhausts range() first, so no extra batch is fetched after the
    # last step (enumerate + break would pull and discard one more).
    for i, batch in zip(range(n_steps), loader):
        x_in, x_out = prepare_batch(batch, device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=_amp_enabled()):
            loss = model.training_loss(x_out, x_in)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        losses.append(loss.item())
        print(f" Step {i}: loss={losses[-1]:.4f}, mem={fmt_mem()}")

    if not losses:
        raise RuntimeError("loader yielded no batches; no training step was run")
    print(f" {len(losses)} training steps completed. Losses: {[f'{v:.4f}' for v in losses]}")

    # ``unet`` aliases the module inside ``model``; drop it too.
    del model, unet, optimizer, scaler, loss, x_in, x_out
    torch.cuda.empty_cache()
    print(" TRAINING STEP OK")


def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    # Streaming is on by default; --streaming is accepted for explicitness
    # and --no-streaming turns it off.
    parser.add_argument("--streaming", action="store_true", default=True)
    parser.add_argument("--no-streaming", dest="streaming", action="store_false")
    parser.add_argument("--local_path", default=None)
    parser.add_argument("--batch_size", type=int, default=4)
    return parser


def _run_stage(results, name, fn, *fn_args):
    """Run one test stage, recording "PASS" or "FAIL: <error>" under *name*.

    Any exception is printed with its traceback instead of propagating, so
    later stages still run. Returns the stage's return value, or None if it
    raised.
    """
    try:
        value = fn(*fn_args)
    except Exception as e:  # top-level harness boundary: report and continue
        print(f" FAIL: {e}")
        traceback.print_exc()
        results[name] = f"FAIL: {e}"
        return None
    results[name] = "PASS"
    return value


def _training_stage(args, ch):
    """Build a fresh train loader with the CLI settings and run test 4 on it."""
    from data_pipeline import create_dataloader

    loader, _ = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        streaming=args.streaming,
        local_path=args.local_path,
    )
    test_training_step(ch, loader)


def main():
    """Parse CLI args, run the four pipeline stages and print a summary.

    Exits with status 1 immediately if data loading fails (every later
    stage needs its output), and at the end if any other stage failed.
    """
    args = _build_parser().parse_args()

    print("=" * 60)
    print("THE WELL - Pipeline End-to-End Test")
    print("=" * 60)
    print(f"Dataset: {args.dataset}")
    print(f"Streaming: {args.streaming}")
    print(f"Batch: {args.batch_size}")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {fmt_mem()}")

    results = {}

    # Test 1: Data — stop here on failure, nothing else can run without it.
    data = _run_stage(results, "data", test_data_loading, args)
    if results["data"] != "PASS":
        sys.exit(1)
    ch, x_in, x_out = data

    # Tests 2-4: a failure is recorded and the next stage still runs.
    _run_stage(results, "diffusion", test_diffusion, ch, x_in, x_out)
    _run_stage(results, "jepa", test_jepa, ch, x_in, x_out)
    _run_stage(results, "training_step", _training_stage, args, ch)

    # Summary
    _banner("SUMMARY")
    all_pass = True
    for name, status in results.items():
        passed = status == "PASS"
        print(f" [{'PASS' if passed else 'FAIL'}] {name}")
        all_pass = all_pass and passed

    if all_pass:
        print("\nAll tests passed! Pipeline is ready for training.")
    else:
        print("\nSome tests failed. Check output above.")
        sys.exit(1)


if __name__ == "__main__":
    main()