model-tester / src /registry.py
Arkm20's picture
Update src/registry.py
fb20032 verified
"""
registry.py — Auto-discovers all Sniper model repos on HuggingFace and
manages artifact loading (model PKLs, calibrators, regime models, metadata).
"""
import os
import re
import json
import time
import logging
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field
import joblib
import requests
# Module-level logger; the helpers below report progress and failures through it.
logger = logging.getLogger("SniperRegistry")
# Patch __main__ immediately so joblib can resolve SniperModel from
# any pickle that was trained with __main__.SniperModel
# (this runs at import time, before any function in this module can load a pickle).
from src.sniper_model import patch_main as _patch_main, safe_load as _safe_load
_patch_main()
# On-disk cache for downloaded artifacts; the directory is created as an import-time side effect.
CACHE_DIR = Path("/tmp/sniper_cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# HF_AUTHOR is the account queried in the model search; MODEL_PREFIX is both the
# search term and the substring a repo id must contain to be kept.
HF_AUTHOR = "Arkm20"
MODEL_PREFIX = "sniper-model"
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class ArtifactBundle:
    """Artifacts and metadata of one timestamped run inside a model repo.

    The model slots default to empty (``None`` / ``{}``); they are filled in
    later by whoever loads the run's pickles.
    """

    repo_id: str
    timestamp: str
    main_model: object = None
    stage1_model: object = None
    calibrator: object = None
    regime_models: dict = field(default_factory=dict)
    metadata: dict = field(default_factory=dict)
    feature_list: list = field(default_factory=list)
    optimal_threshold: float = 0.5

    @property
    def label(self) -> str:
        """Display label: last path segment of ``repo_id`` plus the run timestamp."""
        short = self.repo_id.rpartition("/")[2]
        return f"{short} · run {self.timestamp}"

    @property
    def model_version(self) -> str:
        """``model_version`` entry of the metadata, or ``"?"`` when absent."""
        version = self.metadata.get("model_version", "?")
        return version

    @property
    def has_regime_models(self) -> bool:
        """True when at least one regime model entry is registered."""
        return len(self.regime_models) != 0

    @property
    def has_calibrator(self) -> bool:
        """True when a calibrator object is attached."""
        return not (self.calibrator is None)

    @property
    def has_two_stage(self) -> bool:
        """True when a stage-1 model object is attached."""
        return not (self.stage1_model is None)
@dataclass
class ModelRepo:
    """A model repository together with the run bundles found in it."""

    repo_id: str
    bundles: list  # list[ArtifactBundle], ordered newest first

    @property
    def latest(self) -> Optional[ArtifactBundle]:
        """First entry of ``bundles``, or ``None`` when the list is empty."""
        if not self.bundles:
            return None
        return self.bundles[0]

    @property
    def short_name(self) -> str:
        """Last path segment of ``repo_id`` (the repo name without its owner)."""
        return self.repo_id.rpartition("/")[2]
# ---------------------------------------------------------------------------
# HuggingFace API helpers
# ---------------------------------------------------------------------------
def _hf_api_get(url: str, token: str = None, retries: int = 3) -> dict | list | None:
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
for attempt in range(retries):
try:
resp = requests.get(url, headers=headers, timeout=30)
if resp.status_code == 200:
return resp.json()
if resp.status_code == 404:
return None
except Exception as e:
logger.warning(f"HF API attempt {attempt+1} failed: {e}")
if attempt < retries - 1:
time.sleep(2 ** attempt)
return None
def _list_sniper_repos(token: str = None) -> list[str]:
    """Return the ids of HF_AUTHOR's models whose id contains MODEL_PREFIX.

    Queries the HuggingFace model search (author + search term, limit 50) and
    keeps each entry's ``modelId`` (falling back to ``id``) when it is
    non-empty and contains the prefix. Returns ``[]`` when the API call
    yields nothing.
    """
    listing = _hf_api_get(
        f"https://huggingface.co/api/models?author={HF_AUTHOR}&search={MODEL_PREFIX}&limit=50",
        token,
    )
    if not listing:
        return []
    candidate_ids = [entry.get("modelId") or entry.get("id") for entry in listing]
    matching = []
    for repo_id in candidate_ids:
        if repo_id and MODEL_PREFIX in repo_id:
            matching.append(repo_id)
    return matching
def _list_repo_files(repo_id: str, token: str = None) -> list[str]:
    """Return the ``rfilename`` of every entry in the repo's ``siblings`` list.

    Fetches the model info document for ``repo_id``; entries without an
    ``rfilename`` key are skipped. Returns ``[]`` when the API call yields
    nothing.
    """
    info = _hf_api_get(f"https://huggingface.co/api/models/{repo_id}", token)
    if not info:
        return []
    filenames = []
    for sibling in info.get("siblings", []):
        if "rfilename" in sibling:
            filenames.append(sibling["rfilename"])
    return filenames
def _group_by_timestamp(files: list[str]) -> dict[str, dict]:
"""
Group artifact files by their timestamp suffix.
e.g. lgb_model_20250312_1430.pkl -> ts=20250312_1430
Returns {timestamp: {"main": path, "stage1": path, "calibrator": path,
"regime": {name: path}, "metadata": path}}
"""
ts_pattern = re.compile(r"(\d{8}_\d{4})")
groups: dict[str, dict] = {}
for f in files:
m = ts_pattern.search(f)
if not m:
continue
ts = m.group(1)
if ts not in groups:
groups[ts] = {"main": None, "stage1": None, "calibrator": None,
"regime": {}, "metadata": None}
fname = Path(f).name
if fname.startswith("lgb_model_"):
groups[ts]["main"] = f
elif fname.startswith("stage1_"):
groups[ts]["stage1"] = f
elif fname.startswith("calibrator_"):
groups[ts]["calibrator"] = f
elif fname.startswith("regime_"):
# e.g. regime_mkt0_vix1_20250312_1430.pkl
regime_name = fname.replace(f"_{ts}.pkl", "").replace("regime_", "")
groups[ts]["regime"][regime_name] = f
elif fname.startswith("metadata_final_"):
groups[ts]["metadata"] = f
return groups
# ---------------------------------------------------------------------------
# Artifact download + caching
# ---------------------------------------------------------------------------
def _download_file(repo_id: str, path_in_repo: str, token: str | None = None) -> Path | None:
    """Download one file of a repo into the local cache and return its path.

    The cache location is ``CACHE_DIR/<repo_id with '/' -> '__'>/<path with
    '/' -> '_'>``. An existing cache file is returned without any network
    access.

    Args:
        repo_id: Repository id, e.g. ``owner/name``.
        path_in_repo: File path inside the repository (``main`` revision).
        token: Optional access token, sent as a ``Bearer`` Authorization header.

    Returns:
        Path of the cached file, or ``None`` when all three attempts failed.
    """
    safe_name = path_in_repo.replace("/", "_")
    cache_path = CACHE_DIR / repo_id.replace("/", "__") / safe_name
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    if cache_path.exists():
        return cache_path
    url = f"https://huggingface.co/{repo_id}/resolve/main/{path_in_repo}"
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    # Stream into a sibling ".part" file and rename only after the body is
    # complete. Writing straight to cache_path would let an interrupted
    # download leave a truncated file that every later call treats as a
    # valid cache hit.
    partial_path = cache_path.with_name(cache_path.name + ".part")
    for attempt in range(3):
        try:
            # The context manager releases the streamed connection on every
            # path, including non-200 responses whose body is never read.
            with requests.get(url, headers=headers, timeout=120, stream=True) as resp:
                if resp.status_code == 200:
                    with open(partial_path, "wb") as fh:
                        for chunk in resp.iter_content(chunk_size=8192):
                            fh.write(chunk)
                    partial_path.replace(cache_path)
                    return cache_path
                logger.warning(f"Download {url} returned {resp.status_code}")
        except Exception as e:
            # Deliberately broad: callers rely on a None return, not an exception.
            logger.warning(f"Download attempt {attempt+1} failed: {e}")
            try:
                partial_path.unlink(missing_ok=True)
            except OSError:
                # Cleanup is best-effort; the next attempt truncates the file anyway.
                pass
        if attempt < 2:
            time.sleep(2 ** attempt)
    return None
def _load_pkl(repo_id: str, path_in_repo: str, token: str = None):
if not path_in_repo:
return None
local = _download_file(repo_id, path_in_repo, token)
if local is None:
logger.warning(f"Could not download {path_in_repo}")
return None
try:
return _safe_load(local)
except Exception as e:
logger.error(f"Failed to load {local}: {e}")
return None
def _load_metadata(repo_id: str, path_in_repo: str, token: str = None) -> dict:
if not path_in_repo:
return {}
local = _download_file(repo_id, path_in_repo, token)
if local is None:
return {}
try:
with open(local) as fh:
return json.load(fh)
except Exception as e:
logger.error(f"Failed to parse metadata {local}: {e}")
return {}
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def discover_repos(token: str = None, progress_cb=None) -> list[ModelRepo]:
    """
    Find the Sniper model repos and describe every timestamped run in each.

    For each repo (visited in reverse-sorted id order) the file list is
    grouped by timestamp; each timestamp becomes an ArtifactBundle (newest
    first) populated from the run's metadata file only — model pickles are
    not loaded here. Repos without timestamped artifacts are skipped.

    progress_cb(message: str), when given, receives every status message;
    the same messages are also logged at INFO level.
    """
    def _report(message):
        if progress_cb:
            progress_cb(message)
        logger.info(message)

    _report("Scanning HuggingFace for Sniper model repositories...")
    repo_ids = _list_sniper_repos(token)
    if not repo_ids:
        _report("No sniper-model-* repositories found under Arkm20.")
        return []
    short_names = ", ".join(r.split("/")[-1] for r in repo_ids)
    _report(f"Found {len(repo_ids)} repo(s): {short_names}")

    discovered = []
    for repo_id in sorted(repo_ids, reverse=True):
        _report(f"Listing files in {repo_id}...")
        runs = _group_by_timestamp(_list_repo_files(repo_id, token))
        if not runs:
            _report(f" No timestamped artifacts found in {repo_id}, skipping.")
            continue
        bundles = []
        for stamp in sorted(runs, reverse=True):
            meta_path = runs[stamp]["metadata"]
            if meta_path:
                meta = _load_metadata(repo_id, meta_path, token)
            else:
                meta = {}
            bundles.append(
                ArtifactBundle(
                    repo_id=repo_id,
                    timestamp=stamp,
                    metadata=meta,
                    feature_list=meta.get("feature_list", []),
                    optimal_threshold=meta.get("optimal_threshold", 0.5),
                )
            )
        discovered.append(ModelRepo(repo_id=repo_id, bundles=bundles))
        _report(f" {repo_id}: {len(bundles)} run(s) found.")
    return discovered
def load_bundle(bundle: ArtifactBundle, token: str = None, progress_cb=None) -> ArtifactBundle:
    """
    Download and load all PKL artifacts for a bundle (lazy, called on first use).

    The repo's file list is fetched again and grouped by timestamp to find
    the artifact paths for ``bundle.timestamp``. The bundle is mutated in
    place and also returned. A regime model that fails to load is stored as
    ``None`` under its name.

    progress_cb(message: str), when given, receives every status message;
    the same messages are also logged at INFO level.
    """
    def _cb(msg):
        if progress_cb:
            progress_cb(msg)
        logger.info(msg)

    # Re-list the repo to recover this run's artifact paths.
    repo_id = bundle.repo_id
    files = _list_repo_files(repo_id, token)
    groups = _group_by_timestamp(files)
    g = groups.get(bundle.timestamp)
    if g is None:
        # Listing failed or the run vanished: nothing below can load, and
        # without this warning the failure would leave no trace in the log.
        logger.warning(
            "No artifacts found for run %s in %s; nothing will be loaded",
            bundle.timestamp, repo_id,
        )
        g = {}
    _cb(f"Loading main model ({bundle.timestamp})...")
    bundle.main_model = _load_pkl(repo_id, g.get("main"), token)
    if g.get("stage1"):
        _cb("Loading Stage 1 model...")
        bundle.stage1_model = _load_pkl(repo_id, g["stage1"], token)
    if g.get("calibrator"):
        _cb("Loading calibrator...")
        bundle.calibrator = _load_pkl(repo_id, g["calibrator"], token)
    for rname, rpath in g.get("regime", {}).items():
        _cb(f"Loading regime model: {rname}...")
        bundle.regime_models[rname] = _load_pkl(repo_id, rpath, token)
    if bundle.main_model is None:
        # Do not claim success when the one mandatory artifact is missing.
        _cb("Main model could not be loaded; bundle is not usable for inference.")
    else:
        _cb("All artifacts loaded.")
    return bundle
def get_all_bundle_labels(repos: list[ModelRepo]) -> list[str]:
    """Flatten every bundle's ``label`` into one list, in repo then bundle order."""
    return [bundle.label for repo in repos for bundle in repo.bundles]
def find_bundle_by_label(repos: list[ModelRepo], label: str) -> Optional[ArtifactBundle]:
    """Return the first bundle whose ``label`` equals ``label``, or ``None`` if none does."""
    matches = (
        bundle
        for repo in repos
        for bundle in repo.bundles
        if bundle.label == label
    )
    return next(matches, None)
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def predict_proba(bundle: ArtifactBundle, X, use_regime: bool = True,
                  sp500_above_sma: bool = True, vix_high: bool = False) -> "np.ndarray":
    """
    Run inference: optionally route to a regime model, then apply the calibrator.

    Args:
        bundle: A bundle whose ``main_model`` has been loaded.
        X: Feature matrix handed unchanged to the model's ``predict_proba``.
        use_regime: When true and the bundle has regime models, use the one
            keyed ``mkt{0|1}_vix{0|1}`` if it exists and is not ``None``;
            otherwise fall back to the main model.
        sp500_above_sma: Selects the ``mkt1`` (true) or ``mkt0`` (false) regime.
        vix_high: Selects the ``vix1`` (true) or ``vix0`` (false) regime.

    Returns:
        Column 1 of the model's ``predict_proba`` output, passed through
        ``calibrator.predict`` when the bundle has a calibrator.
        (Presumably a 1-D array of positive-class probabilities — confirm
        against the model and calibrator types in use.)

    Raises:
        RuntimeError: If ``bundle.main_model`` is ``None``.
    """
    # The former function-scope numpy / scipy.special.expit imports were
    # unused; the scipy one made inference fail wherever scipy is absent.
    if bundle.main_model is None:
        raise RuntimeError("Bundle not loaded — call load_bundle() first.")
    # Regime routing
    model = bundle.main_model
    if use_regime and bundle.has_regime_models:
        mkt = 1 if sp500_above_sma else 0
        vix = 1 if vix_high else 0
        # A regime entry may be None when its pickle failed to load.
        regime_model = bundle.regime_models.get(f"mkt{mkt}_vix{vix}")
        if regime_model is not None:
            model = regime_model
    raw_proba = model.predict_proba(X)[:, 1]
    # Calibrate
    if bundle.has_calibrator:
        raw_proba = bundle.calibrator.predict(raw_proba)
    return raw_proba