| """The four models under test. They share encoders, differ in loss and Delta-t. |
| |
| Model variants: |
| A: ECG-JEPA unimodal (I-JEPA self-prediction on ECG only) |
| B: cross-modal JEPA, delta_t = 0 |
| C: symmetric InfoNCE (no predictor) |
| F: PhysioJEPA v1 (cross-modal JEPA, variable delta_t) |
| """ |
| from __future__ import annotations |
|
|
import math
from dataclasses import dataclass
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| from .dt_embed import DeltaTEmbedding |
| from .ecg_encoder import ECGPatchTokeniser |
| from .ema import EMA |
| from .masking import multi_block_mask_1d |
| from .ppg_encoder import PPGPatchTokeniser |
| from .vit import CrossAttentionPredictor, ViT1D |
|
|
|
|
| @dataclass |
| class ModelConfig: |
| ecg_patch: int = 50 |
| ppg_patch: int = 25 |
| d_model: int = 256 |
| ecg_depth: int = 12 |
| ppg_depth: int = 6 |
| heads: int = 8 |
| pred_depth: int = 4 |
| max_tokens: int = 128 |
| |
| query_mode: str = "learned" |
| mask_ratio: float = 0.50 |
|
|
|
|
def _pool(x: torch.Tensor) -> torch.Tensor:
    """Mean-pool a (batch, tokens, d_model) sequence into (batch, d_model)."""
    return x.mean(dim=1)
|
|
|
|
| def _make_query_emb(cfg: ModelConfig) -> tuple[nn.Module | None, torch.Tensor | None]: |
| """Returns either a learned nn.Parameter wrapped in a tiny Module, or a |
| fixed sinusoidal table buffer. Caller should index with positions. |
| """ |
| if cfg.query_mode == "sinusoidal": |
| import math |
| n_pos, d = cfg.max_tokens, cfg.d_model |
| pe = torch.zeros(n_pos, d) |
| pos = torch.arange(0, n_pos, dtype=torch.float32).unsqueeze(1) |
| div = torch.exp(torch.arange(0, d, 2, dtype=torch.float32) * |
| -(math.log(10_000.0) / d)) |
| pe[:, 0::2] = torch.sin(pos * div) |
| pe[:, 1::2] = torch.cos(pos * div) |
| return None, pe |
| return None, None |
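
# Downstream use of the query embedding (for reference): the I-JEPA-style steps index
# it with target positions, e.g. ``q = query_emb[target_idx].unsqueeze(0)``, while
# ``predict_ppg`` takes the first ``n_ppg_tokens`` rows as PPG-position queries.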
|
|
|
|
| class ECGOnlyEncoder(nn.Module): |
| def __init__(self, cfg: ModelConfig): |
| super().__init__() |
| self.tok = ECGPatchTokeniser(patch_size=cfg.ecg_patch, d_model=cfg.d_model, |
| max_patches=cfg.max_tokens) |
| self.trunk = ViT1D(depth=cfg.ecg_depth, d_model=cfg.d_model, heads=cfg.heads) |
|
|
| def forward(self, ecg: torch.Tensor) -> torch.Tensor: |
| return self.trunk(self.tok(ecg)) |
|
|
|
|
| class PPGEncoder(nn.Module): |
| def __init__(self, cfg: ModelConfig): |
| super().__init__() |
| self.tok = PPGPatchTokeniser(patch_size=cfg.ppg_patch, d_model=cfg.d_model, |
| max_patches=cfg.max_tokens) |
| self.trunk = ViT1D(depth=cfg.ppg_depth, d_model=cfg.d_model, heads=cfg.heads) |
|
|
| def forward(self, ppg: torch.Tensor) -> torch.Tensor: |
| return self.trunk(self.tok(ppg)) |
|
|
|
|
| |
| |
| |
class BaselineA(nn.Module):
    """Model A: unimodal ECG I-JEPA baseline (self-prediction on masked ECG tokens)."""

    def __init__(self, cfg: ModelConfig):
| super().__init__() |
| self.cfg = cfg |
| self.ecg = ECGOnlyEncoder(cfg) |
| self.ecg_tgt = EMA(self.ecg) |
| self.predictor = CrossAttentionPredictor( |
| depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads |
| ) |
| _, sinpe = _make_query_emb(cfg) |
| if sinpe is None: |
| self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model)) |
| nn.init.trunc_normal_(self.query_emb, std=0.02) |
| else: |
| self.register_buffer("query_emb", sinpe, persistent=False) |
|
|
| def step(self, batch: dict) -> dict: |
| ecg = batch["ecg"] |
| b = ecg.shape[0] |
| n_ecg = ecg.shape[-1] // self.cfg.ecg_patch |
| ctx_idxs = [] |
| tgt_idxs = [] |
| for _ in range(b): |
| c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), |
| mask_ratio=self.cfg.mask_ratio) |
| ctx_idxs.append(c) |
| tgt_idxs.append(t) |
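        # Per-sample masking: multi_block_mask_1d is expected to return
        # (context_indices, target_indices) over the n_ecg patch positions, with
        # roughly ``mask_ratio`` of the tokens withheld from the context and the
        # targets drawn as 4 contiguous blocks of 4-8 tokens (an illustrative reading
        # of the helper's arguments; exact semantics live in masking.py).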
| |
| |
| tok = self.ecg.tok(ecg) |
| trunk = self.ecg.trunk |
| |
| full_ctx = trunk(tok) |
| tgt_full = self.ecg_tgt.target.trunk(self.ecg_tgt.target.tok(ecg)).detach() |
| L_self = torch.tensor(0.0, device=ecg.device) |
| total = 0 |
        for i in range(b):
            if len(tgt_idxs[i]) == 0:  # guard against masks with no target tokens
                continue
            q = self.query_emb[tgt_idxs[i]].unsqueeze(0)
            ctx_tokens = full_ctx[i : i + 1, ctx_idxs[i], :]
            pred = self.predictor(q, ctx_tokens).squeeze(0)
            tgt_v = tgt_full[i, tgt_idxs[i], :]
            L_self = L_self + F.l1_loss(pred, tgt_v, reduction="mean")
            total += 1
        L_self = L_self / max(total, 1)
        return {"loss": L_self, "L_self": L_self.detach(),
                "L_cross": torch.tensor(0.0, device=ecg.device),
                "z_ecg": _pool(full_ctx.detach())}
|
|
| def targets(self): |
| return [(self.ecg, self.ecg_tgt)] |
|
|
|
|
| |
| |
| |
| class CrossModalBackbone(nn.Module): |
| """Dual online encoders + two EMA targets + cross-attention predictor + Ξt emb.""" |
|
|
| def __init__(self, cfg: ModelConfig, use_predictor: bool = True, use_delta_t: bool = True): |
| super().__init__() |
| self.cfg = cfg |
| self.use_predictor = use_predictor |
| self.use_delta_t = use_delta_t |
| self.ecg = ECGOnlyEncoder(cfg) |
| self.ppg = PPGEncoder(cfg) |
| self.ecg_tgt = EMA(self.ecg) |
| self.ppg_tgt = EMA(self.ppg) |
| if use_predictor: |
| self.predictor = CrossAttentionPredictor( |
| depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads |
| ) |
| _, sinpe = _make_query_emb(cfg) |
| if sinpe is None: |
| self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model)) |
| nn.init.trunc_normal_(self.query_emb, std=0.02) |
| else: |
| self.register_buffer("query_emb", sinpe, persistent=False) |
| if use_delta_t: |
| self.dt_emb = DeltaTEmbedding(d_model=cfg.d_model) |
|
|
| def encode_ctx(self, ecg: torch.Tensor) -> torch.Tensor: |
| return self.ecg(ecg) |
|
|
| def encode_ppg_target(self, ppg: torch.Tensor) -> torch.Tensor: |
| with torch.no_grad(): |
| return self.ppg_tgt.target(ppg).detach() |
|
|
    def predict_ppg(self, z_ecg: torch.Tensor, n_ppg_tokens: int,
                    dt_seconds: torch.Tensor | None) -> torch.Tensor:
        b = z_ecg.shape[0]
        # One query per PPG token position: (b, n_ppg_tokens, d_model).
        q = self.query_emb[:n_ppg_tokens].unsqueeze(0).expand(b, -1, -1)
        ctx = z_ecg
        if self.use_delta_t and dt_seconds is not None:
            # Condition on Δt by appending its embedding as one extra context token.
            dt_tok = self.dt_emb(dt_seconds).unsqueeze(1)
            ctx = torch.cat([ctx, dt_tok], dim=1)
        return self.predictor(q, ctx)
|
|
| def targets(self): |
| return [(self.ecg, self.ecg_tgt), (self.ppg, self.ppg_tgt)] |
|
|
|
|
| |
| |
| |
class BaselineB(nn.Module):
    """Model B: cross-modal JEPA with Δt fixed to 0 (no Δt conditioning)."""

    def __init__(self, cfg: ModelConfig):
| super().__init__() |
| self.cfg = cfg |
| self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=False) |
|
|
| def step(self, batch: dict) -> dict: |
| ecg, ppg = batch["ecg"], batch["ppg"] |
| z_ecg = self.bb.encode_ctx(ecg) |
| z_ppg_tgt = self.bb.encode_ppg_target(ppg) |
| n_ppg = z_ppg_tgt.shape[1] |
| z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=None) |
| L_cross = F.l1_loss(z_pred, z_ppg_tgt) |
|
|
| |
        # Auxiliary ECG self-prediction term (same masking recipe as BaselineA).
        n_ecg = z_ecg.shape[1]
        b = z_ecg.shape[0]
        full_ctx = z_ecg  # reuse the online ECG encoding computed above; no need to re-encode
        tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg)).detach()
| L_self = torch.tensor(0.0, device=ecg.device) |
| for i in range(b): |
| c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), mask_ratio=self.cfg.mask_ratio) |
| if len(t) == 0: |
| continue |
| q = self.bb.query_emb[t].unsqueeze(0) |
| ctx_tokens = full_ctx[i : i + 1, c, :] |
| pred = self.bb.predictor(q, ctx_tokens).squeeze(0) |
| tgt_v = tgt_full[i, t, :] |
| L_self = L_self + F.l1_loss(pred, tgt_v) |
| L_self = L_self / max(b, 1) |
|
|
        loss = L_cross + 0.3 * L_self  # fixed 0.3 weight on the auxiliary ECG self-term
| return {"loss": loss, "L_cross": L_cross.detach(), "L_self": L_self.detach(), |
| "z_ecg": _pool(z_ecg.detach()), "z_ppg": _pool(z_ppg_tgt.detach()), |
| "z_pred": _pool(z_pred.detach())} |
|
|
| def targets(self): |
| return self.bb.targets() |
|
|
|
|
| |
| |
| |
class BaselineC(nn.Module):
    """Model C: symmetric InfoNCE between pooled ECG and PPG embeddings (no predictor)."""

    def __init__(self, cfg: ModelConfig):
| super().__init__() |
| self.cfg = cfg |
| self.ecg = ECGOnlyEncoder(cfg) |
| self.ppg = PPGEncoder(cfg) |
| self.ecg_head = nn.Linear(cfg.d_model, cfg.d_model) |
        self.ppg_head = nn.Linear(cfg.d_model, cfg.d_model)
        # Learnable logit scale (inverse temperature), initialised to 1/0.07 in the
        # CLIP convention and clamped in step() for numerical stability.
        self.log_tau = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07)))
|
|
    def step(self, batch: dict) -> dict:
        ecg, ppg = batch["ecg"], batch["ppg"]
        z_ecg = F.normalize(self.ecg_head(_pool(self.ecg(ecg))), dim=-1)
        z_ppg = F.normalize(self.ppg_head(_pool(self.ppg(ppg))), dim=-1)
        tau = torch.clamp(self.log_tau.exp(), 0.01, 100.0)
        logits = tau * z_ecg @ z_ppg.t()  # (b, b) scaled cosine similarities; diagonal = positives
        b = z_ecg.shape[0]
        labels = torch.arange(b, device=ecg.device)
        # Symmetric InfoNCE: ECG->PPG and PPG->ECG cross-entropies, averaged.
        loss = 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))
        return {"loss": loss, "L_cross": loss.detach(),
                "L_self": torch.tensor(0.0, device=ecg.device),
                "z_ecg": z_ecg.detach(), "z_ppg": z_ppg.detach(),
                "z_pred": z_ppg.detach(), "tau": tau.detach()}
|
|
| def targets(self): |
| return [] |
|
|
|
|
| |
| |
| |
class PhysioJEPA(nn.Module):
    """Model F: PhysioJEPA v1 (cross-modal JEPA conditioned on a variable Δt)."""

    def __init__(self, cfg: ModelConfig):
| super().__init__() |
| self.cfg = cfg |
| self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=True) |
|
|
| def step(self, batch: dict) -> dict: |
| ecg, ppg = batch["ecg"], batch["ppg"] |
| dt = batch["dt_seconds"] |
| z_ecg = self.bb.encode_ctx(ecg) |
| z_ppg_tgt = self.bb.encode_ppg_target(ppg) |
| n_ppg = z_ppg_tgt.shape[1] |
| z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=dt) |
| L_cross = F.l1_loss(z_pred, z_ppg_tgt) |
|
|
| |
        # Auxiliary ECG self-prediction term (same masking recipe as BaselineA).
        n_ecg = z_ecg.shape[1]
        b = z_ecg.shape[0]
        full_ctx = z_ecg  # reuse the online ECG encoding computed above; no need to re-encode
        tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg)).detach()
| L_self = torch.tensor(0.0, device=ecg.device) |
| for i in range(b): |
| c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), mask_ratio=self.cfg.mask_ratio) |
| if len(t) == 0: |
| continue |
| q = self.bb.query_emb[t].unsqueeze(0) |
| ctx_tokens = full_ctx[i : i + 1, c, :] |
| pred = self.bb.predictor(q, ctx_tokens).squeeze(0) |
| tgt_v = tgt_full[i, t, :] |
| L_self = L_self + F.l1_loss(pred, tgt_v) |
| L_self = L_self / max(b, 1) |
|
|
        loss = L_cross + 0.3 * L_self  # fixed 0.3 weight on the auxiliary ECG self-term
| return {"loss": loss, "L_cross": L_cross.detach(), "L_self": L_self.detach(), |
| "z_ecg": _pool(z_ecg.detach()), "z_ppg": _pool(z_ppg_tgt.detach()), |
| "z_pred": _pool(z_pred.detach()), "dt": dt.detach()} |
|
|
| def targets(self): |
| return self.bb.targets() |
|
|
|
|
| MODEL_REGISTRY = {"A": BaselineA, "B": BaselineB, "C": BaselineC, "F": PhysioJEPA} |
|
|