salmasoma
Improve HF asset sync diagnostics and validation
c024608
from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, List
from .paths import CHECKPOINT_DIR, LOCAL_AVRA_WEIGHTS, PROJECT_ROOT
# Sub-model folder names under the AVRA weights directory. Each one is expected
# to provide model_1.pth.tar .. model_5.pth.tar (see required_asset_files).
AVRA_SUBMODELS = ["mta", "pa", "gca-f"]
# Local location of the pipeline figure; ensure_assets_from_hub treats this
# file as optional.
PIPELINE_FIGURE_PATH = PROJECT_ROOT / "src" / "assets" / "Hyperclinical_Pipeline.jpg"
def required_asset_files() -> List[str]:
    """Return the repo-relative paths of every asset inference requires.

    The list starts with the two neurofusion checkpoint files, followed by
    ``model_1`` .. ``model_5`` weight archives for each entry of
    ``AVRA_SUBMODELS`` (sub-model order first, then model index).
    A fresh list is built on every call.
    """
    checkpoint_files = [
        "checkpoints/neurofusion/best_model.pt",
        "checkpoints/neurofusion/preprocessing_stats.json",
    ]
    weight_files = [
        f"src/inference_core/weights/{submodel}/model_{number}.pth.tar"
        for submodel in AVRA_SUBMODELS
        for number in range(1, 6)
    ]
    return checkpoint_files + weight_files
def _resolve_token(token: str | None) -> str | None:
    """Pick the Hugging Face token to use.

    Precedence: the explicit ``token`` argument when it is non-empty, then the
    ``HF_TOKEN`` environment variable when it is non-empty, and finally
    whatever ``HUGGINGFACEHUB_API_TOKEN`` holds (``None`` when unset).
    """
    if token:
        return token
    primary_env = os.getenv("HF_TOKEN")
    if primary_env:
        return primary_env
    # Last fallback is returned as-is, even when unset or empty.
    return os.getenv("HUGGINGFACEHUB_API_TOKEN")
def inspect_assets_repo(
    repo_id: str,
    revision: str = "main",
    token: str | None = None,
) -> Dict[str, object]:
    """Inspect remote HF assets repo and report missing required files.

    Lists the files of the model repo ``repo_id`` at ``revision`` and compares
    that listing with ``required_asset_files()``.

    Args:
        repo_id: Hub repository id; surrounding whitespace is stripped.
        revision: Revision to inspect; empty or blank values fall back to
            ``"main"``.
        token: Explicit access token; when empty, the environment is
            consulted via ``_resolve_token``.

    Returns:
        A dict with the normalised ``repo_id`` and ``revision``, the counts
        ``required_count`` and ``remote_file_count``, the list
        ``missing_required`` (in required-file order) and the flag
        ``has_pipeline_figure``.
    """
    from huggingface_hub import HfApi

    repo_id = repo_id.strip()
    revision = (revision or "main").strip() or "main"

    client = HfApi(token=_resolve_token(token))
    listing = client.list_repo_files(
        repo_id=repo_id,
        repo_type="model",
        revision=revision,
    )
    available = set(listing)

    expected = required_asset_files()
    absent = [path for path in expected if path not in available]
    figure_present = "src/assets/Hyperclinical_Pipeline.jpg" in available

    report: Dict[str, object] = {
        "repo_id": repo_id,
        "revision": revision,
        "required_count": len(expected),
        "remote_file_count": len(available),
        "missing_required": absent,
        "has_pipeline_figure": figure_present,
    }
    return report
def local_asset_status() -> Dict[str, bool]:
    """Report which inference assets are already present on local disk.

    Returns:
        A dict of booleans: ``checkpoint`` and ``stats`` for the two files in
        ``CHECKPOINT_DIR``, ``avra_weights`` (true only when every
        ``model_1`` .. ``model_5`` archive exists for every AVRA sub-model
        under ``LOCAL_AVRA_WEIGHTS``) and ``pipeline_figure``.
    """
    checkpoint_file = CHECKPOINT_DIR / "best_model.pt"
    stats_file = CHECKPOINT_DIR / "preprocessing_stats.json"

    # Stops at the first missing archive, like the all() it replaces.
    any_weight_missing = any(
        not (LOCAL_AVRA_WEIGHTS / submodel / f"model_{number}.pth.tar").exists()
        for submodel in AVRA_SUBMODELS
        for number in range(1, 6)
    )

    status: Dict[str, bool] = {
        "checkpoint": checkpoint_file.exists(),
        "stats": stats_file.exists(),
        "avra_weights": not any_weight_missing,
        "pipeline_figure": PIPELINE_FIGURE_PATH.exists(),
    }
    return status
