Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List | |
| from .paths import CHECKPOINT_DIR, LOCAL_AVRA_WEIGHTS, PROJECT_ROOT | |
# AVRA sub-model names; each one is used as a weight sub-directory name
# (``.../weights/<name>/model_<k>.pth.tar``).
AVRA_SUBMODELS = ["mta", "pa", "gca-f"]

# Local path of the optional pipeline overview figure.
PIPELINE_FIGURE_PATH = PROJECT_ROOT.joinpath("src", "assets", "Hyperclinical_Pipeline.jpg")
def required_asset_files() -> List[str]:
    """Return the repo-relative paths of every mandatory inference asset.

    The result starts with the NeuroFusion checkpoint and its preprocessing
    statistics. It is followed by ``model_1`` .. ``model_5`` weight files for
    each entry of ``AVRA_SUBMODELS``, in that order.
    """
    avra_weights = [
        f"src/inference_core/weights/{submodel}/model_{member}.pth.tar"
        for submodel in AVRA_SUBMODELS
        for member in range(1, 6)
    ]
    return [
        "checkpoints/neurofusion/best_model.pt",
        "checkpoints/neurofusion/preprocessing_stats.json",
        *avra_weights,
    ]
| def _resolve_token(token: str | None) -> str | None: | |
| return token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") | |
def inspect_assets_repo(
    repo_id: str,
    revision: str = "main",
    token: str | None = None,
) -> Dict[str, object]:
    """Inspect remote HF assets repo and report missing required files.

    Args:
        repo_id: Model repo id; surrounding whitespace is stripped.
        revision: Branch, tag or commit; empty/``None`` falls back to "main".
        token: Access token; when falsy, ``_resolve_token`` falls back to the
            ``HF_TOKEN`` / ``HUGGINGFACEHUB_API_TOKEN`` environment variables.

    Returns:
        A dict with the normalised ``repo_id`` and ``revision``,
        ``required_count`` (size of ``required_asset_files()``),
        ``remote_file_count`` (distinct files listed remotely),
        ``missing_required`` (required paths absent remotely, in
        ``required_asset_files()`` order) and ``has_pipeline_figure``.

    Any exception raised by ``HfApi.list_repo_files`` (for example a missing
    repo or an auth failure, depending on the huggingface_hub version)
    propagates unchanged.
    """
    from huggingface_hub import HfApi

    clean_repo = repo_id.strip()
    clean_revision = (revision or "main").strip() or "main"
    client = HfApi(token=_resolve_token(token))

    listing = client.list_repo_files(
        repo_id=clean_repo, repo_type="model", revision=clean_revision
    )
    available = set(listing)

    expected = required_asset_files()
    absent = [path for path in expected if path not in available]

    report: Dict[str, object] = {}
    report["repo_id"] = clean_repo
    report["revision"] = clean_revision
    report["required_count"] = len(expected)
    report["remote_file_count"] = len(available)
    report["missing_required"] = absent
    report["has_pipeline_figure"] = "src/assets/Hyperclinical_Pipeline.jpg" in available
    return report
def local_asset_status() -> Dict[str, bool]:
    """Report which inference assets already exist on the local filesystem.

    Keys:
        checkpoint: ``CHECKPOINT_DIR/best_model.pt`` exists.
        stats: ``CHECKPOINT_DIR/preprocessing_stats.json`` exists.
        avra_weights: every ``model_1..5.pth.tar`` exists under
            ``LOCAL_AVRA_WEIGHTS/<sub>`` for each sub-model in ``AVRA_SUBMODELS``.
        pipeline_figure: ``PIPELINE_FIGURE_PATH`` exists.
    """
    # Lazily yielded so all() stops at the first missing weight file.
    weight_paths = (
        LOCAL_AVRA_WEIGHTS / submodel / f"model_{member}.pth.tar"
        for submodel in AVRA_SUBMODELS
        for member in range(1, 6)
    )
    return {
        "checkpoint": (CHECKPOINT_DIR / "best_model.pt").exists(),
        "stats": (CHECKPOINT_DIR / "preprocessing_stats.json").exists(),
        "avra_weights": all(path.exists() for path in weight_paths),
        "pipeline_figure": PIPELINE_FIGURE_PATH.exists(),
    }
