|
|
|
|
|
""" |
|
|
End-to-end test: data loading → model forward → backward. |
|
|
Verifies that the full pipeline works before committing to long training. |
|
|
|
|
|
Usage: |
|
|
python test_pipeline.py |
|
|
python test_pipeline.py --dataset active_matter --no-streaming --local_path /data/well |
|
|
""" |
|
|
import argparse |
|
|
import sys |
|
|
import time |
|
|
import traceback |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
def fmt_mem():
    """Return a one-line human-readable summary of CUDA memory usage.

    Falls back to the literal string "CPU only" when no GPU is available.
    """
    if not torch.cuda.is_available():
        return "CPU only"
    gb = 1e9
    allocated = torch.cuda.memory_allocated() / gb
    reserved = torch.cuda.memory_reserved() / gb
    capacity = torch.cuda.get_device_properties(0).total_memory / gb
    return f"alloc={allocated:.2f}GB, reserved={reserved:.2f}GB, total={capacity:.1f}GB"
|
|
|
|
|
|
|
|
def test_data_loading(args):
    """Test 1: Load data and print shapes."""
    print("\n" + "=" * 60)
    print("TEST 1: Data Loading")
    print("=" * 60)

    from data_pipeline import create_dataloader, prepare_batch, get_channel_info, get_data_info

    # Time dataset construction separately from first-batch latency.
    start = time.time()
    loader, dataset = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        streaming=args.streaming,
        local_path=args.local_path,
    )
    print(f" Dataset created in {time.time() - start:.1f}s")
    print(f" Dataset length: {len(dataset)}")

    # Dump per-sample metadata so shape mismatches are visible up front.
    sample_info = get_data_info(dataset)
    print(" Sample fields:")
    for field, value in sample_info.items():
        print(f" {field}: {value}")

    ch = get_channel_info(dataset)
    print(f" Channel info: {ch}")

    # First batch is often the slowest (streaming warm-up, worker spawn).
    start = time.time()
    batch = next(iter(loader))
    print(f" First batch loaded in {time.time() - start:.1f}s")
    print(f" Batch keys: {list(batch.keys())}")
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            print(f" {key}: {value.shape} ({value.dtype})")

    # Move one batch through the same preprocessing used in training.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_in, x_out = prepare_batch(batch, device)
    print(f" Model input: {x_in.shape} ({x_in.dtype})")
    print(f" Model target: {x_out.shape} ({x_out.dtype})")
    print(f" GPU memory: {fmt_mem()}")

    return ch, x_in, x_out
