Spaces:
Configuration error
Configuration error
File size: 2,693 Bytes
"""Resolve weights and tokenizer paths, optionally pulling from HuggingFace Hub.

In production we don't want to bake a 158 MB ``.h5`` file into the Docker
image — it makes builds slow and couples weight rotation to image rebuilds.
Instead, the image carries only code, and the runtime pulls the snapshot
from a public Hub repo (pinned to a revision) the first time the container
boots. On HuggingFace Spaces the cache persists across restarts.

When ``BackendSettings.weights_hub_repo`` is unset we fall back to the
local paths declared in settings, which is what unit tests and ``make serve``
use today.
"""
from __future__ import annotations
from pathlib import Path
from typing import Protocol
from app.core.config import BackendSettings
from captioning.utils import get_logger
log = get_logger(__name__)
class SnapshotDownloader(Protocol):
    """Structural type for the callable that fetches a Hub snapshot.

    Covers only the keyword arguments this module passes to
    ``huggingface_hub.snapshot_download``, so any callable accepting these
    three keywords and returning a string can stand in for it.
    """

    def __call__(self, *, repo_id: str, revision: str, cache_dir: str | None) -> str:
        """Fetch ``repo_id`` at ``revision`` and return the resulting path as a string."""
        ...
def resolve_weights(
    settings: BackendSettings,
    downloader: SnapshotDownloader | None = None,
) -> tuple[Path, Path]:
    """Locate the weights file and tokenizer directory the predictor should load.

    The mode is chosen by ``settings.weights_hub_repo``:

    * Local mode (value is falsy): ``settings.weights_path`` and
      ``settings.tokenizer_dir`` are handed back unchanged.
    * Hub mode (value is set): ``downloader`` is called with the configured
      repo, revision and cache directory. Its return value is taken as the
      snapshot directory, and the result is
      ``(snapshot_dir / settings.weights_hub_filename, snapshot_dir)`` — the
      snapshot directory doubles as the tokenizer directory.

    ``downloader`` defaults to ``huggingface_hub.snapshot_download``; passing
    a stub lets tests exercise hub mode without touching the network.

    Returns ``(weights_path, tokenizer_dir)``.
    """
    hub_repo = settings.weights_hub_repo

    # Local mode: nothing to fetch, just report and return the configured paths.
    if not hub_repo:
        local_weights = settings.weights_path
        local_tokenizer = settings.tokenizer_dir
        log.info(
            "weights_source_local",
            weights=str(local_weights),
            tokenizer_dir=str(local_tokenizer),
        )
        return local_weights, local_tokenizer

    if downloader is None:
        # Deferred import: only reached in hub mode when no downloader was
        # injected by the caller.
        from huggingface_hub import snapshot_download

        downloader = snapshot_download
    assert downloader is not None  # narrows the Optional for the type-checker

    # A falsy cache directory setting is forwarded as None.
    cache_root = settings.weights_cache_dir
    cache_dir = str(cache_root) if cache_root else None
    hub_revision = settings.weights_hub_revision

    log.info(
        "weights_source_hub",
        repo=hub_repo,
        revision=hub_revision,
        cache_dir=cache_dir,
    )

    fetched = downloader(
        repo_id=hub_repo,
        revision=hub_revision,
        cache_dir=cache_dir,
    )
    snapshot_dir = Path(fetched)
    weights_path = snapshot_dir / settings.weights_hub_filename

    log.info(
        "weights_downloaded",
        snapshot_dir=str(snapshot_dir),
        weights=str(weights_path),
    )
    return weights_path, snapshot_dir
|