File size: 6,709 Bytes
536ba3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from __future__ import annotations

import pytest
from fastapi import BackgroundTasks, HTTPException

import api.main as api_main


class _StubRegistry:
    """In-memory stand-in for a model registry, used to isolate API tests.

    Attributes set by the constructor:
        model_count: number reported as loaded; bumped to at least 1 by
            ``load_all``.
        _version_meta: metadata mapping supplied via ``meta`` (``{}`` if falsy).
        _catalog: mapping of model name to its catalog entry (``{}`` if falsy).
        models: mapping of model name to a placeholder object once "loaded".
        loaded_models: names recorded, in order, by ``load_all``/``load_model``.
    """

    def __init__(self, model_count=0, meta=None, catalog=None, models=None):
        self.loaded_models: list[str] = []
        self.models = models if models else {}
        self._catalog = catalog if catalog else {}
        self._version_meta = meta if meta else {}
        self.model_count = model_count

    def ensure_metadata_loaded(self):
        """No-op; always returns ``None``."""
        return None

    def refresh_metadata(self):
        """No-op; always returns ``None``."""
        return None

    def load_all(self, only_models=None):
        """Record every name in ``only_models`` as loaded and ensure
        ``model_count`` is at least 1."""
        for model_name in only_models or ():
            self.models[model_name] = object()
            self.loaded_models.append(model_name)
        if self.model_count < 1:
            self.model_count = 1

    def model_on_disk(self, model_name):
        """Return the catalog entry's ``"on_disk"`` value, or ``False`` when
        the model is not in the catalog or the key is absent."""
        if model_name not in self._catalog:
            return False
        return self._catalog[model_name].get("on_disk", False)

    def load_model(self, model_name):
        """Mark ``model_name`` as loaded if it is in the catalog.

        Returns ``True`` when the model was found and recorded, else ``False``.
        """
        if model_name not in self._catalog:
            return False
        self.models[model_name] = object()
        self.loaded_models.append(model_name)
        return True


def test_model_status_key():
    """`_model_status_key("v3", "xgboost")` is expected to give "v3:xgboost"."""
    key = api_main._model_status_key("v3", "xgboost")
    assert key == "v3:xgboost"


def test_version_loaded_false_when_missing(monkeypatch, tmp_path):
    """With an empty artifacts directory, `_version_loaded("v3")` is False."""

    def _fake_artifacts_dir():
        return tmp_path

    monkeypatch.setattr(api_main, "_artifacts_dir", _fake_artifacts_dir)
    loaded = api_main._version_loaded("v3")
    assert loaded is False


def test_version_loaded_true_with_joblib(monkeypatch, tmp_path):
    """A `.joblib` file under `v3/models/classical` makes `_version_loaded("v3")` True."""
    joblib_file = tmp_path / "v3" / "models" / "classical" / "x.joblib"
    joblib_file.parent.mkdir(parents=True)
    joblib_file.write_bytes(b"x")

    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)

    loaded = api_main._version_loaded("v3")
    assert loaded is True


def test_list_versions(run_async, monkeypatch):
    """`list_versions` returns one entry per registry; the first is expected
    to be v3 with status "ready"."""
    v1_meta = {"display": "v1.0", "description": "d1", "features": 12, "champion": "rf"}
    v2_meta = {"display": "v2.0", "description": "d2", "features": 12, "champion": "et"}
    v3_meta = {"display": "v3.0", "description": "d3", "features": 18, "champion": "xgb"}
    registries = {
        "v1": _StubRegistry(model_count=1, meta=v1_meta, catalog={"rf": {}}),
        "v2": _StubRegistry(model_count=0, meta=v2_meta, catalog={"et": {}}),
        "v3": _StubRegistry(model_count=2, meta=v3_meta, catalog={"xgb": {}}),
    }

    monkeypatch.setattr(api_main, "_REGISTRIES", registries)
    monkeypatch.setattr(api_main, "_version_status", {"v3": "ready"})
    # Every version except v2 is reported as present.
    monkeypatch.setattr(api_main, "_version_loaded", lambda v: v != "v2")

    versions = run_async(api_main.list_versions())

    assert len(versions) == 3
    first = versions[0]
    assert first["id"] == "v3"
    assert first["status"] == "ready"


def test_health(run_async, monkeypatch):
    """`health` is expected to report 6 loaded models (1 + 2 + 3 across the
    three stub registries) and status "ok"."""
    per_registry_counts = (("registry_v1", 1), ("registry_v2", 2), ("registry_v3", 3))
    for attr_name, count in per_registry_counts:
        monkeypatch.setattr(api_main, attr_name, _StubRegistry(model_count=count))

    # Bare object exposing only a `device` attribute.
    device_holder = type("R", (), {"device": "cpu"})()
    monkeypatch.setattr(api_main, "registry", device_holder)

    response = run_async(api_main.health())

    assert response.models_loaded == 6
    assert response.status == "ok"


