Spaces:
Running on Zero
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Iterable, List | |
# Scratch location used when the preferred data root is not usable.
APP_TMP = Path("/tmp/bila-space-demo")


def _writable_data_dir(
    *, data_root: Path = Path("/data"), fallback: Path | None = None
) -> Path:
    """Return the base directory the demo should use for caches and scratch data.

    Prefers ``<data_root>/bila-space-demo`` when ``data_root`` is an existing
    directory in which this process may create entries; otherwise returns
    ``fallback`` (``APP_TMP`` when not given). Nothing is created here —
    callers are responsible for ``mkdir``.

    Args:
        data_root: Candidate root directory (default ``/data``; presumably the
            Space's persistent-storage mount — confirm for the deployment).
        fallback: Directory to return when ``data_root`` is unusable; ``None``
            means ``APP_TMP``.

    Returns:
        The chosen directory path.
    """
    # Creating entries inside a directory needs write *and* search (execute)
    # permission, and a plain file at ``data_root`` can never hold our subdir.
    if data_root.is_dir() and os.access(data_root, os.W_OK | os.X_OK):
        return data_root / "bila-space-demo"
    return APP_TMP if fallback is None else fallback
def configure_runtime_cache() -> Path:
    """Point library cache/temp locations at a writable place and create them.

    For each of ``HF_HOME``, ``TORCH_HOME`` and ``GRADIO_TEMP_DIR``, an
    already-set environment value is kept; an unset variable is set to a
    subdirectory of the base chosen by ``_writable_data_dir()``. Every
    resulting directory is then created (including parents).

    NOTE(review): some libraries (e.g. huggingface_hub) read these variables
    once at import time, so this presumably must run before they are
    imported — confirm the call order at the call site.

    Returns:
        The base directory returned by ``_writable_data_dir()``.
    """
    base = _writable_data_dir()
    defaults = (
        ("HF_HOME", base / "hf-home"),
        ("TORCH_HOME", base / "torch-home"),
        ("GRADIO_TEMP_DIR", base / "gradio-tmp"),
    )
    # Export every variable first, then touch the filesystem (same order as
    # before). ``setdefault`` returns the value that ends up in the environment.
    targets = [
        Path(os.environ.setdefault(name, str(default))) for name, default in defaults
    ]
    for target in targets:
        # Expand "~" the way huggingface_hub / torch.hub do when reading these
        # variables; otherwise a literal "~" directory is created under the cwd.
        target.expanduser().mkdir(parents=True, exist_ok=True)
    return base
def _allow_patterns_for_model(
    model_cfg: Dict,
    *,
    file_suffixes: Iterable[str] = (
        ".pth",
        ".bin",
        ".safetensors",
        ".json",
        ".pt",
        ".ckpt",
        ".onnx",
        ".yaml",
        ".yml",
    ),
) -> List[str]:
    """Build download ``allow_patterns`` covering one model's files.

    Each value of ``model_cfg["weights"]`` is a repo-relative path. A path whose
    name ends with one of ``file_suffixes`` (compared case-insensitively) is a
    single file and is emitted verbatim; anything else is treated as a
    directory and emitted as ``"<dir>/**"`` so everything beneath it matches.
    ``model_cfg["evidence"]["metric_file"]``, when present and non-empty, is
    appended verbatim.

    Args:
        model_cfg: Model configuration; must contain a ``"weights"`` mapping and
            may contain an ``"evidence"`` mapping (``None`` is tolerated).
        file_suffixes: Suffixes that mark a path as a single file.

    Returns:
        Patterns in the order the weights are listed, metric file last.

    Raises:
        KeyError: if ``model_cfg`` has no ``"weights"`` entry.
    """
    # str.endswith needs a tuple; lower-case once so "MODEL.PTH" is a file too.
    suffixes = tuple(suffix.lower() for suffix in file_suffixes)
    patterns = []
    for rel_path in model_cfg["weights"].values():
        if rel_path.lower().endswith(suffixes):
            patterns.append(rel_path)
        else:
            # Normalise a trailing slash so "dir/" and "dir" give the same glob.
            patterns.append(rel_path.rstrip("/") + "/**")
    # ``or {}`` also covers an explicit ``"evidence": None`` in the config.
    evidence = model_cfg.get("evidence") or {}
    metric_file = evidence.get("metric_file")
    if metric_file:
        patterns.append(metric_file)
    return patterns
def resolve_model_root(model_key: str, model_cfg: Dict) -> Path:
    """Return the directory holding the weights for one model.

    Resolution order:

    1. ``BILA_MODEL_ROOT`` — a local directory laid out like the weights repo;
       returned with ``~`` expanded and made absolute (existence is not checked
       here).
    2. ``BILA_WEIGHTS_REPO`` — a Hugging Face repo id. Only the files matched by
       ``_allow_patterns_for_model(model_cfg)`` are fetched with
       ``huggingface_hub.snapshot_download`` into ``BILA_MODEL_CACHE`` (default:
       ``hf-cache`` under ``_writable_data_dir()``). ``BILA_WEIGHTS_REPO_TYPE``
       (default ``"model"``) and ``HF_TOKEN`` are forwarded.

    Args:
        model_key: Not used by this function; kept for the caller interface.
        model_cfg: Model configuration passed to ``_allow_patterns_for_model``.

    Returns:
        The local root directory (the downloaded snapshot folder in case 2).

    Raises:
        RuntimeError: if neither ``BILA_MODEL_ROOT`` nor ``BILA_WEIGHTS_REPO``
            is set (or both are empty).
    """
    local_root = os.environ.get("BILA_MODEL_ROOT")
    if local_root:
        return Path(local_root).expanduser().resolve()

    repo_id = os.environ.get("BILA_WEIGHTS_REPO")
    if not repo_id:
        raise RuntimeError(
            "Set BILA_WEIGHTS_REPO to the Hugging Face model repo containing demo weights, "
            "or set BILA_MODEL_ROOT to a local directory with the same layout."
        )

    # Deferred import: the local BILA_MODEL_ROOT path never needs huggingface_hub.
    from huggingface_hub import snapshot_download

    # Expand "~" like BILA_MODEL_ROOT above; an empty value means "use default".
    cache_env = os.environ.get("BILA_MODEL_CACHE")
    if cache_env:
        cache_dir = Path(cache_env).expanduser()
    else:
        cache_dir = _writable_data_dir() / "hf-cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    snapshot_path = snapshot_download(
        repo_id=repo_id,
        repo_type=os.environ.get("BILA_WEIGHTS_REPO_TYPE", "model"),
        cache_dir=str(cache_dir),
        allow_patterns=_allow_patterns_for_model(model_cfg),
        # None lets huggingface_hub fall back to its own configured token.
        token=os.environ.get("HF_TOKEN"),
    )
    return Path(snapshot_path)
def require_paths(root: Path, rel_paths: Iterable[str]) -> Dict[str, Path]:
    """Join each relative path onto ``root`` and insist that every result exists.

    Args:
        root: Directory the relative paths are joined onto with ``/``.
        rel_paths: Paths to check; iterated exactly once, so a generator is fine.

    Returns:
        Mapping from each given relative path to its joined ``Path``.

    Raises:
        FileNotFoundError: naming every joined path that does not exist, one per
            line, in input order (a repeated input is reported repeatedly).
    """
    # Materialise once: the input may be a one-shot iterator and is used twice.
    wanted = list(rel_paths)
    absent = [str(root / rel) for rel in wanted if not (root / rel).exists()]
    if absent:
        raise FileNotFoundError("Missing required weight paths:\n" + "\n".join(absent))
    return {rel: root / rel for rel in wanted}