def ensure_assets_from_hub(
    repo_id: str,
    revision: str = "main",
    token: str | None = None,
    force_download: bool = False,
) -> Dict[str, str]:
    """Download required inference assets from a Hugging Face repo.

    Expected repository layout mirrors this Space repo layout:
      checkpoints/neurofusion/best_model.pt
      checkpoints/neurofusion/preprocessing_stats.json
      src/inference_core/weights/{mta,pa,gca-f}/model_{1..5}.pth.tar
      src/assets/Hyperclinical_Pipeline.jpg  (optional)

    Args:
        repo_id: Hub repository id; surrounding whitespace is stripped.
        revision: Revision to download from; empty or blank values fall back
            to ``"main"``.
        token: Explicit access token; when empty, the environment is
            consulted via ``_resolve_token``.
        force_download: Re-download files even when a local copy exists.

    Returns:
        Mapping of repo-relative path to local file path for every required
        asset (already-present files included), plus the optional figure
        when it was downloaded during this call.

    Raises:
        RuntimeError: The remote repo could not be inspected.
        FileNotFoundError: The remote repo lacks one or more required files.
    """
    import logging

    from huggingface_hub import hf_hub_download

    repo_id = repo_id.strip()
    revision = (revision or "main").strip() or "main"
    token = _resolve_token(token)

    def _fetch(rel_path: str) -> str:
        # Single place for the download arguments shared by all assets.
        return hf_hub_download(
            repo_id=repo_id,
            filename=rel_path,
            revision=revision,
            token=token,
            local_dir=str(PROJECT_ROOT),
            local_dir_use_symlinks=False,
            force_download=force_download,
        )

    try:
        remote_status = inspect_assets_repo(repo_id=repo_id, revision=revision, token=token)
    except Exception as exc:
        raise RuntimeError(
            f"Unable to access HF assets repo '{repo_id}' (revision '{revision}'): {exc}. "
            "If the repo is private, set HF_TOKEN in Space Secrets."
        ) from exc

    # Fail before downloading anything if the remote repo is incomplete.
    missing_required = remote_status["missing_required"]
    if missing_required:
        preview = ", ".join(missing_required[:5])
        if len(missing_required) > 5:
            preview += f", ... (+{len(missing_required) - 5} more)"
        raise FileNotFoundError(
            f"HF assets repo '{repo_id}' is missing required files: {preview}"
        )

    downloaded: Dict[str, str] = {}
    for rel_path in required_asset_files():
        local_path = PROJECT_ROOT / rel_path
        local_path.parent.mkdir(parents=True, exist_ok=True)
        if local_path.exists() and not force_download:
            # Reuse the local copy, but still report it to the caller.
            downloaded[rel_path] = str(local_path)
            continue
        downloaded[rel_path] = _fetch(rel_path)

    # Optional figure download (non-fatal if absent or if the download fails).
    optional_figure = "src/assets/Hyperclinical_Pipeline.jpg"
    # The listing already says whether the figure exists remotely; skip the
    # request when it does not instead of provoking an error to swallow.
    figure_on_hub = bool(remote_status.get("has_pipeline_figure", True))
    try:
        local_figure = PROJECT_ROOT / optional_figure
        local_figure.parent.mkdir(parents=True, exist_ok=True)
        if figure_on_hub and (not local_figure.exists() or force_download):
            downloaded[optional_figure] = _fetch(optional_figure)
    except Exception as exc:
        # Best-effort: keep going, but leave a trace so auth/network/disk
        # problems are diagnosable rather than silently ignored.
        logging.getLogger(__name__).warning(
            "Optional asset '%s' could not be downloaded from '%s' (revision '%s'): %s",
            optional_figure,
            repo_id,
            revision,
            exc,
        )
    return downloaded