File size: 5,154 Bytes
f381be8
 
 
 
1552b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
f381be8
 
 
 
 
 
 
 
 
 
 
1552b5a
f381be8
 
 
1552b5a
 
 
f381be8
1552b5a
4229df6
f381be8
 
 
 
1552b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f381be8
1552b5a
d3996f2
1552b5a
 
 
 
 
 
 
 
 
d3996f2
1552b5a
 
f381be8
 
1552b5a
 
f381be8
 
 
 
1552b5a
f381be8
1552b5a
f381be8
 
1552b5a
 
f381be8
 
1552b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
d3996f2
1552b5a
 
 
 
 
 
 
f381be8
 
1552b5a
 
 
 
 
 
f381be8
1552b5a
 
 
 
 
 
f381be8
d3996f2
 
f381be8
 
 
1552b5a
f381be8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Download model artifacts from Hugging Face Hub at container startup.

Called automatically by the Docker entrypoint before uvicorn starts.
Can also download a specific version on-demand (e.g. from the API).

HF model repo layout  (v1/ and v2/ at repo root):
    v1/models/classical/*.joblib
    v1/models/deep/*.pt  *.keras
    v1/scalers/*.joblib
    v2/models/classical/*.joblib
    v2/models/deep/*.pt  *.keras
    v2/scalers/*.joblib
    v2/results/*.json

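Local layout after download  (local_dir = ARTIFACTS_DIR):
    artifacts/v1/...
    artifacts/v2/...

Note: when run without --version, the script checks for the v3 key models
(artifacts/v3/models/classical/*.joblib) and downloads the whole repo if any
of them is missing.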
"""

import os
import sys
from pathlib import Path

# ──────────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────────
REPO_ID = "NeerajCodz/aiBatteryLifeCycle"
REPO_TYPE = "model"
# Access token taken from the HF_TOKEN environment variable.
#   On a Space: define it under Space Settings -> Secrets.
#   Locally:    set HF_TOKEN in your shell or .env before running.
# Falls back to an empty string when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# HF repo stores v1/ and v2/ at root → local_dir=ARTIFACTS_DIR maps them to
#   artifacts/v1/...  and  artifacts/v2/...
ARTIFACTS_DIR = Path(__file__).resolve().parents[1] / "artifacts"

# Sentinel file — written after a successful full download
SENTINEL = ARTIFACTS_DIR.joinpath(".hf_downloaded")

# ──────────────────────────────────────────────────────────────────────────────


def _hf_kwargs(allow_patterns: list | None = None,
               ignore_patterns: list | None = None) -> dict:
    """Assemble the keyword arguments handed to ``snapshot_download``.

    ``repo_id``, ``repo_type`` and ``local_dir`` are always present.
    ``allow_patterns`` and ``ignore_patterns`` are forwarded only when they
    are truthy (``None`` and empty lists are left out), and ``token`` is
    added only when ``HF_TOKEN`` is a non-empty string.
    """
    kwargs: dict = {
        "repo_id": REPO_ID,
        "repo_type": REPO_TYPE,
        "local_dir": str(ARTIFACTS_DIR),
    }
    # Optional entries, listed in the order they should appear in the result.
    optional = {
        "allow_patterns": allow_patterns,
        "ignore_patterns": ignore_patterns,
        "token": HF_TOKEN,
    }
    kwargs.update({key: value for key, value in optional.items() if value})
    return kwargs


def _key_models(version: str = "v3") -> list:
    """Return the paths of the three classical model files checked for *version*.

    The paths point to ``random_forest``, ``xgboost`` and ``lightgbm``
    ``.joblib`` files under ``ARTIFACTS_DIR/<version>/models/classical``.
    """
    classical_dir = ARTIFACTS_DIR.joinpath(version, "models", "classical")
    model_names = ("random_forest", "xgboost", "lightgbm")
    return [classical_dir / (name + ".joblib") for name in model_names]


def version_loaded(version: str) -> bool:
    """Tell whether every key model file of *version* is present on disk."""
    for model_path in _key_models(version):
        if not model_path.exists():
            return False
    return True


def already_downloaded(version: str = "v3") -> bool:
    """Report whether all key model files of *version* exist on disk.

    Returns ``True`` when none of the paths from ``_key_models(version)`` is
    missing. Otherwise returns ``False``; if the sentinel file is still
    present in that case it is deleted and a message is printed, because it
    no longer reflects what is on disk.
    """
    absent = [path for path in _key_models(version) if not path.exists()]
    if not absent:
        return True

    # Key models are gone but the marker survived: drop the stale marker.
    if SENTINEL.exists():
        SENTINEL.unlink()
        print(f"[download_models] Sentinel stale ({len(absent)} key models missing) β€” will re-download")
    return False


def _ensure_hub():
    """Make sure ``huggingface_hub.snapshot_download`` can be imported.

    If the import fails, ``huggingface_hub>=0.23`` is installed with pip using
    the current interpreter, and the import system's finder caches are then
    invalidated so the freshly installed package is visible to the ``import``
    statement the caller executes right afterwards.

    Raises:
        subprocess.CalledProcessError: if the pip command exits non-zero.
    """
    try:
        from huggingface_hub import snapshot_download  # noqa: F401
    except ImportError:
        import importlib
        import subprocess

        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               "huggingface_hub>=0.23", "-q"])
        # Path finders cache directory listings; without this the package
        # installed a moment ago may not be noticed by the running process.
        importlib.invalidate_caches()


def download_version(version: str) -> None:
    """Fetch one version directory (e.g. ``'v1'`` or ``'v2'``) from the HF repo.

    Only files matching ``<version>/**`` are requested and ``*.log`` files are
    ignored. The files are placed under ``ARTIFACTS_DIR``, which is created
    first if it does not exist. Progress is reported with ``print``.
    """
    _ensure_hub()
    from huggingface_hub import snapshot_download

    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[download_models] Downloading {version}/ from {REPO_ID} -> {ARTIFACTS_DIR}")

    hub_kwargs = _hf_kwargs(
        allow_patterns=[f"{version}/**"],
        ignore_patterns=["*.log"],
    )
    snapshot_download(**hub_kwargs)

    print(f"[download_models] {version}/ ready")


def download_all() -> None:
    """Fetch the whole HF repo (every version directory) into ``ARTIFACTS_DIR``.

    No allow-list is applied, so all repo files except ``*.log`` are
    requested. After ``snapshot_download`` returns, the sentinel file is
    written to mark the full download as complete.
    """
    _ensure_hub()
    from huggingface_hub import snapshot_download

    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[download_models] Downloading all versions from {REPO_ID} -> {ARTIFACTS_DIR}")

    hub_kwargs = _hf_kwargs(ignore_patterns=["*.log"])
    snapshot_download(**hub_kwargs)

    # Written only after the download call returned without raising.
    SENTINEL.write_text("downloaded\n")
    print("[download_models] Artifacts ready")


def main() -> None:
    """Command-line entry point.

    With ``--version X``: download only ``X`` unless its key models already
    exist. Without it: check the v3 key models and download the whole repo
    when any of them is missing.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument("--version", default=None,
                     help="Download only this version, e.g. v1 or v2")
    requested = cli.parse_args().version

    if not requested:
        # Default: ensure v3 (latest) is present
        if already_downloaded("v3"):
            print("[download_models] Artifacts already present β€” skipping download")
        else:
            download_all()
        return

    # A specific version was requested on the command line.
    if version_loaded(requested):
        print(f"[download_models] {requested} already present β€” skipping")
    else:
        download_version(requested)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()