"""Minimal 1D ViT used for both ECG and PPG encoders. Shapes ------ forward(x_tokens): [B, N, d] -> [B, N, d] Patch tokenisation is handled separately (see ecg_encoder.py / ppg_encoder.py) so this module is purely the transformer trunk. """ from __future__ import annotations import torch from torch import nn class MHA(nn.Module): def __init__(self, d: int, heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0): super().__init__() assert d % heads == 0 self.h = heads self.dh = d // heads self.qkv = nn.Linear(d, 3 * d, bias=True) self.proj = nn.Linear(d, d, bias=True) self.ad = nn.Dropout(attn_drop) self.pd = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: b, n, d = x.shape qkv = self.qkv(x).view(b, n, 3, self.h, self.dh).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # [b, h, n, dh] out = nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=self.ad.p if self.training else 0.0 ) out = out.transpose(1, 2).reshape(b, n, d) return self.pd(self.proj(out)) class Block(nn.Module): def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0): super().__init__() self.n1 = nn.LayerNorm(d) self.attn = MHA(d, heads, attn_drop=drop, proj_drop=drop) self.n2 = nn.LayerNorm(d) hidden = int(d * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop), nn.Linear(hidden, d), nn.Dropout(drop), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.n1(x)) x = x + self.mlp(self.n2(x)) return x class ViT1D(nn.Module): """Token-in, token-out transformer trunk with final LayerNorm.""" def __init__( self, depth: int = 12, d_model: int = 256, heads: int = 8, mlp_ratio: float = 4.0, drop: float = 0.0, ): super().__init__() self.blocks = nn.ModuleList( [Block(d_model, heads, mlp_ratio, drop) for _ in range(depth)] ) self.norm = nn.LayerNorm(d_model) def forward(self, tokens: torch.Tensor) -> torch.Tensor: x = tokens for blk in self.blocks: x = blk(x) return self.norm(x) class CrossAttnBlock(nn.Module): """Self-attention → cross-attention(kv=context) → MLP.""" def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0): super().__init__() self.n1 = nn.LayerNorm(d) self.self_attn = MHA(d, heads, attn_drop=drop, proj_drop=drop) self.n2q = nn.LayerNorm(d) self.n2k = nn.LayerNorm(d) self.h = heads self.dh = d // heads self.q = nn.Linear(d, d) self.kv = nn.Linear(d, 2 * d) self.op = nn.Linear(d, d) self.n3 = nn.LayerNorm(d) hidden = int(d * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop), nn.Linear(hidden, d), nn.Dropout(drop), ) def forward(self, x: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor: x = x + self.self_attn(self.n1(x)) q = self.q(self.n2q(x)) kv = self.kv(self.n2k(ctx)) b, n, d = q.shape m = ctx.shape[1] q = q.view(b, n, self.h, self.dh).transpose(1, 2) k, v = kv.view(b, m, 2, self.h, self.dh).permute(2, 0, 3, 1, 4) o = nn.functional.scaled_dot_product_attention(q, k, v) o = o.transpose(1, 2).reshape(b, n, d) x = x + self.op(o) x = x + self.mlp(self.n3(x)) return x class CrossAttentionPredictor(nn.Module): """Query = positional tokens at target positions; KV = ECG context (+ optional Δt token).""" def __init__( self, depth: int = 4, d_model: int = 256, heads: int = 8, mlp_ratio: float = 4.0, drop: float = 0.0, ): super().__init__() self.blocks = nn.ModuleList( [CrossAttnBlock(d_model, heads, mlp_ratio, drop) for _ in range(depth)] ) self.norm = nn.LayerNorm(d_model) def forward(self, queries: torch.Tensor, ctx: torch.Tensor) -> 
torch.Tensor: x = queries for blk in self.blocks: x = blk(x, ctx) return self.norm(x)
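

if __name__ == "__main__":
    # Minimal shape smoke test (a sketch, not part of the original module):
    # the batch/token counts below are arbitrary assumptions, and `queries`
    # is random noise standing in for the positional target tokens mentioned
    # in the CrossAttentionPredictor docstring. It only checks the
    # token-in/token-out contract stated at the top of this file.
    B, N, M, D = 2, 16, 24, 256
    trunk = ViT1D(depth=2, d_model=D, heads=8)
    predictor = CrossAttentionPredictor(depth=2, d_model=D, heads=8)
    ctx = trunk(torch.randn(B, N, D))  # [B, N, D] -> [B, N, D]
    queries = torch.randn(B, M, D)  # stand-in for target-position tokens
    preds = predictor(queries, ctx)  # [B, M, D]
    assert ctx.shape == (B, N, D) and preds.shape == (B, M, D)
    print("shapes ok:", tuple(ctx.shape), tuple(preds.shape))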