Spaces:
Sleeping
Sleeping
| """Runtime defaults and model bootstrap helpers.""" | |
| from __future__ import annotations | |
| import os | |
| import shutil | |
| from pathlib import Path | |
| from src.config import MODELS_DIR | |
def configure_runtime() -> None:
    """Default the common native thread-count variables to a single thread.

    Covers OpenMP, MKL, Apple vecLib and numexpr. A variable that is already
    present in the environment is left untouched, so anything the user exported
    beforehand takes precedence over these conservative defaults.
    """
    thread_vars = (
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    )
    for var_name in thread_vars:
        # Only fill gaps; never override an explicit user setting.
        if var_name not in os.environ:
            os.environ[var_name] = "1"
| def _resolve_hf_model_repo_id() -> str | None: | |
| """Resolve model repo id from environment variables.""" | |
| repo_id = os.getenv("HF_MODEL_REPO_ID") | |
| if repo_id: | |
| return repo_id | |
| user = os.getenv("HF_USERNAME") | |
| repo = os.getenv("HF_MODEL_REPO", "kuechenpassagent-models") | |
| if user: | |
| return f"{user}/{repo}" | |
| return None | |
def _copy_atomically(source: Path, destination: Path) -> None:
    """Copy *source* to *destination* through a temp file in the same directory.

    The data is first written to a hidden ``.part`` file next to *destination*
    and then moved into place with ``os.replace``. An interrupted copy therefore
    never leaves a truncated file under the final name, which a later existence
    check would wrongly accept as a finished artifact.

    Raises:
        OSError: If the copy or the rename fails. The temp file is removed first.
    """
    import tempfile

    destination.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(
        dir=destination.parent,
        prefix=f".{destination.name}.",
        suffix=".part",
    )
    os.close(fd)  # only the reserved name is needed; copy2 reopens it
    tmp_path = Path(tmp_name)
    try:
        shutil.copy2(source, tmp_path)
        os.replace(tmp_path, destination)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def ensure_model_artifacts(
    required_files: tuple[str, ...] = ("prep_time_pipeline.joblib", "food_classifier.pth"),
) -> None:
    """Download missing model files from the HuggingFace model repo, if configured.

    Each name in *required_files* is looked up under ``MODELS_DIR``, which is
    created if needed. Names that are not present as regular files are fetched
    with ``huggingface_hub.hf_hub_download`` from the repo chosen by
    ``_resolve_hf_model_repo_id``. They are then copied into ``MODELS_DIR``.

    This is best-effort and never raises. Each failure is reported as a
    ``[runtime]`` line on stdout and the function moves on. Failures include:

    - no repo configured;
    - ``huggingface_hub`` not installed;
    - a download or copy error for one file.

    Callers must therefore still handle artifacts being absent afterwards.

    Args:
        required_files: File names, relative to ``MODELS_DIR``, that should exist.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    # is_file() rather than exists(): a directory with the same name is not a model.
    missing = [name for name in required_files if not (MODELS_DIR / name).is_file()]
    if not missing:
        return

    repo_id = _resolve_hf_model_repo_id()
    if not repo_id:
        print(
            "[runtime] missing model artifacts and no HF repo configured. "
            "Set HF_MODEL_REPO_ID or HF_USERNAME/HF_MODEL_REPO."
        )
        return

    try:
        from huggingface_hub import hf_hub_download  # type: ignore
    except ImportError:
        print("[runtime] huggingface_hub not installed, cannot auto-download models.")
        return

    # Treat an empty HF_TOKEN as unset rather than passing an empty credential.
    token = os.getenv("HF_TOKEN") or None

    for filename in missing:
        target = MODELS_DIR / filename
        try:
            cached_path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    repo_type="model",
                    token=token,
                )
            )
            # The returned path is normally inside the HF cache; copy only when
            # it is not already the target file itself.
            if cached_path.resolve() != target.resolve():
                _copy_atomically(cached_path, target)
            print(f"[runtime] model ready: {target.name}")
        except Exception as exc:  # noqa: BLE001 - best-effort: report and continue
            print(f"[runtime] could not download {filename} from {repo_id}: {exc}")