Spaces:
Sleeping
Sleeping
| """model_loader.py — PeVe v1.1""" | |
| from __future__ import annotations | |
| import os, pickle, warnings | |
| from pathlib import Path | |
| import numpy as np | |
| from config import MODELS | |
# Module-level caches, filled on first use by the get_*_model() accessors.
# None means "not loaded yet" (or the previous load attempt produced nothing).
_splice_model = None
_context_model = None
_protein_model = None

# Tokenizers that accompany the splice and context models (may stay None).
_splice_tok = None
_context_tok = None
def get_splice_model():
    """Return the cached ``(model, tokenizer)`` pair for the splice model.

    The pair is loaded through ``_load_torch`` using ``MODELS["splice"]`` the
    first time it is requested and stored in module globals. While the cached
    model is still ``None`` every call attempts the load again.
    """
    global _splice_model, _splice_tok
    if _splice_model is not None:
        # Already loaded: hand back the cached objects.
        return _splice_model, _splice_tok
    loaded = _load_torch(MODELS["splice"], "splice")
    _splice_model, _splice_tok = loaded
    return _splice_model, _splice_tok
def get_context_model():
    """Return the cached ``(model, tokenizer)`` pair for the context model.

    The pair is loaded through ``_load_torch`` using ``MODELS["context"]`` the
    first time it is requested and stored in module globals. While the cached
    model is still ``None`` every call attempts the load again.
    """
    global _context_model, _context_tok
    if _context_model is not None:
        # Already loaded: hand back the cached objects.
        return _context_model, _context_tok
    loaded = _load_torch(MODELS["context"], "context")
    _context_model, _context_tok = loaded
    return _context_model, _context_tok
def get_protein_model():
    """Return the cached protein model, loading it on first use.

    The model comes from ``_load_protein(MODELS["protein"])`` and is stored in
    a module global. While the cached value is still ``None`` every call
    attempts the load again.
    """
    global _protein_model
    if _protein_model is not None:
        # Already loaded: hand back the cached object.
        return _protein_model
    _protein_model = _load_protein(MODELS["protein"])
    return _protein_model
def _load_torch(repo_id, key):
    """Load a PyTorch model (and its tokenizer, if any) from a Hugging Face repo.

    Two strategies are tried in order:

    1. ``transformers.AutoModel.from_pretrained`` plus
       ``AutoTokenizer.from_pretrained``.
    2. ``snapshot_download`` of the repo followed by ``torch.load`` on the
       first ``*.pt``, ``*.pth`` or ``*.bin`` file found in the snapshot root
       (in that extension priority, alphabetical within an extension).

    Args:
        repo_id: Repository id passed to ``from_pretrained`` and
            ``snapshot_download``.
        key: Short label used only in log and warning messages.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is ``None`` when it could
        not be loaded or when the direct-load path was used. Both entries are
        ``None`` when every strategy failed.
    """
    import torch
    from huggingface_hub import snapshot_download

    print(f"[PeVe] Loading {key} model from {repo_id}")

    # Strategy 1: standard transformers loading.
    try:
        from transformers import AutoModel, AutoTokenizer

        model = AutoModel.from_pretrained(repo_id)
        model.eval()
        try:
            tok = AutoTokenizer.from_pretrained(repo_id)
        except Exception as tok_err:
            # A missing tokenizer is tolerated, but say so instead of hiding it.
            warnings.warn(f"{key}: tokenizer unavailable ({tok_err})")
            tok = None
        print(f"[PeVe] {key}: loaded via AutoModel")
        return model, tok
    except Exception as e1:
        warnings.warn(f"AutoModel failed ({e1}), trying direct load")

    # Strategy 2: download the repo and unpickle a checkpoint file directly.
    try:
        local = Path(snapshot_download(repo_id=repo_id))
        # glob() order is filesystem dependent; sort within each extension so
        # the same checkpoint is picked on every machine.
        candidates = [
            path
            for pattern in ("*.pt", "*.pth", "*.bin")
            for path in sorted(local.glob(pattern))
        ]
        if not candidates:
            raise FileNotFoundError("No model file found")
        # SECURITY: weights_only=False allows torch.load to unpickle arbitrary
        # objects, which can execute code embedded in the checkpoint. Only
        # point this loader at repositories you trust.
        obj = torch.load(candidates[0], map_location="cpu", weights_only=False)
        # A dict checkpoint yields its "model" entry when present, otherwise
        # the dict itself is returned unchanged.
        # NOTE(review): a dict without "model" is presumably a bare
        # state_dict, not a callable model - confirm callers handle that.
        model = obj.get("model", obj) if isinstance(obj, dict) else obj
        if isinstance(model, torch.nn.Module):
            # Match the AutoModel path: always hand back an inference-mode model.
            model.eval()
        print(f"[PeVe] {key}: loaded via torch.load")
        return model, None
    except Exception as e2:
        warnings.warn(f"Direct load failed ({e2}) — {key} will use fallback")
        return None, None
def _load_protein(repo_id):
    """Load the protein model from a Hugging Face repo snapshot.

    The snapshot root is searched for ``*.pkl``, ``*.json``, ``*.ubj``,
    ``*.bin`` and ``*.model`` files, in that priority (alphabetical within an
    extension). ``.pkl`` files are unpickled; every other file is loaded into
    an ``xgboost.Booster``. A candidate that fails to load is reported with a
    warning and the next one is tried, so an unrelated file matching a pattern
    (for example a non-model ``*.json``) no longer aborts the whole load.

    Args:
        repo_id: Repository id passed to ``snapshot_download``.

    Returns:
        The unpickled object or the loaded ``Booster``, or ``None`` when no
        candidate could be loaded.
    """
    import xgboost as xgb
    from huggingface_hub import snapshot_download

    print(f"[PeVe] Loading protein model from {repo_id}")
    try:
        local = Path(snapshot_download(repo_id=repo_id))
        last_error = None
        for pattern in ("*.pkl", "*.json", "*.ubj", "*.bin", "*.model"):
            # Sort because glob() order is filesystem dependent.
            for path in sorted(local.glob(pattern)):
                try:
                    if path.suffix == ".pkl":
                        # SECURITY: unpickling can execute arbitrary code;
                        # only point this loader at repositories you trust.
                        with open(path, "rb") as fh:
                            return pickle.load(fh)
                    booster = xgb.Booster()
                    booster.load_model(str(path))
                    return booster
                except Exception as err:
                    # Remember the failure and fall through to the next file.
                    last_error = err
                    warnings.warn(
                        f"Protein model: could not load {path.name} ({err})"
                    )
        if last_error is not None:
            raise RuntimeError("No loadable XGBoost file found") from last_error
        raise FileNotFoundError("No XGBoost file found")
    except Exception as exc:
        warnings.warn(f"Protein model load failed: {exc}")
        return None