"""Tests for the version/model helpers and endpoints exposed by ``api.main``.

Each test replaces the module-level registries, status maps and artifact
directory of ``api.main`` via ``monkeypatch`` so that only in-memory stubs and
pytest's ``tmp_path`` are touched.
"""

from __future__ import annotations

import pytest
from fastapi import BackgroundTasks, HTTPException

import api.main as api_main


class _StubRegistry:
    """In-memory stand-in for a per-version model registry.

    Provides the attributes and methods exercised by these tests. "Loading" a
    model stores a placeholder ``object()`` and records the model's name.
    """

    def __init__(self, model_count=0, meta=None, catalog=None, models=None):
        self.model_count = model_count
        # Each falsy argument gets its own fresh dict, so stubs never share state.
        self._version_meta = meta or {}
        self._catalog = catalog or {}
        self.models = models or {}
        # Names passed to load_all / load_model, in call order.
        self.loaded_models: list[str] = []

    def ensure_metadata_loaded(self):
        """No-op: stub metadata is supplied through the constructor."""
        return None

    def refresh_metadata(self):
        """No-op: stub metadata never changes."""
        return None

    def load_all(self, only_models=None):
        """Mark every name in ``only_models`` as loaded; ``model_count`` becomes at least 1."""
        for model_name in only_models or ():
            self.models[model_name] = object()
            self.loaded_models.append(model_name)
        self.model_count = max(self.model_count, 1)

    def model_on_disk(self, model_name):
        """Return the catalog entry's ``on_disk`` flag, or False for unknown models."""
        return self._catalog.get(model_name, {}).get("on_disk", False)

    def load_model(self, model_name):
        """Load one catalog model; return False when the name is not in the catalog."""
        if model_name not in self._catalog:
            return False
        self.models[model_name] = object()
        self.loaded_models.append(model_name)
        return True


def _install_stub(monkeypatch, stub):
    """Serve ``stub`` for v1/v2/v3 and replace ``ensure_metadata_first`` with a no-op."""
    monkeypatch.setattr(api_main, "_REGISTRIES", {"v1": stub, "v2": stub, "v3": stub})
    monkeypatch.setattr(api_main, "ensure_metadata_first", lambda _v: None)


def test_model_status_key():
    """Status keys are built as ``"<version>:<model>"``."""
    assert api_main._model_status_key("v3", "xgboost") == "v3:xgboost"


def test_version_loaded_false_when_missing(monkeypatch, tmp_path):
    """An empty artifacts directory means the version does not count as loaded."""
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    assert api_main._version_loaded("v3") is False


def test_version_loaded_true_with_joblib(monkeypatch, tmp_path):
    """A ``.joblib`` file under ``<version>/models/classical`` makes the version count as loaded."""
    classical_dir = tmp_path / "v3" / "models" / "classical"
    classical_dir.mkdir(parents=True)
    (classical_dir / "x.joblib").write_bytes(b"x")
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    assert api_main._version_loaded("v3") is True


def test_list_versions(run_async, monkeypatch):
    """All three versions are listed; v3 comes first and reports the "ready" status set for it."""
    registries = {
        "v1": _StubRegistry(
            model_count=1,
            meta={"display": "v1.0", "description": "d1", "features": 12, "champion": "rf"},
            catalog={"rf": {}},
        ),
        "v2": _StubRegistry(
            model_count=0,
            meta={"display": "v2.0", "description": "d2", "features": 12, "champion": "et"},
            catalog={"et": {}},
        ),
        "v3": _StubRegistry(
            model_count=2,
            meta={"display": "v3.0", "description": "d3", "features": 18, "champion": "xgb"},
            catalog={"xgb": {}},
        ),
    }
    monkeypatch.setattr(api_main, "_REGISTRIES", registries)
    monkeypatch.setattr(api_main, "_version_status", {"v3": "ready"})
    monkeypatch.setattr(api_main, "_version_loaded", lambda v: v != "v2")

    listing = run_async(api_main.list_versions())

    assert len(listing) == 3
    head = listing[0]
    assert head["id"] == "v3"
    assert head["status"] == "ready"


