| """ | |
| ECG Lead Generator — Model Architecture | |
| CLIP-Conditioned 1D U-Net: 7 known leads → 5 predicted leads (V2-V6) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class FiLM(nn.Module):
    """Feature-wise Linear Modulation (scale + shift) driven by a conditioning vector.

    Two independent linear layers map the conditioning vector to one scale
    and one shift value per channel. The result of each is given a trailing
    singleton axis so it broadcasts over the last axis of the feature map.

    Args:
        cond_d: Size of the conditioning vector fed to both linear layers.
        ch: Number of per-channel scale/shift values to produce.
    """

    def __init__(self, cond_d: int, ch: int):
        super().__init__()
        self.scale = nn.Linear(cond_d, ch)
        self.shift = nn.Linear(cond_d, ch)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Return ``x * (1 + scale(c)) + shift(c)``, broadcast over the last axis.

        The ``1 +`` offset means a zero scale output leaves ``x`` unscaled.
        """
        gamma = self.scale(c).unsqueeze(-1)
        beta = self.shift(c).unsqueeze(-1)
        return x * (1 + gamma) + beta
class ResBlk(nn.Module):
    """Residual conv block with GroupNorm + GELU + optional FiLM conditioning.

    The main branch is ``GroupNorm → GELU → Conv1d(k=3) → Dropout →
    GroupNorm → GELU → Conv1d(k=3)``. Both convolutions use ``padding=1``,
    so the sequence length is preserved. The shortcut branch is a 1x1
    convolution when the channel count changes and an identity otherwise.

    Args:
        ci: Number of input channels.
        co: Number of output channels.
        cd: Size of the conditioning vector. When falsy (``None`` or ``0``)
            no FiLM layer is created and the block is unconditional.
        drop: Dropout probability applied after the first convolution.
    """

    def __init__(self, ci: int, co: int, cd: int = None, drop: float = 0.1):
        super().__init__()
        self.body = nn.Sequential(
            nn.GroupNorm(self._num_groups(ci), ci), nn.GELU(),
            nn.Conv1d(ci, co, 3, padding=1), nn.Dropout(drop),
            nn.GroupNorm(self._num_groups(co), co), nn.GELU(),
            nn.Conv1d(co, co, 3, padding=1),
        )
        self.skip = nn.Conv1d(ci, co, 1) if ci != co else nn.Identity()
        self.film = FiLM(cd, co) if cd else None

    @staticmethod
    def _num_groups(ch: int, max_groups: int = 8) -> int:
        """Return the largest group count ``<= max_groups`` that divides ``ch``.

        GroupNorm requires the channel count to be divisible by the group
        count. A plain ``min(8, ch)`` breaks for widths such as 12 or 20;
        this picks the same value whenever ``min(8, ch)`` already divides
        ``ch`` and otherwise steps down to the next valid divisor.

        Raises:
            ValueError: If ``ch`` is not a positive integer.
        """
        for groups in range(min(max_groups, ch), 0, -1):
            if ch % groups == 0:
                return groups
        raise ValueError(f"channel count must be positive, got {ch}")

    def forward(self, x, c=None):
        """Run the block on ``x``, modulating the main branch with ``c`` if given.

        FiLM is applied to the main-branch output before the shortcut is
        added, and only when the block was built with a conditioning size
        and a conditioning tensor is actually supplied.
        """
        h = self.body(x)
        # Explicit None checks: do not rely on the truthiness of an nn.Module.
        if self.film is not None and c is not None:
            h = self.film(h, c)
        return h + self.skip(x)
class Down(nn.Module):
    """Learned 2x temporal downsampling via a strided 1D convolution.

    The channel count is unchanged. With kernel 4, stride 2 and padding 1
    the output length is ``floor(L / 2)`` for an input of length ``L``.
    """

    def __init__(self, ch):
        super().__init__()
        self.p = nn.Conv1d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Return the strided-convolution output for ``x``."""
        halved = self.p(x)
        return halved
class Up(nn.Module):
    """Learned 2x temporal upsampling followed by concatenation with a skip tensor.

    A transposed convolution (kernel 4, stride 2, padding 1) doubles the
    sequence length and maps ``ci`` channels to ``co``. The result is then
    length-matched to the skip tensor and the two are concatenated along
    the channel axis, upsampled features first.
    """

    def __init__(self, ci, co):
        super().__init__()
        self.u = nn.ConvTranspose1d(ci, co, kernel_size=4, stride=2, padding=1)

    def forward(self, x, skip):
        """Upsample ``x``, fit it to ``skip``'s length, and concatenate on dim 1."""
        target_len = skip.shape[-1]
        upsampled = self.u(x)

        # Too short: zero-pad on the right up to the skip length.
        shortfall = target_len - upsampled.shape[-1]
        if shortfall > 0:
            upsampled = F.pad(upsampled, [0, shortfall])

        # Too long: drop the trailing samples (no-op when lengths already match).
        fitted = upsampled[:, :, :target_len]
        return torch.cat([fitted, skip], dim=1)
