# ecg-lead-generator / model.py
# rishsoraganvi's picture
# Add model weights, config, and architecture
# 637ec8e verified
"""
ECG Lead Generator — Model Architecture
CLIP-Conditioned 1D U-Net: 7 known leads → 5 predicted leads (V2-V6)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class FiLM(nn.Module):
    """Feature-wise Linear Modulation: per-channel gain and offset from a condition.

    Two independent linear heads map the conditioning vector to one gain and
    one offset per feature channel. Both are given a trailing singleton axis
    so they broadcast along the last axis of the feature map.
    """

    def __init__(self, cond_d: int, ch: int):
        """Create the gain and offset heads.

        Args:
            cond_d: Width of the conditioning vector.
            ch: Number of feature channels being modulated.
        """
        super().__init__()
        self.scale = nn.Linear(cond_d, ch)
        self.shift = nn.Linear(cond_d, ch)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Return ``x * (1 + scale(c)) + shift(c)``.

        The gain is expressed as ``1 + scale(c)``, so a zero output from the
        scale head leaves ``x`` unscaled.
        """
        gain = 1 + self.scale(c)[..., None]
        offset = self.shift(c)[..., None]
        return x * gain + offset
class ResBlk(nn.Module):
    """Residual 1-D conv block: (GroupNorm -> GELU -> Conv) x 2 with optional FiLM.

    The main branch is GroupNorm, GELU, a width-3 convolution, dropout, then a
    second GroupNorm, GELU and width-3 convolution. When a conditioning width
    is given, the branch output is FiLM-modulated before the residual sum. The
    shortcut is a 1x1 convolution when the channel count changes and the
    identity otherwise.
    """

    def __init__(self, ci: int, co: int, cd: "int | None" = None, drop: float = 0.1):
        """Build the block.

        Args:
            ci: Number of input channels.
            co: Number of output channels.
            cd: Width of the conditioning vector; a falsy value (``None`` or
                ``0``) builds the block without FiLM.
            drop: Dropout probability applied after the first convolution.
        """
        super().__init__()
        self.body = nn.Sequential(
            nn.GroupNorm(self._num_groups(ci), ci), nn.GELU(),
            nn.Conv1d(ci, co, 3, padding=1), nn.Dropout(drop),
            nn.GroupNorm(self._num_groups(co), co), nn.GELU(),
            nn.Conv1d(co, co, 3, padding=1),
        )
        self.skip = nn.Conv1d(ci, co, 1) if ci != co else nn.Identity()
        self.film = FiLM(cd, co) if cd else None

    @staticmethod
    def _num_groups(ch: int, limit: int = 8) -> int:
        """Return the largest group count not exceeding ``limit`` that divides ``ch``.

        GroupNorm requires the channel count to be divisible by the number of
        groups. A plain ``min(limit, ch)`` violates that for widths such as 12
        or 20; this yields the same value as ``min(limit, ch)`` whenever that
        value is already a divisor.
        """
        for groups in range(min(limit, ch), 0, -1):
            if ch % groups == 0:
                return groups
        return 1

    def forward(self, x, c=None):
        """Apply the block.

        Args:
            x: Feature map whose channel axis (dim 1) has ``ci`` entries.
            c: Optional conditioning vector; ignored when the block was built
                without FiLM.

        Returns:
            The (optionally modulated) branch output plus the shortcut of ``x``.
        """
        h = self.body(x)
        # Explicit None checks: do not rely on nn.Module truthiness.
        if self.film is not None and c is not None:
            h = self.film(h, c)
        return h + self.skip(x)
