# NOTE: scrape artifact removed here (HuggingFace Spaces "Runtime error" banner).
"""
model_loader.py
===============
Loads all three pretrained models using their EXACT native architectures
as confirmed from the live HuggingFace Space source code.

Models:
1. nileshhanotia/mutation-predictor-splice
   -> MutationPredictorCNN_v2 (input dim=1106, 99bp window)
   -> File: mutation_predictor_splice.pt
2. nileshhanotia/mutation-predictor-v4
   -> MutationPredictorCNN_v2 variant (inferred from same family)
   -> File: mutation_predictor_v4.pt (or pytorch_model.pth)
3. nileshhanotia/mutation-pathogenicity-predictor
   -> MutationPredictorCNN (classic, 99bp window)
   -> File: pytorch_model.pth

Architecture notes taken directly from live app source -- nothing redesigned.
"""
| from __future__ import annotations | |
| import logging | |
| import os | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
# -- HuggingFace repo IDs ------------------------------------------------------
# One repo per pretrained checkpoint; ModelRegistry downloads from these lazily.
REPO_SPLICE = "nileshhanotia/mutation-predictor-splice"  # MutationPredictorCNN_v2, splice-aware
REPO_V4 = "nileshhanotia/mutation-predictor-v4"  # MutationPredictorCNN_v2 variant
REPO_CLASSIC = "nileshhanotia/mutation-pathogenicity-predictor"  # classic MutationPredictorCNN
# ------------------------------------------------------------------------------
# Architecture 1 & 2 -- MutationPredictorCNN_v2
# Source: mutation-predictor-splice-app/app.py (exact copy)
# Used by both splice model and v4 model
# ------------------------------------------------------------------------------
def get_mutation_position_from_input(x_flat):
    """Recover the mutated position (0-98) from a flat 1106-dim encoding.

    Indices 990:1089 of the flattened input hold the 99-wide per-position
    diff mask; its argmax is the mutation index for each sample in the batch.
    """
    diff_mask = x_flat[:, 990:1089]
    return torch.argmax(diff_mask, dim=1)
