Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
# Hub repository used when the environment does not name another one.
DEFAULT_MODEL_REPO_ID: str = "dimostzim/siRBench-model"

# Name of the environment variable that can override the repository id.
MODEL_REPO_ENV: str = "SIRBENCH_MODEL_REPO"

# File names requested from the model repository, in fetch order.
ARTIFACT_FILENAMES: tuple[str, ...] = (
    "xgb_model.json",
    "lgbm_model.txt",
    "calibrator.joblib",
    "feature_artifacts.json",
)
def get_model_repo_id() -> str:
    """Return the Hub repository id that model artifacts are fetched from.

    The environment variable named by ``MODEL_REPO_ENV`` takes precedence
    when it holds a non-empty string; otherwise ``DEFAULT_MODEL_REPO_ID``
    is returned.

    Returns:
        The repository id as a string.
    """
    # Use ``or`` instead of a getenv default so that a variable which is
    # defined but empty also falls back, rather than yielding "" as the id.
    return os.getenv(MODEL_REPO_ENV) or DEFAULT_MODEL_REPO_ID
def resolve_artifact_dir(local_dir: str | Path | None = None) -> Path:
    """Return the directory in which artifact files are looked up.

    Args:
        local_dir: Explicit directory to use. When ``None``, the directory
            named ``artifacts`` located two levels above this source file
            is used instead.

    Returns:
        The chosen directory as a ``Path``; nothing is created on disk.
    """
    if local_dir is None:
        # parents[0] is this file's folder, parents[1] is the folder above it.
        this_file = Path(__file__).resolve()
        return this_file.parents[1] / "artifacts"
    return Path(local_dir)
def ensure_artifact_file(filename: str, repo_id: str | None = None, local_dir: str | Path | None = None) -> Path:
    """Return a path to *filename* on disk, downloading it when it is absent.

    The artifact directory (see ``resolve_artifact_dir``) is checked first;
    if the file is already there, no network access happens.

    Args:
        filename: Name of the artifact inside the model repository.
        repo_id: Hub model repository to fetch from. A falsy value falls
            back to ``get_model_repo_id()``.
        local_dir: Directory to look in and to download into. ``None``
            selects the default from ``resolve_artifact_dir``.

    Returns:
        Path of the artifact file.

    Raises:
        Any exception raised by ``hf_hub_download`` or by creating the
        directory propagates unchanged.
    """
    artifact_dir = resolve_artifact_dir(local_dir)
    local_path = artifact_dir / filename
    if local_path.exists():
        return local_path

    artifact_dir.mkdir(parents=True, exist_ok=True)
    downloaded = hf_hub_download(
        repo_id=repo_id or get_model_repo_id(),
        filename=filename,
        repo_type="model",
        # Place the file inside artifact_dir. Without this the download only
        # lands in the Hub cache, the directory created above stays empty and
        # the existence check at the top can never succeed on later calls.
        local_dir=artifact_dir,
    )
    return Path(downloaded)
def ensure_artifacts(repo_id: str | None = None, local_dir: str | Path | None = None) -> dict[str, Path]:
    """Make every file listed in ``ARTIFACT_FILENAMES`` available.

    Args:
        repo_id: Hub model repository to fetch from. A falsy value is
            replaced once, up front, by ``get_model_repo_id()``.
        local_dir: Forwarded unchanged to ``ensure_artifact_file``.

    Returns:
        A dict mapping each artifact file name to the path returned by
        ``ensure_artifact_file`` for it, in ``ARTIFACT_FILENAMES`` order.
    """
    resolved_repo = repo_id if repo_id else get_model_repo_id()
    paths: dict[str, Path] = {}
    for name in ARTIFACT_FILENAMES:
        paths[name] = ensure_artifact_file(name, repo_id=resolved_repo, local_dir=local_dir)
    return paths