# Mutation_XAI / model_loader.py
# Source: Hugging Face Space upload by nileshhanotia (commit 1bf9b9d, verified).
# The original file-page header is kept here as comments so the module parses.
"""
model_loader.py — PeVe Unified Space Model Loading Module
Loading logic adapted from:
- nileshhanotia/mutation-predictor-splice-app (app.py)
- nileshhanotia/mutation-pathogenicity-app (app.py)
- nileshhanotia/mutation-explainable-v6 (model_v6.pkl)
Provides:
load_splice_model() → (model, status_dict)
load_context_model() → (model, status_dict)
load_protein_model() → (model, status_dict)
get_model_status() → combined status dict
"""
import os
import traceback
import pickle
import torch
import torch.nn as nn
# ── Optional: set HF token for private repos ───────────────────────────────
# Either set the environment variable HF_TOKEN before running, or hard-code
# a token here (not recommended for public repos).
# None when the env var is unset → anonymous access (public repos only).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# ══════════════════════════════════════════════════════════════════════════════
# MODULE-LEVEL MODEL HANDLES
# These are populated by the load_*() functions below and read back through
# the get_*_model() accessors at the bottom of this module.
# ══════════════════════════════════════════════════════════════════════════════
_splice_model = None    # MutationPredictorCNN_v2 after load_splice_model() succeeds
_context_model = None   # MutationContextCNN after load_context_model() succeeds
_protein_model = None   # unpickled model object after load_protein_model() succeeds
# ══════════════════════════════════════════════════════════════════════════════
# ARCHITECTURE — Splice Model
# Adapted from: nileshhanotia/mutation-predictor-splice-app app.py
# ══════════════════════════════════════════════════════════════════════════════
def _get_mutation_position_from_input(x_flat):
"""Internal helper used by MutationPredictorCNN_v2.forward()."""
return x_flat[:, 990:1089].argmax(dim=1)
class MutationPredictorCNN_v2(nn.Module):
    """Splice-aware mutation pathogenicity predictor.

    Attribute names and layer-construction order are kept identical to
    mutation-predictor-splice-app/app.py so that saved state dicts load
    without any key remapping.

    Expected flat input layout (>= 1106 features per row):
        [0:1089)     one-hot sequence block, 11 channels x 99 positions
        [1089:1101)  12-dim mutation one-hot
        [1101:1103)  2-dim region features
        [1103:1106)  3-dim splice features
    """

    def __init__(self, fc_region_out=8, splice_fc_out=16):
        super().__init__()
        # NOTE: do not reorder the layer definitions — construction order
        # must match the original module for state-dict compatibility.
        fc1_in = 256 + 32 + fc_region_out + splice_fc_out
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)
        self.fc1 = nn.Linear(fc1_in, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, mutation_positions=None):
        """Return (logit, position importance, region importance, splice importance)."""
        batch = x.size(0)

        # Split the flat feature vector into its four sections.
        seq_block = x[:, :1089]
        mutation_onehot = x[:, 1089:1101]
        region_features = x[:, 1101:1103]
        splice_features = x[:, 1103:1106]

        # Three-stage conv stack over the (11, 99) sequence encoding.
        feat = self.relu(self.bn1(self.conv1(seq_block.view(batch, 11, 99))))
        feat = self.relu(self.bn2(self.conv2(feat)))
        conv_out = self.relu(self.bn3(self.conv3(feat)))

        # Locate the mutated base (recovered from the input when not given)
        # and gather the conv features at that single position.
        if mutation_positions is None:
            mutation_positions = _get_mutation_position_from_input(x)
        pos_idx = mutation_positions.clamp(0, 98).long()
        gather_idx = pos_idx.view(batch, 1, 1).expand(batch, 256, 1)
        site_feat = conv_out.gather(2, gather_idx).squeeze(2)
        imp_score = torch.sigmoid(self.importance_head(site_feat))

        # Globally pooled features drive the auxiliary importance heads.
        pooled = self.global_pool(conv_out).squeeze(-1)
        region_imp = torch.sigmoid(self.region_importance_head(pooled))
        splice_imp = torch.sigmoid(self.splice_importance_head(pooled))

        # Fuse pooled sequence features with the embedded side inputs.
        mut_emb = self.relu(self.mut_fc(mutation_onehot))
        region_emb = self.relu(self.fc_region(region_features))
        splice_emb = self.relu(self.splice_fc(splice_features))
        fused = torch.cat([pooled, mut_emb, region_emb, splice_emb], dim=1)

        hidden = self.dropout(self.relu(self.fc1(fused)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden), imp_score, region_imp, splice_imp
# ══════════════════════════════════════════════════════════════════════════════
# ARCHITECTURE — Context (401 bp CNN) Model
# Adapted from: nileshhanotia/mutation-predictor-v4
# ══════════════════════════════════════════════════════════════════════════════
class MutationContextCNN(nn.Module):
    """CNN over a one-hot DNA context window (5 input channels).

    Mirrors the mutation-predictor-v4 space architecture; if the real v4
    layout differs, load_state_dict raises and the loader records the
    error in its status dict.
    """

    def __init__(self):
        super().__init__()
        # NOTE: construction order preserved for state-dict compatibility.
        self.conv1 = nn.Conv1d(5, 64, kernel_size=11, padding=5)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)

    def forward(self, x):
        """x: (batch, seq_len, 5) → (batch, 1) raw logit."""
        # Conv1d expects (batch, channels, length).
        feats = x.permute(0, 2, 1)
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            feats = self.relu(bn(conv(feats)))
        feats = self.pool(feats).squeeze(-1)
        feats = self.drop(self.relu(self.fc1(feats)))
        feats = self.drop(self.relu(self.fc2(feats)))
        return self.fc3(feats)
