File size: 1,107 Bytes
88e3f4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import pytest

from omniff.models.base import OmniModel
from omniff.models.registry import ModelRegistry


class FakeModel(OmniModel):
    """Minimal in-memory ``OmniModel`` used to observe registry calls.

    ``loaded`` starts as ``False`` and is flipped by ``load`` / ``unload``.
    ``infer`` echoes its input back so callers can check the round trip.
    """

    def __init__(self):
        # Freshly constructed fakes are always unloaded.
        self.loaded = False

    def load(self):
        """Mark the fake as loaded."""
        self.loaded = True

    def unload(self):
        """Mark the fake as unloaded."""
        self.loaded = False

    def infer(self, inputs):
        """Return ``inputs`` unchanged, wrapped under the ``"echo"`` key."""
        return dict(echo=inputs)


def test_register_and_get():
    """A registered model is returned by identity when looked up by name."""
    registry = ModelRegistry()
    fake = FakeModel()
    registry.register("test_model", fake)
    fetched = registry.get("test_model")
    assert fetched is fake


def test_get_missing():
    """Looking up a name that was never registered raises ``KeyError``."""
    empty_registry = ModelRegistry()
    with pytest.raises(KeyError):
        empty_registry.get("nonexistent")


def test_load_model():
    """Loading through the registry flips the registered model to loaded."""
    reg = ModelRegistry()
    model = FakeModel()
    reg.register("m", model)
    # Pin the precondition so the final assert proves reg.load() did the work,
    # rather than the model having been loaded already.
    assert not model.loaded
    reg.load("m")
    assert model.loaded


def test_unload_model():
    """Unloading through the registry flips a loaded model back to unloaded."""
    reg = ModelRegistry()
    model = FakeModel()
    reg.register("m", model)
    reg.load("m")
    # Checkpoint: FakeModel starts unloaded, so without this assert the test
    # would pass even if both reg.load() and reg.unload() were no-ops.
    assert model.loaded
    reg.unload("m")
    assert not model.loaded


def test_list_models():
    """``list`` reports each registered name exactly once."""
    reg = ModelRegistry()
    reg.register("a", FakeModel())
    reg.register("b", FakeModel())
    # Compare sorted lists rather than sets: a set would hide duplicate
    # entries, while this still ignores the order the registry returns.
    assert sorted(reg.list()) == ["a", "b"]