siRBench-predictor / predictor /artifacts.py
dimostzim's picture
Initial siRBench predictor Space
5eda974
from __future__ import annotations
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
# Hugging Face model repo that hosts the predictor artifacts by default.
DEFAULT_MODEL_REPO_ID = "dimostzim/siRBench-model"
# Environment variable whose non-blank value overrides DEFAULT_MODEL_REPO_ID.
MODEL_REPO_ENV = "SIRBENCH_MODEL_REPO"
# Artifact files fetched by ensure_artifacts(), in this order.
ARTIFACT_FILENAMES = (
    "xgb_model.json",
    "lgbm_model.txt",
    "calibrator.joblib",
    "feature_artifacts.json",
)


def get_model_repo_id() -> str:
    """Return the Hugging Face model repo id to fetch artifacts from.

    The value of the ``MODEL_REPO_ENV`` environment variable wins when it is
    set to a non-blank string (surrounding whitespace is stripped). If it is
    unset, empty, or whitespace-only, ``DEFAULT_MODEL_REPO_ID`` is returned.

    Returns:
        A non-empty repo id string.
    """
    # Treat a blank override as "not set". Otherwise an empty variable would
    # yield "" and be passed to the Hub client as an invalid repo id.
    override = os.getenv(MODEL_REPO_ENV, "").strip()
    return override or DEFAULT_MODEL_REPO_ID
def resolve_artifact_dir(local_dir: str | Path | None = None) -> Path:
    """Return the directory where artifact files are looked up.

    Args:
        local_dir: Explicit directory to use. Only ``None`` selects the
            default; any other value (including an empty string) is
            converted with ``Path`` and returned as-is.

    Returns:
        ``Path(local_dir)``, or the ``artifacts`` directory inside the parent
        of this module's own directory when ``local_dir`` is ``None``.
    """
    if local_dir is None:
        # Two levels up from this file's resolved absolute path (the file's
        # own directory's parent), then into "artifacts".
        package_root = Path(__file__).resolve().parents[1]
        return package_root / "artifacts"
    return Path(local_dir)
def ensure_artifact_file(filename: str, repo_id: str | None = None, local_dir: str | Path | None = None) -> Path:
    """Return a local path to ``filename``, downloading it from the Hub if missing.

    The file is first looked up inside the artifact directory chosen by
    ``resolve_artifact_dir(local_dir)``. If it is not present there, it is
    downloaded from the Hugging Face model repo ``repo_id`` into that same
    directory. Later calls then find it locally without another download.

    Args:
        filename: Name of the artifact file inside the model repo.
        repo_id: Model repo to download from. Falsy values fall back to
            ``get_model_repo_id()``.
        local_dir: Directory to look in and download into. ``None`` selects
            the default artifact directory.

    Returns:
        Path to the local copy of the artifact.

    Raises:
        Any exception raised by ``hf_hub_download`` propagates unchanged.
    """
    repo_id = repo_id or get_model_repo_id()
    artifact_dir = resolve_artifact_dir(local_dir)
    local_path = artifact_dir / filename
    # is_file() rather than exists(): a directory that merely shares the
    # artifact's name is not a usable artifact.
    if local_path.is_file():
        return local_path
    artifact_dir.mkdir(parents=True, exist_ok=True)
    # Pass local_dir so the file is placed in artifact_dir. Without it the
    # file only lives in the shared HF cache, and the local lookup above
    # could never short-circuit for downloaded artifacts.
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        local_dir=artifact_dir,
    )
    return Path(downloaded)
def ensure_artifacts(repo_id: str | None = None, local_dir: str | Path | None = None) -> dict[str, Path]:
    """Make every file in ``ARTIFACT_FILENAMES`` available and map names to paths.

    Args:
        repo_id: Model repo to download missing files from. Falsy values fall
            back to ``get_model_repo_id()``, resolved once for all files.
        local_dir: Artifact directory forwarded to ``ensure_artifact_file``.

    Returns:
        A dict keyed by artifact filename, in ``ARTIFACT_FILENAMES`` order,
        whose values are the paths returned by ``ensure_artifact_file``.
    """
    resolved_repo = repo_id if repo_id else get_model_repo_id()
    paths: dict[str, Path] = {}
    for name in ARTIFACT_FILENAMES:
        paths[name] = ensure_artifact_file(name, repo_id=resolved_repo, local_dir=local_dir)
    return paths