Spaces:
Sleeping
Sleeping
File size: 6,709 Bytes
536ba3d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | from __future__ import annotations
import pytest
from fastapi import BackgroundTasks, HTTPException
import api.main as api_main
class _StubRegistry:
def __init__(self, model_count=0, meta=None, catalog=None, models=None):
self.model_count = model_count
self._version_meta = meta or {}
self._catalog = catalog or {}
self.models = models or {}
self.loaded_models: list[str] = []
def ensure_metadata_loaded(self):
return None
def refresh_metadata(self):
return None
def load_all(self, only_models=None):
if only_models:
for name in only_models:
self.models[name] = object()
self.loaded_models.append(name)
self.model_count = max(self.model_count, 1)
def model_on_disk(self, model_name):
return model_name in self._catalog and self._catalog[model_name].get("on_disk", False)
def load_model(self, model_name):
if model_name in self._catalog:
self.models[model_name] = object()
self.loaded_models.append(model_name)
return True
return False
def test_model_status_key():
    """The status key for version ``v3`` and model ``xgboost`` is ``"v3:xgboost"``."""
    key = api_main._model_status_key("v3", "xgboost")
    assert key == "v3:xgboost"
def test_version_loaded_false_when_missing(monkeypatch, tmp_path):
    """``_version_loaded("v3")`` is False when the artifacts dir is an empty tmp dir."""
    # Point the module at the (empty) temporary directory.
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    loaded = api_main._version_loaded("v3")
    assert loaded is False
def test_version_loaded_true_with_joblib(monkeypatch, tmp_path):
    """``_version_loaded("v3")`` is True once a ``.joblib`` file exists for the version."""
    # Create <tmp>/v3/models/classical/x.joblib with a one-byte payload.
    classical_dir = tmp_path.joinpath("v3", "models", "classical")
    classical_dir.mkdir(parents=True)
    classical_dir.joinpath("x.joblib").write_bytes(b"x")

    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)

    loaded = api_main._version_loaded("v3")
    assert loaded is True
def test_list_versions(run_async, monkeypatch):
    """``list_versions`` returns three entries, the first being ``v3`` with status ``ready``."""
    stub_v1 = _StubRegistry(
        model_count=1,
        meta={"display": "v1.0", "description": "d1", "features": 12, "champion": "rf"},
        catalog={"rf": {}},
    )
    stub_v2 = _StubRegistry(
        model_count=0,
        meta={"display": "v2.0", "description": "d2", "features": 12, "champion": "et"},
        catalog={"et": {}},
    )
    stub_v3 = _StubRegistry(
        model_count=2,
        meta={"display": "v3.0", "description": "d3", "features": 18, "champion": "xgb"},
        catalog={"xgb": {}},
    )
    monkeypatch.setattr(api_main, "_REGISTRIES", {"v1": stub_v1, "v2": stub_v2, "v3": stub_v3})
    monkeypatch.setattr(api_main, "_version_status", {"v3": "ready"})
    # Every version except v2 is reported as loaded.
    monkeypatch.setattr(api_main, "_version_loaded", lambda v: v != "v2")

    versions = run_async(api_main.list_versions())

    assert len(versions) == 3
    first = versions[0]
    assert first["id"] == "v3"
    assert first["status"] == "ready"
def test_health(run_async, monkeypatch):
    """``health`` reports 6 loaded models (1 + 2 + 3 across the stubs) and status ``ok``."""
    per_registry_counts = (("registry_v1", 1), ("registry_v2", 2), ("registry_v3", 3))
    for attr_name, count in per_registry_counts:
        monkeypatch.setattr(api_main, attr_name, _StubRegistry(model_count=count))

    # Minimal object exposing only a ``device`` attribute.
    device_holder_cls = type("R", (), {"device": "cpu"})
    monkeypatch.setattr(api_main, "registry", device_holder_cls())

    response = run_async(api_main.health())

    assert response.models_loaded == 6
    assert response.status == "ok"
def test_load_version_unknown(run_async):
    """Loading the unknown version ``v9`` raises ``HTTPException`` with status 400."""
    with pytest.raises(HTTPException) as exc_info:
        run_async(api_main.load_version("v9", None))
    # Checked after the ``with`` block so the assertion actually runs.
    raised = exc_info.value
    assert raised.status_code == 400
