Spaces:
Sleeping
Sleeping
| """ | |
| registry.py — Auto-discovers all Sniper model repos on HuggingFace and | |
| manages artifact loading (model PKLs, calibrators, regime models, metadata). | |
| """ | |
| import os | |
| import re | |
| import json | |
| import time | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional | |
| from dataclasses import dataclass, field | |
| import joblib | |
| import requests | |
# Named logger used by the registry helpers below.
logger = logging.getLogger("SniperRegistry")
# Patch __main__ immediately so joblib can resolve SniperModel from
# any pickle that was trained with __main__.SniperModel.
# NOTE(review): this runs at import time and must happen before any artifact
# is unpickled. _safe_load is what _load_pkl uses to deserialize artifacts —
# presumably a joblib.load wrapper; confirm in src/sniper_model.py.
from src.sniper_model import patch_main as _patch_main, safe_load as _safe_load
_patch_main()
# On-disk cache for downloaded repo files (used by _download_file);
# the directory is created at import time.
CACHE_DIR = Path("/tmp/sniper_cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# HuggingFace owner and repo-name substring used to discover model repos.
HF_AUTHOR = "Arkm20"
MODEL_PREFIX = "sniper-model"
| # --------------------------------------------------------------------------- | |
| # Data structures | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ArtifactBundle:
    """One timestamped run within a model repo.

    Instances start out holding only identity and metadata; the model
    objects stay ``None`` / empty until something assigns them.
    """

    repo_id: str  # HuggingFace repo id, "owner/name"
    timestamp: str  # run stamp parsed from artifact file names (YYYYMMDD_HHMM)
    main_model: object = None  # primary model used for inference
    stage1_model: object = None  # optional stage-1 model
    calibrator: object = None  # optional probability calibrator
    regime_models: dict = field(default_factory=dict)  # regime key (e.g. "mkt0_vix1") -> model
    metadata: dict = field(default_factory=dict)  # parsed run metadata JSON
    feature_list: list = field(default_factory=list)  # feature names from metadata
    optimal_threshold: float = 0.5  # decision threshold from metadata

    @property
    def label(self) -> str:
        """Display label, e.g. ``"sniper-model-v2 · run 20250312_1430"``."""
        repo_short = self.repo_id.split("/")[-1]
        return f"{repo_short} · run {self.timestamp}"

    @property
    def model_version(self) -> str:
        """``metadata["model_version"]``, or ``"?"`` when absent."""
        return self.metadata.get("model_version", "?")

    @property
    def has_regime_models(self) -> bool:
        """True when at least one regime model is registered."""
        return len(self.regime_models) > 0

    @property
    def has_calibrator(self) -> bool:
        """True when a calibrator has been loaded."""
        return self.calibrator is not None

    @property
    def has_two_stage(self) -> bool:
        """True when a stage-1 model has been loaded."""
        return self.stage1_model is not None
@dataclass
class ModelRepo:
    """A discovered HuggingFace model repo together with its timestamped runs."""

    repo_id: str  # HuggingFace repo id, "owner/name"
    bundles: list = field(default_factory=list)  # list[ArtifactBundle], ordered newest first

    @property
    def latest(self) -> Optional["ArtifactBundle"]:
        """The first (newest) bundle, or None if the repo has no runs."""
        return self.bundles[0] if self.bundles else None

    @property
    def short_name(self) -> str:
        """Repo name without the owner prefix."""
        return self.repo_id.split("/")[-1]
| # --------------------------------------------------------------------------- | |
| # HuggingFace API helpers | |
| # --------------------------------------------------------------------------- | |
def _hf_api_get(url: str, token: str | None = None, retries: int = 3) -> dict | list | None:
    """GET a HuggingFace Hub API endpoint and return its decoded JSON body.

    Args:
        url: Fully-qualified API URL.
        token: Optional HF access token, sent as a Bearer header.
        retries: Maximum number of attempts.

    Returns:
        The parsed JSON on HTTP 200. Returns None on 404, on a non-retryable
        client error (4xx other than 429), or once all attempts fail.
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    for attempt in range(retries):
        try:
            resp = requests.get(url, headers=headers, timeout=30)
            status = resp.status_code
            if status == 200:
                return resp.json()
            if status == 404:
                return None
            # Log unexpected statuses; otherwise a bad token or rate limit
            # is indistinguishable from "nothing found" to the caller.
            logger.warning(f"HF API {url} returned {status} (attempt {attempt+1})")
            # 401/403 and other client errors will not succeed on retry;
            # only rate limiting (429) and server errors are retried.
            if 400 <= status < 500 and status != 429:
                return None
        except Exception as e:
            logger.warning(f"HF API attempt {attempt+1} failed: {e}")
        if attempt < retries - 1:
            time.sleep(2 ** attempt)
    return None
def _list_sniper_repos(token: str | None = None, limit: int = 50) -> list[str]:
    """Return ids of repos under ``HF_AUTHOR`` whose id contains ``MODEL_PREFIX``.

    Args:
        token: Optional HF access token (needed to see private repos).
        limit: Maximum number of search results requested from the Hub API.
            Matching repos beyond this count are not returned, so raise it
            if the author publishes more repos.

    Returns:
        Matching repo ids, e.g. ``"Arkm20/sniper-model-v2"``. Returns an
        empty list on API failure or an unexpected response shape.
    """
    url = (
        f"https://huggingface.co/api/models?author={HF_AUTHOR}"
        f"&search={MODEL_PREFIX}&limit={limit}"
    )
    data = _hf_api_get(url, token)
    # The model-search endpoint returns a JSON list. Anything else (None on
    # failure, or an error object) is treated as "no repos" instead of
    # crashing on dict-key iteration below.
    if not isinstance(data, list):
        if data:
            logger.warning(f"Unexpected HF search response type: {type(data).__name__}")
        return []
    repos = [m.get("modelId") or m.get("id") for m in data if isinstance(m, dict)]
    return [r for r in repos if r and MODEL_PREFIX in r]
def _list_repo_files(repo_id: str, token: str = None) -> list[str]:
    """List every file path (``rfilename``) in a HuggingFace model repo.

    Returns an empty list when the repo info cannot be fetched.
    """
    info = _hf_api_get(f"https://huggingface.co/api/models/{repo_id}", token)
    if not info:
        return []
    paths = []
    for sibling in info.get("siblings", []):
        if "rfilename" in sibling:
            paths.append(sibling["rfilename"])
    return paths
| def _group_by_timestamp(files: list[str]) -> dict[str, dict]: | |
| """ | |
| Group artifact files by their timestamp suffix. | |
| e.g. lgb_model_20250312_1430.pkl -> ts=20250312_1430 | |
| Returns {timestamp: {"main": path, "stage1": path, "calibrator": path, | |
| "regime": {name: path}, "metadata": path}} | |
| """ | |
| ts_pattern = re.compile(r"(\d{8}_\d{4})") | |
| groups: dict[str, dict] = {} | |
| for f in files: | |
| m = ts_pattern.search(f) | |
| if not m: | |
| continue | |
| ts = m.group(1) | |
| if ts not in groups: | |
| groups[ts] = {"main": None, "stage1": None, "calibrator": None, | |
| "regime": {}, "metadata": None} | |
| fname = Path(f).name | |
| if fname.startswith("lgb_model_"): | |
| groups[ts]["main"] = f | |
| elif fname.startswith("stage1_"): | |
| groups[ts]["stage1"] = f | |
| elif fname.startswith("calibrator_"): | |
| groups[ts]["calibrator"] = f | |
| elif fname.startswith("regime_"): | |
| # e.g. regime_mkt0_vix1_20250312_1430.pkl | |
| regime_name = fname.replace(f"_{ts}.pkl", "").replace("regime_", "") | |
| groups[ts]["regime"][regime_name] = f | |
| elif fname.startswith("metadata_final_"): | |
| groups[ts]["metadata"] = f | |
| return groups | |
| # --------------------------------------------------------------------------- | |
| # Artifact download + caching | |
| # --------------------------------------------------------------------------- | |
def _download_file(repo_id: str, path_in_repo: str, token: str | None = None) -> Path | None:
    """Download one repo file into the local cache and return its path.

    Files are cached under ``CACHE_DIR/<owner>__<repo>/<path with '/' -> '_'>``.
    An existing cache entry is returned as-is, without checking the remote.

    The body is streamed into a ``.part`` sibling and renamed into place only
    once complete. An interrupted download therefore never leaves a
    truncated file that later calls would accept as a cache hit.

    Returns:
        The cached path, or None if the file is missing (404) or all three
        attempts fail.
    """
    safe_name = path_in_repo.replace("/", "_")
    cache_path = CACHE_DIR / repo_id.replace("/", "__") / safe_name
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    if cache_path.exists():
        return cache_path
    url = f"https://huggingface.co/{repo_id}/resolve/main/{path_in_repo}"
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    tmp_path = cache_path.with_name(cache_path.name + ".part")
    for attempt in range(3):
        try:
            # Context manager releases the streamed connection on every path.
            with requests.get(url, headers=headers, timeout=120, stream=True) as resp:
                if resp.status_code == 200:
                    with open(tmp_path, "wb") as fh:
                        for chunk in resp.iter_content(chunk_size=8192):
                            fh.write(chunk)
                    os.replace(tmp_path, cache_path)  # atomic publish
                    return cache_path
                logger.warning(f"Download {url} returned {resp.status_code}")
                if resp.status_code == 404:
                    return None  # retrying cannot make a missing file appear
        except Exception as e:
            logger.warning(f"Download attempt {attempt+1} failed: {e}")
            tmp_path.unlink(missing_ok=True)  # drop any partial body
        if attempt < 2:
            time.sleep(2 ** attempt)
    return None
def _load_pkl(repo_id: str, path_in_repo: str, token: str = None):
    """Fetch one pickled artifact through the download cache and deserialize it.

    Returns the loaded object. Returns None when ``path_in_repo`` is empty,
    the download fails, or deserialization raises; failures are logged.

    NOTE: these artifacts are pickles, and unpickling can execute arbitrary
    code. Only point this at repos you trust.
    """
    if path_in_repo:
        local = _download_file(repo_id, path_in_repo, token)
        if local is not None:
            try:
                return _safe_load(local)
            except Exception as e:
                logger.error(f"Failed to load {local}: {e}")
                return None
        logger.warning(f"Could not download {path_in_repo}")
    return None
def _load_metadata(repo_id: str, path_in_repo: str, token: str | None = None) -> dict:
    """Download and parse a run's metadata JSON file.

    Returns the decoded JSON object. Returns ``{}`` when no path is given,
    the download fails, the file is not valid JSON, or its top level is not
    an object (callers read the result with ``dict.get``).
    """
    if not path_in_repo:
        return {}
    local = _download_file(repo_id, path_in_repo, token)
    if local is None:
        return {}
    try:
        # JSON is UTF-8; don't depend on the platform's default encoding.
        with open(local, encoding="utf-8") as fh:
            meta = json.load(fh)
    except Exception as e:
        logger.error(f"Failed to parse metadata {local}: {e}")
        return {}
    if not isinstance(meta, dict):
        logger.error(f"Metadata {local} is not a JSON object; ignoring it")
        return {}
    return meta
| # --------------------------------------------------------------------------- | |
| # Public API | |
| # --------------------------------------------------------------------------- | |
def discover_repos(token: str | None = None, progress_cb=None) -> list[ModelRepo]:
    """
    Discover all Arkm20/sniper-model-* repos and return ModelRepo objects.

    Only metadata JSON is downloaded here; model PKLs are left unloaded.
    Runs without a main ``lgb_model_*`` artifact are skipped because they
    cannot be used for inference. Repos left with no runs are omitted.

    progress_cb(message: str) is called at each step for UI updates.
    Repos are returned in reverse-sorted id order, each with bundles newest first.
    """
    def _cb(msg):
        if progress_cb:
            progress_cb(msg)
        logger.info(msg)

    _cb("Scanning HuggingFace for Sniper model repositories...")
    repo_ids = _list_sniper_repos(token)
    if not repo_ids:
        _cb("No sniper-model-* repositories found under Arkm20.")
        return []
    _cb(f"Found {len(repo_ids)} repo(s): {', '.join(r.split('/')[-1] for r in repo_ids)}")

    model_repos = []
    for repo_id in sorted(repo_ids, reverse=True):
        _cb(f"Listing files in {repo_id}...")
        groups = _group_by_timestamp(_list_repo_files(repo_id, token))
        if not groups:
            _cb(f" No timestamped artifacts found in {repo_id}, skipping.")
            continue
        bundles = []
        # YYYYMMDD_HHMM stamps sort lexicographically in time order.
        for ts in sorted(groups, reverse=True):
            g = groups[ts]
            if not g["main"]:
                _cb(f" Run {ts} in {repo_id} has no main model, skipping.")
                continue
            meta = _load_metadata(repo_id, g["metadata"], token) if g["metadata"] else {}
            threshold = meta.get("optimal_threshold")
            bundles.append(ArtifactBundle(
                repo_id=repo_id,
                timestamp=ts,
                metadata=meta,
                # `or []` guards an explicit null in the JSON.
                feature_list=meta.get("feature_list") or [],
                # Explicit None check so a legitimate 0.0 threshold survives.
                optimal_threshold=0.5 if threshold is None else threshold,
            ))
        if not bundles:
            _cb(f" No loadable runs found in {repo_id}, skipping.")
            continue
        model_repos.append(ModelRepo(repo_id=repo_id, bundles=bundles))
        _cb(f" {repo_id}: {len(bundles)} run(s) found.")
    return model_repos
def load_bundle(bundle: ArtifactBundle, token: str | None = None, progress_cb=None) -> ArtifactBundle:
    """
    Download and load all PKL artifacts for a bundle (lazy, called on first use).

    The repo file listing is fetched again to map ``bundle.timestamp`` back
    to artifact paths. Artifacts are assigned onto ``bundle`` in place, and
    the same object is returned.

    If the run's files cannot be found, the bundle is returned unchanged and
    a message is reported. A stage-1 model or calibrator that fails to load
    is stored as None. A regime model that fails to load is left out of
    ``bundle.regime_models``. If the main model fails to load,
    ``bundle.main_model`` is None and a failure message replaces the success
    message.

    progress_cb(message: str) is called at each step for UI updates.
    """
    def _cb(msg):
        if progress_cb:
            progress_cb(msg)
        logger.info(msg)

    repo_id = bundle.repo_id
    # The bundle stores no file paths, so re-list the repo to recover them.
    groups = _group_by_timestamp(_list_repo_files(repo_id, token))
    g = groups.get(bundle.timestamp)
    if g is None:
        _cb(f"No artifacts found for run {bundle.timestamp} in {repo_id}.")
        return bundle

    _cb(f"Loading main model ({bundle.timestamp})...")
    bundle.main_model = _load_pkl(repo_id, g.get("main"), token)
    if g.get("stage1"):
        _cb("Loading Stage 1 model...")
        bundle.stage1_model = _load_pkl(repo_id, g["stage1"], token)
    if g.get("calibrator"):
        _cb("Loading calibrator...")
        bundle.calibrator = _load_pkl(repo_id, g["calibrator"], token)
    for rname, rpath in g.get("regime", {}).items():
        _cb(f"Loading regime model: {rname}...")
        model = _load_pkl(repo_id, rpath, token)
        if model is not None:
            bundle.regime_models[rname] = model

    if bundle.main_model is None:
        _cb(f"Main model for run {bundle.timestamp} could not be loaded.")
    else:
        _cb("All artifacts loaded.")
    return bundle
def get_all_bundle_labels(repos: list[ModelRepo]) -> list[str]:
    """Flatten every bundle's display label into one list for dropdowns.

    Order follows ``repos`` and, within each repo, its ``bundles`` list.
    """
    return [bundle.label for repo in repos for bundle in repo.bundles]
def find_bundle_by_label(repos: list[ModelRepo], label: str) -> Optional[ArtifactBundle]:
    """Return the first bundle whose display label equals *label*, else None."""
    matches = (b for repo in repos for b in repo.bundles if b.label == label)
    return next(matches, None)
| # --------------------------------------------------------------------------- | |
| # Inference helpers | |
| # --------------------------------------------------------------------------- | |
def predict_proba(bundle: "ArtifactBundle", X, use_regime: bool = True,
                  sp500_above_sma: bool = True, vix_high: bool = False) -> "np.ndarray":
    """
    Run inference: optionally route to a regime model, then apply the calibrator.

    Args:
        bundle: A bundle whose artifacts have been loaded.
        X: Input accepted by the model's ``predict_proba``. Presumably the
            columns are ordered as ``bundle.feature_list`` — TODO confirm.
        use_regime: When True and a regime model exists for the current
            market state, use it instead of the main model.
        sp500_above_sma: Selects the ``mkt1`` (True) / ``mkt0`` (False) part
            of the regime key.
        vix_high: Selects the ``vix1`` (True) / ``vix0`` (False) part.

    Returns:
        Positive-class probabilities, shape (N,). They are passed through the
        calibrator when the bundle has one.

    Raises:
        RuntimeError: if the bundle's main model has not been loaded.
    """
    import numpy as np

    if bundle.main_model is None:
        raise RuntimeError("Bundle not loaded — call load_bundle() first.")

    # Read the artifact attributes directly instead of the has_* accessors,
    # so routing depends only on what is actually loaded.
    model = bundle.main_model
    if use_regime and bundle.regime_models:
        rkey = f"mkt{int(bool(sp500_above_sma))}_vix{int(bool(vix_high))}"
        regime_model = bundle.regime_models.get(rkey)
        if regime_model is not None:  # missing/failed regime -> main model
            model = regime_model

    proba = model.predict_proba(X)[:, 1]
    if bundle.calibrator is not None:
        proba = bundle.calibrator.predict(proba)
    return np.asarray(proba)