"""Resolve Hugging Face Hub checkpoint references to local cached paths. Any `checkpointing.pretrained_*` config value may be given as a Hugging Face reference instead of a local path: hf:/// # latest revision on main hf:///@ # pinned branch/tag/commit Example: checkpointing.pretrained_model=hf://autonomousvision/learn2splat/model.ckpt The file is downloaded once into ``./checkpoints`` (``HF_CACHE_DIR`` below, relative to the working directory), laid out by its in-repo path, and the local path is returned, so all downstream ``torch.load`` calls keep working unchanged. Gated/private repos (e.g. ``autonomousvision/learn2splat``) require authentication: run ``huggingface-cli login`` or set the ``HF_TOKEN`` environment variable. """ from __future__ import annotations from .io import cyan HF_PREFIX = "hf://" # hf:// checkpoints (and their sibling config.yaml) are downloaded here on # first access — relative to the working directory — as plain files laid out # by their in-repo path (e.g. ./checkpoints/dense/checkpoints/model.ckpt), # instead of the global HF cache's models--*/snapshots// structure. # huggingface_hub still skips the download when the local copy is current. HF_CACHE_DIR = "checkpoints" def is_hf_ref(path: str | None) -> bool: return isinstance(path, str) and path.startswith(HF_PREFIX) def resolve_hf_ref(ref: str) -> str: """Download an ``hf://`` reference and return the local cached file path.""" try: from huggingface_hub import hf_hub_download except ImportError as e: # pragma: no cover - depends on env raise ImportError( "huggingface_hub is required to load 'hf://' checkpoints. " "Install it with `pip install huggingface_hub`." ) from e body = ref[len(HF_PREFIX):] revision = None if "@" in body: body, revision = body.rsplit("@", 1) parts = body.split("/") if len(parts) < 3: raise ValueError( f"Invalid HF checkpoint reference {ref!r}. Expected " f"'hf:////[@]'." ) repo_id = "/".join(parts[:2]) filename = "/".join(parts[2:]) print(cyan(f"Resolving HF checkpoint {ref} (repo={repo_id}, " f"file={filename}, revision={revision or 'main'})")) local_path = hf_hub_download( repo_id=repo_id, filename=filename, revision=revision, local_dir=HF_CACHE_DIR, ) print(cyan(f"Downloaded to {local_path}")) return local_path def maybe_resolve_hf_ref(path: str | None) -> str | None: """Resolve `path` if it is an `hf://` reference, otherwise return it as-is.""" if is_hf_ref(path): return resolve_hf_ref(path) return path def hf_sibling_config(ref: str) -> str | None: """Download the ``config.yaml`` that sits next to an ``hf://`` checkpoint. Released checkpoints are laid out as ``/checkpoints/.ckpt`` with the training config at ``/config.yaml`` (the same `/../../` relation `_find_config_for_checkpoint` expects). ``hf_hub_download`` only fetches the requested file, so the sibling config must be fetched explicitly; pulling it into the same repo/revision snapshot makes it discoverable. Returns the local path, or ``None`` if ``ref`` is not an ``hf://`` reference / the sibling does not exist. """ if not is_hf_ref(ref): return None from pathlib import PurePosixPath from huggingface_hub import hf_hub_download body = ref[len(HF_PREFIX):] revision = None if "@" in body: body, revision = body.rsplit("@", 1) parts = body.split("/") if len(parts) < 3: return None repo_id = "/".join(parts[:2]) file_in_repo = "/".join(parts[2:]) cfg_in_repo = str(PurePosixPath(file_in_repo).parent.parent / "config.yaml") try: local = hf_hub_download( repo_id=repo_id, filename=cfg_in_repo, revision=revision, local_dir=HF_CACHE_DIR, ) print(cyan(f"Fetched sibling config {cfg_in_repo} -> {local}")) return local except Exception as e: # sibling may not exist for non-standard layouts print(cyan(f"No sibling config.yaml for {ref} ({type(e).__name__}); " f"will fall back to local config discovery.")) return None