""" model_loader.py — PeVe Unified Space Model Loading Module Loading logic adapted from: - nileshhanotia/mutation-predictor-splice-app (app.py) - nileshhanotia/mutation-pathogenicity-app (app.py) - nileshhanotia/mutation-explainable-v6 (model_v6.pkl) Provides: load_splice_model() → (model, status_dict) load_context_model() → (model, status_dict) load_protein_model() → (model, status_dict) get_model_status() → combined status dict """ import os import traceback import pickle import torch import torch.nn as nn # ── Optional: set HF token for private repos ─────────────────────────────── # Either set the environment variable HF_TOKEN before running, or hard-code # a token here (not recommended for public repos). HF_TOKEN = os.environ.get("HF_TOKEN", None) # ══════════════════════════════════════════════════════════════════════════════ # MODULE-LEVEL MODEL HANDLES # These are populated by the load_*() functions below. # ══════════════════════════════════════════════════════════════════════════════ _splice_model = None _context_model = None _protein_model = None # ══════════════════════════════════════════════════════════════════════════════ # ARCHITECTURE — Splice Model # Adapted from: nileshhanotia/mutation-predictor-splice-app app.py # ══════════════════════════════════════════════════════════════════════════════ def _get_mutation_position_from_input(x_flat): """Internal helper used by MutationPredictorCNN_v2.forward().""" return x_flat[:, 990:1089].argmax(dim=1) class MutationPredictorCNN_v2(nn.Module): """ Splice-aware mutation predictor. Architecture copied verbatim from mutation-predictor-splice-app/app.py to guarantee weight compatibility. 
""" def __init__(self, fc_region_out=8, splice_fc_out=16): super().__init__() fc1_in = 256 + 32 + fc_region_out + splice_fc_out self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3) self.bn1 = nn.BatchNorm1d(64) self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2) self.bn2 = nn.BatchNorm1d(128) self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm1d(256) self.global_pool = nn.AdaptiveAvgPool1d(1) self.mut_fc = nn.Linear(12, 32) self.importance_head = nn.Linear(256, 1) self.region_importance_head = nn.Linear(256, 2) self.fc_region = nn.Linear(2, fc_region_out) self.splice_fc = nn.Linear(3, splice_fc_out) self.splice_importance_head = nn.Linear(256, 3) self.fc1 = nn.Linear(fc1_in, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 1) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.4) def forward(self, x, mutation_positions=None): bs = x.size(0) seq_flat = x[:, :1089] mut_onehot = x[:, 1089:1101] region_feat = x[:, 1101:1103] splice_feat = x[:, 1103:1106] h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99)))) h = self.relu(self.bn2(self.conv2(h))) conv_out = self.relu(self.bn3(self.conv3(h))) if mutation_positions is None: mutation_positions = _get_mutation_position_from_input(x) pos_idx = mutation_positions.clamp(0, 98).long() pe = pos_idx.view(bs, 1, 1).expand(bs, 256, 1) mut_feat = conv_out.gather(2, pe).squeeze(2) imp_score = torch.sigmoid(self.importance_head(mut_feat)) pooled = self.global_pool(conv_out).squeeze(-1) r_imp = torch.sigmoid(self.region_importance_head(pooled)) s_imp = torch.sigmoid(self.splice_importance_head(pooled)) m = self.relu(self.mut_fc(mut_onehot)) r = self.relu(self.fc_region(region_feat)) s = self.relu(self.splice_fc(splice_feat)) fused = torch.cat([pooled, m, r, s], dim=1) out = self.dropout(self.relu(self.fc1(fused))) out = self.dropout(self.relu(self.fc2(out))) logit = self.fc3(out) return logit, imp_score, r_imp, s_imp # 
# ══════════════════════════════════════════════════════════════════════════════
# ARCHITECTURE — Context (401 bp CNN) Model
# Adapted from: nileshhanotia/mutation-predictor-v4
# ══════════════════════════════════════════════════════════════════════════════
class MutationContextCNN(nn.Module):
    """
    401 bp context window CNN for mutation pathogenicity.

    Architecture mirrors the v4 space model; weights are loaded from a state
    dict. NOTE(review): load_context_model() calls load_state_dict(strict=False),
    so if the actual v4 architecture differs, mismatched keys are tolerated
    silently rather than raising — confirm the layer names against the real
    v4 checkpoint before trusting its predictions.
    """

    def __init__(self):
        super().__init__()
        # 5 input channels — presumably A/C/G/T/N one-hot over the 401 bp
        # window; TODO confirm against the upstream encoder.
        self.conv1 = nn.Conv1d(5, 64, kernel_size=11, padding=5)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)

    def forward(self, x):
        """Return a single pathogenicity logit per sample, shape (batch, 1)."""
        # x: (batch, seq_len, channels) → permute → (batch, channels, seq_len)
        h = x.permute(0, 2, 1)
        h = self.relu(self.bn1(self.conv1(h)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.relu(self.bn3(self.conv3(h)))
        h = self.pool(h).squeeze(-1)
        h = self.drop(self.relu(self.fc1(h)))
        h = self.drop(self.relu(self.fc2(h)))
        return self.fc3(h)


# ══════════════════════════════════════════════════════════════════════════════
# LOADER — Splice Model
# ══════════════════════════════════════════════════════════════════════════════
def load_splice_model():
    """
    Load MutationPredictorCNN_v2 from nileshhanotia/mutation-predictor-splice.

    Loading logic adapted from:
    nileshhanotia/mutation-predictor-splice-app app.py
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        sd = ckpt["model_state_dict"]

    Side effects: downloads the checkpoint via huggingface_hub and populates
    the module-level handle _splice_model (None on failure).

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _splice_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download  # local import for clarity

        MODEL_REPO = "nileshhanotia/mutation-predictor-splice"
        MODEL_FILENAME = "mutation_predictor_splice.pt"

        print(f"[splice] Downloading {MODEL_FILENAME} from {MODEL_REPO} …")
        ckpt_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            token=HF_TOKEN,
        )

        print(f"[splice] Loading checkpoint from {ckpt_path} …")
        # NOTE(review): weights_only=False unpickles arbitrary objects from
        # the downloaded file — acceptable only because the repo is trusted.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        sd = ckpt["model_state_dict"]

        # Infer architecture hyper-params from the state dict (exact pattern from app.py)
        fc_region_out = sd["fc_region.weight"].shape[0]
        splice_fc_out = sd["splice_fc.weight"].shape[0]

        model = MutationPredictorCNN_v2(
            fc_region_out=fc_region_out,
            splice_fc_out=splice_fc_out,
        )
        model.load_state_dict(sd)
        model.eval()

        # val_accuracy is optional in the checkpoint; NaN formats cleanly.
        val_acc = ckpt.get("val_accuracy", float("nan"))
        print(f"[splice] ✓ Loaded. val_accuracy={val_acc:.4f} | "
              f"fc_region_out={fc_region_out} | splice_fc_out={splice_fc_out}")

        _splice_model = model
        status["loaded"] = True
    except Exception:
        # Capture the full traceback into the status dict so callers (e.g. a
        # UI) can surface it without this module raising.
        tb = traceback.format_exc()
        print(f"[splice] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _splice_model = None
    return _splice_model, status


# ══════════════════════════════════════════════════════════════════════════════
# LOADER — Context Model (401 bp CNN, mutation-predictor-v4)
# ══════════════════════════════════════════════════════════════════════════════
def load_context_model():
    """
    Load the 401 bp context CNN from nileshhanotia/mutation-predictor-v4.

    Loading logic adapted from:
    nileshhanotia/mutation-pathogenicity-app app.py
        checkpoint = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

    Side effects: downloads the checkpoint via huggingface_hub and populates
    the module-level handle _context_model (None on failure).

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _context_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download

        MODEL_REPO = "nileshhanotia/mutation-predictor-v4"
        # Try common checkpoint filenames used in HF spaces; the repo's exact
        # filename is not known here, so probe them in order.
        CANDIDATE_FILENAMES = [
            "pytorch_model.pth",
            "mutation_predictor_v4.pt",
            "model.pt",
            "model.pth",
            "checkpoint.pth",
        ]

        ckpt_path = None
        last_error = ""
        for fname in CANDIDATE_FILENAMES:
            try:
                print(f"[context] Trying {fname} from {MODEL_REPO} …")
                ckpt_path = hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=fname,
                    token=HF_TOKEN,
                )
                print(f"[context] Found: {fname}")
                break
            except Exception as e:
                # Remember only the most recent failure for the error message.
                last_error = str(e)
                continue

        if ckpt_path is None:
            raise FileNotFoundError(
                f"None of the candidate filenames found in {MODEL_REPO}. "
                f"Last error: {last_error}"
            )

        print(f"[context] Loading checkpoint from {ckpt_path} …")
        # NOTE(review): weights_only=False unpickles arbitrary objects from
        # the downloaded file — acceptable only because the repo is trusted.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        # Support both raw state-dict and wrapped checkpoint
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            sd = checkpoint["model_state_dict"]
        elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            sd = checkpoint["state_dict"]
        else:
            sd = checkpoint  # assume it IS the state dict

        model = MutationContextCNN()
        # strict=False tolerates minor arch diffs — NOTE(review): this also
        # silently skips every mismatched key, so a wrong architecture still
        # "loads"; consider logging the missing/unexpected keys it returns.
        model.load_state_dict(sd, strict=False)
        model.eval()

        print("[context] ✓ Loaded MutationContextCNN (401 bp).")
        _context_model = model
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[context] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _context_model = None
    return _context_model, status


# ══════════════════════════════════════════════════════════════════════════════
# LOADER — Protein Model (XGBoost .pkl from mutation-explainable-v6)
# ══════════════════════════════════════════════════════════════════════════════
def load_protein_model():
    """
    Load the pickled XGBoost model from nileshhanotia/mutation-explainable-v6.

    Loading logic adapted from:
    nileshhanotia/mutation-explainable-v6 (model_v6.pkl)

    Uses Python pickle / joblib — NOT XGBoost Booster.load_model(). The model
    is already stored as a complete trained sklearn-compatible object.

    Side effects: downloads the pickle via huggingface_hub and populates the
    module-level handle _protein_model (None on failure).

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _protein_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download

        MODEL_REPO = "nileshhanotia/mutation-explainable-v6"
        MODEL_FILENAME = "model_v6.pkl"

        print(f"[protein] Downloading {MODEL_FILENAME} from {MODEL_REPO} …")
        pkl_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            token=HF_TOKEN,
        )

        print(f"[protein] Loading pickle from {pkl_path} …")
        # Try joblib first (common for sklearn/xgboost pipelines), fall back to pickle.
        # NOTE(review): unpickling executes arbitrary code from the downloaded
        # artifact — acceptable only because the repo is trusted.
        try:
            import joblib
            model = joblib.load(pkl_path)
            print("[protein] Loaded via joblib.")
        except Exception:
            with open(pkl_path, "rb") as f:
                model = pickle.load(f)
            print("[protein] Loaded via pickle.")

        print(f"[protein] ✓ Loaded protein model: {type(model).__name__}")
        _protein_model = model
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[protein] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _protein_model = None
    return _protein_model, status