# ══════════════════════════════════════════════════════════════════════════════
# LOADER — Splice Model
# ══════════════════════════════════════════════════════════════════════════════
def load_splice_model():
    """Download and initialise the splice model from the Hub.

    Fetches mutation_predictor_splice.pt from
    nileshhanotia/mutation-predictor-splice, rebuilds MutationPredictorCNN_v2
    with hyper-parameters inferred from the checkpoint's own weights, and
    caches the eval-mode model in the module-level handle.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _splice_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download  # local import for clarity

        repo_id = "nileshhanotia/mutation-predictor-splice"
        filename = "mutation_predictor_splice.pt"
        print(f"[splice] Downloading {filename} from {repo_id} …")
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, token=HF_TOKEN)

        print(f"[splice] Loading checkpoint from {ckpt_path} …")
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        sd = ckpt["model_state_dict"]

        # Infer architecture hyper-params from the state dict itself so the
        # rebuilt module matches the checkpoint exactly.
        fc_region_out = sd["fc_region.weight"].shape[0]
        splice_fc_out = sd["splice_fc.weight"].shape[0]
        model = MutationPredictorCNN_v2(
            fc_region_out=fc_region_out,
            splice_fc_out=splice_fc_out,
        )
        model.load_state_dict(sd)
        model.eval()

        val_acc = ckpt.get("val_accuracy", float("nan"))
        print(f"[splice] ✓ Loaded. val_accuracy={val_acc:.4f} | "
              f"fc_region_out={fc_region_out} | splice_fc_out={splice_fc_out}")
        _splice_model = model
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[splice] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _splice_model = None
    return _splice_model, status
# ══════════════════════════════════════════════════════════════════════════════
# LOADER — Context Model (401 bp CNN, mutation-predictor-v4)
# ══════════════════════════════════════════════════════════════════════════════
def load_context_model():
    """Download and initialise the context CNN from the Hub.

    Probes a list of common checkpoint filenames in
    nileshhanotia/mutation-predictor-v4, accepts either a wrapped checkpoint
    dict or a raw state dict, and caches the eval-mode MutationContextCNN in
    the module-level handle.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _context_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download

        repo_id = "nileshhanotia/mutation-predictor-v4"
        # Common checkpoint filenames used in HF spaces, tried in order.
        candidates = (
            "pytorch_model.pth",
            "mutation_predictor_v4.pt",
            "model.pt",
            "model.pth",
            "checkpoint.pth",
        )
        ckpt_path, last_error = None, ""
        for fname in candidates:
            try:
                print(f"[context] Trying {fname} from {repo_id} …")
                ckpt_path = hf_hub_download(repo_id=repo_id, filename=fname, token=HF_TOKEN)
                print(f"[context] Found: {fname}")
                break
            except Exception as e:
                last_error = str(e)
        if ckpt_path is None:
            raise FileNotFoundError(
                f"None of the candidate filenames found in {repo_id}. "
                f"Last error: {last_error}"
            )

        print(f"[context] Loading checkpoint from {ckpt_path} …")
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        # Accept wrapped checkpoints as well as a bare state dict.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            sd = checkpoint["model_state_dict"]
        elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            sd = checkpoint["state_dict"]
        else:
            sd = checkpoint  # assume it IS the state dict

        model = MutationContextCNN()
        model.load_state_dict(sd, strict=False)  # strict=False tolerates minor arch diffs
        model.eval()
        print("[context] ✓ Loaded MutationContextCNN (401 bp).")
        _context_model = model
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[context] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _context_model = None
    return _context_model, status
