"""Minimal 1D ViT used for both ECG and PPG encoders.
Shapes
------
forward(x_tokens): [B, N, d] -> [B, N, d]
Patch tokenisation is handled separately (see ecg_encoder.py / ppg_encoder.py)
so this module is purely the transformer trunk.
"""
from __future__ import annotations
import torch
from torch import nn
class MHA(nn.Module):
    """Multi-head self-attention over a token sequence.

    One fused linear layer produces queries, keys and values; attention is
    computed with ``scaled_dot_product_attention`` and the merged heads pass
    through an output projection followed by dropout.

    Args:
        d: Token embedding width; must be divisible by ``heads``.
        heads: Number of attention heads; must be positive.
        attn_drop: Dropout probability handed to the attention kernel
            (only while the module is in training mode).
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``d``.
    """

    def __init__(self, d: int, heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        # Explicit exceptions instead of ``assert``: asserts vanish under
        # ``python -O`` and heads == 0 would otherwise be a ZeroDivisionError.
        if heads <= 0:
            raise ValueError(f"heads must be a positive integer, got {heads}")
        if d % heads != 0:
            raise ValueError(f"d ({d}) must be divisible by heads ({heads})")
        self.h = heads
        self.dh = d // heads
        self.qkv = nn.Linear(d, 3 * d, bias=True)
        self.proj = nn.Linear(d, d, bias=True)
        # ``ad`` only stores the attention dropout probability; the dropout
        # itself is applied inside scaled_dot_product_attention.
        self.ad = nn.Dropout(attn_drop)
        self.pd = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape [b, n, d]; returns [b, n, d]."""
        b, n, d = x.shape
        # [b, n, 3*d] -> [3, b, h, n, dh], then split along the leading axis.
        qkv = self.qkv(x).view(b, n, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each [b, h, n, dh]
        out = nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.ad.p if self.training else 0.0
        )
        # Merge heads back: [b, h, n, dh] -> [b, n, d].
        out = out.transpose(1, 2).reshape(b, n, d)
        return self.pd(self.proj(out))
class Block(nn.Module):
    """Pre-norm transformer layer: self-attention, then MLP, each with a residual.

    ``drop`` is used for the attention dropout, the attention output
    projection dropout and both dropout layers inside the MLP.
    """

    def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        mlp_width = int(d * mlp_ratio)

        # Attention sub-layer.
        self.n1 = nn.LayerNorm(d)
        self.attn = MHA(d, heads, attn_drop=drop, proj_drop=drop)

        # Feed-forward sub-layer: expand -> GELU -> project back.
        self.n2 = nn.LayerNorm(d)
        feed_forward = [
            nn.Linear(d, mlp_width),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_width, d),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        attended = x + self.attn(self.n1(x))
        return attended + self.mlp(self.n2(attended))
class ViT1D(nn.Module):
    """Transformer trunk: a stack of ``Block`` layers closed by a LayerNorm.

    Takes tokens in and returns tokens of the same shape; no patch embedding
    is performed here.
    """

    def __init__(
        self,
        depth: int = 12,
        d_model: int = 256,
        heads: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
    ):
        super().__init__()
        # ``depth`` identical layers, all sharing the same hyper-parameters.
        self.blocks = nn.ModuleList(
            Block(d_model, heads, mlp_ratio, drop) for _ in range(depth)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Run ``tokens`` through every block in order, then the final norm."""
        hidden = tokens
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)
class CrossAttnBlock(nn.Module):
    """Self-attention → cross-attention(kv=context) → MLP, each with a residual.

    ``drop`` reaches the self-attention module and the MLP dropouts; the
    cross-attention path itself applies no dropout.
    """

    def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        mlp_width = int(d * mlp_ratio)

        # Stage 1: self-attention over the query tokens.
        self.n1 = nn.LayerNorm(d)
        self.self_attn = MHA(d, heads, attn_drop=drop, proj_drop=drop)

        # Stage 2: cross-attention, with separate norms for queries and context.
        self.n2q = nn.LayerNorm(d)
        self.n2k = nn.LayerNorm(d)
        self.h = heads
        self.dh = d // heads
        self.q = nn.Linear(d, d)
        self.kv = nn.Linear(d, 2 * d)
        self.op = nn.Linear(d, d)

        # Stage 3: feed-forward.
        self.n3 = nn.LayerNorm(d)
        feed_forward = [
            nn.Linear(d, mlp_width),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_width, d),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        """Update query tokens ``x`` using themselves and the context ``ctx``.

        The output has the same shape as ``x``; ``ctx`` supplies the keys and
        values of the cross-attention step and may have a different length.
        """
        x = x + self.self_attn(self.n1(x))

        queries = self.q(self.n2q(x))
        keys_values = self.kv(self.n2k(ctx))
        batch, n_query, width = queries.shape
        n_ctx = ctx.shape[1]

        # Split heads: queries -> [b, h, n, dh]; keys/values -> 2 x [b, h, m, dh].
        queries = queries.view(batch, n_query, self.h, self.dh).transpose(1, 2)
        keys, values = (
            keys_values.view(batch, n_ctx, 2, self.h, self.dh)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        attended = nn.functional.scaled_dot_product_attention(queries, keys, values)
        # Merge heads back to [b, n, d] before the output projection.
        attended = attended.transpose(1, 2).reshape(batch, n_query, width)

        x = x + self.op(attended)
        return x + self.mlp(self.n3(x))
class CrossAttentionPredictor(nn.Module):
    """Stack of ``CrossAttnBlock`` layers followed by a final LayerNorm.

    Query = positional tokens at target positions; KV = ECG context
    (+ optional Δt token). Every layer attends to the same context tensor.
    """

    def __init__(
        self,
        depth: int = 4,
        d_model: int = 256,
        heads: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
    ):
        super().__init__()
        # ``depth`` identical cross-attention layers.
        self.blocks = nn.ModuleList(
            CrossAttnBlock(d_model, heads, mlp_ratio, drop) for _ in range(depth)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, queries: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        """Refine ``queries`` against ``ctx`` layer by layer, then normalise."""
        hidden = queries
        for layer in self.blocks:
            hidden = layer(hidden, ctx)
        return self.norm(hidden)
|