Spaces:
Running
Running
"""
Download model artifacts from Hugging Face Hub at container startup.

Called automatically by the Docker entrypoint before uvicorn starts.
Can also download a specific version on demand (e.g. from the API).

HF model repo layout (one directory per version at the repo root):
    v1/models/classical/*.joblib
    v1/models/deep/*.pt  *.keras
    v1/scalers/*.joblib
    v2/models/classical/*.joblib
    v2/models/deep/*.pt  *.keras
    v2/scalers/*.joblib
    v2/results/*.json
    v3/...   latest version; the default startup check requires
             v3/models/classical/{random_forest,xgboost,lightgbm}.joblib

Local layout after download (local_dir = ARTIFACTS_DIR):
    artifacts/v1/...
    artifacts/v2/...
    artifacts/v3/...
"""
| import os | |
| import sys | |
| from pathlib import Path | |
# ------------------------------------------------------------------------------
# Config
# ------------------------------------------------------------------------------
REPO_ID = "NeerajCodz/aiBatteryLifeCycle"  # Hub repository holding the artifacts
REPO_TYPE = "model"

# Access token, taken from the HF_TOKEN Space Secret (Space Settings -> Secrets).
# When running locally, export HF_TOKEN in the shell or .env first. An empty
# value means no token is passed explicitly to snapshot_download.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Download target: <parent of this file's directory>/artifacts. The Hub repo
# keeps one directory per version (v1/, v2/, v3/, ...) at its root, so using
# this as local_dir yields artifacts/v1/..., artifacts/v2/..., and so on.
ARTIFACTS_DIR = Path(__file__).resolve().parent.parent.joinpath("artifacts")

# Marker file written after a full download has completed.
SENTINEL = ARTIFACTS_DIR.joinpath(".hf_downloaded")
# ------------------------------------------------------------------------------
def _hf_kwargs(allow_patterns: list | None = None,
               ignore_patterns: list | None = None) -> dict:
    """Assemble the keyword arguments for ``snapshot_download``.

    The repo id, repo type and local target directory are always set.
    ``allow_patterns``, ``ignore_patterns`` and the HF token are added only
    when they are truthy (non-empty), so empty values are never forwarded.
    """
    optional = {
        "allow_patterns": allow_patterns,
        "ignore_patterns": ignore_patterns,
        "token": HF_TOKEN,
    }
    return {
        "repo_id": REPO_ID,
        "repo_type": REPO_TYPE,
        "local_dir": str(ARTIFACTS_DIR),
        # Keep only the optional entries that carry a value.
        **{name: value for name, value in optional.items() if value},
    }
def _key_models(version: str = "v3") -> list:
    """List the classical model files whose presence marks *version* as loaded.

    These are the random_forest, xgboost and lightgbm ``.joblib`` files under
    ``ARTIFACTS_DIR/<version>/models/classical/``.
    """
    classical_dir = ARTIFACTS_DIR.joinpath(version, "models", "classical")
    model_names = ("random_forest", "xgboost", "lightgbm")
    return [classical_dir / (name + ".joblib") for name in model_names]
def version_loaded(version: str) -> bool:
    """Tell whether every key model file of *version* is present on disk.

    Stops at the first missing file. Unlike ``already_downloaded``, this
    never touches the sentinel file.
    """
    for model_path in _key_models(version):
        if not model_path.exists():
            return False
    return True
def already_downloaded(version: str = "v3") -> bool:
    """Return True only when all three BestEnsemble component models are present.

    Checks the random_forest / xgboost / lightgbm ``.joblib`` files under
    ``artifacts/<version>/models/classical/``. The sentinel file is not
    required for a True result.

    When any key model is missing, the sentinel is deleted, because it no
    longer reflects what is on disk. False is then returned so the caller
    re-downloads.

    Args:
        version: artifact version directory to inspect (default ``"v3"``).

    Returns:
        True if every key model exists, otherwise False.
    """
    missing = [p for p in _key_models(version) if not p.exists()]
    if not missing:
        return True

    # EAFP instead of exists() + unlink(): there is no window for the file to
    # disappear in between. The "stale" notice is printed only when a
    # sentinel was actually there to be stale.
    try:
        SENTINEL.unlink()
    except FileNotFoundError:
        pass
    else:
        print(f"[download_models] Sentinel stale ({len(missing)} key models missing) - will re-download")
    return False
def _ensure_hub():
    """Make sure ``huggingface_hub`` is importable, pip-installing it if needed.

    Callers import ``snapshot_download`` immediately after this returns, so
    a package installed here must be visible to the running interpreter.

    Raises:
        subprocess.CalledProcessError: if the ``pip install`` command fails.
    """
    try:
        from huggingface_hub import snapshot_download  # noqa: F401
    except ImportError:
        import importlib
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               "huggingface_hub>=0.23", "-q"])
        # The import system caches directory contents. Without this, a package
        # installed after the interpreter started can still be reported as
        # missing by the caller's import that follows.
        importlib.invalidate_caches()
def download_version(version: str) -> None:
    """Fetch a single version directory (e.g. 'v1', 'v2', 'v3') into artifacts/.

    Only files under ``<version>/`` in the Hub repo are requested, and
    ``*.log`` files are skipped. huggingface_hub is installed first if it is
    not importable. The sentinel file is not written here.
    """
    _ensure_hub()
    from huggingface_hub import snapshot_download

    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[download_models] Downloading {version}/ from {REPO_ID} -> {ARTIFACTS_DIR}")

    wanted = [f"{version}/**"]
    skipped = ["*.log"]
    request = _hf_kwargs(allow_patterns=wanted, ignore_patterns=skipped)
    snapshot_download(**request)

    print(f"[download_models] {version}/ ready")
def download_all() -> None:
    """Fetch the whole Hub repo (every version, e.g. v1 + v2 + v3) into artifacts/.

    ``*.log`` files are skipped. When the download call returns, the
    sentinel file is written to record a completed full download.
    """
    _ensure_hub()
    from huggingface_hub import snapshot_download

    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[download_models] Downloading all versions from {REPO_ID} -> {ARTIFACTS_DIR}")

    request = _hf_kwargs(ignore_patterns=["*.log"])
    snapshot_download(**request)

    # Only reached when snapshot_download did not raise.
    SENTINEL.write_text("downloaded\n")
    print("[download_models] Artifacts ready")
def main() -> None:
    """Command-line entry point.

    With ``--version vN``, download only that version, unless its key models
    are already on disk. Without it, ensure the latest version (v3) is
    present, downloading every version from the Hub when it is not.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Download model artifacts from the Hugging Face Hub.")
    parser.add_argument("--version", default=None,
                        help="Download only this version, e.g. v1, v2 or v3")
    args = parser.parse_args()

    if args.version:
        if version_loaded(args.version):
            # Plain ASCII in log lines: container consoles may not be UTF-8.
            print(f"[download_models] {args.version} already present - skipping")
        else:
            download_version(args.version)
        return

    # Default: ensure v3 (latest) is present
    if already_downloaded("v3"):
        print("[download_models] Artifacts already present - skipping download")
        return
    download_all()


if __name__ == "__main__":
    main()