""" model_loader.py =============== Authoritative architecture definitions for all three models, matched exactly to the trained checkpoint shapes. splice → MutationPredictorCNN_v2 (input: 1106-dim flat vector) v4 → MutationPredictorCNN_v4 (input: seq/mut/region/splice tensors) classic → MutationPredictorClassic (input: 1103-dim flat vector, from classic repo) """ from __future__ import annotations import logging import os from typing import Optional import numpy as np import torch import torch.nn as nn from huggingface_hub import hf_hub_download logger = logging.getLogger("mutation_xai.loader") # ═══════════════════════════════════════════════════════════════════════════════ # Shared constants # ═══════════════════════════════════════════════════════════════════════════════ NUCL = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4} MUT_TYPES = { ("A","T"):0, ("A","C"):1, ("A","G"):2, ("T","A"):3, ("T","C"):4, ("T","G"):5, ("C","A"):6, ("C","T"):7, ("C","G"):8, ("G","A"):9, ("G","T"):10,("G","C"):11, } ALL_BASES = ["A", "T", "C", "G"] # ═══════════════════════════════════════════════════════════════════════════════ # ① SPLICE MODEL — MutationPredictorCNN_v2 # ═══════════════════════════════════════════════════════════════════════════════ def _get_mutation_position_from_input(x_flat: torch.Tensor) -> torch.Tensor: """Infer mutation position from input tensor (sequence difference mask).""" return x_flat[:, 990:1089].argmax(dim=1) class MutationPredictorCNN_v2(nn.Module): """Splice-aware CNN — exact architecture from mutation-predictor-splice.""" def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16): super().__init__() fc1_in = 256 + 32 + fc_region_out + splice_fc_out self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3) self.bn1 = nn.BatchNorm1d(64) self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2) self.bn2 = nn.BatchNorm1d(128) self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm1d(256) self.global_pool = nn.AdaptiveAvgPool1d(1) self.mut_fc = nn.Linear(12, 32) self.importance_head = nn.Linear(256, 1) self.region_importance_head = nn.Linear(256, 2) self.fc_region = nn.Linear(2, fc_region_out) self.splice_fc = nn.Linear(3, splice_fc_out) self.splice_importance_head = nn.Linear(256, 3) self.fc1 = nn.Linear(fc1_in, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 1) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.4) def forward(self, x: torch.Tensor, mutation_positions: Optional[torch.Tensor] = None): bs = x.size(0) seq_flat = x[:, :1089] mut_onehot = x[:, 1089:1101] region_feat= x[:, 1101:1103] splice_feat= x[:, 1103:1106] h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99)))) h = self.relu(self.bn2(self.conv2(h))) conv_out = self.relu(self.bn3(self.conv3(h))) if mutation_positions is None: mutation_positions = _get_mutation_position_from_input(x) pos_idx = mutation_positions.clamp(0, 98).long() pe = pos_idx.view(bs, 1, 1).expand(bs, 256, 1) mut_feat = conv_out.gather(2, pe).squeeze(2) imp_score = torch.sigmoid(self.importance_head(mut_feat)) pooled = self.global_pool(conv_out).squeeze(-1) r_imp = torch.sigmoid(self.region_importance_head(pooled)) s_imp = torch.sigmoid(self.splice_importance_head(pooled)) m = self.relu(self.mut_fc(mut_onehot)) r = self.relu(self.fc_region(region_feat)) s = self.relu(self.splice_fc(splice_feat)) fused = torch.cat([pooled, m, r, s], dim=1) out = self.dropout(self.relu(self.fc1(fused))) out = self.dropout(self.relu(self.fc2(out))) logit = self.fc3(out) return logit, imp_score, r_imp, s_imp # ═══════════════════════════════════════════════════════════════════════════════ # ② V4 MODEL — MutationPredictorCNN_v4 # ═══════════════════════════════════════════════════════════════════════════════ class MutationPredictorCNN_v4(nn.Module): """V4 model — takes separate (seq, mut, region, splice) tensor inputs.""" def __init__(self): super().__init__() self.conv1 = nn.Conv1d(11, 64, 7, padding=3) self.conv2 = nn.Conv1d(64, 128, 5, padding=2) self.conv3 = nn.Conv1d(128, 256, 3, padding=1) self.pool = nn.AdaptiveAvgPool1d(1) self.mut_fc = nn.Linear(12, 32) self.region_fc= nn.Linear(2, 8) self.splice_fc= nn.Linear(3, 16) self.fc1 = nn.Linear(312, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 1) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.3) def forward(self, seq: torch.Tensor, mut: torch.Tensor, region: torch.Tensor, splice: torch.Tensor): x = self.relu(self.conv1(seq)) x = self.relu(self.conv2(x)) x = self.relu(self.conv3(x)) x = self.pool(x).squeeze(-1) m = self.relu(self.mut_fc(mut)) r = self.relu(self.region_fc(region)) s = self.relu(self.splice_fc(splice)) x = torch.cat([x, m, r, s], dim=1) x = self.dropout(self.relu(self.fc1(x))) x = self.relu(self.fc2(x)) return self.fc3(x) # ═══════════════════════════════════════════════════════════════════════════════ # ③ CLASSIC MODEL — MutationPredictorClassic # Mirrors the architecture in the explainable-space repo's model.py # Input: 1103-dim flat vector (99 ref enc + 99 mut enc + 99 diff + 12 mut_oh + 2 region + 3 splice = 1103) # Outputs: logit, importance_head_output (per-position), region_imp (2,) # ═══════════════════════════════════════════════════════════════════════════════ class MutationPredictorClassic(nn.Module): """Classic explainable model from mutation-pathogenicity-predictor.""" def __init__(self, input_dim: int = 1103): super().__init__() # Sequence portion: 99 × 11 channels self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3) self.bn1 = nn.BatchNorm1d(64) self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2) self.bn2 = nn.BatchNorm1d(128) self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm1d(256) self.pool = nn.AdaptiveAvgPool1d(1) # Importance head — from Linear(256,1) in explainable repo self.importance_head = nn.Linear(256, 1) self.region_importance_head = nn.Linear(256, 2) self.mut_fc = nn.Linear(12, 32) self.region_fc = nn.Linear(2, 8) self.splice_fc = nn.Linear(3, 16) # 256 + 32 + 8 + 16 = 312 self.fc1 = nn.Linear(312, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 1) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.3) def forward(self, x: torch.Tensor): """ x: (batch, 1103) [0:1089] = ref(99×5) + mut(99×5) + diff(99×1) flattened into 99×11 → 1089 [1089:1101] = mutation onehot (12) [1101:1103] = region flags (2) [1103:1106] = splice flags (3) — may be absent in 1103-dim variant """ bs = x.size(0) seq_flat = x[:, :1089] # If input is 1103, splice indices are [1100:1103]; handle both 1103 and 1106 shapes if x.size(1) >= 1106: mut_onehot = x[:, 1089:1101] region_feat = x[:, 1101:1103] splice_feat = x[:, 1103:1106] else: mut_onehot = x[:, 1089:1101] region_feat = x[:, 1101:1103] splice_feat = torch.zeros(bs, 3, device=x.device) h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99)))) h = self.relu(self.bn2(self.conv2(h))) conv_out = self.relu(self.bn3(self.conv3(h))) pooled = self.pool(conv_out).squeeze(-1) imp = torch.sigmoid(self.importance_head(pooled)) r_imp = torch.sigmoid(self.region_importance_head(pooled)) m = self.relu(self.mut_fc(mut_onehot)) r = self.relu(self.region_fc(region_feat)) s = self.relu(self.splice_fc(splice_feat)) fused = torch.cat([pooled, m, r, s], dim=1) out = self.dropout(self.relu(self.fc1(fused))) out = self.relu(self.fc2(out)) logit = self.fc3(out) return logit, imp, r_imp # ═══════════════════════════════════════════════════════════════════════════════ # ENCODERS # ═══════════════════════════════════════════════════════════════════════════════ def _encode_seq_11ch(seq: str, n: int = 99) -> torch.Tensor: """Encode sequence as (n, 5) one-hot. Channel layout: A/T/G/C/N.""" seq = (seq.upper() + "N" * n)[:n] enc = torch.zeros(n, 5) for i, c in enumerate(seq): enc[i, NUCL.get(c, 4)] = 1.0 return enc def encode_for_v2(ref_seq: str, mut_seq: str, exon_flag: int = 0, intron_flag: int = 0, donor_flag: int = 0, acceptor_flag: int = 0, region_flag: int = 0) -> torch.Tensor: """ Build the 1106-dim input vector used by both splice and classic models. Layout: [0:495] ref (99 × 5) [495:990] mut (99 × 5) [990:1089] diff (99 × 1) [1089:1101] mutation onehot (12) [1101:1103] region flags [exon, intron] [1103:1106] splice flags [donor, acceptor, region] """ n = 99 re = _encode_seq_11ch(ref_seq, n) # (99,5) me = _encode_seq_11ch(mut_seq, n) # (99,5) dm = torch.zeros(n, 1) rb = mb = None for i in range(min(len(ref_seq), len(mut_seq), n)): if ref_seq[i].upper() != mut_seq[i].upper(): dm[i, 0] = 1.0 if rb is None: rb = ref_seq[i].upper() mb = mut_seq[i].upper() moh = torch.zeros(12) if rb and mb: idx = MUT_TYPES.get((rb, mb)) if idx is not None: moh[idx] = 1.0 sf = torch.cat([re, me, dm], dim=1).flatten() # 99 × 11 = 1089 rt = torch.tensor([float(exon_flag), float(intron_flag)]) st = torch.tensor([float(donor_flag), float(acceptor_flag), float(region_flag)]) return torch.cat([sf, moh, rt, st]) def encode_for_v4(ref_seq: str, mut_seq: str, exon_flag: int = 0, intron_flag: int = 0, donor_flag: int = 0, acceptor_flag: int = 0, region_flag: int = 0): """ Returns separate tensors (seq, mut_oh, region, splice) for MutationPredictorCNN_v4. seq: (1, 11, 99) — stacked ref/mut/diff channels """ flat = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag, donor_flag, acceptor_flag, region_flag) seq_flat = flat[:1089].view(11, 99).unsqueeze(0) # (1,11,99) mut_oh = flat[1089:1101].unsqueeze(0) # (1,12) region = flat[1101:1103].unsqueeze(0) # (1,2) splice = flat[1103:1106].unsqueeze(0) # (1,3) return seq_flat, mut_oh, region, splice def find_mutation_pos(ref_seq: str, mut_seq: str) -> int: """Return 0-indexed position of first differing character, or -1.""" for i in range(min(len(ref_seq), len(mut_seq), 99)): if ref_seq[i].upper() != mut_seq[i].upper(): return i return -1 # ═══════════════════════════════════════════════════════════════════════════════ # MODEL REGISTRY — loads all three models once at startup # ═══════════════════════════════════════════════════════════════════════════════ SPLICE_REPO = "nileshhanotia/mutation-predictor-splice" V4_REPO = "nileshhanotia/mutation-predictor-v4" CLASSIC_REPO = "nileshhanotia/mutation-pathogenicity-predictor" SPLICE_FILENAME = "mutation_predictor_splice.pt" V4_FILENAME = "mutation_predictor_splice_v4.pt" CLASSIC_FILENAME = "mutation_predictor.pt" # common name; fallback tried def _load_ckpt(repo: str, filename: str, token: Optional[str] = None) -> dict: """Download checkpoint, return state dict or full ckpt dict.""" path = hf_hub_download(repo_id=repo, filename=filename, token=token) ckpt = torch.load(path, map_location="cpu", weights_only=False) return ckpt def _try_filenames(repo: str, candidates: list[str], token: Optional[str] = None) -> dict: for fn in candidates: try: return _load_ckpt(repo, fn, token) except Exception: continue raise FileNotFoundError( f"None of {candidates} found in repo {repo}") class ModelRegistry: """Lazy singleton that loads each model exactly once.""" def __init__(self, hf_token: Optional[str] = None): self._token = hf_token self._splice = None self._v4 = None self._classic = None self._splice_val_acc = 0.0 self._v4_val_acc = 0.0 self._classic_val_acc = 0.0 # ── individual loaders ──────────────────────────────────────────────────── def _load_splice(self): logger.info("Loading splice model …") ckpt = _try_filenames(SPLICE_REPO, [SPLICE_FILENAME, "model.pt", "pytorch_model.pt"], self._token) sd = ckpt.get("model_state_dict", ckpt) fc_region_out = sd["fc_region.weight"].shape[0] splice_fc_out = sd["splice_fc.weight"].shape[0] m = MutationPredictorCNN_v2(fc_region_out=fc_region_out, splice_fc_out=splice_fc_out) m.load_state_dict(sd) m.eval() self._splice_val_acc = float(ckpt.get("val_accuracy", 0)) logger.info(f"Splice model ready (val_acc={self._splice_val_acc:.4f})") return m def _load_v4(self): logger.info("Loading v4 model …") ckpt = _try_filenames(V4_REPO, [V4_FILENAME, "mutation_predictor_splice_v4.pt", "model.pt", "pytorch_model.pt"], self._token) sd = ckpt.get("model_state_dict", ckpt) m = MutationPredictorCNN_v4() # Strict=False so we survive minor shape drift between checkpoints missing, unexpected = m.load_state_dict(sd, strict=False) if missing: logger.warning(f"V4 missing keys: {missing[:6]}") m.eval() self._v4_val_acc = float(ckpt.get("val_accuracy", 0)) logger.info(f"V4 model ready (val_acc={self._v4_val_acc:.4f})") return m def _load_classic(self): logger.info("Loading classic model …") ckpt = _try_filenames(CLASSIC_REPO, [CLASSIC_FILENAME, "model.pt", "mutation_predictor_classic.pt", "pytorch_model.pt"], self._token) sd = ckpt.get("model_state_dict", ckpt) # Detect input_dim from first conv weight (channels × kernel = 11 × kernel) m = MutationPredictorClassic() missing, unexpected = m.load_state_dict(sd, strict=False) if missing: logger.warning(f"Classic missing keys: {missing[:6]}") m.eval() self._classic_val_acc = float(ckpt.get("val_accuracy", 0)) logger.info(f"Classic model ready (val_acc={self._classic_val_acc:.4f})") return m # ── properties ──────────────────────────────────────────────────────────── @property def splice(self) -> MutationPredictorCNN_v2: if self._splice is None: self._splice = self._load_splice() return self._splice @property def v4(self) -> MutationPredictorCNN_v4: if self._v4 is None: self._v4 = self._load_v4() return self._v4 @property def classic(self) -> MutationPredictorClassic: if self._classic is None: self._classic = self._load_classic() return self._classic @property def val_accs(self) -> dict: return { "splice": self._splice_val_acc, "v4": self._v4_val_acc, "classic": self._classic_val_acc, }