Mutation_XAI / model_loader_w.py
nileshhanotia's picture
Rename model_loader.py to model_loader_w.py
a530b25 verified
"""model_loader_w.py — PeVe v1.1: lazy loaders for the splice, context and protein models."""
from __future__ import annotations
import os, pickle, warnings
from pathlib import Path
import numpy as np
from config import MODELS
# Module-level caches, filled in on first use by the get_*_model() accessors.
# A value of None means "not loaded yet" (or the last load attempt failed).
_splice_model = None
_context_model = None
_protein_model = None

# Tokenizers that accompany the splice / context models (may stay None).
_splice_tok = None
_context_tok = None
def get_splice_model():
    """Return the cached ``(model, tokenizer)`` pair for the splice model.

    The pair is loaded through ``_load_torch`` on first use and stored in
    module globals. While the cached model is still ``None`` (nothing loaded
    yet, or the previous attempt produced no model) each call tries again.
    """
    global _splice_model, _splice_tok
    if _splice_model is not None:
        # Already loaded: hand back the cached objects.
        return _splice_model, _splice_tok
    loaded = _load_torch(MODELS["splice"], "splice")
    _splice_model, _splice_tok = loaded
    return _splice_model, _splice_tok
def get_context_model():
    """Return the cached ``(model, tokenizer)`` pair for the context model.

    The pair is loaded through ``_load_torch`` on first use and stored in
    module globals. While the cached model is still ``None`` (nothing loaded
    yet, or the previous attempt produced no model) each call tries again.
    """
    global _context_model, _context_tok
    if _context_model is not None:
        # Already loaded: hand back the cached objects.
        return _context_model, _context_tok
    loaded = _load_torch(MODELS["context"], "context")
    _context_model, _context_tok = loaded
    return _context_model, _context_tok
def get_protein_model():
    """Return the cached protein model, loading it on first use.

    The model comes from ``_load_protein`` and is stored in a module global.
    While the cached value is still ``None`` (nothing loaded yet, or the
    previous attempt failed) each call tries again.
    """
    global _protein_model
    if _protein_model is not None:
        # Already loaded: hand back the cached object.
        return _protein_model
    _protein_model = _load_protein(MODELS["protein"])
    return _protein_model
def _load_torch(repo_id, key):
    """Load a PyTorch model (and tokenizer, when available) from a hub repo.

    Two strategies are tried in order:

    1. ``transformers.AutoModel.from_pretrained`` together with
       ``AutoTokenizer.from_pretrained``. A tokenizer failure is tolerated
       and yields ``None`` for the tokenizer.
    2. Download a snapshot of the repo and ``torch.load`` the first
       top-level ``*.pt``, ``*.pth`` or ``*.bin`` file (in that priority,
       alphabetical within each extension).

    Optional dependencies are imported inside the ``try`` blocks so that a
    missing package degrades to the fallback result instead of raising.

    Args:
        repo_id: Repository identifier handed to the loaders.
        key: Short label used only in log and warning messages.

    Returns:
        ``(model, tokenizer)``. The tokenizer is ``None`` on the direct-load
        path or when it could not be loaded. ``(None, None)`` is returned when
        both strategies fail; a warning is issued for each failure.
    """
    print(f"[PeVe] Loading {key} model from {repo_id}")
    try:
        from transformers import AutoModel, AutoTokenizer
        model = AutoModel.from_pretrained(repo_id)
        model.eval()
        try:
            tok = AutoTokenizer.from_pretrained(repo_id)
        except Exception:
            # Deliberate best effort: the model is still usable without a
            # tokenizer, so the failure is not propagated.
            tok = None
        print(f"[PeVe] {key}: loaded via AutoModel")
        return model, tok
    except Exception as e1:
        warnings.warn(f"AutoModel failed ({e1}), trying direct load")
    try:
        import torch
        from huggingface_hub import snapshot_download
        local = Path(snapshot_download(repo_id=repo_id))
        # Sort within each extension so the chosen file does not depend on
        # filesystem enumeration order.
        candidates = [
            path
            for pattern in ("*.pt", "*.pth", "*.bin")
            for path in sorted(local.glob(pattern))
        ]
        if not candidates:
            raise FileNotFoundError("No model file found")
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only use with trusted repos.
        obj = torch.load(candidates[0], map_location="cpu", weights_only=False)
        # A checkpoint dict may wrap the model under "model"; a dict without
        # that key (e.g. a bare state dict) is returned unchanged.
        model = obj.get("model", obj) if isinstance(obj, dict) else obj
        if isinstance(model, torch.nn.Module):
            # Match the AutoModel path: put the module in inference mode.
            model.eval()
        print(f"[PeVe] {key}: loaded via torch.load")
        return model, None
    except Exception as e2:
        warnings.warn(f"Direct load failed ({e2}) — {key} will use fallback")
        return None, None
def _load_protein(repo_id):
    """Load the protein model from a snapshot of a hub repo.

    Top-level files are tried by extension in the order ``*.pkl``,
    ``*.json``, ``*.ubj``, ``*.bin``, ``*.model`` (alphabetical within each
    extension). ``.pkl`` files are unpickled; every other file is loaded into
    an ``xgboost.Booster``. A file that fails to load is skipped so that an
    unrelated match (for example a non-model ``*.json``) does not hide a
    valid model file later in the search.

    Optional dependencies are imported inside the ``try`` so a missing
    package results in ``None`` rather than an exception.

    Args:
        repo_id: Repository identifier passed to ``snapshot_download``.

    Returns:
        The unpickled object or the loaded ``Booster``; ``None`` (with a
        warning) when nothing could be loaded.
    """
    print(f"[PeVe] Loading protein model from {repo_id}")
    try:
        from huggingface_hub import snapshot_download
        local = Path(snapshot_download(repo_id=repo_id))
        last_error = None
        for pattern in ("*.pkl", "*.json", "*.ubj", "*.bin", "*.model"):
            for path in sorted(local.glob(pattern)):
                try:
                    if path.suffix == ".pkl":
                        # SECURITY: pickle.load can execute arbitrary code
                        # from the file. Only use with trusted repos.
                        with open(path, "rb") as f:
                            return pickle.load(f)
                    import xgboost as xgb
                    booster = xgb.Booster()
                    booster.load_model(str(path))
                    return booster
                except Exception as exc:
                    # Remember why this candidate failed and try the next.
                    last_error = exc
        if last_error is not None:
            # Candidates existed but none loaded: report the real cause.
            raise last_error
        raise FileNotFoundError("No XGBoost file found")
    except Exception as exc:
        warnings.warn(f"Protein model load failed: {exc}")
        return None