"""The four models under test. They share encoders, differ in loss and Delta-t. Model variants: A: ECG-JEPA unimodal (I-JEPA self-prediction on ECG only) B: cross-modal JEPA, delta_t = 0 C: symmetric InfoNCE (no predictor) F: PhysioJEPA v1 (cross-modal JEPA, variable delta_t) """ from __future__ import annotations from dataclasses import dataclass, field import torch import torch.nn.functional as F from torch import nn from .dt_embed import DeltaTEmbedding from .ecg_encoder import ECGPatchTokeniser from .ema import EMA from .masking import multi_block_mask_1d from .ppg_encoder import PPGPatchTokeniser from .vit import CrossAttentionPredictor, ViT1D @dataclass class ModelConfig: ecg_patch: int = 50 ppg_patch: int = 25 d_model: int = 256 ecg_depth: int = 12 ppg_depth: int = 6 heads: int = 8 pred_depth: int = 4 max_tokens: int = 128 # ablation knobs query_mode: str = "learned" # "learned" | "sinusoidal" mask_ratio: float = 0.50 def _pool(x: torch.Tensor) -> torch.Tensor: return x.mean(dim=1) def _make_query_emb(cfg: ModelConfig) -> tuple[nn.Module | None, torch.Tensor | None]: """Returns either a learned nn.Parameter wrapped in a tiny Module, or a fixed sinusoidal table buffer. Caller should index with positions. """ if cfg.query_mode == "sinusoidal": import math n_pos, d = cfg.max_tokens, cfg.d_model pe = torch.zeros(n_pos, d) pos = torch.arange(0, n_pos, dtype=torch.float32).unsqueeze(1) div = torch.exp(torch.arange(0, d, 2, dtype=torch.float32) * -(math.log(10_000.0) / d)) pe[:, 0::2] = torch.sin(pos * div) pe[:, 1::2] = torch.cos(pos * div) return None, pe # caller stores as buffer return None, None # caller creates learned Parameter class ECGOnlyEncoder(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.tok = ECGPatchTokeniser(patch_size=cfg.ecg_patch, d_model=cfg.d_model, max_patches=cfg.max_tokens) self.trunk = ViT1D(depth=cfg.ecg_depth, d_model=cfg.d_model, heads=cfg.heads) def forward(self, ecg: torch.Tensor) -> torch.Tensor: return self.trunk(self.tok(ecg)) # [B, N_e, d] class PPGEncoder(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.tok = PPGPatchTokeniser(patch_size=cfg.ppg_patch, d_model=cfg.d_model, max_patches=cfg.max_tokens) self.trunk = ViT1D(depth=cfg.ppg_depth, d_model=cfg.d_model, heads=cfg.heads) def forward(self, ppg: torch.Tensor) -> torch.Tensor: return self.trunk(self.tok(ppg)) # --------------------------------------------------------------------------- # Baseline A — ECG-JEPA unimodal (I-JEPA style self-prediction) # --------------------------------------------------------------------------- class BaselineA(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.cfg = cfg self.ecg = ECGOnlyEncoder(cfg) self.ecg_tgt = EMA(self.ecg) self.predictor = CrossAttentionPredictor( depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads ) _, sinpe = _make_query_emb(cfg) if sinpe is None: self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model)) nn.init.trunc_normal_(self.query_emb, std=0.02) else: self.register_buffer("query_emb", sinpe, persistent=False) def step(self, batch: dict) -> dict: ecg = batch["ecg"] # [B, 1, T] b = ecg.shape[0] n_ecg = ecg.shape[-1] // self.cfg.ecg_patch ctx_idxs = [] tgt_idxs = [] for _ in range(b): c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), mask_ratio=self.cfg.mask_ratio) ctx_idxs.append(c) tgt_idxs.append(t) # All sequences same B but variable ctx/tgt lengths — process per-sample # then pack. For efficiency use a padded approach. tok = self.ecg.tok(ecg) # [B, N, d] trunk = self.ecg.trunk # context forward: apply trunk on full sequence then gather ctx/tgt tokens full_ctx = trunk(tok) # [B, N, d] tgt_full = self.ecg_tgt.target.trunk(self.ecg_tgt.target.tok(ecg)).detach() L_self = torch.tensor(0.0, device=ecg.device) total = 0 for i in range(b): q = self.query_emb[tgt_idxs[i]].unsqueeze(0) # [1, n_t, d] ctx_tokens = full_ctx[i : i + 1, ctx_idxs[i], :] pred = self.predictor(q, ctx_tokens).squeeze(0) tgt_v = tgt_full[i, tgt_idxs[i], :] L_self = L_self + F.l1_loss(pred, tgt_v, reduction="mean") total += 1 L_self = L_self / max(total, 1) return {"loss": L_self, "L_self": L_self.detach(), "L_cross": torch.tensor(0.0), "z_ecg": _pool(full_ctx.detach())} def targets(self): return [(self.ecg, self.ecg_tgt)] # --------------------------------------------------------------------------- # Shared cross-modal backbone for Baselines B, C, and E3 PhysioJEPA # --------------------------------------------------------------------------- class CrossModalBackbone(nn.Module): """Dual online encoders + two EMA targets + cross-attention predictor + Δt emb.""" def __init__(self, cfg: ModelConfig, use_predictor: bool = True, use_delta_t: bool = True): super().__init__() self.cfg = cfg self.use_predictor = use_predictor self.use_delta_t = use_delta_t self.ecg = ECGOnlyEncoder(cfg) self.ppg = PPGEncoder(cfg) self.ecg_tgt = EMA(self.ecg) self.ppg_tgt = EMA(self.ppg) if use_predictor: self.predictor = CrossAttentionPredictor( depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads ) _, sinpe = _make_query_emb(cfg) if sinpe is None: self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model)) nn.init.trunc_normal_(self.query_emb, std=0.02) else: self.register_buffer("query_emb", sinpe, persistent=False) if use_delta_t: self.dt_emb = DeltaTEmbedding(d_model=cfg.d_model) def encode_ctx(self, ecg: torch.Tensor) -> torch.Tensor: return self.ecg(ecg) def encode_ppg_target(self, ppg: torch.Tensor) -> torch.Tensor: with torch.no_grad(): return self.ppg_tgt.target(ppg).detach() def predict_ppg(self, z_ecg: torch.Tensor, n_ppg_tokens: int, dt_seconds: torch.Tensor | None) -> torch.Tensor: b = z_ecg.shape[0] q = self.query_emb[:n_ppg_tokens].unsqueeze(0).expand(b, -1, -1) ctx = z_ecg if self.use_delta_t and dt_seconds is not None: dt_tok = self.dt_emb(dt_seconds).unsqueeze(1) # [B, 1, d] ctx = torch.cat([ctx, dt_tok], dim=1) return self.predictor(q, ctx) def targets(self): return [(self.ecg, self.ecg_tgt), (self.ppg, self.ppg_tgt)] # --------------------------------------------------------------------------- # Baseline B — symmetric cross-modal JEPA, Δt = 0 # --------------------------------------------------------------------------- class BaselineB(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.cfg = cfg self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=False) def step(self, batch: dict) -> dict: ecg, ppg = batch["ecg"], batch["ppg"] z_ecg = self.bb.encode_ctx(ecg) # [B, N_e, d] z_ppg_tgt = self.bb.encode_ppg_target(ppg) # [B, N_p, d] n_ppg = z_ppg_tgt.shape[1] z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=None) L_cross = F.l1_loss(z_pred, z_ppg_tgt) # auxiliary self-prediction on ECG (I-JEPA style) — same code path as BaselineA n_ecg = z_ecg.shape[1] b = z_ecg.shape[0] tok = self.bb.ecg.tok(ecg) full_ctx = self.bb.ecg.trunk(tok) tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg)).detach() L_self = torch.tensor(0.0, device=ecg.device) for i in range(b): c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), mask_ratio=self.cfg.mask_ratio) if len(t) == 0: continue q = self.bb.query_emb[t].unsqueeze(0) ctx_tokens = full_ctx[i : i + 1, c, :] pred = self.bb.predictor(q, ctx_tokens).squeeze(0) tgt_v = tgt_full[i, t, :] L_self = L_self + F.l1_loss(pred, tgt_v) L_self = L_self / max(b, 1) loss = L_cross + 0.3 * L_self return {"loss": loss, "L_cross": L_cross.detach(), "L_self": L_self.detach(), "z_ecg": _pool(z_ecg.detach()), "z_ppg": _pool(z_ppg_tgt.detach()), "z_pred": _pool(z_pred.detach())} def targets(self): return self.bb.targets() # --------------------------------------------------------------------------- # Baseline C — symmetric InfoNCE (no predictor) # --------------------------------------------------------------------------- class BaselineC(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.cfg = cfg self.ecg = ECGOnlyEncoder(cfg) self.ppg = PPGEncoder(cfg) self.ecg_head = nn.Linear(cfg.d_model, cfg.d_model) self.ppg_head = nn.Linear(cfg.d_model, cfg.d_model) # Standard CLIP-style temperature init: physical τ ≈ 0.07 → multiplier ≈ 14.3. # The earlier init log_tau=0 made multiplier=1, leaving logits ∈ [-1, 1] which # gives loss ≈ ln(B) = uninformative ceiling. self.log_tau = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07))) def step(self, batch: dict) -> dict: ecg, ppg = batch["ecg"], batch["ppg"] z_ecg = F.normalize(self.ecg_head(_pool(self.ecg(ecg))), dim=-1) z_ppg = F.normalize(self.ppg_head(_pool(self.ppg(ppg))), dim=-1) tau = torch.clamp(self.log_tau.exp(), 0.01, 100.0) logits = tau * z_ecg @ z_ppg.t() b = z_ecg.shape[0] labels = torch.arange(b, device=ecg.device) loss = 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) return {"loss": loss, "L_cross": loss.detach(), "L_self": torch.tensor(0.0), "z_ecg": z_ecg.detach(), "z_ppg": z_ppg.detach(), "z_pred": z_ppg.detach(), "tau": tau.detach()} def targets(self): return [] # no EMA — pure contrastive # --------------------------------------------------------------------------- # E3 — PhysioJEPA v1 (variable Δt cross-modal JEPA) # --------------------------------------------------------------------------- class PhysioJEPA(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.cfg = cfg self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=True) def step(self, batch: dict) -> dict: ecg, ppg = batch["ecg"], batch["ppg"] dt = batch["dt_seconds"] # [B] z_ecg = self.bb.encode_ctx(ecg) z_ppg_tgt = self.bb.encode_ppg_target(ppg) n_ppg = z_ppg_tgt.shape[1] z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=dt) L_cross = F.l1_loss(z_pred, z_ppg_tgt) # auxiliary ECG self-prediction n_ecg = z_ecg.shape[1] b = z_ecg.shape[0] tok = self.bb.ecg.tok(ecg) full_ctx = self.bb.ecg.trunk(tok) tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg)).detach() L_self = torch.tensor(0.0, device=ecg.device) for i in range(b): c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), mask_ratio=self.cfg.mask_ratio) if len(t) == 0: continue q = self.bb.query_emb[t].unsqueeze(0) ctx_tokens = full_ctx[i : i + 1, c, :] pred = self.bb.predictor(q, ctx_tokens).squeeze(0) tgt_v = tgt_full[i, t, :] L_self = L_self + F.l1_loss(pred, tgt_v) L_self = L_self / max(b, 1) loss = L_cross + 0.3 * L_self return {"loss": loss, "L_cross": L_cross.detach(), "L_self": L_self.detach(), "z_ecg": _pool(z_ecg.detach()), "z_ppg": _pool(z_ppg_tgt.detach()), "z_pred": _pool(z_pred.detach()), "dt": dt.detach()} def targets(self): return self.bb.targets() MODEL_REGISTRY = {"A": BaselineA, "B": BaselineB, "C": BaselineC, "F": PhysioJEPA}