def ensure_assets_from_hub(
    repo_id: str,
    revision: str = "main",
    token: str | None = None,
    force_download: bool = False,
) -> Dict[str, str]:
    """Download required inference assets from a Hugging Face repo.

    Expected repository layout mirrors this Space repo layout:
        checkpoints/neurofusion/best_model.pt
        checkpoints/neurofusion/preprocessing_stats.json
        src/inference_core/weights/{mta,pa,gca-f}/model_{1..5}.pth.tar
        src/assets/Hyperclinical_Pipeline.jpg (optional)

    Every file is placed under ``PROJECT_ROOT`` at the same relative path.
    Files that already exist locally are reused unless ``force_download``.

    Args:
        repo_id: Model repo id; surrounding whitespace is stripped.
        revision: Branch, tag or commit; empty/``None`` falls back to "main".
        token: Access token. When falsy, ``HF_TOKEN`` and then
            ``HUGGINGFACEHUB_API_TOKEN`` are read from the environment.
        force_download: Re-download files even when they exist locally.

    Returns:
        Mapping from repo-relative path to local file path for every required
        asset. The pipeline figure is included whenever it ends up available
        locally, whether reused or freshly downloaded.

    Raises:
        RuntimeError: The remote repo could not be inspected (wraps the cause).
        FileNotFoundError: The remote repo lacks one or more required files.
    """
    import logging

    from huggingface_hub import hf_hub_download

    repo_id = repo_id.strip()
    revision = (revision or "main").strip() or "main"
    token = _resolve_token(token)

    try:
        remote_status = inspect_assets_repo(repo_id=repo_id, revision=revision, token=token)
    except Exception as exc:
        raise RuntimeError(
            f"Unable to access HF assets repo '{repo_id}' (revision '{revision}'): {exc}. "
            "If the repo is private, set HF_TOKEN in Space Secrets."
        ) from exc

    # Fail fast with a short preview of what is missing instead of
    # discovering the gaps one download at a time.
    missing_required = remote_status["missing_required"]
    if missing_required:
        preview = ", ".join(missing_required[:5])
        if len(missing_required) > 5:
            preview += f", ... (+{len(missing_required) - 5} more)"
        raise FileNotFoundError(
            f"HF assets repo '{repo_id}' is missing required files: {preview}"
        )

    def _fetch(rel_path: str) -> str:
        """Return the local path of ``rel_path``, downloading it if needed."""
        local_path = PROJECT_ROOT / rel_path
        if local_path.exists() and not force_download:
            return str(local_path)
        local_path.parent.mkdir(parents=True, exist_ok=True)
        return hf_hub_download(
            repo_id=repo_id,
            filename=rel_path,
            revision=revision,
            token=token,
            local_dir=str(PROJECT_ROOT),
            # NOTE(review): newer huggingface_hub releases deprecate this
            # argument; confirm against the pinned version before removing.
            local_dir_use_symlinks=False,
            force_download=force_download,
        )

    downloaded: Dict[str, str] = {
        rel_path: _fetch(rel_path) for rel_path in required_asset_files()
    }

    # Optional figure: reuse a local copy, or download it only when the
    # remote listing says it exists. This avoids a request that is known
    # to fail.
    optional_figure = "src/assets/Hyperclinical_Pipeline.jpg"
    figure_is_local = (PROJECT_ROOT / optional_figure).exists() and not force_download
    if figure_is_local or remote_status["has_pipeline_figure"]:
        try:
            downloaded[optional_figure] = _fetch(optional_figure)
        except Exception as exc:
            # The figure is cosmetic, so setup must not fail over it.
            # Still leave a trace instead of silently swallowing the error.
            logging.getLogger(__name__).warning(
                "Skipping optional asset %s from '%s': %s", optional_figure, repo_id, exc
            )
    return downloaded