PhysioJEPA / scripts /smoke_test.py
guychuk's picture
Upload folder using huggingface_hub
31e2456 verified
"""CPU single-batch smoke test — gate before launching any GPU training.

Verifies that each of the 4 model variants (A, B, C, F) runs a forward and
backward pass on a tiny batch of real-shaped random tensors without producing
a non-finite loss, and prints the loss over a few optimiser steps so its
trend can be inspected.
"""
from __future__ import annotations
import os
import numpy as np
import torch
from physiojepa.models import MODEL_REGISTRY, ModelConfig
def _fake_batch(b: int = 4, device: str = "cpu") -> dict:
    """Return a random batch whose tensors mimic the shapes of real data.

    Args:
        b: batch size.
        device: device on which every tensor is allocated.

    Returns:
        Dict with ``"ecg"`` of shape ``(b, 1, 2500)``, ``"ppg"`` of shape
        ``(b, 1, 1250)`` (both standard-normal noise), ``"dt_seconds"`` of
        shape ``(b,)`` drawn uniformly from [0.05, 0.5) seconds, and
        ``"ptt_ms"`` of shape ``(b,)`` filled with NaN (presumably marking
        the label as unavailable — confirm against the models).
    """
    dt_min_s, dt_span_s = 0.05, 0.45  # 50-500 ms
    # Dict-literal values are evaluated left to right, so the RNG draws happen
    # in the order ecg -> ppg -> dt_seconds.
    batch = {
        "ecg": torch.randn(b, 1, 2500, device=device),
        "ppg": torch.randn(b, 1, 1250, device=device),
        "dt_seconds": torch.rand(b, device=device) * dt_span_s + dt_min_s,
    }
    batch["ptt_ms"] = torch.full((b,), float("nan"), device=device)
    return batch
def main() -> None:
    """Run a short CPU training loop on every model variant as a pre-GPU gate.

    For each variant in ``("A", "B", "C", "F")``, the model is built from a
    default ``ModelConfig`` and trained for a few AdamW steps on one fixed
    fake batch. Each step does the following:

    - calls ``m.step(batch)``;
    - checks that the returned loss and every parameter gradient are finite;
    - applies the optimiser step;
    - updates each ``(online, target)`` pair yielded by ``m.targets()``
      with ``tau=0.996``.

    Losses are printed per step. A warning is printed if the last loss is
    not below the first one.

    Raises:
        AssertionError: if any loss or gradient is non-finite. It is raised
            explicitly rather than via ``assert`` so the gate still works
            when Python runs with ``-O`` (which strips ``assert``).
    """
    torch.manual_seed(0)
    np.random.seed(0)
    cfg = ModelConfig()
    device = torch.device("cpu")
    n_steps = 3
    for variant in ("A", "B", "C", "F"):
        print(f"=== {variant} ===")
        m = MODEL_REGISTRY[variant](cfg).to(device)
        opt = torch.optim.AdamW(m.parameters(), lr=1e-3)
        # One fixed batch per variant ("overfit one batch"): losses across
        # steps are then measured on the same data and are comparable. Each
        # step gets a fresh clone in case ``m.step`` modifies inputs in place.
        base_batch = _fake_batch(device=device)
        losses: list[float] = []
        for step in range(n_steps):
            batch = {k: v.clone() for k, v in base_batch.items()}
            opt.zero_grad(set_to_none=True)
            out = m.step(batch)
            loss = out["loss"]
            val = loss.item()
            # Check before backward so the failure points at the step that
            # produced the non-finite value.
            if not np.isfinite(val):
                raise AssertionError(f"non-finite loss in {variant}")
            loss.backward()
            n_bad = sum(
                1
                for p in m.parameters()
                if p.grad is not None and not torch.isfinite(p.grad).all()
            )
            if n_bad:
                raise AssertionError(
                    f"non-finite gradient in {variant} "
                    f"({n_bad} parameter tensors)"
                )
            opt.step()
            for online, tgt in m.targets():
                tgt.update(online, tau=0.996)
            losses.append(val)
            # ``float`` accepts both 0-d tensors and plain numbers, so the
            # optional loss components print however they are returned.
            l_cross = float(out.get("L_cross", 0.0))
            l_self = float(out.get("L_self", 0.0))
            print(f" step={step} loss={val:.4f} "
                  f"L_cross={l_cross:.4f} "
                  f"L_self={l_self:.4f}")
        print(f" -> losses: {[round(x, 4) for x in losses]}")
        # Non-fatal on purpose. A handful of steps, with the target networks
        # also being updated, does not guarantee a strict decrease. A hard
        # failure here could block training on a false alarm.
        if losses[-1] >= losses[0]:
            print(f" WARNING: loss did not decrease for {variant} "
                  f"({losses[0]:.4f} -> {losses[-1]:.4f})")
    print("\nSMOKE TEST PASSED")
# Run the smoke test only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()