the-well-diffusion / test_pipeline.py
AlexWortega's picture
Upload test_pipeline.py with huggingface_hub
3cfa9a4 verified
#!/usr/bin/env python3
"""
End-to-end test: data loading → model forward → backward.
Verifies that the full pipeline works before committing to long training.
Usage:
python test_pipeline.py
python test_pipeline.py --dataset active_matter --no-streaming --local_path /data/well
"""
import argparse
import sys
import time
import traceback
import torch
import torch.nn as nn
def fmt_mem():
    """Return a one-line, human-readable summary of GPU memory usage.

    When CUDA is unavailable the string is simply "CPU only". Otherwise it
    reports the memory allocated and reserved by torch (for torch's default
    CUDA device) and the total memory of device 0, in GB (1e9 bytes).
    """
    if not torch.cuda.is_available():
        return "CPU only"
    bytes_per_gb = 1e9
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
    capacity_gb = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
    return (
        f"alloc={allocated_gb:.2f}GB, "
        f"reserved={reserved_gb:.2f}GB, "
        f"total={capacity_gb:.1f}GB"
    )
def test_data_loading(args):
    """Test 1: Load data and print shapes.

    Builds the training dataloader described by ``args``, prints dataset and
    first-batch diagnostics, and converts the first batch into model tensors
    with ``prepare_batch``.

    Args:
        args: parsed CLI namespace with ``dataset``, ``batch_size``,
            ``streaming`` and ``local_path`` attributes.

    Returns:
        tuple: ``(ch, x_in, x_out)``, which holds the channel info from
        ``get_channel_info`` and the model input/target that ``prepare_batch``
        produced for the first batch on the test device ("cuda" if available,
        else "cpu").
    """
    print("\n" + "=" * 60)
    print("TEST 1: Data Loading")
    print("=" * 60)
    from data_pipeline import create_dataloader, prepare_batch, get_channel_info, get_data_info

    t0 = time.time()
    loader, dataset = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        streaming=args.streaming,
        local_path=args.local_path,
    )
    print(f" Dataset created in {time.time() - t0:.1f}s")

    # Streaming / iterable datasets may not implement __len__, in which case
    # len() raises TypeError. Streaming is the CLI default, so report the
    # length as unknown instead of failing the whole pipeline test here.
    try:
        n_samples = len(dataset)
    except (TypeError, NotImplementedError):
        n_samples = "unknown (dataset has no length, e.g. streaming)"
    print(f" Dataset length: {n_samples}")

    # Probe shapes
    info = get_data_info(dataset)
    print(" Sample fields:")
    for k, v in info.items():
        print(f" {k}: {v}")
    ch = get_channel_info(dataset)
    print(f" Channel info: {ch}")

    # Load one batch
    t0 = time.time()
    batch = next(iter(loader))
    print(f" First batch loaded in {time.time() - t0:.1f}s")
    print(f" Batch keys: {list(batch.keys())}")
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            print(f" {k}: {v.shape} ({v.dtype})")

    # Prepare for model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_in, x_out = prepare_batch(batch, device)
    print(f" Model input: {x_in.shape} ({x_in.dtype})")
    print(f" Model target: {x_out.shape} ({x_out.dtype})")
    print(f" GPU memory: {fmt_mem()}")
    return ch, x_in, x_out