class MutationPredictorCNN_v2(nn.Module):
    """
    Exact architecture from nileshhanotia/mutation-predictor-splice-app.

    The flat 1106-dim input vector per sample is laid out as:
      * [0:1089)    -- 99bp x 11-channel block (ref 5ch + mut 5ch + diff mask)
      * [1089:1101) -- 12-dim mutation-type one-hot
      * [1101:1103) -- 2-dim region flags
      * [1103:1106) -- 3-dim splice flags

    ``fc_region_out`` and ``splice_fc_out`` are inferred from a checkpoint's
    state_dict shapes so they auto-adapt to v4 vs splice checkpoints.
    Attribute names must stay as-is: they are the state_dict keys.
    """

    def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16):
        super().__init__()
        # Fused width: pooled conv (256) + mutation proj (32) + region + splice.
        fc1_in = 256 + 32 + fc_region_out + splice_fc_out
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)
        self.fc1 = nn.Linear(fc1_in, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)
        # Explainability hooks -- populated during forward()
        self._conv3_activations: torch.Tensor | None = None
        self._mutation_feature: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x, mutation_positions=None):
        """Run the network.

        Args:
            x: (B, 1106) flat encoding (see class docstring).
            mutation_positions: optional (B,) tensor of positions 0-98; when
                None it is recovered from the diff-mask slice of ``x``.

        Returns:
            Tuple ``(logit, importance, region_importance, splice_importance)``
            with shapes (B,1), (B,1), (B,2), (B,3).
        """
        bs = x.size(0)
        seq_flat = x[:, :1089]
        mut_onehot = x[:, 1089:1101]
        region_feat = x[:, 1101:1103]
        splice_feat = x[:, 1103:1106]
        h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99))))
        h = self.relu(self.bn2(self.conv2(h)))
        conv_out = self.relu(self.bn3(self.conv3(h)))  # (B, 256, 99)
        # -- hook: save conv3 activations for explainability helpers
        self._conv3_activations = conv_out.detach().clone()
        if mutation_positions is None:
            mutation_positions = get_mutation_position_from_input(x)
        pos_idx = mutation_positions.clamp(0, 98).long()
        # Gather the 256-dim conv feature column at the mutation position.
        pe = pos_idx.view(bs, 1, 1).expand(bs, 256, 1)
        mut_feat = conv_out.gather(2, pe).squeeze(2)  # (B, 256)
        # -- hook: save mutation-centered feature
        self._mutation_feature = mut_feat.detach().clone()
        imp_score = torch.sigmoid(self.importance_head(mut_feat))
        pooled = self.global_pool(conv_out).squeeze(-1)  # (B, 256)
        self._pooled = pooled.detach().clone()
        r_imp = torch.sigmoid(self.region_importance_head(pooled))
        s_imp = torch.sigmoid(self.splice_importance_head(pooled))
        m = self.relu(self.mut_fc(mut_onehot))
        r = self.relu(self.fc_region(region_feat))
        s = self.relu(self.splice_fc(splice_feat))
        fused = torch.cat([pooled, m, r, s], dim=1)
        out = self.dropout(self.relu(self.fc1(fused)))
        out = self.dropout(self.relu(self.fc2(out)))
        return self.fc3(out), imp_score, r_imp, s_imp

    # -- Explainability extraction helpers ------------------------------------
    def conv3_norm_profile(self) -> np.ndarray | None:
        """L2 norm across channels at each of 99 positions -> shape (99,).

        Normalized so the peak is ~1. Assumes the last forward pass used
        batch size 1 (squeeze(0)); returns None before any forward pass.
        """
        if self._conv3_activations is None:
            return None
        # .cpu() so this also works when the model ran on GPU.
        arr = self._conv3_activations.squeeze(0).norm(dim=0).cpu().numpy()
        return arr / (arr.max() + 1e-9)

    def mutation_centered_peak(self, mutation_pos: int) -> float | None:
        """Activation value at the mutation position in conv3, or None."""
        profile = self.conv3_norm_profile()
        if profile is None or mutation_pos < 0 or mutation_pos >= len(profile):
            return None
        return float(profile[mutation_pos])

    def mutation_peak_ratio(self, mutation_pos: int) -> float | None:
        """peak_signal / mean_signal -- how focused is the activation.

        FIX: also guard the upper bound. The original only checked
        ``mutation_pos < 0`` and raised IndexError for positions >= 99;
        now out-of-range positions return None, consistent with
        ``mutation_centered_peak``.
        """
        profile = self.conv3_norm_profile()
        if profile is None or mutation_pos < 0 or mutation_pos >= len(profile):
            return None
        mean_val = float(profile.mean()) + 1e-9
        peak_val = float(profile[mutation_pos])
        return round(peak_val / mean_val, 4)

    def importance_head_vector(self) -> np.ndarray | None:
        """Raw mutation-centered feature vector -> shape (256,), or None."""
        if self._mutation_feature is None:
            return None
        return self._mutation_feature.squeeze(0).cpu().numpy()
