"""Runtime cache setup and model-weight location helpers."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, Iterable, List

# Fallback base directory, used when /data is absent or not writable.
APP_TMP = Path("/tmp/bila-space-demo")


def _writable_data_dir() -> Path:
    """Return the base directory under which caches should be placed.

    Returns ``/data/bila-space-demo`` when ``/data`` exists and is writable
    by the current process, otherwise ``APP_TMP``.  The returned directory
    is not created here.
    """
    data = Path("/data")
    if data.exists() and os.access(data, os.W_OK):
        return data / "bila-space-demo"
    return APP_TMP


def configure_runtime_cache() -> Path:
    """Point the HF, Torch and Gradio cache env vars at writable directories.

    For each of ``HF_HOME``, ``TORCH_HOME`` and ``GRADIO_TEMP_DIR``: a value
    already present in the environment is kept, otherwise a default
    sub-directory of the writable base is exported.  Every resulting
    directory is created if it does not exist yet.

    Returns:
        The base directory chosen by ``_writable_data_dir()``.
    """
    base = _writable_data_dir()
    defaults = (
        ("HF_HOME", "hf-home"),
        ("TORCH_HOME", "torch-home"),
        ("GRADIO_TEMP_DIR", "gradio-tmp"),
    )
    # setdefault() returns the pre-existing value when the variable is set,
    # so user-provided locations always win over the defaults.
    cache_dirs = [
        Path(os.environ.setdefault(env_name, str(base / subdir)))
        for env_name, subdir in defaults
    ]
    # Export all variables first, then create the directories.
    for cache_dir in cache_dirs:
        cache_dir.mkdir(parents=True, exist_ok=True)
    return base


def _allow_patterns_for_model(model_cfg: Dict) -> List[str]:
    """Build the ``allow_patterns`` list used to download one model's files.

    Args:
        model_cfg: Mapping with a ``"weights"`` dict whose values are
            repo-relative paths, and optionally
            ``model_cfg["evidence"]["metric_file"]``.

    Returns:
        Patterns in the order of the ``weights`` values, followed by the
        metric file when one is configured.
    """
    file_suffixes = (".pth", ".bin", ".safetensors", ".json")
    patterns: List[str] = []
    for rel_path in model_cfg["weights"].values():
        if rel_path.endswith(file_suffixes):
            # Recognised single-file entry: request exactly that file.
            patterns.append(rel_path)
            continue
        if rel_path and not rel_path.endswith("/"):
            # Ambiguous entry: it may be a file with an unlisted extension
            # (e.g. ".pt", ".ckpt").  A bare "<path>/**" glob can never match
            # the file itself, so also request the exact path.
            patterns.append(rel_path)
        # Directory entry: request everything underneath it.
        patterns.append(rel_path.rstrip("/") + "/**")

    metric_file = model_cfg.get("evidence", {}).get("metric_file")
    if metric_file:
        patterns.append(metric_file)
    return patterns


def resolve_model_root(model_key: str, model_cfg: Dict) -> Path:
    """Return the directory that holds the weights described by *model_cfg*.

    Resolution order:

    1. ``BILA_MODEL_ROOT``, when set, is used as a local directory (user
       expanded and resolved); nothing is downloaded.
    2. Otherwise the repo named by ``BILA_WEIGHTS_REPO`` is fetched with
       ``huggingface_hub.snapshot_download``, restricted to the patterns from
       ``_allow_patterns_for_model``.

    Args:
        model_key: Identifier of the model; currently not used by this
            function.
        model_cfg: Model configuration passed to
            ``_allow_patterns_for_model``.

    Returns:
        Path of the local model root or of the downloaded snapshot.

    Raises:
        RuntimeError: If neither ``BILA_MODEL_ROOT`` nor
            ``BILA_WEIGHTS_REPO`` is set.
    """
    local_root = os.environ.get("BILA_MODEL_ROOT")
    if local_root:
        return Path(local_root).expanduser().resolve()

    repo_id = os.environ.get("BILA_WEIGHTS_REPO")
    if not repo_id:
        raise RuntimeError(
            "Set BILA_WEIGHTS_REPO to the Hugging Face model repo containing demo weights, "
            "or set BILA_MODEL_ROOT to a local directory with the same layout."
        )

    # Imported lazily so the local-directory path works without the package.
    from huggingface_hub import snapshot_download

    cache_dir = Path(
        os.environ.get("BILA_MODEL_CACHE", str(_writable_data_dir() / "hf-cache"))
    )
    cache_dir.mkdir(parents=True, exist_ok=True)
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        repo_type=os.environ.get("BILA_WEIGHTS_REPO_TYPE", "model"),
        cache_dir=str(cache_dir),
        allow_patterns=_allow_patterns_for_model(model_cfg),
        token=os.environ.get("HF_TOKEN"),
    )
    return Path(snapshot_path)


def require_paths(root: Path, rel_paths: Iterable[str]) -> Dict[str, Path]:
    """Resolve *rel_paths* against *root* and verify that each one exists.

    Args:
        root: Directory the relative paths are joined to.
        rel_paths: Relative paths that must exist under *root*.

    Returns:
        Mapping from each relative path to its joined path.

    Raises:
        FileNotFoundError: If any path is missing; the message lists every
            missing path, one per line.
    """
    resolved: Dict[str, Path] = {rel_path: root / rel_path for rel_path in rel_paths}
    missing = [str(path) for path in resolved.values() if not path.exists()]
    if missing:
        raise FileNotFoundError("Missing required weight paths:\n" + "\n".join(missing))
    return resolved