kuechenpassagent / src /runtime.py
lederyou's picture
Upload folder using huggingface_hub
db662ea verified
Raw
History Blame Contribute Delete
2.34 kB
"""Runtime defaults and model bootstrap helpers."""
from __future__ import annotations
import os
import shutil
from pathlib import Path
from src.config import MODELS_DIR
def configure_runtime() -> None:
    """Default the thread-count environment variables to a single thread.

    Each variable is written only when it is absent from ``os.environ``, so any
    value the user has already exported takes precedence.
    """
    thread_vars = (
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    )
    for var_name in thread_vars:
        os.environ.setdefault(var_name, "1")
def _resolve_hf_model_repo_id() -> str | None:
    """Resolve the HuggingFace model repo id from environment variables.

    Lookup order:
      1. ``HF_MODEL_REPO_ID`` is used as-is (surrounding whitespace trimmed).
      2. Otherwise ``HF_USERNAME`` and ``HF_MODEL_REPO`` (default
         ``"kuechenpassagent-models"``) are joined as ``"<user>/<repo>"``.

    Empty or whitespace-only values are treated as unset. For example,
    ``HF_MODEL_REPO=""`` falls back to the default repo name instead of
    producing the invalid id ``"<user>/"``.

    Returns:
        The repo id, or ``None`` when neither ``HF_MODEL_REPO_ID`` nor
        ``HF_USERNAME`` is set.
    """

    def _env(name: str) -> str:
        # Normalise "unset", "" and "   " to the same empty value.
        return (os.getenv(name) or "").strip()

    repo_id = _env("HF_MODEL_REPO_ID")
    if repo_id:
        return repo_id

    user = _env("HF_USERNAME")
    if not user:
        return None

    repo = _env("HF_MODEL_REPO") or "kuechenpassagent-models"
    return f"{user}/{repo}"
def ensure_model_artifacts(
    required_files: tuple[str, ...] = ("prep_time_pipeline.joblib", "food_classifier.pth"),
) -> None:
    """Download missing model files from the HuggingFace model repo if configured.

    Each name in ``required_files`` is looked up under ``MODELS_DIR``. Only the
    names that do not exist there are fetched with ``hf_hub_download`` and then
    copied into ``MODELS_DIR``.

    This is best-effort and never raises. Each of the following is reported as a
    ``[runtime]`` message on stdout and then skipped:

    - no repo configured;
    - ``huggingface_hub`` not installed;
    - a per-file download or copy failure.

    Args:
        required_files: File names expected both in ``MODELS_DIR`` and in the
            model repo.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    missing = [name for name in required_files if not (MODELS_DIR / name).exists()]
    if not missing:
        return

    repo_id = _resolve_hf_model_repo_id()
    if not repo_id:
        print(
            "[runtime] missing model artifacts and no HF repo configured. "
            "Set HF_MODEL_REPO_ID or HF_USERNAME/HF_MODEL_REPO."
        )
        return

    try:
        from huggingface_hub import hf_hub_download  # type: ignore
    except ImportError:
        print("[runtime] huggingface_hub not installed, cannot auto-download models.")
        return

    # An empty HF_TOKEN is treated like an unset one.
    token = os.getenv("HF_TOKEN") or None
    for filename in missing:
        target = MODELS_DIR / filename
        # Stage the copy next to the target so os.replace is a same-directory rename.
        partial = target.with_name(target.name + ".part")
        try:
            cached_path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    repo_type="model",
                    token=token,
                )
            )
            if cached_path.resolve() != target.resolve():
                # Never write directly to ``target``. An interrupted copy would
                # leave a truncated file there, and the existence check above
                # would then accept it as valid on every later run.
                shutil.copy2(cached_path, partial)
                os.replace(partial, target)
            print(f"[runtime] model ready: {target.name}")
        except Exception as exc:  # noqa: BLE001 - best-effort bootstrap
            # Drop any half-written staging file so it cannot linger.
            try:
                partial.unlink()
            except OSError:
                pass  # nothing was staged, or it is already gone
            print(f"[runtime] could not download {filename} from {repo_id}: {exc}")