"""
ECG-Transformer: Foundation Model for 12-Lead ECG Analysis

QueenBee Medical AI Stack

Architecture:
- 1D patch embedding (100 ms patches at 500 Hz = 50 samples)
- 12-lead processing with lead embeddings
- Transformer encoder with rotary position embeddings
- Multi-label classification heads (5 superclasses + 71 SCP codes)
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE): rotates query/key feature pairs by a
    position-dependent angle, so attention scores depend on relative offsets."""

    def __init__(self, dim: int, max_seq_len: int = 2048):
        super().__init__()
        # Standard RoPE inverse-frequency schedule over even feature indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.shape[1]
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        # Shape (1, 1, seq_len, dim) so cos/sin broadcast over batch and heads.
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        return cos, sin


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate feature pairs: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary rotation to queries and keys."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
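

# Hedged usage sketch (illustrative only; the tensor sizes below are made up
# and not part of the model). RoPE is applied per attention head, so q and k
# are (batch, heads, seq, head_dim) while cos/sin broadcast over batch/heads.
def _rope_shape_demo() -> None:
    rope = RotaryPositionEmbedding(dim=32)
    q = torch.randn(2, 8, 101, 32)             # (batch, heads, seq, head_dim)
    k = torch.randn(2, 8, 101, 32)
    cos, sin = rope(torch.zeros(2, 101, 256))  # forward only reads x.shape[1]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape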


class ECGPatchEmbedding(nn.Module):
    """
    Convert a 12-lead ECG into a token sequence.

    Each patch covers 50 samples (100 ms at 500 Hz), a clinically meaningful
    window on the order of a QRS complex.
    """

    def __init__(
        self,
        num_leads: int = 12,
        patch_size: int = 50,
        embed_dim: int = 256,
        signal_length: int = 5000,
    ):
        super().__init__()
        self.num_leads = num_leads
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = signal_length // patch_size

        # Strided 1D conv: mixes all leads and projects each 50-sample patch
        # to one token.
        self.proj = nn.Conv1d(
            num_leads,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Learned per-lead embeddings; averaged in forward, since the conv
        # above has already mixed the leads into the channel dimension.
        self.lead_embedding = nn.Embedding(num_leads, embed_dim)

        # Learnable classification token prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, num_leads, signal_length) - 12-lead ECG
        returns: (batch, num_patches + 1, embed_dim)
        """
        B = x.shape[0]

        x = self.proj(x)       # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, embed_dim)

        # Add the mean lead embedding as a shared bias over all tokens.
        lead_ids = torch.arange(self.num_leads, device=x.device)
        lead_emb = self.lead_embedding(lead_ids).mean(dim=0, keepdim=True)
        x = x + lead_emb.unsqueeze(0)

        # Prepend the CLS token.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        return self.norm(x)
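

# Hedged shape walk-through (illustrative only): with the defaults, a 10 s
# record sampled at 500 Hz yields 5000 // 50 = 100 patch tokens plus one CLS.
def _patch_embedding_shape_demo() -> None:
    embed = ECGPatchEmbedding()     # defaults: 12 leads, patch_size 50
    ecg = torch.randn(2, 12, 5000)  # (batch, leads, samples)
    tokens = embed(ecg)
    assert tokens.shape == (2, 101, 256)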


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings."""

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, \
            "embed_dim must be divisible by num_heads"

        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.rotary = RotaryPositionEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, N, C = x.shape

        # Project to Q, K, V and split heads: (3, B, num_heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Rotate queries and keys by position before computing scores.
        cos, sin = self.rotary(x)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V.
        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Merge heads back to (B, N, C) and apply the output projection.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.out_proj(x)

        return x
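

# Hedged note (assumes PyTorch >= 2.0): the explicit attention above matches
# F.scaled_dot_product_attention, which could replace it for fused kernels;
# the explicit form is kept in the module for readability. Sketch:
def _sdpa_equivalence_sketch() -> None:
    q = torch.randn(2, 8, 101, 32)
    k = torch.randn(2, 8, 101, 32)
    v = torch.randn(2, 8, 101, 32)
    fused = F.scaled_dot_product_attention(q, k, v)  # no mask, no dropout
    manual = F.softmax((q @ k.transpose(-2, -1)) * 32 ** -0.5, dim=-1) @ v
    assert torch.allclose(fused, manual, atol=1e-4)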


class TransformerBlock(nn.Module):
    """Transformer block with pre-norm architecture."""

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual updates: normalize, transform, then add back.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ECGTransformer(nn.Module):
    """
    ECG-Transformer foundation model.

    Trained on the PTB-XL dataset (21,799 12-lead ECGs).
    Multi-label classification over 5 superclasses + 71 SCP diagnostic codes.
    """

    def __init__(
        self,
        num_leads: int = 12,
        signal_length: int = 5000,
        patch_size: int = 50,
        embed_dim: int = 256,
        depth: int = 8,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        num_superclasses: int = 5,
        num_scp_codes: int = 71,
    ):
        super().__init__()
        self.embed_dim = embed_dim

        self.patch_embed = ECGPatchEmbedding(
            num_leads=num_leads,
            patch_size=patch_size,
            embed_dim=embed_dim,
            signal_length=signal_length,
        )

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Two multi-label heads over the CLS feature.
        self.superclass_head = nn.Linear(embed_dim, num_superclasses)
        self.scp_head = nn.Linear(embed_dim, num_scp_codes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract the pooled CLS feature from an ECG signal."""
        x = self.patch_embed(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return x[:, 0]  # CLS token

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x: (batch, 12, 5000) - 12-lead ECG signal
        returns: (superclass_logits, scp_logits)
        """
        features = self.forward_features(x)

        superclass_logits = self.superclass_head(features)
        scp_logits = self.scp_head(features)

        return superclass_logits, scp_logits

    def get_embeddings(self, x: torch.Tensor) -> torch.Tensor:
        """Get CLS embeddings for downstream tasks."""
        return self.forward_features(x)
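

# Hedged sketch of a downstream linear probe on frozen embeddings. The probe
# head, optimizer settings, and tiny config are illustrative assumptions, not
# part of the QueenBee stack; multi-hot labels pair with BCEWithLogitsLoss.
def _linear_probe_sketch() -> None:
    model = ECGTransformer(embed_dim=128, depth=2, num_heads=4)
    for p in model.parameters():
        p.requires_grad_(False)            # freeze the backbone
    probe = nn.Linear(model.embed_dim, 5)  # e.g. the 5 superclasses
    opt = torch.optim.AdamW(probe.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    x = torch.randn(8, 12, 5000)             # stand-in batch
    y = torch.randint(0, 2, (8, 5)).float()  # multi-hot labels
    loss = criterion(probe(model.get_embeddings(x)), y)
    loss.backward()
    opt.step()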


def ecg_transformer_tiny(**kwargs):
    """Tiny model for fast iteration (~0.9M params)."""
    return ECGTransformer(
        embed_dim=128,
        depth=4,
        num_heads=4,
        **kwargs
    )


def ecg_transformer_small(**kwargs):
    """Small model (~4.9M params)."""
    return ECGTransformer(
        embed_dim=256,
        depth=6,
        num_heads=8,
        **kwargs
    )


def ecg_transformer_base(**kwargs):
    """Base model (~14.5M params)."""
    return ECGTransformer(
        embed_dim=384,
        depth=8,
        num_heads=8,
        **kwargs
    )


def ecg_transformer_large(**kwargs):
    """Large model (~38M params)."""
    return ECGTransformer(
        embed_dim=512,
        depth=12,
        num_heads=16,
        **kwargs
    )
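

# Hedged training-loss sketch: the module docstring describes multi-label
# heads, so both outputs pair with BCEWithLogitsLoss; the equal 1:1 weighting
# of the two heads below is an illustrative assumption, not a stated recipe.
def _joint_loss_sketch(model: ECGTransformer, ecg: torch.Tensor,
                       y_super: torch.Tensor, y_scp: torch.Tensor) -> torch.Tensor:
    superclass_logits, scp_logits = model(ecg)
    bce = nn.BCEWithLogitsLoss()
    return bce(superclass_logits, y_super) + bce(scp_logits, y_scp)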


if __name__ == "__main__":
    # Smoke test: build the small model and run a random batch through it.
    model = ecg_transformer_small()
    print("ECG-Transformer Small")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    x = torch.randn(4, 12, 5000)  # batch of 4 random 10 s, 12-lead records
    superclass, scp = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Superclass output: {superclass.shape}")
    print(f"SCP codes output: {scp.shape}")

    emb = model.get_embeddings(x)
    print(f"Embedding shape: {emb.shape}")