Spaces:
Sleeping
Sleeping
| """ | |
| model_loader.py — PeVe Unified Space Model Loading Module | |
| Loading logic adapted from: | |
| - nileshhanotia/mutation-predictor-splice-app (app.py) | |
| - nileshhanotia/mutation-pathogenicity-app (app.py) | |
| - nileshhanotia/mutation-explainable-v6 (model_v6.pkl) | |
| Provides: | |
| load_splice_model() → (model, status_dict) | |
| load_context_model() → (model, status_dict) | |
| load_protein_model() → (model, status_dict) | |
| get_model_status() → combined status dict | |
| """ | |
| import os | |
| import traceback | |
| import pickle | |
| import torch | |
| import torch.nn as nn | |
| # ── Optional: set HF token for private repos ─────────────────────────────── | |
| # Either set the environment variable HF_TOKEN before running, or hard-code | |
| # a token here (not recommended for public repos). | |
| HF_TOKEN = os.environ.get("HF_TOKEN", None) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # MODULE-LEVEL MODEL HANDLES | |
| # These are populated by the load_*() functions below. | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| _splice_model = None | |
| _context_model = None | |
| _protein_model = None | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # ARCHITECTURE — Splice Model | |
| # Adapted from: nileshhanotia/mutation-predictor-splice-app app.py | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| def _get_mutation_position_from_input(x_flat): | |
| """Internal helper used by MutationPredictorCNN_v2.forward().""" | |
| return x_flat[:, 990:1089].argmax(dim=1) | |
class MutationPredictorCNN_v2(nn.Module):
    """
    Splice-aware mutation predictor.

    Layer attribute names, shapes, and construction order are kept
    identical to mutation-predictor-splice-app/app.py so that saved
    state dicts load without remapping.

    forward() returns a 4-tuple:
        (pathogenicity logit, mutation-site importance,
         region importance, splice importance)
    where the last three are sigmoid-squashed scores.
    """

    def __init__(self, fc_region_out=8, splice_fc_out=16):
        super().__init__()
        # Fusion width: pooled conv features (256) + mutation embedding (32)
        # + region embedding + splice embedding.
        fc1_in = 256 + 32 + fc_region_out + splice_fc_out
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)
        self.fc1 = nn.Linear(fc1_in, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, mutation_positions=None):
        batch = x.size(0)
        # Flat input layout:
        #   [0:1089]     11-channel × 99 bp one-hot sequence (flattened)
        #   [1089:1101]  12-dim mutation one-hot
        #   [1101:1103]  2-dim region features
        #   [1103:1106]  3-dim splice features
        seq_part = x[:, :1089]
        mut_part = x[:, 1089:1101]
        region_part = x[:, 1101:1103]
        splice_part = x[:, 1103:1106]

        feat = self.relu(self.bn1(self.conv1(seq_part.view(batch, 11, 99))))
        feat = self.relu(self.bn2(self.conv2(feat)))
        conv_out = self.relu(self.bn3(self.conv3(feat)))

        # Fall back to decoding the mutation site from the input itself
        # when the caller did not provide explicit positions.
        if mutation_positions is None:
            mutation_positions = _get_mutation_position_from_input(x)
        positions = mutation_positions.clamp(0, 98).long()

        # Pull the conv feature column at each sample's mutation site.
        gather_idx = positions.view(batch, 1, 1).expand(-1, 256, -1)
        site_feat = conv_out.gather(2, gather_idx).squeeze(2)
        importance = torch.sigmoid(self.importance_head(site_feat))

        pooled = self.global_pool(conv_out).squeeze(-1)
        region_importance = torch.sigmoid(self.region_importance_head(pooled))
        splice_importance = torch.sigmoid(self.splice_importance_head(pooled))

        mut_emb = self.relu(self.mut_fc(mut_part))
        region_emb = self.relu(self.fc_region(region_part))
        splice_emb = self.relu(self.splice_fc(splice_part))

        combined = torch.cat([pooled, mut_emb, region_emb, splice_emb], dim=1)
        hidden = self.dropout(self.relu(self.fc1(combined)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden), importance, region_importance, splice_importance
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # ARCHITECTURE — Context (401 bp CNN) Model | |
| # Adapted from: nileshhanotia/mutation-predictor-v4 | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class MutationContextCNN(nn.Module):
    """
    401 bp context-window CNN for mutation pathogenicity.

    Mirrors the mutation-predictor-v4 space architecture; weights are
    loaded from a state dict, so layer names and shapes here must not
    change.  If the real v4 architecture differs, load_state_dict will
    raise and the error lands in the loader's status dict.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(5, 64, kernel_size=11, padding=5)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)

    def forward(self, x):
        # Input arrives as (batch, seq_len, channels); Conv1d expects
        # channels second, so swap the last two dims first.
        feats = x.permute(0, 2, 1)
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            feats = self.relu(bn(conv(feats)))
        pooled = self.pool(feats).squeeze(-1)
        hidden = self.drop(self.relu(self.fc1(pooled)))
        hidden = self.drop(self.relu(self.fc2(hidden)))
        return self.fc3(hidden)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # LOADER — Splice Model | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def load_splice_model():
    """Download and initialise the splice-aware CNN from the HF Hub.

    Mirrors the loading pattern of
    nileshhanotia/mutation-predictor-splice-app app.py: the checkpoint
    is a dict whose "model_state_dict" entry holds the weights, and the
    fc_region / splice_fc widths are read back out of that state dict so
    the rebuilt architecture always matches the stored weights.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _splice_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download  # deferred: optional dep

        MODEL_REPO = "nileshhanotia/mutation-predictor-splice"
        MODEL_FILENAME = "mutation_predictor_splice.pt"
        print(f"[splice] Downloading {MODEL_FILENAME} from {MODEL_REPO} …")
        ckpt_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILENAME, token=HF_TOKEN
        )

        print(f"[splice] Loading checkpoint from {ckpt_path} …")
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # acceptable only because the artifact comes from a trusted repo.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        state_dict = checkpoint["model_state_dict"]

        # Infer layer widths from the stored weights instead of
        # hard-coding them (exact pattern from the original app.py).
        fc_region_out = state_dict["fc_region.weight"].shape[0]
        splice_fc_out = state_dict["splice_fc.weight"].shape[0]
        net = MutationPredictorCNN_v2(
            fc_region_out=fc_region_out, splice_fc_out=splice_fc_out
        )
        net.load_state_dict(state_dict)
        net.eval()

        val_acc = checkpoint.get("val_accuracy", float("nan"))
        print(f"[splice] ✓ Loaded. val_accuracy={val_acc:.4f} | "
              f"fc_region_out={fc_region_out} | splice_fc_out={splice_fc_out}")
        _splice_model = net
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[splice] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _splice_model = None
    return _splice_model, status
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # LOADER — Context Model (401 bp CNN, mutation-predictor-v4) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def load_context_model():
    """
    Load the 401 bp context CNN from nileshhanotia/mutation-predictor-v4.

    Loading logic adapted from:
        nileshhanotia/mutation-pathogenicity-app app.py
            checkpoint = torch.load(MODEL_PATH, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])

    The checkpoint filename in the repo is not known in advance, so a
    list of common names is tried in order; the first that downloads
    successfully is used.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _context_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download

        MODEL_REPO = "nileshhanotia/mutation-predictor-v4"
        # Try common checkpoint filenames used in HF spaces
        CANDIDATE_FILENAMES = [
            "pytorch_model.pth",
            "mutation_predictor_v4.pt",
            "model.pt",
            "model.pth",
            "checkpoint.pth",
        ]
        ckpt_path = None
        last_error = ""
        for fname in CANDIDATE_FILENAMES:
            try:
                print(f"[context] Trying {fname} from {MODEL_REPO} …")
                ckpt_path = hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=fname,
                    token=HF_TOKEN,
                )
                print(f"[context] Found: {fname}")
                break
            except Exception as e:
                last_error = str(e)
                continue
        if ckpt_path is None:
            raise FileNotFoundError(
                f"None of the candidate filenames found in {MODEL_REPO}. "
                f"Last error: {last_error}"
            )
        print(f"[context] Loading checkpoint from {ckpt_path} …")
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # acceptable only because the repo is trusted.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # Support both raw state-dict and wrapped checkpoint
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            sd = checkpoint["model_state_dict"]
        elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            sd = checkpoint["state_dict"]
        else:
            sd = checkpoint  # assume it IS the state dict
        model = MutationContextCNN()
        # strict=False tolerates minor arch diffs — but a silent mismatch
        # would leave layers randomly initialised with no warning.  Surface
        # any missing/unexpected keys so a bad checkpoint is visible in logs.
        load_result = model.load_state_dict(sd, strict=False)
        if load_result.missing_keys:
            print(f"[context] ⚠ Missing keys (kept at random init): "
                  f"{load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"[context] ⚠ Unexpected keys (ignored): "
                  f"{load_result.unexpected_keys}")
        model.eval()
        print("[context] ✓ Loaded MutationContextCNN (401 bp).")
        _context_model = model
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[context] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _context_model = None
    return _context_model, status
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # LOADER — Protein Model (XGBoost .pkl from mutation-explainable-v6) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def load_protein_model():
    """Download and unpickle the protein model from the HF Hub.

    Loading logic adapted from nileshhanotia/mutation-explainable-v6
    (model_v6.pkl).  The artifact is a complete trained
    sklearn-compatible object, so plain pickle / joblib is used — NOT
    XGBoost Booster.load_model().

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
    """
    global _protein_model
    status = {"loaded": False, "error_message": ""}
    try:
        from huggingface_hub import hf_hub_download

        MODEL_REPO = "nileshhanotia/mutation-explainable-v6"
        MODEL_FILENAME = "model_v6.pkl"
        print(f"[protein] Downloading {MODEL_FILENAME} from {MODEL_REPO} …")
        pkl_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILENAME, token=HF_TOKEN
        )

        print(f"[protein] Loading pickle from {pkl_path} …")
        # NOTE(security): unpickling executes arbitrary code; acceptable
        # only because the artifact comes from a trusted repo.
        # joblib first (usual for sklearn/xgboost pipelines), then pickle.
        try:
            import joblib
            model = joblib.load(pkl_path)
            print("[protein] Loaded via joblib.")
        except Exception:
            with open(pkl_path, "rb") as f:
                model = pickle.load(f)
            print("[protein] Loaded via pickle.")

        print(f"[protein] ✓ Loaded protein model: {type(model).__name__}")
        _protein_model = model
        status["loaded"] = True
    except Exception:
        tb = traceback.format_exc()
        print(f"[protein] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _protein_model = None
    return _protein_model, status
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # STATUS AGGREGATOR | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def get_model_status() -> dict:
    """Run all three loaders and return one combined status dictionary.

    Returns
    -------
    dict
        {
            "splice":  {"loaded": bool, "error_message": str},
            "context": {"loaded": bool, "error_message": str},
            "protein": {"loaded": bool, "error_message": str},
        }
    """
    banner = "=" * 60
    print(banner)
    print("PeVe — starting unified model loading")
    print(banner)

    # Dict literal evaluates left-to-right, preserving the original
    # splice → context → protein load order.
    status = {
        "splice": load_splice_model()[1],
        "context": load_context_model()[1],
        "protein": load_protein_model()[1],
    }

    # Summary report
    print("\n" + banner)
    print("PeVe — model loading complete")
    print(banner)
    for name, entry in status.items():
        icon = "✓" if entry["loaded"] else "✗"
        print(f" [{icon}] {name:10s} loaded={entry['loaded']}")
    print(banner + "\n")
    return status
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # PUBLIC ACCESSORS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def get_splice_model():
    """Return the splice model handle, or None if load_splice_model() has not succeeded."""
    return _splice_model
def get_context_model():
    """Return the context model handle, or None if load_context_model() has not succeeded."""
    return _context_model
def get_protein_model():
    """Return the protein model handle, or None if load_protein_model() has not succeeded."""
    return _protein_model
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # SELF-TEST | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| if __name__ == "__main__": | |
| print("Testing model loading...") | |
| status = get_model_status() | |
| print(status) | |