# PhysioJEPA/src/physiojepa/models.py
# Uploaded by guychuk via huggingface_hub (commit 31e2456, verified).
"""The four models under test. They share encoders, differ in loss and Delta-t.
Model variants:
A: ECG-JEPA unimodal (I-JEPA self-prediction on ECG only)
B: cross-modal JEPA, delta_t = 0
C: symmetric InfoNCE (no predictor)
F: PhysioJEPA v1 (cross-modal JEPA, variable delta_t)
"""
from __future__ import annotations
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from torch import nn
from .dt_embed import DeltaTEmbedding
from .ecg_encoder import ECGPatchTokeniser
from .ema import EMA
from .masking import multi_block_mask_1d
from .ppg_encoder import PPGPatchTokeniser
from .vit import CrossAttentionPredictor, ViT1D
@dataclass
class ModelConfig:
    """Hyper-parameters shared by every model variant in this module.

    ``ecg_patch`` / ``ppg_patch`` are patch lengths (in samples along the last
    signal axis) handed to the respective patch tokenisers. ``max_tokens``
    bounds the number of patches per modality: it sizes the tokenisers'
    ``max_patches`` and the predictor query table.

    Raises:
        ValueError: if ``query_mode`` is not ``"learned"`` or ``"sinusoidal"``.
    """

    ecg_patch: int = 50
    ppg_patch: int = 25
    d_model: int = 256
    ecg_depth: int = 12
    ppg_depth: int = 6
    heads: int = 8
    pred_depth: int = 4
    max_tokens: int = 128
    # ablation knobs
    query_mode: str = "learned"  # "learned" | "sinusoidal"
    mask_ratio: float = 0.50

    def __post_init__(self) -> None:
        # Fail fast so a typo cannot silently select the wrong kind of
        # predictor query embedding and invalidate an ablation run.
        if self.query_mode not in ("learned", "sinusoidal"):
            raise ValueError(
                f"query_mode must be 'learned' or 'sinusoidal', got {self.query_mode!r}"
            )
def _pool(x: torch.Tensor) -> torch.Tensor:
return x.mean(dim=1)
def _make_query_emb(cfg: ModelConfig) -> tuple[nn.Module | None, torch.Tensor | None]:
"""Returns either a learned nn.Parameter wrapped in a tiny Module, or a
fixed sinusoidal table buffer. Caller should index with positions.
"""
if cfg.query_mode == "sinusoidal":
import math
n_pos, d = cfg.max_tokens, cfg.d_model
pe = torch.zeros(n_pos, d)
pos = torch.arange(0, n_pos, dtype=torch.float32).unsqueeze(1)
div = torch.exp(torch.arange(0, d, 2, dtype=torch.float32) *
-(math.log(10_000.0) / d))
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
return None, pe # caller stores as buffer
return None, None # caller creates learned Parameter
class ECGOnlyEncoder(nn.Module):
    """ECG encoder: an ECG patch tokeniser followed by a ViT1D trunk.

    ``forward`` returns per-patch embeddings of shape ``[B, N_e, d]``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Submodules are registered tokeniser-first, trunk-second. Parameter
        # order (which optimiser state relies on) therefore stays fixed.
        self.tok = ECGPatchTokeniser(
            patch_size=cfg.ecg_patch,
            d_model=cfg.d_model,
            max_patches=cfg.max_tokens,
        )
        self.trunk = ViT1D(
            depth=cfg.ecg_depth,
            d_model=cfg.d_model,
            heads=cfg.heads,
        )

    def forward(self, ecg: torch.Tensor) -> torch.Tensor:
        tokens = self.tok(ecg)
        encoded = self.trunk(tokens)
        return encoded  # [B, N_e, d]
class PPGEncoder(nn.Module):
    """PPG encoder: a PPG patch tokeniser followed by a (shallower) ViT1D trunk.

    ``forward`` returns per-patch token embeddings.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Keep tokeniser-then-trunk registration order (stable parameter order).
        self.tok = PPGPatchTokeniser(
            patch_size=cfg.ppg_patch,
            d_model=cfg.d_model,
            max_patches=cfg.max_tokens,
        )
        self.trunk = ViT1D(
            depth=cfg.ppg_depth,
            d_model=cfg.d_model,
            heads=cfg.heads,
        )

    def forward(self, ppg: torch.Tensor) -> torch.Tensor:
        tokens = self.tok(ppg)
        encoded = self.trunk(tokens)
        return encoded