class Down(nn.Module):
    """Learned downsampler: one strided convolution that keeps the channel count.

    With kernel 4, stride 2 and padding 1 the output length is
    ``floor(L / 2)`` for an input of length ``L >= 2``.
    """

    def __init__(self, ch):
        """Create the strided convolution for ``ch`` input and output channels."""
        super().__init__()
        self.p = nn.Conv1d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Return the strided-convolution output for ``x``."""
        pooled = self.p(x)
        return pooled
class Up(nn.Module):
    """Learned upsampler that also merges the matching encoder skip tensor.

    A transposed convolution (kernel 4, stride 2, padding 1) doubles the
    length. The result is then aligned to the skip tensor's length and
    concatenated with it along the channel axis, upsampled features first.
    """

    def __init__(self, ci, co):
        """Create the transposed convolution mapping ``ci`` to ``co`` channels."""
        super().__init__()
        self.u = nn.ConvTranspose1d(ci, co, kernel_size=4, stride=2, padding=1)

    def forward(self, x, skip):
        """Upsample ``x``, match ``skip``'s length, and concatenate on dim 1.

        A too-short upsampled tensor is zero-padded on the right; a too-long
        one is cropped on the right.
        """
        target_len = skip.shape[-1]
        upsampled = self.u(x)
        shortfall = target_len - upsampled.shape[-1]
        if shortfall > 0:
            # E.g. an odd-length skip loses one sample through a halve/double round trip.
            upsampled = F.pad(upsampled, (0, shortfall))
        aligned = upsampled[:, :, :target_len]
        return torch.cat([aligned, skip], dim=1)
class LeadGenerator(nn.Module):
    """
    CLIP-conditioned 1D U-Net.

    Input : [B, 7, L] — 7 known ECG leads (I, II, III, aVR, aVL, aVF, V1)
    Cond  : [B, D]    — CLIP visual embedding (FiLM-injected at every scale)
    Output: [B, 5, L] — predicted leads V2, V3, V4, V5, V6

    The encoder has four residual stages, each followed by a stride-2
    downsample; two residual blocks form the bottleneck; the decoder mirrors
    the encoder with four upsample-and-concatenate stages. Every residual
    block receives the same projected conditioning vector.
    """

    def __init__(self, ni=7, no=5, ch=64, cd=1024, drop=0.1):
        """Build the network.

        Args:
            ni: Number of input channels (known leads).
            no: Number of output channels (predicted leads).
            ch: Base channel width; stages use ``ch``, ``2*ch``, ``4*ch``, ``8*ch``.
            cd: Width of the incoming conditioning embedding.
            drop: Dropout probability passed to every residual block.
        """
        super().__init__()
        # The raw embedding is projected once to width 4*ch; every block's
        # FiLM layer consumes that projected vector.
        self.cproj = nn.Sequential(
            nn.Linear(cd, ch * 4), nn.GELU(), nn.Linear(ch * 4, ch * 4)
        )
        C = ch * 4

        # Encoder: residual block, then stride-2 downsample.
        self.e1 = ResBlk(ni, ch, C, drop)
        self.d1 = Down(ch)
        self.e2 = ResBlk(ch, ch * 2, C, drop)
        self.d2 = Down(ch * 2)
        self.e3 = ResBlk(ch * 2, ch * 4, C, drop)
        self.d3 = Down(ch * 4)
        self.e4 = ResBlk(ch * 4, ch * 8, C, drop)
        self.d4 = Down(ch * 8)

        # Bottleneck.
        self.m1 = ResBlk(ch * 8, ch * 8, C, drop)
        self.m2 = ResBlk(ch * 8, ch * 8, C, drop)

        # Decoder: each Up concatenates the skip tensor, so the following
        # block sees (upsampled channels + skip channels) input channels.
        self.u4 = Up(ch * 8, ch * 8)
        self.r4 = ResBlk(ch * 16, ch * 8, C, drop)
        self.u3 = Up(ch * 8, ch * 4)
        self.r3 = ResBlk(ch * 8, ch * 4, C, drop)
        self.u2 = Up(ch * 4, ch * 2)
        self.r2 = ResBlk(ch * 4, ch * 2, C, drop)
        self.u1 = Up(ch * 2, ch)
        self.r1 = ResBlk(ch * 2, ch, C, drop)

        # GroupNorm needs a group count that divides the channel count;
        # min(8, ch) does not for widths such as 12 or 20. Pick the largest
        # divisor of ch not exceeding 8, which equals min(8, ch) whenever
        # that value already divides ch.
        out_groups = next(
            (g for g in range(min(8, ch), 0, -1) if ch % g == 0), 1
        )
        self.out = nn.Sequential(
            nn.GroupNorm(out_groups, ch), nn.GELU(), nn.Conv1d(ch, no, 1)
        )

    def forward(self, x, clip_emb):
        """Predict the output leads.

        Args:
            x: Input signal tensor, ``[B, ni, L]``.
            clip_emb: Conditioning embedding, ``[B, cd]``.

        Returns:
            Tensor with ``no`` channels from the final 1x1 convolution.
        """
        c = self.cproj(clip_emb)

        # Encoder: keep each stage's pre-downsample output as a skip tensor.
        s1 = self.e1(x, c)
        x = self.d1(s1)
        s2 = self.e2(x, c)
        x = self.d2(s2)
        s3 = self.e3(x, c)
        x = self.d3(s3)
        s4 = self.e4(x, c)
        x = self.d4(s4)

        # Bottleneck.
        x = self.m1(x, c)
        x = self.m2(x, c)

        # Decoder: consume the skips deepest-first.
        x = self.r4(self.u4(x, s4), c)
        x = self.r3(self.u3(x, s3), c)
        x = self.r2(self.u2(x, s2), c)
        x = self.r1(self.u1(x, s1), c)
        return self.out(x)
def load_from_hub(repo_id: str = "your-username/ecg-lead-generator") -> LeadGenerator:
    """Load LeadGenerator weights from Hugging Face Hub.

    Downloads ``config.json`` from the repository, builds a ``LeadGenerator``
    from its ``n_in``, ``n_out``, ``base_ch`` and ``clip_dim`` entries, and
    then loads the weights. ``model.safetensors`` is tried first; if that
    fails for any reason (including ``safetensors`` not being installed), the
    cause is logged and ``lead_generator_weights.pt`` is loaded instead, taking
    the state dict from its ``"model_state"`` entry.

    Args:
        repo_id: Hub repository identifier to download from.

    Returns:
        The model with weights loaded, switched to eval mode. No device move
        is performed here.

    Raises:
        KeyError: If the config lacks one of the expected keys, or the ``.pt``
            checkpoint has no ``"model_state"`` entry.
        Exception: Download and state-dict loading errors propagate unchanged.
    """
    import json
    import logging

    from huggingface_hub import hf_hub_download

    config_path = hf_hub_download(repo_id, "config.json")
    # Explicit encoding: do not depend on the platform's locale default.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    model = LeadGenerator(
        ni=cfg["n_in"],
        no=cfg["n_out"],
        ch=cfg["base_ch"],
        cd=cfg["clip_dim"],
    )

    try:
        from safetensors.torch import load_file

        w_path = hf_hub_download(repo_id, "model.safetensors")
        state = load_file(w_path)
    except Exception as err:
        # The fallback is deliberately best-effort, so the broad catch stays,
        # but the reason is reported instead of being silently discarded.
        logging.getLogger(__name__).warning(
            "Could not load model.safetensors from %s (%r); "
            "falling back to lead_generator_weights.pt",
            repo_id,
            err,
        )
        w_path = hf_hub_download(repo_id, "lead_generator_weights.pt")
        # NOTE(security): torch.load unpickles the file, and on PyTorch
        # versions where weights_only defaults to False this can execute
        # arbitrary code. Only use this path with repositories you trust;
        # consider weights_only=True if the checkpoint holds only tensors
        # and plain containers.
        ckpt = torch.load(w_path, map_location="cpu")
        state = ckpt["model_state"]

    model.load_state_dict(state)
    model.eval()
    return model