"""model_loader.py — PeVe v1.1""" from __future__ import annotations import os, pickle, warnings from pathlib import Path import numpy as np from config import MODELS _splice_model = _context_model = _protein_model = None _splice_tok = _context_tok = None def get_splice_model(): global _splice_model, _splice_tok if _splice_model is None: _splice_model, _splice_tok = _load_torch(MODELS["splice"], "splice") return _splice_model, _splice_tok def get_context_model(): global _context_model, _context_tok if _context_model is None: _context_model, _context_tok = _load_torch(MODELS["context"], "context") return _context_model, _context_tok def get_protein_model(): global _protein_model if _protein_model is None: _protein_model = _load_protein(MODELS["protein"]) return _protein_model def _load_torch(repo_id, key): import torch from huggingface_hub import snapshot_download print(f"[PeVe] Loading {key} model from {repo_id}") try: from transformers import AutoModel, AutoTokenizer model = AutoModel.from_pretrained(repo_id) model.eval() try: tok = AutoTokenizer.from_pretrained(repo_id) except Exception: tok = None print(f"[PeVe] {key}: loaded via AutoModel") return model, tok except Exception as e1: warnings.warn(f"AutoModel failed ({e1}), trying direct load") try: local = snapshot_download(repo_id=repo_id) candidates = list(Path(local).glob("*.pt")) + list(Path(local).glob("*.pth")) + list(Path(local).glob("*.bin")) if not candidates: raise FileNotFoundError("No model file found") obj = torch.load(candidates[0], map_location="cpu", weights_only=False) model = obj.get("model", obj) if isinstance(obj, dict) else obj print(f"[PeVe] {key}: loaded via torch.load") return model, None except Exception as e2: warnings.warn(f"Direct load failed ({e2}) — {key} will use fallback") return None, None def _load_protein(repo_id): import xgboost as xgb from huggingface_hub import snapshot_download print(f"[PeVe] Loading protein model from {repo_id}") try: local = snapshot_download(repo_id=repo_id) for ext in ["*.pkl","*.json","*.ubj","*.bin","*.model"]: for p in Path(local).glob(ext): if p.suffix == ".pkl": with open(p,"rb") as f: return pickle.load(f) m = xgb.Booster(); m.load_model(str(p)); return m raise FileNotFoundError("No XGBoost file found") except Exception as exc: warnings.warn(f"Protein model load failed: {exc}") return None