def test_load_version_returns_downloading_when_already_running(run_async, monkeypatch):
    """``load_version("v1")`` returns status ``downloading`` when v1 is already marked so."""
    shared_stub = _StubRegistry()
    # The same stub instance backs all three versions.
    monkeypatch.setattr(api_main, "_REGISTRIES", dict.fromkeys(("v1", "v2", "v3"), shared_stub))
    monkeypatch.setattr(api_main, "_version_status", {"v1": "downloading"})
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    result = run_async(api_main.load_version("v1", BackgroundTasks()))

    assert result["status"] == "downloading"
def test_load_version_loads_from_disk_when_available(run_async, monkeypatch):
    """With ``_version_loaded`` patched to True, ``load_version("v1")`` returns ``ready``.

    The stub registry's ``model_count`` must end up at 1 or more.
    """
    shared_stub = _StubRegistry(model_count=0)
    monkeypatch.setattr(api_main, "_REGISTRIES", dict.fromkeys(("v1", "v2", "v3"), shared_stub))
    monkeypatch.setattr(api_main, "_version_status", {})
    monkeypatch.setattr(api_main, "_version_loaded", lambda _v: True)
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    result = run_async(api_main.load_version("v1", BackgroundTasks()))

    assert result["status"] == "ready"
    assert shared_stub.model_count >= 1
def test_get_version_models_meta_and_datamap(run_async, monkeypatch, tmp_path):
    """``get_version_models_meta("v1")`` returns the contents of both JSON files on disk."""
    version_dir = tmp_path / "v1"
    version_dir.mkdir(parents=True)
    version_dir.joinpath("models.json").write_text(
        '{"models":{"rf":{"family":"classical"}}}', encoding="utf-8"
    )
    version_dir.joinpath("datamap.json").write_text(
        '{"files":[{"path":"figures/a.png"}]}', encoding="utf-8"
    )

    shared_stub = _StubRegistry()
    monkeypatch.setattr(api_main, "_REGISTRIES", dict.fromkeys(("v1", "v2", "v3"), shared_stub))
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    payload = run_async(api_main.get_version_models_meta("v1"))

    assert "models_meta" in payload
    assert "datamap" in payload
    assert "rf" in payload["models_meta"]["models"]
    first_file = payload["datamap"]["files"][0]
    assert first_file["path"] == "figures/a.png"
def test_get_version_datamap_generates_when_missing(run_async, monkeypatch, tmp_path):
    """``get_version_datamap("v1")`` returns the patched writer's output when no datamap exists."""
    # The version directory exists but contains no datamap.json.
    (tmp_path / "v1").mkdir(parents=True)
    shared_stub = _StubRegistry()

    def _fake_write_datamap(version):
        # Stand-in for write_datamap: emit an empty file list for the version.
        target = tmp_path / version / "datamap.json"
        target.write_text('{"files":[]}', encoding="utf-8")

    monkeypatch.setattr(api_main, "_REGISTRIES", dict.fromkeys(("v1", "v2", "v3"), shared_stub))
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)
    monkeypatch.setattr(api_main, "write_datamap", _fake_write_datamap)

    datamap = run_async(api_main.get_version_datamap("v1"))

    assert datamap["files"] == []
def test_list_version_models_and_load_single_model(run_async, monkeypatch):
    """List the catalog for ``v1``, then load one file-backed and one file-less model.

    Both ``load_single_model`` calls are expected to return status ``ready``.
    """
    xgboost_entry = {
        "family": "classical",
        "r2": 0.9,
        "file": "models/classical/xgboost.joblib",
        "on_disk": True,
    }
    # Catalog entry without a backing file.
    ensemble_entry = {"family": "ensemble", "r2": 0.93, "file": None}
    shared_stub = _StubRegistry(
        catalog={"xgboost": xgboost_entry, "best_ensemble": ensemble_entry}
    )

    monkeypatch.setattr(api_main, "_REGISTRIES", dict.fromkeys(("v1", "v2", "v3"), shared_stub))
    monkeypatch.setattr(api_main, "_model_status", {})
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)

    listed = run_async(api_main.list_version_models("v1"))
    listed_names = {row["name"] for row in listed}
    assert listed_names == {"xgboost", "best_ensemble"}
    known_statuses = ("on_disk", "ready", "not_downloaded")
    assert any(row["status"] in known_statuses for row in listed)

    loaded_real = run_async(api_main.load_single_model("v1", "xgboost", BackgroundTasks()))
    assert loaded_real["status"] == "ready"

    loaded_virtual = run_async(
        api_main.load_single_model("v1", "best_ensemble", BackgroundTasks())
    )
    assert loaded_virtual["status"] == "ready"
|