File size: 1,107 Bytes
88e3f4a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import pytest
from omniff.models.base import OmniModel
from omniff.models.registry import ModelRegistry
class FakeModel(OmniModel):
    """Minimal in-memory OmniModel stand-in that records whether it is loaded."""

    # NOTE(review): OmniModel.__init__ is not invoked here; confirm the base
    # class needs no initialization of its own.
    def __init__(self):
        # Tracks load state so tests can observe registry-driven load/unload.
        self.loaded = False

    def load(self):
        """Mark the model as loaded."""
        self.loaded = True

    def unload(self):
        """Mark the model as no longer loaded."""
        self.loaded = False

    def infer(self, inputs):
        """Return ``inputs`` unchanged under the ``"echo"`` key."""
        return dict(echo=inputs)
def test_register_and_get():
    """A registered model is handed back by identity from ``get``."""
    registry = ModelRegistry()
    fake = FakeModel()
    registry.register("test_model", fake)
    fetched = registry.get("test_model")
    assert fetched is fake
def test_get_missing():
    """Looking up a name that was never registered raises KeyError."""
    empty_registry = ModelRegistry()
    with pytest.raises(KeyError):
        empty_registry.get("nonexistent")
def test_load_model():
    """Loading through the registry calls the model's own ``load``."""
    registry = ModelRegistry()
    fake = FakeModel()
    registry.register("m", fake)
    registry.load("m")
    assert fake.loaded
def test_unload_model():
    """Unloading through the registry reverses a prior registry load."""
    reg = ModelRegistry()
    model = FakeModel()
    reg.register("m", model)
    reg.load("m")
    # FakeModel starts with loaded=False, so without this check the final
    # assertion would pass even if neither load nor unload did anything.
    assert model.loaded
    reg.unload("m")
    assert not model.loaded
def test_list_models():
    """``list`` reports the name of every registered model."""
    registry = ModelRegistry()
    for name in ("a", "b"):
        registry.register(name, FakeModel())
    expected = {"a", "b"}
    assert set(registry.list()) == expected
|