|
|
|
|
|
|
|
|
def test_diffusion(ch, x_in, x_out):
    """Test 2: Diffusion model forward + backward."""
    print("\n" + "=" * 60)
    print("TEST 2: Diffusion Model")
    print("=" * 60)

    from unet import UNet
    from diffusion import GaussianDiffusion

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]

    # Conditioning is done by channel-concat, hence in_channels = c_out + c_in.
    backbone = UNet(
        in_channels=c_out + c_in,
        out_channels=c_out,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
    )
    device = x_in.device
    model = GaussianDiffusion(backbone, timesteps=1000).to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Parameters: {n_params:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward pass under bf16 autocast, mirroring the training configuration.
    start = time.time()
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        loss = model.training_loss(x_out, x_in)
    print(f" Forward pass: loss={loss.item():.4f} ({time.time() - start:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    start = time.time()
    loss.backward()
    print(f" Backward pass: ({time.time() - start:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # Truncate the noise schedule to 5 steps so sampling finishes quickly.
    model.eval()
    model.T = 5
    for schedule_attr in (
        "betas",
        "alphas",
        "alpha_bar",
        "sqrt_alpha_bar",
        "sqrt_one_minus_alpha_bar",
        "sqrt_recip_alpha",
        "posterior_variance",
    ):
        setattr(model, schedule_attr, getattr(model, schedule_attr)[:5])

    start = time.time()
    with torch.no_grad():
        sample = model.sample(x_in[:2], shape=(2, c_out, x_in.shape[2], x_in.shape[3]))
    print(f" Sampling (5 steps, B=2): shape={sample.shape} ({time.time() - start:.3f}s)")

    del model
    torch.cuda.empty_cache()
    print(" DIFFUSION OK")
|
|
|
|
|
|
|
|
def test_jepa(ch, x_in, x_out):
    """Test 3: JEPA forward + backward."""
    print("\n" + "=" * 60)
    print("TEST 3: JEPA Model")
    print("=" * 60)

    from jepa import JEPA

    device = x_in.device
    model = JEPA(
        in_channels=ch["input_channels"],
        latent_channels=128,
        base_ch=32,
        pred_hidden=256,
    ).to(device)

    # Trainable vs total differ because the EMA target network is frozen.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Trainable parameters: {trainable:,}")
    print(f" Total parameters (incl EMA target): {total:,}")
    print(f" GPU memory after model: {fmt_mem()}")

    # Forward under bf16 autocast, mirroring the training configuration.
    start = time.time()
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        loss, metrics = model.compute_loss(x_in, x_out)
    print(f" Forward: loss={loss.item():.4f}, metrics={metrics} ({time.time() - start:.3f}s)")
    print(f" GPU memory after forward: {fmt_mem()}")

    start = time.time()
    loss.backward()
    print(f" Backward: ({time.time() - start:.3f}s)")
    print(f" GPU memory after backward: {fmt_mem()}")

    # Exercise the EMA target update path once.
    model.update_target()
    print(" EMA update: OK")

    # Inference path: latent predictions for a 2-sample slice.
    model.eval()
    with torch.no_grad():
        z_pred, z_target = model(x_in[:2], x_out[:2])
    print(f" Latent shapes: pred={z_pred.shape}, target={z_target.shape}")

    del model
    torch.cuda.empty_cache()
    print(" JEPA OK")
|
|
|
|
|
|
|
|
def test_training_step(ch, loader):
    """Test 4: Full training step with optimizer and grad scaling."""
    print("\n" + "=" * 60)
    print("TEST 4: Full Training Step")
    print("=" * 60)

    from data_pipeline import prepare_batch
    from unet import UNet
    from diffusion import GaussianDiffusion

    c_in = ch["input_channels"]
    c_out = ch["output_channels"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    backbone = UNet(in_channels=c_out + c_in, out_channels=c_out, base_ch=64)
    model = GaussianDiffusion(backbone, timesteps=1000).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.amp.GradScaler("cuda")

    model.train()
    losses = []

    # Run three real optimizer steps to catch NaNs / OOM before long training.
    for step, batch in enumerate(loader):
        if step >= 3:
            break

        x_in, x_out = prepare_batch(batch, device)
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            loss = model.training_loss(x_out, x_in)

        # Unscale before clipping so the norm threshold applies to true grads.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        losses.append(loss.item())
        print(f" Step {step}: loss={loss.item():.4f}, mem={fmt_mem()}")

    print(f" 3 training steps completed. Losses: {[f'{l:.4f}' for l in losses]}")
    del model, optimizer, scaler
    torch.cuda.empty_cache()
    print(" TRAINING STEP OK")
|
|
|
|
|
|
|
|
def main():
    """Run the end-to-end pipeline tests and print a pass/fail summary.

    Runs data loading first and aborts (exit code 1) if it fails, since
    the later tests depend on its channel info and sample tensors. The
    remaining tests run independently; any failure is recorded and the
    process exits with code 1 after the summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    # --streaming is on by default; --no-streaming flips the same dest off.
    parser.add_argument("--streaming", action="store_true", default=True)
    parser.add_argument("--no-streaming", dest="streaming", action="store_false")
    parser.add_argument("--local_path", default=None)
    parser.add_argument("--batch_size", type=int, default=4)
    args = parser.parse_args()

    print("=" * 60)
    print("THE WELL - Pipeline End-to-End Test")
    print("=" * 60)
    print(f"Dataset: {args.dataset}")
    print(f"Streaming: {args.streaming}")
    print(f"Batch: {args.batch_size}")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {fmt_mem()}")

    results = {}

    # Hard prerequisite: later tests need ch/x_in/x_out from this step.
    try:
        ch, x_in, x_out = test_data_loading(args)
        results["data"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["data"] = f"FAIL: {e}"
        sys.exit(1)

    try:
        test_diffusion(ch, x_in, x_out)
        results["diffusion"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["diffusion"] = f"FAIL: {e}"

    try:
        test_jepa(ch, x_in, x_out)
        results["jepa"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["jepa"] = f"FAIL: {e}"

    # Build a fresh loader for the training-step test: the loader from
    # test_data_loading already had its iterator partially consumed.
    try:
        # Plain import instead of the previous __import__("data_pipeline") hack.
        from data_pipeline import create_dataloader

        loader, _ = create_dataloader(
            dataset_name=args.dataset,
            split="train",
            batch_size=args.batch_size,
            streaming=args.streaming,
            local_path=args.local_path,
        )
        test_training_step(ch, loader)
        results["training_step"] = "PASS"
    except Exception as e:
        print(f" FAIL: {e}")
        traceback.print_exc()
        results["training_step"] = f"FAIL: {e}"

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    all_pass = True
    for name, status in results.items():
        icon = "PASS" if status == "PASS" else "FAIL"
        print(f" [{icon}] {name}")
        if status != "PASS":
            all_pass = False

    if all_pass:
        print("\nAll tests passed! Pipeline is ready for training.")
    else:
        print("\nSome tests failed. Check output above.")
        sys.exit(1)
|
|
|
|
|
|
|
|
# Entry point: run all pipeline tests when invoked as a script.
if __name__ == "__main__":
    main()
|
|
|