def test_health(run_async, monkeypatch):
    """Health reports "ok" and a total of 1 + 2 + 3 = 6 loaded models."""
    for attr_name, count in (("registry_v1", 1), ("registry_v2", 2), ("registry_v3", 3)):
        monkeypatch.setattr(api_main, attr_name, _StubRegistry(model_count=count))
    device_stub = type("R", (), {"device": "cpu"})()
    monkeypatch.setattr(api_main, "registry", device_stub)

    response = run_async(api_main.health())

    assert response.models_loaded == 6
    assert response.status == "ok"


def test_load_version_unknown(run_async):
    """Requesting an unregistered version is rejected with HTTP 400."""
    with pytest.raises(HTTPException) as exc_info:
        run_async(api_main.load_version("v9", None))
    assert exc_info.value.status_code == 400


def test_load_version_returns_downloading_when_already_running(run_async, monkeypatch):
    """A version already marked "downloading" is reported as such."""
    _install_stub(monkeypatch, _StubRegistry())
    monkeypatch.setattr(api_main, "_version_status", {"v1": "downloading"})

    result = run_async(api_main.load_version("v1", BackgroundTasks()))

    assert result["status"] == "downloading"


def test_load_version_loads_from_disk_when_available(run_async, monkeypatch):
    """When artifacts are on disk, the version becomes "ready" and the registry gains models."""
    stub = _StubRegistry(model_count=0)
    _install_stub(monkeypatch, stub)
    monkeypatch.setattr(api_main, "_version_status", {})
    monkeypatch.setattr(api_main, "_version_loaded", lambda _v: True)

    result = run_async(api_main.load_version("v1", BackgroundTasks()))

    assert result["status"] == "ready"
    assert stub.model_count >= 1


def test_get_version_models_meta_and_datamap(run_async, monkeypatch, tmp_path):
    """models.json and datamap.json from the version directory are returned together."""
    version_dir = tmp_path / "v1"
    version_dir.mkdir(parents=True)
    (version_dir / "models.json").write_text(
        '{"models":{"rf":{"family":"classical"}}}', encoding="utf-8"
    )
    (version_dir / "datamap.json").write_text(
        '{"files":[{"path":"figures/a.png"}]}', encoding="utf-8"
    )
    _install_stub(monkeypatch, _StubRegistry())
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)

    payload = run_async(api_main.get_version_models_meta("v1"))

    assert "models_meta" in payload
    assert "datamap" in payload
    assert "rf" in payload["models_meta"]["models"]
    assert payload["datamap"]["files"][0]["path"] == "figures/a.png"


def test_get_version_datamap_generates_when_missing(run_async, monkeypatch, tmp_path):
    """With no datamap.json present, ``write_datamap`` is invoked and its output is served."""
    (tmp_path / "v1").mkdir(parents=True)

    def fake_write_datamap(version):
        # Stand-in for the real generator: write an empty datamap for `version`.
        (tmp_path / version / "datamap.json").write_text('{"files":[]}', encoding="utf-8")

    _install_stub(monkeypatch, _StubRegistry())
    monkeypatch.setattr(api_main, "_artifacts_dir", lambda: tmp_path)
    monkeypatch.setattr(api_main, "write_datamap", fake_write_datamap)

    datamap = run_async(api_main.get_version_datamap("v1"))

    assert datamap["files"] == []


def test_list_version_models_and_load_single_model(run_async, monkeypatch):
    """Catalog models are listed, and both a file-backed and a file-less model load as "ready"."""
    catalog = {
        "xgboost": {
            "family": "classical",
            "r2": 0.9,
            "file": "models/classical/xgboost.joblib",
            "on_disk": True,
        },
        "best_ensemble": {"family": "ensemble", "r2": 0.93, "file": None},
    }
    _install_stub(monkeypatch, _StubRegistry(catalog=catalog))
    monkeypatch.setattr(api_main, "_model_status", {})

    rows = run_async(api_main.list_version_models("v1"))
    assert {row["name"] for row in rows} == {"xgboost", "best_ensemble"}
    known_statuses = ("on_disk", "ready", "not_downloaded")
    assert any(row["status"] in known_statuses for row in rows)

    # Same order as before: the on-disk model first, then the file-less ensemble.
    for model_name in ("xgboost", "best_ensemble"):
        result = run_async(api_main.load_single_model("v1", model_name, BackgroundTasks()))
        assert result["status"] == "ready", model_name