# ---------------------------------------------------------------------------
# Baseline A — ECG-JEPA unimodal (I-JEPA style self-prediction)
# ---------------------------------------------------------------------------
class BaselineA(nn.Module):
    """Baseline A: unimodal ECG-JEPA (I-JEPA-style masked self-prediction).

    For each sample, a multi-block mask splits ECG patch indices into context
    and target sets. The cross-attention predictor maps positional queries for
    the target patches, attending to the online encoder's context embeddings,
    onto the EMA target encoder's embeddings of those patches. The loss is the
    L1 error.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.ecg = ECGOnlyEncoder(cfg)
        self.ecg_tgt = EMA(self.ecg)
        self.predictor = CrossAttentionPredictor(
            depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads
        )
        # Predictor queries come from a trainable table by default, or from a
        # fixed sinusoidal table (non-persistent buffer) in sinusoidal mode.
        _, sinpe = _make_query_emb(cfg)
        if sinpe is None:
            self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model))
            nn.init.trunc_normal_(self.query_emb, std=0.02)
        else:
            self.register_buffer("query_emb", sinpe, persistent=False)

    def step(self, batch: dict) -> dict:
        """Compute the masked self-prediction loss for one batch.

        Args:
            batch: must contain ``"ecg"`` (``[B, 1, T]``).

        Returns:
            A dict with:

            * ``"loss"``: the differentiable loss.
            * ``"L_self"``: its detached copy.
            * ``"L_cross"``: a zero placeholder on the same device.
            * ``"z_ecg"``: the mean-pooled, detached online embedding.
        """
        ecg = batch["ecg"]  # [B, 1, T]
        device = ecg.device
        b = ecg.shape[0]
        n_ecg = ecg.shape[-1] // self.cfg.ecg_patch

        # One independent (context, target) split per sample. Masks are drawn
        # before any forward pass, as before.
        masks = [
            multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8),
                                mask_ratio=self.cfg.mask_ratio)
            for _ in range(b)
        ]

        # Context and target sets have different lengths per sample, so the
        # predictor is run sample by sample (no padding/packing).
        # NOTE(review): the online trunk encodes the *full* token sequence, and
        # context tokens are gathered afterwards. Through self-attention the
        # context embeddings have therefore already seen the target patches.
        # Canonical I-JEPA encodes only the context tokens; confirm intended.
        full_ctx = self.ecg.trunk(self.ecg.tok(ecg))  # [B, N, d]
        with torch.no_grad():
            tgt_full = self.ecg_tgt.target.trunk(self.ecg_tgt.target.tok(ecg))

        losses = []
        for i, (ctx_idx, tgt_idx) in enumerate(masks):
            if len(tgt_idx) == 0:
                # An L1 mean over an empty target set is NaN and would poison
                # the whole batch loss, so skip this sample.
                continue
            q = self.query_emb[tgt_idx].unsqueeze(0)  # [1, n_t, d]
            ctx_tokens = full_ctx[i : i + 1, ctx_idx, :]
            pred = self.predictor(q, ctx_tokens).squeeze(0)
            losses.append(F.l1_loss(pred, tgt_full[i, tgt_idx, :], reduction="mean"))

        # Average over the samples that actually contributed a target set.
        if losses:
            L_self = torch.stack(losses).mean()
        else:
            L_self = torch.zeros((), device=device)

        return {"loss": L_self, "L_self": L_self.detach(),
                "L_cross": torch.zeros((), device=device),
                "z_ecg": _pool(full_ctx.detach())}

    def targets(self):
        """Return the ``(online module, EMA wrapper)`` pairs to be EMA-updated."""
        return [(self.ecg, self.ecg_tgt)]
# ---------------------------------------------------------------------------
# Shared cross-modal backbone for Baseline B and E3 PhysioJEPA (Baseline C has its own encoders)
# ---------------------------------------------------------------------------
class CrossModalBackbone(nn.Module):
    """Dual online encoders + two EMA targets + cross-attention predictor + Δt embedding.

    Used by Baseline B (``use_delta_t=False``) and PhysioJEPA
    (``use_delta_t=True``).

    NOTE(review): the models in this module that use this backbone only pass
    PPG through the EMA *target* (``encode_ppg_target``). The online
    ``self.ppg`` encoder never enters a loss and gets no gradients, so its EMA
    copy presumably stays at the initial random weights. Confirm that a frozen
    random PPG target is intended.
    """

    def __init__(self, cfg: ModelConfig, use_predictor: bool = True, use_delta_t: bool = True):
        super().__init__()
        self.cfg = cfg
        self.use_predictor = use_predictor
        self.use_delta_t = use_delta_t
        self.ecg = ECGOnlyEncoder(cfg)
        self.ppg = PPGEncoder(cfg)
        self.ecg_tgt = EMA(self.ecg)
        self.ppg_tgt = EMA(self.ppg)
        if use_predictor:
            self.predictor = CrossAttentionPredictor(
                depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads
            )
            # Queries: trainable table, or fixed sinusoidal (non-persistent buffer).
            _, sinpe = _make_query_emb(cfg)
            if sinpe is None:
                self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model))
                nn.init.trunc_normal_(self.query_emb, std=0.02)
            else:
                self.register_buffer("query_emb", sinpe, persistent=False)
        if use_delta_t:
            self.dt_emb = DeltaTEmbedding(d_model=cfg.d_model)

    def encode_ctx(self, ecg: torch.Tensor) -> torch.Tensor:
        """Encode ECG with the online (gradient-receiving) encoder -> ``[B, N_e, d]``."""
        return self.ecg(ecg)

    def encode_ppg_target(self, ppg: torch.Tensor) -> torch.Tensor:
        """Encode PPG with the EMA target encoder, without gradient tracking."""
        with torch.no_grad():
            return self.ppg_tgt.target(ppg)

    def predict_ppg(self, z_ecg: torch.Tensor, n_ppg_tokens: int,
                    dt_seconds: torch.Tensor | None) -> torch.Tensor:
        """Predict ``n_ppg_tokens`` PPG token embeddings from ECG context.

        Args:
            z_ecg: online ECG embeddings ``[B, N_e, d]``.
            n_ppg_tokens: number of PPG tokens to predict. The first
                ``n_ppg_tokens`` rows of the query table are used.
            dt_seconds: per-sample time offset. It is appended as one extra
                context token only when the backbone was built with
                ``use_delta_t=True``; otherwise it is ignored.

        Returns:
            Predicted embeddings ``[B, n_ppg_tokens, d]``.

        Raises:
            RuntimeError: if the backbone was built with ``use_predictor=False``.
            ValueError: if ``n_ppg_tokens`` exceeds the query table size.
        """
        if not self.use_predictor:
            raise RuntimeError("predict_ppg requires a backbone built with use_predictor=True")
        n_queries = self.query_emb.shape[0]
        if n_ppg_tokens > n_queries:
            raise ValueError(
                f"n_ppg_tokens={n_ppg_tokens} exceeds the query table size ({n_queries}); "
                "increase ModelConfig.max_tokens"
            )
        b = z_ecg.shape[0]
        q = self.query_emb[:n_ppg_tokens].unsqueeze(0).expand(b, -1, -1)  # [B, n_ppg, d]
        ctx = z_ecg
        if self.use_delta_t and dt_seconds is not None:
            dt_tok = self.dt_emb(dt_seconds).unsqueeze(1)  # [B, 1, d]
            ctx = torch.cat([ctx, dt_tok], dim=1)
        return self.predictor(q, ctx)

    def targets(self):
        """Return the ``(online module, EMA wrapper)`` pairs to be EMA-updated."""
        return [(self.ecg, self.ecg_tgt), (self.ppg, self.ppg_tgt)]
# ---------------------------------------------------------------------------
# Baseline B — cross-modal JEPA (ECG → PPG prediction), Δt = 0
# ---------------------------------------------------------------------------
class BaselineB(nn.Module):
    """Baseline B: cross-modal JEPA predicting PPG target tokens from ECG, Δt = 0.

    loss = L1(predicted PPG tokens, EMA PPG tokens)
           + 0.3 * auxiliary I-JEPA-style ECG self-prediction loss.

    The auxiliary term reuses the backbone's predictor and query table.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=False)

    def step(self, batch: dict) -> dict:
        """Compute the combined loss for a batch with ``"ecg"`` and ``"ppg"`` keys."""
        ecg, ppg = batch["ecg"], batch["ppg"]
        z_ecg = self.bb.encode_ctx(ecg)  # [B, N_e, d]
        z_ppg_tgt = self.bb.encode_ppg_target(ppg)  # [B, N_p, d]
        n_ppg = z_ppg_tgt.shape[1]
        z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=None)
        L_cross = F.l1_loss(z_pred, z_ppg_tgt)

        # Auxiliary ECG self-prediction. encode_ctx already ran the online ECG
        # encoder (tokeniser + trunk) over the full sequence, so z_ecg is reused
        # as the context bank instead of repeating that forward pass.
        # NOTE(review): because the full sequence is encoded, the context tokens
        # have attended to the target patches. Canonical I-JEPA encodes only
        # context tokens; confirm intended.
        n_ecg = z_ecg.shape[1]
        with torch.no_grad():
            tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg))
        self_losses = []
        for i in range(z_ecg.shape[0]):
            c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8),
                                       mask_ratio=self.cfg.mask_ratio)
            if len(t) == 0:
                continue  # an empty target set would make l1_loss return NaN
            q = self.bb.query_emb[t].unsqueeze(0)
            pred = self.bb.predictor(q, z_ecg[i : i + 1, c, :]).squeeze(0)
            self_losses.append(F.l1_loss(pred, tgt_full[i, t, :]))
        # Average over contributing samples only. Dividing by the batch size
        # would shrink the loss whenever a mask came back empty.
        if self_losses:
            L_self = torch.stack(self_losses).mean()
        else:
            L_self = torch.zeros((), device=ecg.device)

        loss = L_cross + 0.3 * L_self
        return {"loss": loss, "L_cross": L_cross.detach(), "L_self": L_self.detach(),
                "z_ecg": _pool(z_ecg.detach()), "z_ppg": _pool(z_ppg_tgt.detach()),
                "z_pred": _pool(z_pred.detach())}

    def targets(self):
        """Return the backbone's ``(online module, EMA wrapper)`` pairs."""
        return self.bb.targets()