# ══════════════════════════════════════════════════════════════════════════════
# LOADER — Protein Model (XGBoost .pkl from mutation-explainable-v6)
# ══════════════════════════════════════════════════════════════════════════════
def load_protein_model():
    """Download and unpickle the protein model from the Hub.

    Fetches model_v6.pkl from nileshhanotia/mutation-explainable-v6 and
    deserialises it (joblib first, plain pickle as fallback). The pickle
    already contains a complete trained sklearn-compatible object, so no
    Booster.load_model() step is needed. The result is cached in the
    module-level handle.

    SECURITY NOTE: unpickling executes arbitrary code from the file — only
    load artefacts from repos you trust.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _protein_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download

        repo_id = "nileshhanotia/mutation-explainable-v6"
        filename = "model_v6.pkl"
        print(f"[protein] Downloading {filename} from {repo_id} …")
        pkl_path = hf_hub_download(repo_id=repo_id, filename=filename, token=HF_TOKEN)

        print(f"[protein] Loading pickle from {pkl_path} …")
        # joblib is the usual serialiser for sklearn/xgboost pipelines;
        # fall back to the stdlib pickle if it is unavailable or fails.
        try:
            import joblib
            model = joblib.load(pkl_path)
            print("[protein] Loaded via joblib.")
        except Exception:
            with open(pkl_path, "rb") as f:
                model = pickle.load(f)
            print("[protein] Loaded via pickle.")

        print(f"[protein] ✓ Loaded protein model: {type(model).__name__}")
        _protein_model = model
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[protein] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _protein_model = None
    return _protein_model, status
# ══════════════════════════════════════════════════════════════════════════════
# STATUS AGGREGATOR
# ══════════════════════════════════════════════════════════════════════════════
def get_model_status() -> dict:
    """Load all three models and return a unified status dictionary.

    Returns
    -------
    {
        "splice": {"loaded": bool, "error_message": str},
        "context": {"loaded": bool, "error_message": str},
        "protein": {"loaded": bool, "error_message": str},
    }
    """
    banner = "=" * 60
    print(banner)
    print("PeVe — starting unified model loading")
    print(banner)

    # Run the loaders in the same fixed order and collect their statuses.
    status = {}
    for name, loader in (("splice", load_splice_model),
                         ("context", load_context_model),
                         ("protein", load_protein_model)):
        _, status[name] = loader()

    # Summary report
    print("\n" + banner)
    print("PeVe — model loading complete")
    print(banner)
    for name, s in status.items():
        icon = "✓" if s["loaded"] else "✗"
        print(f" [{icon}] {name:10s} loaded={s['loaded']}")
    print(banner + "\n")
    return status
# ══════════════════════════════════════════════════════════════════════════════
# PUBLIC ACCESSORS
# ══════════════════════════════════════════════════════════════════════════════
def get_splice_model() -> "MutationPredictorCNN_v2 | None":
    """Return the cached splice model, or None if load_splice_model() has not succeeded."""
    return _splice_model
def get_context_model() -> "MutationContextCNN | None":
    """Return the cached context model, or None if load_context_model() has not succeeded."""
    return _context_model
def get_protein_model():
    """Return the cached protein model (unpickled object), or None if load_protein_model() has not succeeded."""
    return _protein_model
# ══════════════════════════════════════════════════════════════════════════════
# SELF-TEST
# ══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Smoke test: attempt to load all three models (hits the network) and
    # dump the combined status dict.
    print("Testing model loading...")
    status = get_model_status()
    print(status)