def test_diffusion(ch, x_in, x_out):
    """Test 2: Diffusion model forward + backward.

    Builds a UNet with ``c_out + c_in`` input channels wrapped in
    ``GaussianDiffusion``. It then runs one ``training_loss`` forward/backward
    pass, followed by a quick 5-step sampling pass on the first two inputs.

    Args:
        ch: channel info dict with ``"input_channels"`` and
            ``"output_channels"``.
        x_in: conditioning batch from ``prepare_batch``. The sampling shape
            below indexes it as (B, C, H, W), so it assumes 2-D spatial data.
            TODO confirm for 3-D datasets.
        x_out: target batch from ``prepare_batch``.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Diffusion Model")
    print("=" * 60)
    from unet import UNet
    from diffusion import GaussianDiffusion

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]
    device = x_in.device
    on_cuda = device.type == "cuda"

    def _sync():
        # CUDA kernels run asynchronously: wait for them so wall-clock timings
        # cover the actual work, not just the kernel launches.
        if on_cuda:
            torch.cuda.synchronize()

    unet = UNet(
        in_channels=c_out + c_in,
        out_channels=c_out,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
    )
    model = GaussianDiffusion(unet, timesteps=1000)
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Parameters: {n_params:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward. bf16 autocast is only enabled on CUDA; on CPU it is explicitly
    # disabled (same numerics as before, without the "CUDA not available" warning).
    t0 = time.time()
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=on_cuda):
        loss = model.training_loss(x_out, x_in)
    loss_value = loss.item()  # .item() also synchronizes the device
    print(f" Forward pass: loss={loss_value:.4f} ({time.time() - t0:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    # Backward
    t0 = time.time()
    loss.backward()
    _sync()
    print(f" Backward pass: ({time.time() - t0:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # Quick sampling test (just 5 steps for speed): shrink the schedule in
    # place. The model is discarded afterwards, so nothing is restored.
    # NOTE(review): only these per-timestep tensors are truncated; if
    # GaussianDiffusion holds other per-timestep buffers, they keep length 1000.
    n_sample_steps = 5
    model.eval()
    model.T = n_sample_steps
    for buf_name in (
        "betas",
        "alphas",
        "alpha_bar",
        "sqrt_alpha_bar",
        "sqrt_one_minus_alpha_bar",
        "sqrt_recip_alpha",
        "posterior_variance",
    ):
        setattr(model, buf_name, getattr(model, buf_name)[:n_sample_steps])

    t0 = time.time()
    with torch.no_grad():
        sample = model.sample(x_in[:2], shape=(2, c_out, x_in.shape[2], x_in.shape[3]))
    _sync()
    print(f" Sampling (5 steps, B=2): shape={sample.shape} ({time.time() - t0:.3f}s)")
    del model
    torch.cuda.empty_cache()
    print(" DIFFUSION OK")
def test_jepa(ch, x_in, x_out):
    """Test 3: JEPA forward + backward.

    Builds a JEPA model and runs ``compute_loss`` forward/backward and one
    ``update_target()`` call (reported as the EMA update). It then checks the
    latent shapes returned by a no-grad forward on the first two samples.

    Args:
        ch: channel info dict with ``"input_channels"``.
        x_in: model input batch from ``prepare_batch``.
        x_out: model target batch from ``prepare_batch``.
    """
    print("\n" + "=" * 60)
    print("TEST 3: JEPA Model")
    print("=" * 60)
    from jepa import JEPA

    c_in = ch["input_channels"]
    device = x_in.device
    on_cuda = device.type == "cuda"

    def _sync():
        # CUDA kernels run asynchronously: wait for them so wall-clock timings
        # cover the actual work, not just the kernel launches.
        if on_cuda:
            torch.cuda.synchronize()

    model = JEPA(
        in_channels=c_in,
        latent_channels=128,
        base_ch=32,
        pred_hidden=256,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Trainable parameters: {n_params:,}")
    print(f" Total parameters (incl EMA target): {total_params:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward. bf16 autocast is only enabled on CUDA; on CPU it is explicitly
    # disabled (same numerics as before, without the "CUDA not available" warning).
    t0 = time.time()
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=on_cuda):
        loss, metrics = model.compute_loss(x_in, x_out)
    loss_value = loss.item()  # .item() also synchronizes the device
    print(f" Forward: loss={loss_value:.4f}, metrics={metrics} ({time.time() - t0:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    # Backward
    t0 = time.time()
    loss.backward()
    _sync()
    print(f" Backward: ({time.time() - t0:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # EMA update
    model.update_target()
    print(" EMA update: OK")

    # Check latent shapes
    model.eval()
    with torch.no_grad():
        z_pred, z_target = model(x_in[:2], x_out[:2])
    print(f" Latent shapes: pred={z_pred.shape}, target={z_target.shape}")
    del model
    torch.cuda.empty_cache()
    print(" JEPA OK")
def test_training_step(ch, loader, n_steps=3):
    """Test 4: Full training step with optimizer and grad scaling.

    Trains a default-config diffusion model for ``n_steps`` steps on the
    first ``n_steps`` batches of ``loader``. Each step runs a bf16-autocast
    forward (CUDA only) and a GradScaler-scaled backward, then unscales,
    clips the gradient norm to 1.0 and takes an AdamW step.

    Args:
        ch: channel info dict with ``"input_channels"`` / ``"output_channels"``.
        loader: iterable of raw batches accepted by ``prepare_batch``.
        n_steps: number of training steps (batches) to run; defaults to 3.

    Raises:
        ValueError: if ``n_steps`` is less than 1.
        RuntimeError: if ``loader`` yields no batches, i.e. no step was run.
    """
    print("\n" + "=" * 60)
    print("TEST 4: Full Training Step")
    print("=" * 60)
    from itertools import islice
    from data_pipeline import prepare_batch
    from unet import UNet
    from diffusion import GaussianDiffusion

    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    on_cuda = device == "cuda"
    unet = UNet(in_channels=c_out + c_in, out_channels=c_out, base_ch=64)
    model = GaussianDiffusion(unet, timesteps=1000).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # Loss scaling and autocast only take effect on CUDA; disabling them
    # explicitly on CPU avoids "CUDA is not available" warnings without
    # changing the math.
    scaler = torch.amp.GradScaler("cuda", enabled=on_cuda)
    model.train()

    losses = []
    # islice stops after n_steps batches without pulling (and discarding) an
    # extra batch from the loader, which can be slow for streaming data.
    for i, batch in enumerate(islice(loader, n_steps)):
        x_in, x_out = prepare_batch(batch, device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device, dtype=torch.bfloat16, enabled=on_cuda):
            loss = model.training_loss(x_out, x_in)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        loss_value = loss.item()
        losses.append(loss_value)
        print(f" Step {i}: loss={loss_value:.4f}, mem={fmt_mem()}")

    # An empty loader would otherwise "pass" without training at all.
    if not losses:
        raise RuntimeError("loader yielded no batches; no training step was run")
    print(f" {len(losses)} training steps completed. Losses: {[f'{v:.4f}' for v in losses]}")
    del model, optimizer, scaler
    torch.cuda.empty_cache()
    print(" TRAINING STEP OK")
def _run_stage(results, name, fn, *fn_args):
    """Run one pipeline check and record its outcome in ``results[name]``.

    On success records ``"PASS"`` and returns ``fn(*fn_args)``. On any
    ``Exception`` it prints the message and traceback, records
    ``"FAIL: <message>"`` and returns None so the remaining checks still run.
    """
    try:
        value = fn(*fn_args)
    except Exception as e:  # top-level test harness: report, don't abort
        print(f" FAIL: {e}")
        traceback.print_exc()
        results[name] = f"FAIL: {e}"
        return None
    results[name] = "PASS"
    return value


def main():
    """Run the four end-to-end pipeline checks and print a PASS/FAIL summary.

    Exits with status 1 immediately if data loading fails, because every later
    check needs its tensors. Also exits with status 1 after the summary if any
    other check failed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    # Streaming defaults to on. --streaming is accepted for explicitness and
    # --no-streaming turns it off.
    parser.add_argument("--streaming", action="store_true", default=True)
    parser.add_argument("--no-streaming", dest="streaming", action="store_false")
    parser.add_argument("--local_path", default=None)
    parser.add_argument("--batch_size", type=int, default=4)
    args = parser.parse_args()

    has_cuda = torch.cuda.is_available()
    print("=" * 60)
    print("THE WELL - Pipeline End-to-End Test")
    print("=" * 60)
    print(f"Dataset: {args.dataset}")
    print(f"Streaming: {args.streaming}")
    print(f"Batch: {args.batch_size}")
    print(f"Device: {'cuda' if has_cuda else 'cpu'}")
    if has_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {fmt_mem()}")

    results = {}

    # Test 1: Data. Every later test consumes its output, so stop on failure.
    data = _run_stage(results, "data", test_data_loading, args)
    if results["data"] != "PASS":
        sys.exit(1)
    ch, x_in, x_out = data

    # Tests 2 and 3: model forward/backward on the batch from Test 1.
    _run_stage(results, "diffusion", test_diffusion, ch, x_in, x_out)
    _run_stage(results, "jepa", test_jepa, ch, x_in, x_out)

    # Test 4: test_data_loading does not return its loader, so build a new one.
    # Construction happens inside the stage so its errors count as a
    # training_step failure.
    def training_stage():
        from data_pipeline import create_dataloader

        loader, _ = create_dataloader(
            dataset_name=args.dataset,
            split="train",
            batch_size=args.batch_size,
            streaming=args.streaming,
            local_path=args.local_path,
        )
        test_training_step(ch, loader)

    _run_stage(results, "training_step", training_stage)

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, status in results.items():
        if status == "PASS":
            print(f" [PASS] {name}")
        else:
            # Repeat the failure reason so it is visible without scrolling back.
            print(f" [FAIL] {name}: {status.partition(': ')[2]}")
    if all(status == "PASS" for status in results.values()):
        print("\nAll tests passed! Pipeline is ready for training.")
    else:
        print("\nSome tests failed. Check output above.")
        sys.exit(1)
if __name__ == "__main__":
    # Run the end-to-end checks only when executed as a script, not on import.
    main()