# Mutation-XAI / model_loader.py
# nileshhanotia's picture
# Create model_loader.py
# ccc687d verified
"""
model_loader.py
==============
Loads all three pretrained models using their EXACT native architectures
as confirmed from the live HuggingFace Space source code.
Models:
1. nileshhanotia/mutation-predictor-splice
β†’ MutationPredictorCNN_v2 (input dim=1106, 99bp window)
β†’ File: mutation_predictor_splice.pt
2. nileshhanotia/mutation-predictor-v4
β†’ MutationPredictorCNN_v2 variant (inferred from same family)
β†’ File: mutation_predictor_v4.pt (or pytorch_model.pth)
3. nileshhanotia/mutation-pathogenicity-predictor
β†’ MutationPredictorCNN (classic, 99bp window)
β†’ File: pytorch_model.pth
Architecture notes taken directly from live app source β€” nothing redesigned.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
# Module-level logger; all load successes and failures below are reported through it.
logger = logging.getLogger(__name__)
# ── HuggingFace repo IDs ──────────────────────────────────────────────────────
# Repository holding the splice checkpoint (loaded as MutationPredictorCNN_v2).
REPO_SPLICE = "nileshhanotia/mutation-predictor-splice"
# Repository holding the v4 checkpoint (loaded as MutationPredictorCNN_v2).
REPO_V4 = "nileshhanotia/mutation-predictor-v4"
# Repository holding the classic checkpoint (loaded as MutationPredictorCNN).
REPO_CLASSIC = "nileshhanotia/mutation-pathogenicity-predictor"
# ═══════════════════════════════════════════════════════════════════════════════
# Architecture 1 & 2 — MutationPredictorCNN_v2
# Source: mutation-predictor-splice-app/app.py (exact copy)
# Used by both the splice model and the v4 model
# ═══════════════════════════════════════════════════════════════════════════════
def get_mutation_position_from_input(x_flat):
    """Return, for every row, the argmax index within columns 990..1088.

    ``MutationPredictorCNN_v2.forward`` views the first 1089 columns as
    ``(11, 99)``, so under that view these 99 columns are the last channel
    and the returned index is a position in ``[0, 98]``. A row that is all
    zeros in this range yields 0.

    NOTE(review): ``encode_for_v2`` flattens a ``(99, 11)`` tensor, so whether
    this slice really lines up with its per-position difference column should
    be confirmed against the training code before relying on the result.

    Args:
        x_flat: 2-D tensor with at least 1089 columns.

    Returns:
        1-D integer tensor with one index per row.
    """
    last_channel = x_flat[:, 990:1089]
    return torch.argmax(last_channel, dim=1)
class MutationPredictorCNN_v2(nn.Module):
    """Three-layer 1-D CNN with mutation, region and splice side inputs.

    Architecture taken from nileshhanotia/mutation-predictor-splice-app.
    Layer names and sizes must stay exactly as they are: the loader restores
    checkpoints with ``strict=True``. ``fc_region_out`` and ``splice_fc_out``
    are read from the checkpoint's state_dict shapes by the loader, so the
    same class serves both the v4 and the splice checkpoints.

    Args:
        fc_region_out: Output width of the region-feature projection.
        splice_fc_out: Output width of the splice-feature projection.
    """

    def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16):
        super().__init__()
        # Fused vector = pooled conv features + mutation + region + splice.
        fc1_in = 256 + 32 + fc_region_out + splice_fc_out
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)
        self.fc1 = nn.Linear(fc1_in, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)
        # Explainability caches: detached copies written by every forward()
        # call, so they always describe the most recent batch.
        self._conv3_activations: torch.Tensor | None = None
        self._mutation_feature: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x, mutation_positions=None):
        """Run the network on a batch of flat 1106-column feature rows.

        Column layout read here: ``[0:1089]`` sequence block (viewed as
        11 channels x 99 positions), ``[1089:1101]`` mutation one-hot,
        ``[1101:1103]`` region features, ``[1103:1106]`` splice features.

        Args:
            x: Tensor of shape ``(B, 1106)``.
            mutation_positions: Optional ``(B,)`` tensor of positions; when
                omitted it is derived from ``x`` with
                ``get_mutation_position_from_input``. Values are clamped
                to ``[0, 98]``.

        Returns:
            Tuple ``(logit, imp_score, r_imp, s_imp)`` with shapes
            ``(B, 1)``, ``(B, 1)``, ``(B, 2)`` and ``(B, 3)``; the last three
            have passed through a sigmoid.
        """
        bs = x.size(0)
        seq_flat = x[:, :1089]
        mut_onehot = x[:, 1089:1101]
        region_feat = x[:, 1101:1103]
        splice_feat = x[:, 1103:1106]

        # NOTE(review): encode_for_v2 flattens a (99, 11) tensor, while this
        # view reads the same values as (11, 99). Left untouched because the
        # pretrained weights were presumably trained with exactly this
        # layout - confirm against the training code.
        h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99))))
        h = self.relu(self.bn2(self.conv2(h)))
        conv_out = self.relu(self.bn3(self.conv3(h)))  # (B, 256, 99)
        # Cache: conv3 activations.
        self._conv3_activations = conv_out.detach().clone()

        if mutation_positions is None:
            mutation_positions = get_mutation_position_from_input(x)
        pos_idx = mutation_positions.clamp(0, 98).long()
        # Pick, for each sample, the 256-channel column at its position.
        pe = pos_idx.view(bs, 1, 1).expand(bs, 256, 1)
        mut_feat = conv_out.gather(2, pe).squeeze(2)  # (B, 256)
        # Cache: mutation-centered feature.
        self._mutation_feature = mut_feat.detach().clone()
        imp_score = torch.sigmoid(self.importance_head(mut_feat))

        pooled = self.global_pool(conv_out).squeeze(-1)  # (B, 256)
        self._pooled = pooled.detach().clone()
        r_imp = torch.sigmoid(self.region_importance_head(pooled))
        s_imp = torch.sigmoid(self.splice_importance_head(pooled))

        m = self.relu(self.mut_fc(mut_onehot))
        r = self.relu(self.fc_region(region_feat))
        s = self.relu(self.splice_fc(splice_feat))
        fused = torch.cat([pooled, m, r, s], dim=1)
        out = self.dropout(self.relu(self.fc1(fused)))
        out = self.dropout(self.relu(self.fc2(out)))
        return self.fc3(out), imp_score, r_imp, s_imp

    # -- Explainability extraction helpers ---------------------------------

    def conv3_norm_profile(self) -> np.ndarray | None:
        """L2 norm across channels at each of the 99 positions, max-normalised.

        Describes the first sample of the most recent forward() batch.

        Returns:
            Array of shape ``(99,)``, or ``None`` before the first forward().
        """
        if self._conv3_activations is None:
            return None
        # Index the first sample explicitly: squeeze(0) is a no-op for
        # batches larger than one and would make norm() reduce over the
        # batch axis. .cpu() is required before .numpy() when the model
        # lives on an accelerator.
        arr = self._conv3_activations[0].norm(dim=0).cpu().numpy()
        return arr / (arr.max() + 1e-9)

    def mutation_centered_peak(self, mutation_pos: int) -> float | None:
        """Normalised conv3 profile value at ``mutation_pos``.

        Returns ``None`` when no forward() has run yet or the position lies
        outside the profile.
        """
        profile = self.conv3_norm_profile()
        if profile is None or not 0 <= mutation_pos < len(profile):
            return None
        return float(profile[mutation_pos])

    def mutation_peak_ratio(self, mutation_pos: int) -> float | None:
        """Profile value at ``mutation_pos`` divided by the profile mean.

        Indicates how focused the activation is on the mutation. Rounded to
        four decimals. Returns ``None`` when no forward() has run yet or the
        position lies outside the profile (same bounds as
        ``mutation_centered_peak``).
        """
        profile = self.conv3_norm_profile()
        if profile is None or not 0 <= mutation_pos < len(profile):
            return None
        mean_val = float(profile.mean()) + 1e-9
        peak_val = float(profile[mutation_pos])
        return round(peak_val / mean_val, 4)

    def importance_head_vector(self) -> np.ndarray | None:
        """Raw mutation-centered feature vector of the first sample.

        Returns:
            Array of shape ``(256,)``, or ``None`` before the first forward().
        """
        if self._mutation_feature is None:
            return None
        return self._mutation_feature[0].cpu().numpy()
# ═══════════════════════════════════════════════════════════════════════════════
# Architecture 3 — MutationPredictorCNN (classic)
# Source: mutation-pathogenicity-app — uses external encoder.py / model.py
# We reconstruct the standard architecture from the import signature
# ═══════════════════════════════════════════════════════════════════════════════
class MutationPredictorCNN(nn.Module):
    """Classic architecture for nileshhanotia/mutation-pathogenicity-predictor.

    The app imports ``MutationPredictorCNN`` from model.py with no arguments,
    so this is the default-constructor variant, reconstructed here.
    Input: ``(B, in_channels, L)`` tensor, e.g. the ``(1, 8, 99)`` output of
    ``encode_for_classic`` (ref 4-channel one-hot stacked on mut 4-channel
    one-hot).

    Args:
        in_channels: Number of input channels for the first convolution.
        seq_len: Accepted for signature compatibility; not used by any layer
            (the adaptive pooling makes the network length-agnostic).
    """

    def __init__(self, in_channels: int = 8, seq_len: int = 99):
        super().__init__()
        # Standard 3-layer CNN matching the import signature.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)
        self.imp = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)
        # Explainability caches: detached copies written by every forward().
        self._conv3_activations: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x):
        """Return ``(logit, importance)``, each of shape ``(B, 1)``.

        ``importance`` has passed through a sigmoid; ``logit`` has not.
        """
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.relu(self.bn3(self.conv3(h)))
        self._conv3_activations = h.detach().clone()
        p = self.pool(h).squeeze(-1)  # (B, 256)
        self._pooled = p.detach().clone()
        logit = self.fc2(self.drop(self.relu(self.fc1(p))))
        importance = torch.sigmoid(self.imp(p))
        return logit, importance

    def conv3_norm_profile(self) -> np.ndarray | None:
        """Per-position L2 norm over conv3 channels, max-normalised.

        Describes the first sample of the most recent forward() batch.

        Returns:
            1-D array with one value per sequence position, or ``None``
            before the first forward().
        """
        if self._conv3_activations is None:
            return None
        # Index the first sample explicitly: squeeze(0) is a no-op for
        # batches larger than one and would make norm() reduce over the
        # batch axis. .cpu() is required before .numpy() when the model
        # lives on an accelerator.
        arr = self._conv3_activations[0].norm(dim=0).cpu().numpy()
        return arr / (arr.max() + 1e-9)

    def importance_score(self) -> float | None:
        """Sigmoid importance of the first sample of the last forward().

        Returns ``None`` before the first forward().
        """
        if self._pooled is None:
            return None
        # no_grad: this is read-only introspection, no graph is needed.
        # Slicing to the first row keeps .item() valid for any batch size.
        with torch.no_grad():
            score = torch.sigmoid(self.imp(self._pooled[:1]))
        return float(score.reshape(-1)[0].item())
# ═══════════════════════════════════════════════════════════════════════════════
# Encoders — taken directly from live app source
# ═══════════════════════════════════════════════════════════════════════════════
# Channel index of each nucleotide symbol in the 5-channel one-hot encoding.
NUCL = {base: channel for channel, base in enumerate("ATGCN")}

# Slot of each (reference base, mutated base) substitution in the 12-way
# mutation one-hot vector; the slot is the pair's position in this list.
MUT_TYPES = {
    pair: slot
    for slot, pair in enumerate([
        ("A", "T"), ("A", "C"), ("A", "G"),
        ("T", "A"), ("T", "C"), ("T", "G"),
        ("C", "A"), ("C", "T"), ("C", "G"),
        ("G", "A"), ("G", "T"), ("G", "C"),
    ])
}
def _encode_seq_5ch(seq: str, n: int = 99) -> torch.Tensor:
    """One-hot encode ``seq`` into an ``(n, 5)`` tensor for the v2 models.

    The sequence is upper-cased, right-padded with ``"N"`` and truncated to
    ``n`` symbols. Each row has a single 1.0 in the column given by ``NUCL``;
    any symbol not in ``NUCL`` goes to column 4, the same column as ``"N"``.
    """
    window = seq.upper().ljust(n, "N")[:n]
    onehot = torch.zeros(n, 5)
    for row, symbol in enumerate(window):
        column = NUCL.get(symbol, 4)
        onehot[row, column] = 1.0
    return onehot
def encode_for_v2(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0) -> torch.Tensor:
    """Build the full 1106-value input row for MutationPredictorCNN_v2.

    Follows the logic of splice-app/app.py ``encode_variant()``. Layout of the
    returned 1-D tensor:

    * 1089 values: per-position ``[ref one-hot (5), mut one-hot (5), diff (1)]``
      for 99 positions, flattened position by position;
    * 12 values: one-hot of the first substitution found, per ``MUT_TYPES``
      (all zeros when there is no difference or the pair is not listed);
    * 2 values: ``exon_flag``, ``intron_flag``;
    * 3 values: ``donor_flag``, ``acceptor_flag``, ``region_flag``.

    Both sequences are upper-cased before anything is compared, so the
    difference mask and the substitution type agree with the one-hot
    channels (which are case-insensitive). Positions beyond the shorter of
    the two sequences are never marked as differing.

    Args:
        ref_seq: Reference sequence.
        mut_seq: Mutated sequence.
        exon_flag, intron_flag: Region features, stored as floats.
        donor_flag, acceptor_flag, region_flag: Splice features, stored as
            floats.

    Returns:
        Float tensor of shape ``(1106,)``.
    """
    # Normalise case once: comparing raw characters would flag "a" vs "A" as
    # a mutation although both encode to the same one-hot row.
    ref_seq = ref_seq.upper()
    mut_seq = mut_seq.upper()

    ref_enc = _encode_seq_5ch(ref_seq)
    mut_enc = _encode_seq_5ch(mut_seq)

    diff_mask = torch.zeros(99, 1)
    first_ref = first_mut = None
    for i, (ref_base, mut_base) in enumerate(zip(ref_seq[:99], mut_seq[:99])):
        if ref_base != mut_base:
            diff_mask[i, 0] = 1.0
            # Only the first differing position defines the mutation type.
            if first_ref is None:
                first_ref, first_mut = ref_base, mut_base

    mut_onehot = torch.zeros(12)
    if first_ref is not None:
        idx = MUT_TYPES.get((first_ref, first_mut))
        if idx is not None:
            mut_onehot[idx] = 1.0

    seq_flat = torch.cat([ref_enc, mut_enc, diff_mask], dim=1).flatten()  # 99*11 = 1089
    region = torch.tensor([float(exon_flag), float(intron_flag)])
    splice = torch.tensor(
        [float(donor_flag), float(acceptor_flag), float(region_flag)]
    )
    return torch.cat([seq_flat, mut_onehot, region, splice])  # 1106
def encode_for_classic(ref_seq: str, mut_seq: str) -> torch.Tensor:
    """Encode a ref/mut pair as a ``(1, 8, 99)`` float32 tensor.

    Input format for MutationPredictorCNN (classic), reconstructed from the
    MutationEncoder import in the pathogenicity app. Channels 0-3 are the
    reference one-hot in A, C, G, T order and channels 4-7 the mutated one-hot
    in the same order. Sequences are upper-cased, padded with ``"N"`` and cut
    to 99 symbols; any symbol other than A/C/G/T leaves its column all zero.
    """
    channel_of = {"A": 0, "C": 1, "G": 2, "T": 3}
    width = 99

    def one_hot(seq: str) -> np.ndarray:
        # (4, width) grid with a single 1.0 per recognised base.
        padded = (seq.upper() + "N" * width)[:width]
        grid = np.zeros((4, width), dtype=np.float32)
        for col, base in enumerate(padded):
            row = channel_of.get(base)
            if row is not None:
                grid[row, col] = 1.0
        return grid

    stacked = np.concatenate([one_hot(ref_seq), one_hot(mut_seq)], axis=0)  # (8, 99)
    return torch.from_numpy(stacked).unsqueeze(0)  # (1, 8, 99)
def find_mutation_pos(ref_seq: str, mut_seq: str) -> int:
    """Return the index of the first differing base, or -1 if there is none.

    Only the first 99 positions that exist in both sequences are examined.
    The comparison ignores case, matching the encoders in this module, which
    upper-case their input: ``"a"`` vs ``"A"`` is not a mutation.
    """
    for i, (ref_base, mut_base) in enumerate(zip(ref_seq[:99], mut_seq[:99])):
        # Compare per character so indices always refer to the input strings.
        if ref_base.upper() != mut_base.upper():
            return i
    return -1
# ═══════════════════════════════════════════════════════════════════════════════
# Registry
# ═══════════════════════════════════════════════════════════════════════════════
class ModelRegistry:
    """Lazily downloads and caches the three pretrained models.

    Each model is fetched from its HuggingFace repo on first access of the
    corresponding property (``splice``, ``v4``, ``classic``) and then reused.
    When a checkpoint cannot be downloaded or loaded, a randomly initialised
    model is returned instead and ``demo_mode`` is set to ``True``.

    Attributes:
        token: HuggingFace token; falls back to the ``HF_TOKEN`` env variable.
        demo_mode: ``True`` once any model fell back to random weights.
        val_acc_splice: ``val_accuracy`` stored in the splice checkpoint
            (0.0 when absent).
        val_acc_v4: ``val_accuracy`` stored in the v4 checkpoint
            (0.0 when absent).
    """

    # Download cache location passed to hf_hub_download.
    _CACHE_DIR = "/tmp/mutation_xai"

    def __init__(self, hf_token: str | None = None):
        self.token = hf_token or os.environ.get("HF_TOKEN")
        self._splice: MutationPredictorCNN_v2 | None = None
        self._v4: MutationPredictorCNN_v2 | None = None
        self._classic: MutationPredictorCNN | None = None
        self.demo_mode = False
        self.val_acc_splice = 0.0
        self.val_acc_v4 = 0.0

    @property
    def splice(self) -> MutationPredictorCNN_v2:
        """Splice model, loaded on first access."""
        if self._splice is None:
            self._splice = self._load_v2(
                REPO_SPLICE, "mutation_predictor_splice.pt", "splice")
        return self._splice

    @property
    def v4(self) -> MutationPredictorCNN_v2:
        """v4 model, loaded on first access."""
        if self._v4 is None:
            self._v4 = self._load_v2(
                REPO_V4, "mutation_predictor_v4.pt", "v4",
                fallback_files=["pytorch_model.pth", "model.pth"])
        return self._v4

    @property
    def classic(self) -> MutationPredictorCNN:
        """Classic pathogenicity model, loaded on first access."""
        if self._classic is None:
            self._classic = self._load_classic()
        return self._classic

    def _hf_download(self, repo_id: str, filenames: list[str]) -> str | None:
        """Return the local path of the first downloadable file, else None.

        Best effort by design: a missing ``huggingface_hub`` package or a
        failed download never raises, but every failure is logged so an
        unexpected demo-mode fallback can be diagnosed.
        """
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            logger.warning(
                "huggingface_hub is not installed; cannot download from %s",
                repo_id)
            return None
        for fname in filenames:
            try:
                return hf_hub_download(repo_id, fname, token=self.token,
                                       cache_dir=self._CACHE_DIR)
            except Exception as exc:
                # Expected for filenames that do not exist in the repo.
                logger.debug("Could not fetch %s from %s: %s",
                             fname, repo_id, exc)
        return None

    def _load_v2(self, repo_id: str, primary: str, tag: str,
                 fallback_files: list[str] | None = None) -> MutationPredictorCNN_v2:
        """Load a MutationPredictorCNN_v2 checkpoint, or fall back to demo.

        Args:
            repo_id: HuggingFace repo to download from.
            primary: Preferred checkpoint filename.
            tag: ``"splice"`` stores the checkpoint accuracy in
                ``val_acc_splice``; any other value stores it in
                ``val_acc_v4``. Also used in log messages.
            fallback_files: Filenames tried after ``primary``; a default
                list is used when omitted or empty.

        Returns:
            The model in eval mode (random weights in demo mode).
        """
        files = [primary] + (fallback_files or [
            "pytorch_model.pth", "model.pth", "model.pt"])
        path = self._hf_download(repo_id, files)
        model = None
        if path:
            try:
                # SECURITY: weights_only=False unpickles arbitrary objects
                # from a downloaded file; only safe for trusted repos.
                ckpt = torch.load(path, map_location="cpu", weights_only=False)
                # Accept both a wrapped checkpoint and a bare state_dict.
                sd = ckpt.get("model_state_dict", ckpt)
                # Head widths differ between checkpoints; read them from
                # the stored weights instead of hard-coding them.
                model = MutationPredictorCNN_v2(
                    fc_region_out=sd["fc_region.weight"].shape[0],
                    splice_fc_out=sd["splice_fc.weight"].shape[0])
                model.load_state_dict(sd, strict=True)
                val_acc = ckpt.get("val_accuracy", 0.0)
                if tag == "splice":
                    self.val_acc_splice = val_acc
                else:
                    self.val_acc_v4 = val_acc
                logger.info("Loaded %s from %s", tag, repo_id)
            except Exception as e:
                logger.warning("Failed to load %s: %s — demo mode", tag, e)
                model = None
        if model is None:
            self.demo_mode = True
            model = MutationPredictorCNN_v2()
            logger.warning("%s running in DEMO mode (random weights)", tag)
        model.eval()
        return model

    def _load_classic(self) -> MutationPredictorCNN:
        """Load the classic checkpoint, or fall back to demo mode.

        The checkpoint filename is not known in advance, so the repo is
        listed first and any ``.pt``/``.pth``/``.bin`` file found is tried
        before a list of plausible names.

        Returns:
            The model in eval mode (random weights in demo mode).
        """
        # Diagnostic: list ALL files in the repo so the real filename is
        # visible in the logs.
        pt_files: list[str] = []
        try:
            from huggingface_hub import list_repo_files
            all_files = list(list_repo_files(REPO_CLASSIC, token=self.token))
            logger.info("Files in %s: %s", REPO_CLASSIC, all_files)
            # Auto-detect any checkpoint-looking file in the repo.
            pt_files = [f for f in all_files
                        if f.endswith(('.pt', '.pth', '.bin'))]
            if pt_files:
                logger.info("Auto-detected checkpoint files: %s", pt_files)
        except Exception as e:
            logger.warning("Could not list repo files: %s", e)
            pt_files = []

        # Auto-detected names first, then plausible names, most likely first.
        # dict.fromkeys drops duplicates while keeping order, so a name is
        # never downloaded twice.
        candidates = list(dict.fromkeys(pt_files + [
            "mutation_predictor.pt",
            "mutation_pathogenicity_predictor.pt",
            "mutation_predictor_classic.pt",
            "pytorch_model.pt",
            "pytorch_model.pth",
            "model.pt",
            "model.pth",
            "checkpoint.pt",
            "best_model.pt",
            "classifier.pt",
        ]))
        path = self._hf_download(REPO_CLASSIC, candidates)
        model = MutationPredictorCNN()
        if path:
            try:
                # SECURITY: weights_only=False unpickles arbitrary objects
                # from a downloaded file; only safe for trusted repos.
                ckpt = torch.load(path, map_location="cpu", weights_only=False)
                sd = ckpt.get("model_state_dict", ckpt)
                # strict=False tolerates key mismatches with the
                # reconstructed architecture, so check what was restored.
                result = model.load_state_dict(sd, strict=False)
                if result.missing_keys or result.unexpected_keys:
                    logger.warning(
                        "Classic checkpoint does not fully match the model "
                        "(missing=%s, unexpected=%s)",
                        list(result.missing_keys),
                        list(result.unexpected_keys))
                if len(result.missing_keys) >= len(model.state_dict()):
                    # Not a single tensor was restored: weights are random.
                    self.demo_mode = True
                    logger.warning(
                        "Classic checkpoint matched no model weights — demo mode")
                else:
                    logger.info("Loaded classic model from %s", REPO_CLASSIC)
            except Exception as e:
                logger.warning("Failed to load classic: %s — demo mode", e)
                self.demo_mode = True
        else:
            self.demo_mode = True
            logger.warning(
                "Classic model: none of %s found in %s — running DEMO mode",
                candidates, REPO_CLASSIC
            )
        model.eval()
        return model
#Content is user-generated and unverified.