"""Minimal 1D ViT used for both ECG and PPG encoders.

Shapes
------
forward(x_tokens): [B, N, d] -> [B, N, d]

Patch tokenisation is handled separately (see ecg_encoder.py / ppg_encoder.py),
so ViT1D is purely the transformer trunk. The file also defines the
cross-attention predictor that reads out target positions from the ECG context.
"""
from __future__ import annotations

import torch
from torch import nn


class MHA(nn.Module):
    def __init__(self, d: int, heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        assert d % heads == 0
        self.h = heads
        self.dh = d // heads
        self.qkv = nn.Linear(d, 3 * d, bias=True)
        self.proj = nn.Linear(d, d, bias=True)
        self.ad = nn.Dropout(attn_drop)
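        # Note: self.ad is never applied as a module; only its .p is read in
        # forward() and passed to scaled_dot_product_attention as dropout_p.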
        self.pd = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        qkv = self.qkv(x).view(b, n, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # [b, h, n, dh]
        out = nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.ad.p if self.training else 0.0
        )
        out = out.transpose(1, 2).reshape(b, n, d)
        return self.pd(self.proj(out))

class Block(nn.Module):
    def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        self.n1 = nn.LayerNorm(d)
        self.attn = MHA(d, heads, attn_drop=drop, proj_drop=drop)
        self.n2 = nn.LayerNorm(d)
        hidden = int(d * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop),
            nn.Linear(hidden, d), nn.Dropout(drop),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.n1(x))
        x = x + self.mlp(self.n2(x))
        return x

class ViT1D(nn.Module):
    """Token-in, token-out transformer trunk with final LayerNorm."""

    def __init__(
        self,
        depth: int = 12,
        d_model: int = 256,
        heads: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [Block(d_model, heads, mlp_ratio, drop) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = tokens
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)

class CrossAttnBlock(nn.Module):
    """Self-attention → cross-attention (kv = context) → MLP."""

    def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        self.n1 = nn.LayerNorm(d)
        self.self_attn = MHA(d, heads, attn_drop=drop, proj_drop=drop)
        self.n2q = nn.LayerNorm(d)
        self.n2k = nn.LayerNorm(d)
        assert d % heads == 0
        self.h = heads
        self.dh = d // heads
        self.q = nn.Linear(d, d)
        self.kv = nn.Linear(d, 2 * d)
        self.op = nn.Linear(d, d)
        self.n3 = nn.LayerNorm(d)
        hidden = int(d * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop),
            nn.Linear(hidden, d), nn.Dropout(drop),
        )

    def forward(self, x: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        x = x + self.self_attn(self.n1(x))
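        # Cross-attention: queries come from the running query stream x,
        # keys/values from the (separately normalised) context sequence.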
        q = self.q(self.n2q(x))
        kv = self.kv(self.n2k(ctx))
        b, n, d = q.shape
        m = ctx.shape[1]
        q = q.view(b, n, self.h, self.dh).transpose(1, 2)
        k, v = kv.view(b, m, 2, self.h, self.dh).permute(2, 0, 3, 1, 4)
        o = nn.functional.scaled_dot_product_attention(q, k, v)
        o = o.transpose(1, 2).reshape(b, n, d)
        x = x + self.op(o)
        x = x + self.mlp(self.n3(x))
        return x

class CrossAttentionPredictor(nn.Module):
    """Query = positional tokens at target positions; KV = ECG context (+ optional Δt token)."""

    def __init__(
        self,
        depth: int = 4,
        d_model: int = 256,
        heads: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [CrossAttnBlock(d_model, heads, mlp_ratio, drop) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, queries: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        x = queries
        for blk in self.blocks:
            x = blk(x, ctx)
        return self.norm(x)
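
if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only). The batch size, token
    # counts and width below are arbitrary assumptions, not values fixed by
    # the ECG/PPG encoders; patch tokenisation happens upstream of this file.
    b, n_ctx, n_tgt, d = 2, 64, 16, 256
    trunk = ViT1D(depth=2, d_model=d, heads=8)
    predictor = CrossAttentionPredictor(depth=2, d_model=d, heads=8)

    ctx_tokens = torch.randn(b, n_ctx, d)  # e.g. ECG context tokens
    queries = torch.randn(b, n_tgt, d)     # e.g. positional tokens at target positions

    ctx = trunk(ctx_tokens)                # [b, n_ctx, d]
    pred = predictor(queries, ctx)         # [b, n_tgt, d]
    assert ctx.shape == (b, n_ctx, d) and pred.shape == (b, n_tgt, d)
    print("ViT1D out:", tuple(ctx.shape), "predictor out:", tuple(pred.shape))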