File size: 2,921 Bytes
2461f82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Unit tests for ``app.services.weights_loader.resolve_weights``.

These tests never hit the network — the downloader is injected as a stub.
"""

from __future__ import annotations

from pathlib import Path

from app.core.config import BackendSettings
from app.services.weights_loader import resolve_weights


def test_resolve_weights_local_mode_returns_settings_paths_verbatim() -> None:
    """With no hub repo configured, the local paths from settings come back as-is."""
    expected_weights = Path("models/v1.0.0/model.h5")
    expected_tokenizer = Path("models/v1.0.0")
    settings = BackendSettings(
        weights_path=expected_weights,
        tokenizer_dir=expected_tokenizer,
        weights_hub_repo=None,
    )

    # No downloader is supplied: local mode must not need one.
    resolved_weights, resolved_tokenizer = resolve_weights(settings, downloader=None)

    assert resolved_weights == expected_weights
    assert resolved_tokenizer == expected_tokenizer


def test_resolve_weights_hub_mode_calls_downloader_with_expected_args(tmp_path: Path) -> None:
    """Hub mode calls the downloader once with repo, revision and cache dir (as ``str``).

    Both returned paths are then derived from the snapshot directory the
    downloader reports.
    """
    snapshot_dir = tmp_path / "snapshots" / "abc123"
    snapshot_dir.mkdir(parents=True)
    cache_root = tmp_path / "cache"
    recorded: list[dict[str, object]] = []

    def fake_downloader(*, repo_id: str, revision: str, cache_dir: str | None) -> str:
        # Record every invocation so the test can check both arguments and call count.
        recorded.append(dict(repo_id=repo_id, revision=revision, cache_dir=cache_dir))
        return str(snapshot_dir)

    settings = BackendSettings(
        weights_hub_repo="user/captioning-weights",
        weights_hub_revision="v1.0.0",
        weights_hub_filename="model.h5",
        weights_cache_dir=cache_root,
    )

    weights_path, tokenizer_dir = resolve_weights(settings, downloader=fake_downloader)

    expected_call = {
        "repo_id": "user/captioning-weights",
        "revision": "v1.0.0",
        "cache_dir": str(cache_root),
    }
    assert recorded == [expected_call]
    assert weights_path == snapshot_dir / "model.h5"
    assert tokenizer_dir == snapshot_dir


def test_resolve_weights_hub_mode_passes_none_cache_dir_when_unset(tmp_path: Path) -> None:
    """When ``weights_cache_dir`` is ``None``, the downloader receives ``cache_dir=None``."""
    snapshot_dir = tmp_path / "snap"
    snapshot_dir.mkdir()
    observed: list[str | None] = []

    def fake_downloader(*, repo_id: str, revision: str, cache_dir: str | None) -> str:
        observed.append(cache_dir)
        return str(snapshot_dir)

    # Only the forwarded cache_dir matters here; the resolved paths are ignored.
    resolve_weights(
        BackendSettings(
            weights_hub_repo="user/captioning-weights",
            weights_cache_dir=None,
        ),
        downloader=fake_downloader,
    )

    assert observed == [None]


def test_resolve_weights_hub_mode_honors_custom_weights_filename(tmp_path: Path) -> None:
    """A non-default ``weights_hub_filename`` is joined onto the snapshot directory."""
    snapshot_dir = tmp_path / "snap"
    snapshot_dir.mkdir()
    custom_name = "captioning.weights.h5"

    def fake_downloader(*, repo_id: str, revision: str, cache_dir: str | None) -> str:
        # Download arguments are irrelevant here; always hand back the prepared snapshot.
        return str(snapshot_dir)

    settings = BackendSettings(
        weights_hub_repo="user/captioning-weights",
        weights_hub_filename=custom_name,
    )

    resolved_weights, _tokenizer_dir = resolve_weights(settings, downloader=fake_downloader)

    assert resolved_weights == snapshot_dir / custom_name