# ------------------------------------------------------------------------------
# Architecture 3 -- MutationPredictorCNN (classic)
# Source: mutation-pathogenicity-app -- uses external encoder.py / model.py
# We reconstruct the standard architecture from the import signature
# ------------------------------------------------------------------------------
class MutationPredictorCNN(nn.Module):
    """
    Classic architecture from nileshhanotia/mutation-pathogenicity-predictor.

    The app imports ``MutationPredictorCNN`` from model.py with no args, so
    this mirrors the standard default-constructor variant: a three-layer 1D
    conv trunk over an 8-channel, 99bp dual-sequence encoding (ref one-hot
    stacked with mut one-hot), global average pooling, a small classifier
    head, and a sigmoid importance head.

    Attribute names double as state_dict keys -- do not rename.
    """

    def __init__(self, in_channels: int = 8, seq_len: int = 99):
        super().__init__()
        # Conv trunk: 8 -> 64 -> 128 -> 256 channels with shrinking kernels.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)
        self.imp = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)
        # Explainability caches, filled on every forward() call.
        self._conv3_activations: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x):
        """Map (B, 8, 99) input to ``(logit (B,1), importance (B,1))``."""
        feats = self.relu(self.bn1(self.conv1(x)))
        feats = self.relu(self.bn2(self.conv2(feats)))
        feats = self.relu(self.bn3(self.conv3(feats)))
        self._conv3_activations = feats.detach().clone()
        pooled = self.pool(feats).squeeze(-1)
        self._pooled = pooled.detach().clone()
        hidden = self.drop(self.relu(self.fc1(pooled)))
        logit = self.fc2(hidden)
        importance = torch.sigmoid(self.imp(pooled))
        return logit, importance

    def conv3_norm_profile(self) -> np.ndarray | None:
        """Per-position channel L2 norm of conv3, scaled so the max is ~1."""
        if self._conv3_activations is None:
            return None
        norms = self._conv3_activations.squeeze(0).norm(dim=0).numpy()
        return norms / (norms.max() + 1e-9)

    def importance_score(self) -> float | None:
        """Scalar importance from the cached pooled features (batch of 1)."""
        if self._pooled is None:
            return None
        return float(torch.sigmoid(self.imp(self._pooled)).squeeze().item())
# ------------------------------------------------------------------------------
# Encoders -- taken directly from live app source
# ------------------------------------------------------------------------------
# Nucleotide -> channel index for the 5-channel one-hot ('N' = unknown/pad).
NUCL = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}
# (ref_base, mut_base) -> index into the 12-dim substitution one-hot.
# Covers all 12 ordered single-nucleotide substitutions; same-base or
# N-involving pairs are absent, so lookups for them return None.
MUT_TYPES = {
    ("A","T"):0, ("A","C"):1, ("A","G"):2,
    ("T","A"):3, ("T","C"):4, ("T","G"):5,
    ("C","A"):6, ("C","T"):7, ("C","G"):8,
    ("G","A"):9, ("G","T"):10,("G","C"):11,
}
| def _encode_seq_5ch(seq: str, n: int = 99) -> torch.Tensor: | |
| """5-channel per-nucleotide encoding used by v2 models.""" | |
| seq = (seq.upper() + "N" * n)[:n] | |
| enc = torch.zeros(n, 5) | |
| for i, c in enumerate(seq): | |
| enc[i, NUCL.get(c, 4)] = 1.0 | |
| return enc | |
def encode_for_v2(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0) -> torch.Tensor:
    """
    Build the full 1106-dim encoding for MutationPredictorCNN_v2.

    Layout: flattened [ref 5ch | mut 5ch | diff mask] block (99*11 = 1089),
    then the 12-dim mutation-type one-hot, 2 region flags and 3 splice
    flags. Exact logic from splice-app/app.py encode_variant().
    """
    ref_enc = _encode_seq_5ch(ref_seq)
    mut_enc = _encode_seq_5ch(mut_seq)
    diff_mask = torch.zeros(99, 1)
    first_ref = first_mut = None
    # Mark every differing position (case-sensitive, as in the original app)
    # and remember the FIRST differing base pair for the type one-hot.
    for i, (r_ch, m_ch) in enumerate(zip(ref_seq[:99], mut_seq[:99])):
        if r_ch != m_ch:
            diff_mask[i, 0] = 1.0
            if first_ref is None:
                first_ref = r_ch.upper()
                first_mut = m_ch.upper()
    mut_onehot = torch.zeros(12)
    if first_ref and first_mut:
        type_idx = MUT_TYPES.get((first_ref, first_mut))
        if type_idx is not None:
            mut_onehot[type_idx] = 1.0
    seq_flat = torch.cat([ref_enc, mut_enc, diff_mask], dim=1).flatten()  # 99*11=1089
    region_t = torch.tensor([float(exon_flag), float(intron_flag)])
    splice_t = torch.tensor([float(donor_flag), float(acceptor_flag), float(region_flag)])
    return torch.cat([seq_flat, mut_onehot, region_t, splice_t])  # 1106
