"""Resolve model artifact files locally, downloading them from the Hugging Face Hub when missing."""

from __future__ import annotations

import os
from pathlib import Path

from huggingface_hub import hf_hub_download

# Hub repository used when the environment does not name another one.
DEFAULT_MODEL_REPO_ID = "dimostzim/siRBench-model"
# Name of the environment variable that overrides the default repository id.
MODEL_REPO_ENV = "SIRBENCH_MODEL_REPO"
# Every artifact file that ensure_artifacts() resolves.
ARTIFACT_FILENAMES = (
    "xgb_model.json",
    "lgbm_model.txt",
    "calibrator.joblib",
    "feature_artifacts.json",
)


def get_model_repo_id() -> str:
    """Return the Hub repository id to fetch artifacts from.

    The ``SIRBENCH_MODEL_REPO`` environment variable takes precedence when it
    holds a non-empty value; otherwise ``DEFAULT_MODEL_REPO_ID`` is returned.
    """
    # `or` (rather than a getenv default) so that a variable that is set but
    # empty falls back to the default instead of yielding an empty repo id.
    return os.getenv(MODEL_REPO_ENV) or DEFAULT_MODEL_REPO_ID


def resolve_artifact_dir(local_dir: str | Path | None = None) -> Path:
    """Return the directory in which artifact files are looked up.

    Args:
        local_dir: Explicit directory to use. When ``None``, the ``artifacts``
            directory next to this module's parent package directory is used.

    Returns:
        The artifact directory as a ``Path``. The directory is not created.
    """
    if local_dir is not None:
        return Path(local_dir)
    return Path(__file__).resolve().parents[1] / "artifacts"


def ensure_artifact_file(
    filename: str,
    repo_id: str | None = None,
    local_dir: str | Path | None = None,
) -> Path:
    """Return a local path to ``filename``, downloading it if it is missing.

    Args:
        filename: Name of the artifact file inside the model repository.
        repo_id: Hub repository id. Falls back to ``get_model_repo_id()`` when
            ``None`` or empty.
        local_dir: Directory to look in and download into. Resolved through
            ``resolve_artifact_dir``.

    Returns:
        Path to the artifact file: the existing local file when present,
        otherwise the path reported by ``hf_hub_download``.

    Raises:
        Any exception raised by ``hf_hub_download`` or by creating the
        artifact directory propagates unchanged.
    """
    repo_id = repo_id or get_model_repo_id()
    artifact_dir = resolve_artifact_dir(local_dir)
    local_path = artifact_dir / filename

    # Only a regular file counts as "already present"; a directory that merely
    # shares the artifact's name must not be returned as the artifact.
    if local_path.is_file():
        return local_path

    artifact_dir.mkdir(parents=True, exist_ok=True)
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        # Download into the artifact directory that was just created so the
        # file is found by the local check above on subsequent calls.
        local_dir=artifact_dir,
    )
    return Path(downloaded)


def ensure_artifacts(
    repo_id: str | None = None,
    local_dir: str | Path | None = None,
) -> dict[str, Path]:
    """Resolve every file in ``ARTIFACT_FILENAMES`` to a local path.

    Args:
        repo_id: Hub repository id. Falls back to ``get_model_repo_id()`` when
            ``None`` or empty.
        local_dir: Directory to look in and download into; forwarded to
            ``ensure_artifact_file``.

    Returns:
        Mapping from artifact filename to its local ``Path``, in the order of
        ``ARTIFACT_FILENAMES``.
    """
    # Resolve the repository once so every file comes from the same repo even
    # if the environment changes while the files are being fetched.
    repo_id = repo_id or get_model_repo_id()
    return {
        filename: ensure_artifact_file(filename, repo_id=repo_id, local_dir=local_dir)
        for filename in ARTIFACT_FILENAMES
    }