| """Minimal 1D ViT used for both ECG and PPG encoders. |
| |
| Shapes |
| ------ |
| forward(x_tokens): [B, N, d] -> [B, N, d] |
| |
| Patch tokenisation is handled separately (see ecg_encoder.py / ppg_encoder.py) |
| so this module is purely the transformer trunk. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import nn |
|
|
|
|
class MHA(nn.Module):
    """Multi-head self-attention over a token sequence of shape [B, N, d]."""

    def __init__(self, d: int, heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        assert d % heads == 0, "d must be divisible by heads"
        self.h = heads
        self.dh = d // heads
        self.qkv = nn.Linear(d, 3 * d, bias=True)  # fused q/k/v projection
        self.proj = nn.Linear(d, d, bias=True)     # output projection
        self.ad = nn.Dropout(attn_drop)            # attention-weight dropout
        self.pd = nn.Dropout(proj_drop)            # post-projection dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        # [B, N, 3d] -> [3, B, H, N, dh], then split into q, k, v.
        qkv = self.qkv(x).view(b, n, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Fused attention kernel; attention dropout only applies during training.
        out = nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.ad.p if self.training else 0.0
        )
        out = out.transpose(1, 2).reshape(b, n, d)  # merge heads -> [B, N, d]
        return self.pd(self.proj(out))


class Block(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each with a residual."""

    def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        self.n1 = nn.LayerNorm(d)
        self.attn = MHA(d, heads, attn_drop=drop, proj_drop=drop)
        self.n2 = nn.LayerNorm(d)
        hidden = int(d * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop),
            nn.Linear(hidden, d), nn.Dropout(drop),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.n1(x))
        x = x + self.mlp(self.n2(x))
        return x


class ViT1D(nn.Module):
    """Token-in, token-out transformer trunk with final LayerNorm."""

    def __init__(
        self,
        depth: int = 12,
        d_model: int = 256,
        heads: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [Block(d_model, heads, mlp_ratio, drop) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = tokens
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)


class CrossAttnBlock(nn.Module):
    """Self-attention → cross-attention (kv = context) → MLP, each pre-norm with a residual."""

    def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        assert d % heads == 0, "d must be divisible by heads"
        self.n1 = nn.LayerNorm(d)
        self.self_attn = MHA(d, heads, attn_drop=drop, proj_drop=drop)
        self.n2q = nn.LayerNorm(d)     # norm applied to the queries
        self.n2k = nn.LayerNorm(d)     # norm applied to the context (keys/values)
        self.h = heads
        self.dh = d // heads
        self.q = nn.Linear(d, d)       # query projection
        self.kv = nn.Linear(d, 2 * d)  # fused key/value projection
        self.op = nn.Linear(d, d)      # cross-attention output projection
        self.n3 = nn.LayerNorm(d)
        hidden = int(d * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop),
            nn.Linear(hidden, d), nn.Dropout(drop),
        )

    def forward(self, x: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        # Self-attention over the query tokens.
        x = x + self.self_attn(self.n1(x))
        # Cross-attention: queries come from x, keys/values from the context.
        q = self.q(self.n2q(x))
        kv = self.kv(self.n2k(ctx))
        b, n, d = q.shape
        m = ctx.shape[1]
        q = q.view(b, n, self.h, self.dh).transpose(1, 2)                # [B, H, N, dh]
        k, v = kv.view(b, m, 2, self.h, self.dh).permute(2, 0, 3, 1, 4)  # each [B, H, M, dh]
        o = nn.functional.scaled_dot_product_attention(q, k, v)
        o = o.transpose(1, 2).reshape(b, n, d)                           # merge heads -> [B, N, d]
        x = x + self.op(o)
        # MLP sub-layer.
        x = x + self.mlp(self.n3(x))
        return x


class CrossAttentionPredictor(nn.Module):
    """Query = positional tokens at target positions; KV = ECG context (+ optional Δt token)."""

    def __init__(
        self,
        depth: int = 4,
        d_model: int = 256,
        heads: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [CrossAttnBlock(d_model, heads, mlp_ratio, drop) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, queries: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        x = queries
        for blk in self.blocks:
            x = blk(x, ctx)
        return self.norm(x)
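

# The block below is a minimal smoke test, not part of the training code: it
# sketches how the trunk and predictor compose and checks that token shapes
# round-trip. The dimensions (batch 2, 48 context tokens, 24 target queries,
# d_model 256) are illustrative assumptions, not the project's configuration.
if __name__ == "__main__":
    encoder = ViT1D(depth=2, d_model=256, heads=8)
    predictor = CrossAttentionPredictor(depth=2, d_model=256, heads=8)

    ecg_tokens = torch.randn(2, 48, 256)  # tokenised ECG patches [B, M, d]
    queries = torch.randn(2, 24, 256)     # positional queries at target positions [B, N, d]

    ctx = encoder(ecg_tokens)             # [2, 48, 256]
    pred = predictor(queries, ctx)        # [2, 24, 256]
    print(ctx.shape, pred.shape)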