def encode_for_classic(ref_seq: str, mut_seq: str) -> torch.Tensor:
    """
    8-channel encoding for MutationPredictorCNN (classic).

    Reconstructed from the MutationEncoder import in the pathogenicity app:
    ref 4-channel one-hot (rows 0-3) stacked with mut 4-channel one-hot
    (rows 4-7) over 99 positions; characters outside A/C/G/T (including the
    'N' padding) leave their column all-zero. Returns shape (1, 8, 99).
    """
    base_idx = {"A": 0, "C": 1, "G": 2, "T": 3}
    n = 99
    ref = (ref_seq.upper() + "N" * n)[:n]
    mut = (mut_seq.upper() + "N" * n)[:n]
    # Single buffer: rows 0-3 are the ref channels, rows 4-7 the mut channels.
    channels = np.zeros((8, n), dtype=np.float32)
    for i in range(n):
        r_idx = base_idx.get(ref[i])
        if r_idx is not None:
            channels[r_idx, i] = 1.0
        m_idx = base_idx.get(mut[i])
        if m_idx is not None:
            channels[4 + m_idx, i] = 1.0
    return torch.from_numpy(channels).unsqueeze(0)  # (1, 8, 99)
def find_mutation_pos(ref_seq: str, mut_seq: str) -> int:
    """Return the first index (within the 99bp window) where the two
    sequences differ, or -1 if they match over the compared span.

    Comparison is case-sensitive and stops at the shorter sequence.
    """
    limit = min(len(ref_seq), len(mut_seq), 99)
    return next((i for i in range(limit) if ref_seq[i] != mut_seq[i]), -1)
