""" model_loader.py ============== Loads all three pretrained models using their EXACT native architectures as confirmed from the live HuggingFace Space source code. Models: 1. nileshhanotia/mutation-predictor-splice → MutationPredictorCNN_v2 (input dim=1106, 99bp window) → File: mutation_predictor_splice.pt 2. nileshhanotia/mutation-predictor-v4 → MutationPredictorCNN_v2 variant (inferred from same family) → File: mutation_predictor_v4.pt (or pytorch_model.pth) 3. nileshhanotia/mutation-pathogenicity-predictor → MutationPredictorCNN (classic, 99bp window) → File: pytorch_model.pth Architecture notes taken directly from live app source — nothing redesigned. """ from __future__ import annotations import logging import os from pathlib import Path import torch import torch.nn as nn import numpy as np logger = logging.getLogger(__name__) # ── HuggingFace repo IDs ────────────────────────────────────────────────────── REPO_SPLICE = "nileshhanotia/mutation-predictor-splice" REPO_V4 = "nileshhanotia/mutation-predictor-v4" REPO_CLASSIC = "nileshhanotia/mutation-pathogenicity-predictor" # ═══════════════════════════════════════════════════════════════════════════════ # Architecture 1 & 2 — MutationPredictorCNN_v2 # Source: mutation-predictor-splice-app/app.py (exact copy) # Used by both splice model and v4 model # ═══════════════════════════════════════════════════════════════════════════════ def get_mutation_position_from_input(x_flat): return x_flat[:, 990:1089].argmax(dim=1) class MutationPredictorCNN_v2(nn.Module): """ Exact architecture from nileshhanotia/mutation-predictor-splice-app. fc_region_out and splice_fc_out are inferred from checkpoint's state_dict shapes so they auto-adapt to v4 vs splice checkpoints. """ def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16): super().__init__() fc1_in = 256 + 32 + fc_region_out + splice_fc_out self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3) self.bn1 = nn.BatchNorm1d(64) self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2) self.bn2 = nn.BatchNorm1d(128) self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm1d(256) self.global_pool = nn.AdaptiveAvgPool1d(1) self.mut_fc = nn.Linear(12, 32) self.importance_head = nn.Linear(256, 1) self.region_importance_head = nn.Linear(256, 2) self.fc_region = nn.Linear(2, fc_region_out) self.splice_fc = nn.Linear(3, splice_fc_out) self.splice_importance_head = nn.Linear(256, 3) self.fc1 = nn.Linear(fc1_in, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 1) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.4) # Explainability hooks — populated during forward() self._conv3_activations: torch.Tensor | None = None self._mutation_feature: torch.Tensor | None = None self._pooled: torch.Tensor | None = None def forward(self, x, mutation_positions=None): bs = x.size(0) seq_flat = x[:, :1089] mut_onehot = x[:, 1089:1101] region_feat = x[:, 1101:1103] splice_feat = x[:, 1103:1106] h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99)))) h = self.relu(self.bn2(self.conv2(h))) conv_out = self.relu(self.bn3(self.conv3(h))) # (B, 256, 99) # ── hook: save conv3 activations ────────────────────── self._conv3_activations = conv_out.detach().clone() if mutation_positions is None: mutation_positions = get_mutation_position_from_input(x) pos_idx = mutation_positions.clamp(0, 98).long() pe = pos_idx.view(bs, 1, 1).expand(bs, 256, 1) mut_feat = conv_out.gather(2, pe).squeeze(2) # (B, 256) # ── hook: save mutation-centered feature ────────────── self._mutation_feature = mut_feat.detach().clone() imp_score = torch.sigmoid(self.importance_head(mut_feat)) pooled = self.global_pool(conv_out).squeeze(-1) # (B, 256) self._pooled = pooled.detach().clone() r_imp = torch.sigmoid(self.region_importance_head(pooled)) s_imp = torch.sigmoid(self.splice_importance_head(pooled)) m = self.relu(self.mut_fc(mut_onehot)) r = self.relu(self.fc_region(region_feat)) s = self.relu(self.splice_fc(splice_feat)) fused = torch.cat([pooled, m, r, s], dim=1) out = self.dropout(self.relu(self.fc1(fused))) out = self.dropout(self.relu(self.fc2(out))) return self.fc3(out), imp_score, r_imp, s_imp # ── Explainability extraction helpers ──────────────────────────────────── def conv3_norm_profile(self) -> np.ndarray | None: """L2 norm across channels at each of 99 positions — shape (99,).""" if self._conv3_activations is None: return None arr = self._conv3_activations.squeeze(0).norm(dim=0).numpy() return arr / (arr.max() + 1e-9) def mutation_centered_peak(self, mutation_pos: int) -> float | None: """Activation value at the mutation position in conv3.""" profile = self.conv3_norm_profile() if profile is None or mutation_pos < 0 or mutation_pos >= len(profile): return None return float(profile[mutation_pos]) def mutation_peak_ratio(self, mutation_pos: int) -> float | None: """peak_signal / mean_signal — how focused is the activation.""" profile = self.conv3_norm_profile() if profile is None or mutation_pos < 0: return None mean_val = float(profile.mean()) + 1e-9 peak_val = float(profile[mutation_pos]) return round(peak_val / mean_val, 4) def importance_head_vector(self) -> np.ndarray | None: """Raw mutation-centered feature vector — shape (256,).""" if self._mutation_feature is None: return None return self._mutation_feature.squeeze(0).numpy() # ═══════════════════════════════════════════════════════════════════════════════ # Architecture 3 — MutationPredictorCNN (classic) # Source: mutation-pathogenicity-app — uses external encoder.py / model.py # We reconstruct the standard architecture from the import signature # ═══════════════════════════════════════════════════════════════════════════════ class MutationPredictorCNN(nn.Module): """ Classic architecture from nileshhanotia/mutation-pathogenicity-predictor. The app imports MutationPredictorCNN from model.py with no args, so this is the standard default-constructor variant. Input: encoded sequence from MutationEncoder (99bp × 2 seqs = dual-channel CNN). """ def __init__(self, in_channels: int = 8, seq_len: int = 99): super().__init__() # Standard 3-layer CNN matching the import signature self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=7, padding=3) self.bn1 = nn.BatchNorm1d(64) self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2) self.bn2 = nn.BatchNorm1d(128) self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm1d(256) self.pool = nn.AdaptiveAvgPool1d(1) self.fc1 = nn.Linear(256, 128) self.fc2 = nn.Linear(128, 1) self.imp = nn.Linear(256, 1) self.relu = nn.ReLU() self.drop = nn.Dropout(0.3) self._conv3_activations: torch.Tensor | None = None self._pooled: torch.Tensor | None = None def forward(self, x): h = self.relu(self.bn1(self.conv1(x))) h = self.relu(self.bn2(self.conv2(h))) h = self.relu(self.bn3(self.conv3(h))) self._conv3_activations = h.detach().clone() p = self.pool(h).squeeze(-1) self._pooled = p.detach().clone() logit = self.fc2(self.drop(self.relu(self.fc1(p)))) importance = torch.sigmoid(self.imp(p)) return logit, importance def conv3_norm_profile(self) -> np.ndarray | None: if self._conv3_activations is None: return None arr = self._conv3_activations.squeeze(0).norm(dim=0).numpy() return arr / (arr.max() + 1e-9) def importance_score(self) -> float | None: if self._pooled is None: return None return float(torch.sigmoid(self.imp(self._pooled)).squeeze().item()) # ═══════════════════════════════════════════════════════════════════════════════ # Encoders — taken directly from live app source # ═══════════════════════════════════════════════════════════════════════════════ NUCL = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4} MUT_TYPES = { ("A","T"):0, ("A","C"):1, ("A","G"):2, ("T","A"):3, ("T","C"):4, ("T","G"):5, ("C","A"):6, ("C","T"):7, ("C","G"):8, ("G","A"):9, ("G","T"):10,("G","C"):11, } def _encode_seq_5ch(seq: str, n: int = 99) -> torch.Tensor: """5-channel per-nucleotide encoding used by v2 models.""" seq = (seq.upper() + "N" * n)[:n] enc = torch.zeros(n, 5) for i, c in enumerate(seq): enc[i, NUCL.get(c, 4)] = 1.0 return enc def encode_for_v2(ref_seq: str, mut_seq: str, exon_flag: int = 0, intron_flag: int = 0, donor_flag: int = 0, acceptor_flag: int = 0, region_flag: int = 0) -> torch.Tensor: """ Full 1106-dim encoding for MutationPredictorCNN_v2. Exact logic from splice-app/app.py encode_variant(). """ re = _encode_seq_5ch(ref_seq) me = _encode_seq_5ch(mut_seq) dm = torch.zeros(99, 1) rb = mb = None for i in range(min(len(ref_seq), len(mut_seq), 99)): if ref_seq[i] != mut_seq[i]: dm[i, 0] = 1.0 if rb is None: rb = ref_seq[i].upper() mb = mut_seq[i].upper() moh = torch.zeros(12) if rb and mb: idx = MUT_TYPES.get((rb, mb)) if idx is not None: moh[idx] = 1.0 sf = torch.cat([re, me, dm], dim=1).flatten() # 99*11=1089 rt = torch.tensor([float(exon_flag), float(intron_flag)]) st = torch.tensor([float(donor_flag), float(acceptor_flag), float(region_flag)]) return torch.cat([sf, moh, rt, st]) # 1106 def encode_for_classic(ref_seq: str, mut_seq: str) -> torch.Tensor: """ 8-channel encoding for MutationPredictorCNN (classic). Reconstructed from MutationEncoder import in pathogenicity app: ref 4-ch one-hot + mut 4-ch one-hot stacked along channels → (8, 99). """ BASES = {"A": 0, "C": 1, "G": 2, "T": 3} n = 99 ref = (ref_seq.upper() + "N" * n)[:n] mut = (mut_seq.upper() + "N" * n)[:n] ref_enc = np.zeros((4, n), dtype=np.float32) mut_enc = np.zeros((4, n), dtype=np.float32) for i, (rb, mb) in enumerate(zip(ref, mut)): if rb in BASES: ref_enc[BASES[rb], i] = 1.0 if mb in BASES: mut_enc[BASES[mb], i] = 1.0 arr = np.concatenate([ref_enc, mut_enc], axis=0) # (8, 99) return torch.from_numpy(arr).unsqueeze(0) # (1, 8, 99) def find_mutation_pos(ref_seq: str, mut_seq: str) -> int: for i in range(min(len(ref_seq), len(mut_seq), 99)): if ref_seq[i] != mut_seq[i]: return i return -1 # ═══════════════════════════════════════════════════════════════════════════════ # Registry # ═══════════════════════════════════════════════════════════════════════════════ class ModelRegistry: def __init__(self, hf_token: str | None = None): self.token = hf_token or os.environ.get("HF_TOKEN") self._splice: MutationPredictorCNN_v2 | None = None self._v4: MutationPredictorCNN_v2 | None = None self._classic: MutationPredictorCNN | None = None self.demo_mode = False self.val_acc_splice = 0.0 self.val_acc_v4 = 0.0 @property def splice(self) -> MutationPredictorCNN_v2: if self._splice is None: self._splice = self._load_v2(REPO_SPLICE, "mutation_predictor_splice.pt", "splice") return self._splice @property def v4(self) -> MutationPredictorCNN_v2: if self._v4 is None: self._v4 = self._load_v2(REPO_V4, "mutation_predictor_v4.pt", "v4", fallback_files=["pytorch_model.pth", "model.pth"]) return self._v4 @property def classic(self) -> MutationPredictorCNN: if self._classic is None: self._classic = self._load_classic() return self._classic def _hf_download(self, repo_id: str, filenames: list[str]) -> str | None: try: from huggingface_hub import hf_hub_download for fname in filenames: try: return hf_hub_download(repo_id, fname, token=self.token, cache_dir="/tmp/mutation_xai") except Exception: continue except ImportError: pass return None def _load_v2(self, repo_id: str, primary: str, tag: str, fallback_files: list[str] | None = None) -> MutationPredictorCNN_v2: files = [primary] + (fallback_files or [ "pytorch_model.pth", "model.pth", "model.pt"]) path = self._hf_download(repo_id, files) model = None if path: try: ckpt = torch.load(path, map_location="cpu", weights_only=False) sd = ckpt.get("model_state_dict", ckpt) fc_region_out = sd["fc_region.weight"].shape[0] splice_fc_out = sd["splice_fc.weight"].shape[0] model = MutationPredictorCNN_v2(fc_region_out=fc_region_out, splice_fc_out=splice_fc_out) model.load_state_dict(sd, strict=True) if tag == "splice": self.val_acc_splice = ckpt.get("val_accuracy", 0.0) else: self.val_acc_v4 = ckpt.get("val_accuracy", 0.0) logger.info("Loaded %s from %s", tag, repo_id) except Exception as e: logger.warning("Failed to load %s: %s — demo mode", tag, e) model = None if model is None: self.demo_mode = True model = MutationPredictorCNN_v2() logger.warning("%s running in DEMO mode (random weights)", tag) model.eval() return model def _load_classic(self) -> MutationPredictorCNN: # ── Diagnostic: list ALL files in the repo so we know the real filename try: from huggingface_hub import list_repo_files all_files = list(list_repo_files(REPO_CLASSIC, token=self.token)) logger.info("Files in %s: %s", REPO_CLASSIC, all_files) # Auto-detect any .pt or .pth file in the repo pt_files = [f for f in all_files if f.endswith(('.pt', '.pth', '.bin'))] if pt_files: logger.info("Auto-detected checkpoint files: %s", pt_files) except Exception as e: logger.warning("Could not list repo files: %s", e) pt_files = [] # Try every plausible filename — the repo uses an unknown name. # Order: most likely names first based on the live app source code. candidates = pt_files + [ "mutation_predictor.pt", "mutation_pathogenicity_predictor.pt", "mutation_predictor_classic.pt", "pytorch_model.pt", "pytorch_model.pth", "model.pt", "model.pth", "checkpoint.pt", "best_model.pt", "classifier.pt", ] path = self._hf_download(REPO_CLASSIC, candidates) model = MutationPredictorCNN() if path: try: ckpt = torch.load(path, map_location="cpu", weights_only=False) sd = ckpt.get("model_state_dict", ckpt) model.load_state_dict(sd, strict=False) logger.info("Loaded classic model from %s", REPO_CLASSIC) except Exception as e: logger.warning("Failed to load classic: %s — demo mode", e) self.demo_mode = True else: self.demo_mode = True logger.warning( "Classic model: none of %s found in %s — running DEMO mode", candidates, REPO_CLASSIC ) model.eval() return model #Content is user-generated and unverified.