# image-captioning-api/backend/app/tests/test_weights_loader.py
# apoorvrajdev — feat(backend): pull weights from HuggingFace Hub at startup when
# BACKEND_WEIGHTS_HUB_REPO is set (commit 2461f82)
"""Unit tests for ``app.services.weights_loader.resolve_weights``.
These tests never hit the network — the downloader is injected as a stub.
"""
from __future__ import annotations
from pathlib import Path
from app.core.config import BackendSettings
from app.services.weights_loader import resolve_weights
def test_resolve_weights_local_mode_returns_settings_paths_verbatim() -> None:
    """With no hub repo configured, the paths from settings come back unchanged."""
    local_settings = BackendSettings(
        weights_path=Path("models/v1.0.0/model.h5"),
        tokenizer_dir=Path("models/v1.0.0"),
        weights_hub_repo=None,
    )

    # No downloader is supplied for this call.
    resolved_weights, resolved_tokenizer_dir = resolve_weights(
        local_settings,
        downloader=None,
    )

    assert resolved_weights == Path("models/v1.0.0/model.h5")
    assert resolved_tokenizer_dir == Path("models/v1.0.0")
def test_resolve_weights_hub_mode_calls_downloader_with_expected_args(tmp_path: Path) -> None:
    """With a hub repo set, the downloader receives repo id, revision and cache dir.

    The returned paths are expected to be rooted at the snapshot directory the
    stub downloader reports.
    """
    snapshot_dir = tmp_path / "snapshots" / "abc123"
    snapshot_dir.mkdir(parents=True)
    cache_root = tmp_path / "cache"
    recorded_calls: list[dict[str, object]] = []

    def recording_downloader(*, repo_id: str, revision: str, cache_dir: str | None) -> str:
        # Record the keyword arguments so the test can inspect them afterwards.
        recorded_calls.append(
            {"repo_id": repo_id, "revision": revision, "cache_dir": cache_dir}
        )
        return str(snapshot_dir)

    hub_settings = BackendSettings(
        weights_hub_repo="user/captioning-weights",
        weights_hub_revision="v1.0.0",
        weights_hub_filename="model.h5",
        weights_cache_dir=cache_root,
    )

    resolved_weights, resolved_tokenizer_dir = resolve_weights(
        hub_settings,
        downloader=recording_downloader,
    )

    expected_call = {
        "repo_id": "user/captioning-weights",
        "revision": "v1.0.0",
        "cache_dir": str(cache_root),
    }
    assert recorded_calls == [expected_call]
    assert resolved_weights == snapshot_dir / "model.h5"
    assert resolved_tokenizer_dir == snapshot_dir
def test_resolve_weights_hub_mode_passes_none_cache_dir_when_unset(tmp_path: Path) -> None:
    """An unset cache directory reaches the downloader as ``None``."""
    snapshot_dir = tmp_path / "snap"
    snapshot_dir.mkdir()
    observed_cache_dirs: list[str | None] = []

    def spy_downloader(*, repo_id: str, revision: str, cache_dir: str | None) -> str:
        # Only the cache_dir argument matters for this test.
        observed_cache_dirs.append(cache_dir)
        return str(snapshot_dir)

    hub_settings = BackendSettings(
        weights_hub_repo="user/captioning-weights",
        weights_cache_dir=None,
    )

    resolve_weights(hub_settings, downloader=spy_downloader)

    assert observed_cache_dirs == [None]
def test_resolve_weights_hub_mode_honors_custom_weights_filename(tmp_path: Path) -> None:
    """The configured hub filename is used for the weights path under the snapshot."""
    snapshot_dir = tmp_path / "snap"
    snapshot_dir.mkdir()

    def stub_downloader(*, repo_id: str, revision: str, cache_dir: str | None) -> str:
        # Arguments are ignored; always report the prepared snapshot directory.
        return str(snapshot_dir)

    hub_settings = BackendSettings(
        weights_hub_repo="user/captioning-weights",
        weights_hub_filename="captioning.weights.h5",
    )

    resolved_weights, _ = resolve_weights(hub_settings, downloader=stub_downloader)

    assert resolved_weights == snapshot_dir / "captioning.weights.h5"