# ══════════════════════════════════════════════════════════════════════════════
# STATUS AGGREGATOR
# ══════════════════════════════════════════════════════════════════════════════
def get_model_status() -> dict:
    """
    Load all three models and return a unified status dictionary.

    Calls load_splice_model(), load_context_model() and load_protein_model()
    in sequence (populating the module-level handles as a side effect) and
    prints a summary report.

    Returns
    -------
    {
        "splice":  {"loaded": bool, "error_message": str},
        "context": {"loaded": bool, "error_message": str},
        "protein": {"loaded": bool, "error_message": str},
    }
    """
    print("=" * 60)
    print("PeVe — starting unified model loading")
    print("=" * 60)

    _, splice_status = load_splice_model()
    _, context_status = load_context_model()
    _, protein_status = load_protein_model()

    status = {
        "splice": splice_status,
        "context": context_status,
        "protein": protein_status,
    }

    # Summary report
    print("\n" + "=" * 60)
    print("PeVe — model loading complete")
    print("=" * 60)
    for name, s in status.items():
        icon = "✓" if s["loaded"] else "✗"
        print(f" [{icon}] {name:10s} loaded={s['loaded']}")
    print("=" * 60 + "\n")

    return status


# ══════════════════════════════════════════════════════════════════════════════
# PUBLIC ACCESSORS
# ══════════════════════════════════════════════════════════════════════════════
def get_splice_model():
    """Return the loaded splice model handle (None if not loaded)."""
    return _splice_model


def get_context_model():
    """Return the loaded context model handle (None if not loaded)."""
    return _context_model


def get_protein_model():
    """Return the loaded protein model handle (None if not loaded)."""
    return _protein_model


# ══════════════════════════════════════════════════════════════════════════════
# SELF-TEST
# ══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    print("Testing model loading...")
    status = get_model_status()
    print(status)