# ------------------------------------------------------------------------------
# Registry
# ------------------------------------------------------------------------------
class ModelRegistry:
    """Lazy, fault-tolerant loader for the three pretrained models.

    Each model is downloaded from the HuggingFace Hub on first access and
    cached on the instance. If a checkpoint cannot be downloaded or
    deserialized, the model falls back to randomly-initialized weights and
    ``demo_mode`` is set to True so callers can surface the degradation.
    """

    def __init__(self, hf_token: str | None = None):
        # Auth token: explicit argument wins, else the HF_TOKEN env var.
        self.token = hf_token or os.environ.get("HF_TOKEN")
        self._splice: MutationPredictorCNN_v2 | None = None
        self._v4: MutationPredictorCNN_v2 | None = None
        self._classic: MutationPredictorCNN | None = None
        # True once ANY model had to fall back to random weights.
        self.demo_mode = False
        # Validation accuracies read from the v2 checkpoints (0.0 if absent).
        self.val_acc_splice = 0.0
        self.val_acc_v4 = 0.0

    def splice(self) -> MutationPredictorCNN_v2:
        """Return the splice model, loading it lazily on first call."""
        if self._splice is None:
            self._splice = self._load_v2(REPO_SPLICE, "mutation_predictor_splice.pt", "splice")
        return self._splice

    def v4(self) -> MutationPredictorCNN_v2:
        """Return the v4 model, loading it lazily on first call."""
        if self._v4 is None:
            self._v4 = self._load_v2(REPO_V4,
                                     "mutation_predictor_v4.pt", "v4",
                                     fallback_files=["pytorch_model.pth", "model.pth"])
        return self._v4

    def classic(self) -> MutationPredictorCNN:
        """Return the classic model, loading it lazily on first call."""
        if self._classic is None:
            self._classic = self._load_classic()
        return self._classic

    def _hf_download(self, repo_id: str, filenames: list[str]) -> str | None:
        """Try each filename in order; return the first local path downloaded.

        Returns None when huggingface_hub is not installed or no candidate
        file exists in the repo.
        """
        try:
            from huggingface_hub import hf_hub_download
            for fname in filenames:
                try:
                    return hf_hub_download(repo_id, fname, token=self.token,
                                           cache_dir="/tmp/mutation_xai")
                except Exception:
                    # File absent / network error: fall through to next name.
                    continue
        except ImportError:
            # huggingface_hub unavailable -- caller will use demo mode.
            pass
        return None

    def _load_v2(self, repo_id: str, primary: str, tag: str,
                 fallback_files: list[str] | None = None) -> MutationPredictorCNN_v2:
        """Load a MutationPredictorCNN_v2 checkpoint, inferring head sizes.

        The widths of ``fc_region`` and ``splice_fc`` are read from the
        checkpoint's state_dict shapes, so one loader serves both the splice
        and v4 checkpoints. Any failure falls back to random weights and
        flips ``demo_mode``.
        """
        files = [primary] + (fallback_files or [
            "pytorch_model.pth", "model.pth", "model.pt"])
        path = self._hf_download(repo_id, files)
        model = None
        if path:
            try:
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # only acceptable because these repos are trusted first-party.
                ckpt = torch.load(path, map_location="cpu", weights_only=False)
                # Checkpoint may be a wrapper dict or a bare state_dict.
                sd = ckpt.get("model_state_dict", ckpt)
                # Infer head widths from the checkpoint itself.
                fc_region_out = sd["fc_region.weight"].shape[0]
                splice_fc_out = sd["splice_fc.weight"].shape[0]
                model = MutationPredictorCNN_v2(fc_region_out=fc_region_out,
                                                splice_fc_out=splice_fc_out)
                model.load_state_dict(sd, strict=True)
                if tag == "splice":
                    self.val_acc_splice = ckpt.get("val_accuracy", 0.0)
                else:
                    self.val_acc_v4 = ckpt.get("val_accuracy", 0.0)
                logger.info("Loaded %s from %s", tag, repo_id)
            except Exception as e:
                logger.warning("Failed to load %s: %s β demo mode", tag, e)
                model = None
        if model is None:
            self.demo_mode = True
            model = MutationPredictorCNN_v2()
            logger.warning("%s running in DEMO mode (random weights)", tag)
        model.eval()
        return model

    def _load_classic(self) -> MutationPredictorCNN:
        """Load the classic model, auto-detecting the checkpoint filename.

        The repo's real checkpoint name is unknown, so the repo file list is
        queried first and any .pt/.pth/.bin files are tried before a list of
        plausible hard-coded names.
        """
        # -- Diagnostic: list ALL files in the repo so we know the real filename
        try:
            from huggingface_hub import list_repo_files
            all_files = list(list_repo_files(REPO_CLASSIC, token=self.token))
            logger.info("Files in %s: %s", REPO_CLASSIC, all_files)
            # Auto-detect any .pt or .pth file in the repo
            pt_files = [f for f in all_files if f.endswith(('.pt', '.pth', '.bin'))]
            if pt_files:
                logger.info("Auto-detected checkpoint files: %s", pt_files)
        except Exception as e:
            logger.warning("Could not list repo files: %s", e)
            pt_files = []
        # Try every plausible filename -- the repo uses an unknown name.
        # Order: most likely names first based on the live app source code.
        candidates = pt_files + [
            "mutation_predictor.pt",
            "mutation_pathogenicity_predictor.pt",
            "mutation_predictor_classic.pt",
            "pytorch_model.pt",
            "pytorch_model.pth",
            "model.pt",
            "model.pth",
            "checkpoint.pt",
            "best_model.pt",
            "classifier.pt",
        ]
        path = self._hf_download(REPO_CLASSIC, candidates)
        model = MutationPredictorCNN()
        if path:
            try:
                # NOTE(review): weights_only=False -- trusted repo only.
                ckpt = torch.load(path, map_location="cpu", weights_only=False)
                sd = ckpt.get("model_state_dict", ckpt)
                # strict=False: this architecture was reconstructed from the
                # app's import signature, so tolerate key mismatches.
                model.load_state_dict(sd, strict=False)
                logger.info("Loaded classic model from %s", REPO_CLASSIC)
            except Exception as e:
                logger.warning("Failed to load classic: %s β demo mode", e)
                self.demo_mode = True
        else:
            self.demo_mode = True
            logger.warning(
                "Classic model: none of %s found in %s β running DEMO mode",
                candidates, REPO_CLASSIC
            )
        model.eval()
        return model
| #Content is user-generated and unverified. | |