def test_load_version_unknown(run_async):
    """Requesting version "v9" must raise an HTTPException with status 400."""
    with pytest.raises(HTTPException) as excinfo:
        run_async(api_main.load_version("v9", None))

    raised = excinfo.value
    assert raised.status_code == 400


def test_load_version_returns_downloading_when_already_running(run_async, monkeypatch):
    """When v1 is already marked "downloading", `load_version` reports that status."""
    shared_registry = _StubRegistry()
    # All three versions point at the same stub instance.
    registries = dict.fromkeys(("v1", "v2", "v3"), shared_registry)

    monkeypatch.setattr(api_main, "_REGISTRIES", registries)
    monkeypatch.setattr(api_main, "_version_status", {"v1": "downloading"})
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    result = run_async(api_main.load_version("v1", BackgroundTasks()))

    assert result["status"] == "downloading"


def test_load_version_loads_from_disk_when_available(run_async, monkeypatch):
    """With `_version_loaded` patched to True, `load_version` is expected to
    return "ready" and leave the registry with at least one model counted."""
    shared_registry = _StubRegistry(model_count=0)
    # All three versions point at the same stub instance.
    registries = dict.fromkeys(("v1", "v2", "v3"), shared_registry)

    monkeypatch.setattr(api_main, "_REGISTRIES", registries)
    monkeypatch.setattr(api_main, "_version_status", {})
    monkeypatch.setattr(api_main, "_version_loaded", lambda _v: True)
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    result = run_async(api_main.load_version("v1", BackgroundTasks()))

    assert result["status"] == "ready"
    assert shared_registry.model_count >= 1


def test_get_version_models_meta_and_datamap(run_async, monkeypatch, tmp_path):
    """`get_version_models_meta` is expected to return the contents of the
    version's `models.json` and `datamap.json` under "models_meta" and "datamap"."""
    version_dir = tmp_path / "v1"
    version_dir.mkdir(parents=True)
    fixture_files = {
        "models.json": '{"models":{"rf":{"family":"classical"}}}',
        "datamap.json": '{"files":[{"path":"figures/a.png"}]}',
    }
    for filename, payload in fixture_files.items():
        (version_dir / filename).write_text(payload, encoding="utf-8")

    shared_registry = _StubRegistry()
    registries = dict.fromkeys(("v1", "v2", "v3"), shared_registry)
    monkeypatch.setattr(api_main, "_REGISTRIES", registries)
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    payload = run_async(api_main.get_version_models_meta("v1"))

    assert "models_meta" in payload
    assert "datamap" in payload
    assert "rf" in payload["models_meta"]["models"]
    first_file = payload["datamap"]["files"][0]
    assert first_file["path"] == "figures/a.png"


def test_get_version_datamap_generates_when_missing(run_async, monkeypatch, tmp_path):
    """With no `datamap.json` on disk, `get_version_datamap` is expected to
    return what the patched `write_datamap` produces (an empty file list)."""
    version_dir = tmp_path / "v1"
    version_dir.mkdir(parents=True)

    def _fake_write_datamap(version):
        # Stand-in generator: writes an empty datamap for the given version.
        target = tmp_path / version / "datamap.json"
        target.write_text('{"files":[]}', encoding="utf-8")

    shared_registry = _StubRegistry()
    registries = dict.fromkeys(("v1", "v2", "v3"), shared_registry)
    monkeypatch.setattr(api_main, "_REGISTRIES", registries)
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)
    monkeypatch.setattr(api_main, "write_datamap", _fake_write_datamap)

    datamap = run_async(api_main.get_version_datamap("v1"))

    assert datamap["files"] == []


def test_list_version_models_and_load_single_model(run_async, monkeypatch):
    """`list_version_models` lists both catalog entries, and
    `load_single_model` is expected to return "ready" for each of them
    (one with a file on disk, one whose `file` is None)."""
    catalog = {
        "xgboost": {"family": "classical", "r2": 0.9, "file": "models/classical/xgboost.joblib", "on_disk": True},
        "best_ensemble": {"family": "ensemble", "r2": 0.93, "file": None},
    }
    shared_registry = _StubRegistry(catalog=catalog)
    registries = dict.fromkeys(("v1", "v2", "v3"), shared_registry)
    monkeypatch.setattr(api_main, "_REGISTRIES", registries)
    monkeypatch.setattr(api_main, "_model_status", {})
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    rows = run_async(api_main.list_version_models("v1"))
    listed_names = {row["name"] for row in rows}
    assert listed_names == {"xgboost", "best_ensemble"}
    # Only requires that at least one row carries one of these statuses.
    accepted_statuses = ("on_disk", "ready", "not_downloaded")
    assert any(row["status"] in accepted_statuses for row in rows)

    for model_name in ("xgboost", "best_ensemble"):
        result = run_async(api_main.load_single_model("v1", model_name, BackgroundTasks()))
        assert result["status"] == "ready"