# ---------------------------------------------------------------------------
# Baseline C — symmetric InfoNCE (no predictor)
# ---------------------------------------------------------------------------
class BaselineC(nn.Module):
    """Baseline C: symmetric CLIP-style InfoNCE between pooled ECG and PPG embeddings.

    There is no predictor and there are no EMA targets: both encoders and
    their projection heads are trained directly by the contrastive loss.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.ecg = ECGOnlyEncoder(cfg)
        self.ppg = PPGEncoder(cfg)
        self.ecg_head = nn.Linear(cfg.d_model, cfg.d_model)
        self.ppg_head = nn.Linear(cfg.d_model, cfg.d_model)
        # Learnable log of the logit scale (inverse temperature). Standard
        # CLIP-style init: temperature τ ≈ 0.07 → multiplier ≈ 14.3. The earlier
        # init log_tau = 0 gave multiplier 1, leaving logits in [-1, 1], which
        # keeps the loss near ln(B), an uninformative ceiling.
        self.log_tau = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07)))

    def step(self, batch: dict) -> dict:
        """Compute the symmetric InfoNCE loss for a batch with ``"ecg"`` and ``"ppg"``."""
        ecg, ppg = batch["ecg"], batch["ppg"]
        # Pool tokens, project, then L2-normalise each row.
        z_ecg = F.normalize(self.ecg_head(_pool(self.ecg(ecg))), dim=-1)
        z_ppg = F.normalize(self.ppg_head(_pool(self.ppg(ppg))), dim=-1)
        # Despite its name, `tau` is the logit multiplier (1 / temperature),
        # clamped to [0.01, 100].
        tau = torch.clamp(self.log_tau.exp(), 0.01, 100.0)
        logits = tau * z_ecg @ z_ppg.t()  # [B, B]; the diagonal holds matched pairs
        b = z_ecg.shape[0]
        labels = torch.arange(b, device=ecg.device)
        loss = 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))
        return {"loss": loss, "L_cross": loss.detach(),
                "L_self": torch.zeros((), device=ecg.device),
                "z_ecg": z_ecg.detach(), "z_ppg": z_ppg.detach(),
                "z_pred": z_ppg.detach(), "tau": tau.detach()}

    def targets(self):
        """No EMA targets: this baseline is purely contrastive."""
        return []  # no EMA — pure contrastive
# ---------------------------------------------------------------------------
# E3 — PhysioJEPA v1 (variable Δt cross-modal JEPA)
# ---------------------------------------------------------------------------
class PhysioJEPA(nn.Module):
    """PhysioJEPA v1: cross-modal JEPA conditioned on a per-sample time offset Δt.

    loss = L1(predicted PPG tokens, EMA PPG tokens)
           + 0.3 * auxiliary I-JEPA-style ECG self-prediction loss.

    ``batch["dt_seconds"]`` is passed to the backbone predictor as Δt
    conditioning.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=True)

    def step(self, batch: dict) -> dict:
        """Compute the combined loss for a batch with ``"ecg"``, ``"ppg"`` and ``"dt_seconds"``."""
        ecg, ppg = batch["ecg"], batch["ppg"]
        dt = batch["dt_seconds"]  # [B]
        z_ecg = self.bb.encode_ctx(ecg)  # [B, N_e, d]
        z_ppg_tgt = self.bb.encode_ppg_target(ppg)
        n_ppg = z_ppg_tgt.shape[1]
        z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=dt)
        L_cross = F.l1_loss(z_pred, z_ppg_tgt)

        # Auxiliary ECG self-prediction. encode_ctx already ran the online ECG
        # encoder (tokeniser + trunk) over the full sequence, so z_ecg is reused
        # as the context bank instead of repeating that forward pass.
        # NOTE(review): because the full sequence is encoded, the context tokens
        # have attended to the target patches. Canonical I-JEPA encodes only
        # context tokens; confirm intended.
        n_ecg = z_ecg.shape[1]
        with torch.no_grad():
            tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg))
        self_losses = []
        for i in range(z_ecg.shape[0]):
            c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8),
                                       mask_ratio=self.cfg.mask_ratio)
            if len(t) == 0:
                continue  # an empty target set would make l1_loss return NaN
            q = self.bb.query_emb[t].unsqueeze(0)
            pred = self.bb.predictor(q, z_ecg[i : i + 1, c, :]).squeeze(0)
            self_losses.append(F.l1_loss(pred, tgt_full[i, t, :]))
        # Average over contributing samples only. Dividing by the batch size
        # would shrink the loss whenever a mask came back empty.
        if self_losses:
            L_self = torch.stack(self_losses).mean()
        else:
            L_self = torch.zeros((), device=ecg.device)

        loss = L_cross + 0.3 * L_self
        return {"loss": loss, "L_cross": L_cross.detach(), "L_self": L_self.detach(),
                "z_ecg": _pool(z_ecg.detach()), "z_ppg": _pool(z_ppg_tgt.detach()),
                "z_pred": _pool(z_pred.detach()), "dt": dt.detach()}

    def targets(self):
        """Return the backbone's ``(online module, EMA wrapper)`` pairs."""
        return self.bb.targets()
# Variant letter (as listed in the module docstring) -> model class.
MODEL_REGISTRY = dict(
    A=BaselineA,
    B=BaselineB,
    C=BaselineC,
    F=PhysioJEPA,
)