class LeadGenerator(nn.Module):
    """
    CLIP-conditioned 1D U-Net.

    Input : [B, 7, L] — 7 known ECG leads (I, II, III, aVR, aVL, aVF, V1)
    Cond  : [B, D]    — CLIP visual embedding (FiLM-injected at every scale)
    Output: [B, 5, L] — predicted leads V2, V3, V4, V5, V6

    The network has four encoder stages (residual block + 2x downsample),
    two bottleneck residual blocks, and four decoder stages (2x upsample,
    concatenation with the matching encoder output, residual block). Every
    residual block receives the projected conditioning vector.

    Args:
        ni: Number of input channels (known leads).
        no: Number of output channels (predicted leads).
        ch: Base channel width; stages use ``ch``, ``2*ch``, ``4*ch``, ``8*ch``.
        cd: Dimensionality of the raw conditioning embedding.
        drop: Dropout probability passed to every residual block.
    """

    def __init__(self, ni=7, no=5, ch=64, cd=1024, drop=0.1):
        super().__init__()
        # The raw embedding is projected once to a shared width of 4*ch and
        # that projected vector is what every residual block is conditioned on.
        C = ch * 4
        self.cproj = nn.Sequential(
            nn.Linear(cd, C), nn.GELU(), nn.Linear(C, C)
        )

        # Encoder: channel width doubles at each stage.
        self.e1 = ResBlk(ni, ch, C, drop)
        self.d1 = Down(ch)
        self.e2 = ResBlk(ch, ch * 2, C, drop)
        self.d2 = Down(ch * 2)
        self.e3 = ResBlk(ch * 2, ch * 4, C, drop)
        self.d3 = Down(ch * 4)
        self.e4 = ResBlk(ch * 4, ch * 8, C, drop)
        self.d4 = Down(ch * 8)

        # Bottleneck.
        self.m1 = ResBlk(ch * 8, ch * 8, C, drop)
        self.m2 = ResBlk(ch * 8, ch * 8, C, drop)

        # Decoder: each residual block takes upsampled + skip channels, so
        # its input width is the sum of the two.
        self.u4 = Up(ch * 8, ch * 8)
        self.r4 = ResBlk(ch * 16, ch * 8, C, drop)
        self.u3 = Up(ch * 8, ch * 4)
        self.r3 = ResBlk(ch * 8, ch * 4, C, drop)
        self.u2 = Up(ch * 4, ch * 2)
        self.r2 = ResBlk(ch * 4, ch * 2, C, drop)
        self.u1 = Up(ch * 2, ch)
        self.r1 = ResBlk(ch * 2, ch, C, drop)

        # GroupNorm needs a group count that divides the channel count.
        # min(8, ch) fails for widths like 12, so use the largest divisor
        # of ch that is <= 8 (same value whenever min(8, ch) was valid).
        out_groups = next(g for g in range(min(8, ch), 0, -1) if ch % g == 0)
        self.out = nn.Sequential(
            nn.GroupNorm(out_groups, ch), nn.GELU(), nn.Conv1d(ch, no, 1)
        )

    def forward(self, x, clip_emb):
        """Predict the output leads from ``x`` conditioned on ``clip_emb``.

        Args:
            x: Input signal tensor; per the class docstring, ``[B, ni, L]``.
            clip_emb: Conditioning embedding; per the class docstring, ``[B, cd]``.

        Returns:
            The output of the final 1x1 convolution, with ``no`` channels.
        """
        c = self.cproj(clip_emb)

        # Encoder: keep each pre-downsample activation as a skip connection.
        s1 = self.e1(x, c)
        x = self.d1(s1)
        s2 = self.e2(x, c)
        x = self.d2(s2)
        s3 = self.e3(x, c)
        x = self.d3(s3)
        s4 = self.e4(x, c)
        x = self.d4(s4)

        # Bottleneck.
        x = self.m1(x, c)
        x = self.m2(x, c)

        # Decoder: consume the skips deepest-first.
        x = self.r4(self.u4(x, s4), c)
        x = self.r3(self.u3(x, s3), c)
        x = self.r2(self.u2(x, s2), c)
        x = self.r1(self.u1(x, s1), c)
        return self.out(x)
def load_from_hub(repo_id: str = "your-username/ecg-lead-generator") -> LeadGenerator:
    """Load LeadGenerator weights from Hugging Face Hub.

    Downloads ``config.json`` from the repository, builds a ``LeadGenerator``
    from its ``n_in``, ``n_out``, ``base_ch`` and ``clip_dim`` entries, then
    loads weights from ``model.safetensors``. If that fails for any reason
    (safetensors not installed, file missing, unreadable file) it falls back
    to the pickle checkpoint ``lead_generator_weights.pt``.

    Args:
        repo_id: Hugging Face Hub repository identifier.

    Returns:
        The model with weights loaded, on CPU tensors from the fallback path,
        switched to eval mode.

    Raises:
        KeyError: If ``config.json`` lacks one of the required entries.
        Exception: Whatever the download or weight loading raises when both
            the safetensors path and the ``.pt`` fallback fail.
    """
    import json
    import logging

    from huggingface_hub import hf_hub_download

    config_path = hf_hub_download(repo_id, "config.json")
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    model = LeadGenerator(
        ni=cfg["n_in"],
        no=cfg["n_out"],
        ch=cfg["base_ch"],
        cd=cfg["clip_dim"],
    )

    try:
        from safetensors.torch import load_file

        w_path = hf_hub_download(repo_id, "model.safetensors")
        state = load_file(w_path)
    except Exception as err:
        # Deliberately best-effort, but record why the preferred format was
        # skipped instead of swallowing the reason silently.
        logging.getLogger(__name__).warning(
            "Could not load model.safetensors from %s (%s: %s); "
            "falling back to lead_generator_weights.pt",
            repo_id, type(err).__name__, err,
        )
        w_path = hf_hub_download(repo_id, "lead_generator_weights.pt")
        # NOTE(security): torch.load unpickles the file. Only use this path
        # with repositories you trust; consider weights_only=True once the
        # checkpoint contents are confirmed to be plain tensors.
        ckpt = torch.load(w_path, map_location="cpu")
        # Accept both a training checkpoint wrapping the weights under
        # "model_state" and a bare state dict saved directly.
        state = ckpt["model_state"] if "model_state" in ckpt else ckpt

    model.load_state_dict(state)
    model.eval()
    return model