""" registry.py — Auto-discovers all Sniper model repos on HuggingFace and manages artifact loading (model PKLs, calibrators, regime models, metadata). """ import os import re import json import time import logging from pathlib import Path from typing import Optional from dataclasses import dataclass, field import joblib import requests logger = logging.getLogger("SniperRegistry") # Patch __main__ immediately so joblib can resolve SniperModel from # any pickle that was trained with __main__.SniperModel from src.sniper_model import patch_main as _patch_main, safe_load as _safe_load _patch_main() CACHE_DIR = Path("/tmp/sniper_cache") CACHE_DIR.mkdir(parents=True, exist_ok=True) HF_AUTHOR = "Arkm20" MODEL_PREFIX = "sniper-model" # --------------------------------------------------------------------------- # Data structures # --------------------------------------------------------------------------- @dataclass class ArtifactBundle: """One timestamped run within a model repo.""" repo_id: str timestamp: str main_model: object = None stage1_model: object = None calibrator: object = None regime_models: dict = field(default_factory=dict) metadata: dict = field(default_factory=dict) feature_list: list = field(default_factory=list) optimal_threshold: float = 0.5 @property def label(self) -> str: repo_short = self.repo_id.split("/")[-1] return f"{repo_short} · run {self.timestamp}" @property def model_version(self) -> str: return self.metadata.get("model_version", "?") @property def has_regime_models(self) -> bool: return len(self.regime_models) > 0 @property def has_calibrator(self) -> bool: return self.calibrator is not None @property def has_two_stage(self) -> bool: return self.stage1_model is not None @dataclass class ModelRepo: repo_id: str bundles: list # list[ArtifactBundle], ordered newest first @property def latest(self) -> Optional[ArtifactBundle]: return self.bundles[0] if self.bundles else None @property def short_name(self) -> str: return self.repo_id.split("/")[-1] # --------------------------------------------------------------------------- # HuggingFace API helpers # --------------------------------------------------------------------------- def _hf_api_get(url: str, token: str = None, retries: int = 3) -> dict | list | None: headers = {} if token: headers["Authorization"] = f"Bearer {token}" for attempt in range(retries): try: resp = requests.get(url, headers=headers, timeout=30) if resp.status_code == 200: return resp.json() if resp.status_code == 404: return None except Exception as e: logger.warning(f"HF API attempt {attempt+1} failed: {e}") if attempt < retries - 1: time.sleep(2 ** attempt) return None def _list_sniper_repos(token: str = None) -> list[str]: """Return all repo_ids matching Arkm20/sniper-model-*""" url = f"https://huggingface.co/api/models?author={HF_AUTHOR}&search={MODEL_PREFIX}&limit=50" data = _hf_api_get(url, token) if not data: return [] repos = [m.get("modelId") or m.get("id") for m in data] return [r for r in repos if r and MODEL_PREFIX in r] def _list_repo_files(repo_id: str, token: str = None) -> list[str]: url = f"https://huggingface.co/api/models/{repo_id}" data = _hf_api_get(url, token) if not data: return [] siblings = data.get("siblings", []) return [s["rfilename"] for s in siblings if "rfilename" in s] def _group_by_timestamp(files: list[str]) -> dict[str, dict]: """ Group artifact files by their timestamp suffix. e.g. lgb_model_20250312_1430.pkl -> ts=20250312_1430 Returns {timestamp: {"main": path, "stage1": path, "calibrator": path, "regime": {name: path}, "metadata": path}} """ ts_pattern = re.compile(r"(\d{8}_\d{4})") groups: dict[str, dict] = {} for f in files: m = ts_pattern.search(f) if not m: continue ts = m.group(1) if ts not in groups: groups[ts] = {"main": None, "stage1": None, "calibrator": None, "regime": {}, "metadata": None} fname = Path(f).name if fname.startswith("lgb_model_"): groups[ts]["main"] = f elif fname.startswith("stage1_"): groups[ts]["stage1"] = f elif fname.startswith("calibrator_"): groups[ts]["calibrator"] = f elif fname.startswith("regime_"): # e.g. regime_mkt0_vix1_20250312_1430.pkl regime_name = fname.replace(f"_{ts}.pkl", "").replace("regime_", "") groups[ts]["regime"][regime_name] = f elif fname.startswith("metadata_final_"): groups[ts]["metadata"] = f return groups # --------------------------------------------------------------------------- # Artifact download + caching # --------------------------------------------------------------------------- def _download_file(repo_id: str, path_in_repo: str, token: str = None) -> Path | None: safe_name = path_in_repo.replace("/", "_") cache_path = CACHE_DIR / repo_id.replace("/", "__") / safe_name cache_path.parent.mkdir(parents=True, exist_ok=True) if cache_path.exists(): return cache_path url = f"https://huggingface.co/{repo_id}/resolve/main/{path_in_repo}" headers = {} if token: headers["Authorization"] = f"Bearer {token}" for attempt in range(3): try: resp = requests.get(url, headers=headers, timeout=120, stream=True) if resp.status_code == 200: with open(cache_path, "wb") as fh: for chunk in resp.iter_content(chunk_size=8192): fh.write(chunk) return cache_path logger.warning(f"Download {url} returned {resp.status_code}") except Exception as e: logger.warning(f"Download attempt {attempt+1} failed: {e}") if attempt < 2: time.sleep(2 ** attempt) return None def _load_pkl(repo_id: str, path_in_repo: str, token: str = None): if not path_in_repo: return None local = _download_file(repo_id, path_in_repo, token) if local is None: logger.warning(f"Could not download {path_in_repo}") return None try: return _safe_load(local) except Exception as e: logger.error(f"Failed to load {local}: {e}") return None def _load_metadata(repo_id: str, path_in_repo: str, token: str = None) -> dict: if not path_in_repo: return {} local = _download_file(repo_id, path_in_repo, token) if local is None: return {} try: with open(local) as fh: return json.load(fh) except Exception as e: logger.error(f"Failed to parse metadata {local}: {e}") return {} # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def discover_repos(token: str = None, progress_cb=None) -> list[ModelRepo]: """ Discover all Arkm20/sniper-model-* repos and return ModelRepo objects. progress_cb(message: str) is called at each step for UI updates. """ def _cb(msg): if progress_cb: progress_cb(msg) logger.info(msg) _cb("Scanning HuggingFace for Sniper model repositories...") repo_ids = _list_sniper_repos(token) if not repo_ids: _cb("No sniper-model-* repositories found under Arkm20.") return [] _cb(f"Found {len(repo_ids)} repo(s): {', '.join(r.split('/')[-1] for r in repo_ids)}") model_repos = [] for repo_id in sorted(repo_ids, reverse=True): _cb(f"Listing files in {repo_id}...") files = _list_repo_files(repo_id, token) groups = _group_by_timestamp(files) if not groups: _cb(f" No timestamped artifacts found in {repo_id}, skipping.") continue bundles = [] for ts in sorted(groups.keys(), reverse=True): g = groups[ts] meta = _load_metadata(repo_id, g["metadata"], token) if g["metadata"] else {} bundle = ArtifactBundle( repo_id=repo_id, timestamp=ts, metadata=meta, feature_list=meta.get("feature_list", []), optimal_threshold=meta.get("optimal_threshold", 0.5), ) bundles.append(bundle) model_repos.append(ModelRepo(repo_id=repo_id, bundles=bundles)) _cb(f" {repo_id}: {len(bundles)} run(s) found.") return model_repos def load_bundle(bundle: ArtifactBundle, token: str = None, progress_cb=None) -> ArtifactBundle: """ Download and load all PKL artifacts for a bundle (lazy, called on first use). """ def _cb(msg): if progress_cb: progress_cb(msg) logger.info(msg) # Find the group again from metadata to get file paths repo_id = bundle.repo_id files = _list_repo_files(repo_id, token) groups = _group_by_timestamp(files) g = groups.get(bundle.timestamp, {}) _cb(f"Loading main model ({bundle.timestamp})...") bundle.main_model = _load_pkl(repo_id, g.get("main"), token) if g.get("stage1"): _cb("Loading Stage 1 model...") bundle.stage1_model = _load_pkl(repo_id, g["stage1"], token) if g.get("calibrator"): _cb("Loading calibrator...") bundle.calibrator = _load_pkl(repo_id, g["calibrator"], token) for rname, rpath in g.get("regime", {}).items(): _cb(f"Loading regime model: {rname}...") bundle.regime_models[rname] = _load_pkl(repo_id, rpath, token) _cb("All artifacts loaded.") return bundle def get_all_bundle_labels(repos: list[ModelRepo]) -> list[str]: """Return flat list of all bundle labels for dropdown population.""" labels = [] for repo in repos: for bundle in repo.bundles: labels.append(bundle.label) return labels def find_bundle_by_label(repos: list[ModelRepo], label: str) -> Optional[ArtifactBundle]: """Reverse-lookup a bundle from its display label.""" for repo in repos: for bundle in repo.bundles: if bundle.label == label: return bundle return None # --------------------------------------------------------------------------- # Inference helpers # --------------------------------------------------------------------------- def predict_proba(bundle: ArtifactBundle, X, use_regime: bool = True, sp500_above_sma: bool = True, vix_high: bool = False) -> "np.ndarray": """ Run inference: optionally route to regime model, apply calibrator. Returns calibrated probabilities (N,). """ import numpy as np from scipy.special import expit if bundle.main_model is None: raise RuntimeError("Bundle not loaded — call load_bundle() first.") # Regime routing model = bundle.main_model if use_regime and bundle.has_regime_models: mkt = 1 if sp500_above_sma else 0 vix = 1 if vix_high else 0 rkey = f"mkt{mkt}_vix{vix}" if rkey in bundle.regime_models and bundle.regime_models[rkey] is not None: model = bundle.regime_models[rkey] raw_proba = model.predict_proba(X)[:, 1] # Calibrate if bundle.has_calibrator: raw_proba = bundle.calibrator.predict(raw_proba) return raw_proba