File size: 1,834 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""CPU single-batch smoke test — gate before launching any GPU training.

Verifies that all 4 models forward+backward on a tiny batch with real-shaped
tensors, no NaN, and the loss decreases over a few optimiser steps.
"""
from __future__ import annotations

import os

import numpy as np
import torch

from physiojepa.models import MODEL_REGISTRY, ModelConfig


def _fake_batch(b: int = 4, device: str = "cpu") -> dict:
    ecg = torch.randn(b, 1, 2500, device=device)
    ppg = torch.randn(b, 1, 1250, device=device)
    dt = torch.rand(b, device=device) * 0.45 + 0.05  # 50-500 ms
    return {"ecg": ecg, "ppg": ppg, "dt_seconds": dt,
            "ptt_ms": torch.full((b,), float("nan"), device=device)}


def main() -> None:
    """Run a few CPU optimisation steps per model variant and fail fast on NaN.

    For each variant in ``("A", "B", "C", "F")`` this builds the model from a
    default ``ModelConfig``, then runs three steps. Each step draws a fresh
    random batch, does forward, backward and an AdamW step, and applies the
    EMA target update with ``tau=0.996``. Per-step losses are printed. A
    warning (non-fatal) is printed when the final loss is not below the
    first one.

    Raises:
        AssertionError: if any variant produces a non-finite loss. This is
            raised explicitly rather than via ``assert`` so the gate still
            works when Python runs with ``-O``.
    """
    torch.manual_seed(0)
    np.random.seed(0)
    cfg = ModelConfig()
    device = torch.device("cpu")
    for variant in ("A", "B", "C", "F"):
        print(f"=== {variant} ===")
        m = MODEL_REGISTRY[variant](cfg).to(device)
        opt = torch.optim.AdamW(m.parameters(), lr=1e-3)
        losses = []
        for step in range(3):
            # Build the batch on the same device the model was moved to.
            batch = _fake_batch(device=str(device))
            opt.zero_grad(set_to_none=True)
            out = m.step(batch)
            val = float(out["loss"].item())
            # Check before backward()/step() so a NaN never reaches the weights,
            # and raise explicitly because `assert` is stripped under `python -O`.
            if not np.isfinite(val):
                raise AssertionError(f"non-finite loss in {variant}")
            out["loss"].backward()
            opt.step()
            # EMA update of each target network from its online counterpart.
            for online, tgt in m.targets():
                tgt.update(online, tau=0.996)
            losses.append(val)
            print(f"  step={step} loss={val:.4f} "
                  f"L_cross={float(out.get('L_cross', torch.tensor(0.0)).item()):.4f} "
                  f"L_self={float(out.get('L_self', torch.tensor(0.0)).item()):.4f}")
        print(f"  -> losses: {[round(x, 4) for x in losses]}")
        # Non-fatal: each step sees a fresh random batch, so a 3-step trend is noisy.
        if losses[-1] >= losses[0]:
            print(f"  WARNING: loss did not decrease for {variant} "
                  f"({losses[0]:.4f} -> {losses[-1]:.4f})")
    print("\nSMOKE TEST PASSED")


if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    main()