File size: 1,560 Bytes
88e3f4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pytest

from omniff.models.base import OmniModel
from omniff.plugins import ModelPlugin, PluginRegistry


class DummyModel(OmniModel):
    """Minimal ``OmniModel`` subclass used as the model class in these tests.

    Loading and unloading only flip an internal flag, and ``infer`` ignores
    its input and returns a fixed payload. The base-class initializer is not
    invoked here.
    """

    def __init__(self, model_id="dummy", device="cpu"):
        """Store ``model_id`` and ``device`` verbatim; start in the unloaded state."""
        self.model_id, self.device = model_id, device
        self._loaded = False

    def load(self):
        """Mark the model as loaded."""
        self._loaded = True

    def unload(self):
        """Mark the model as not loaded."""
        self._loaded = False

    @property
    def is_loaded(self):
        """Current value of the loaded flag set by ``load``/``unload``."""
        return self._loaded

    def infer(self, inputs):
        """Return a constant result; ``inputs`` is not inspected."""
        canned = {"text": "dummy"}
        return canned


def test_register_and_get():
    """A registered plugin is reported by ``has`` and returned (same object) by ``get``."""
    registry = PluginRegistry()
    entry = ModelPlugin("test", DummyModel, "TEXT_SIMPLE")

    registry.register(entry)

    assert registry.has("test")
    # Identity, not equality: the registry must hand back the very object registered.
    assert registry.get("test") is entry


def test_get_missing():
    """Looking up an unknown name raises ``KeyError`` mentioning "not registered"."""
    empty_registry = PluginRegistry()

    # Only the lookup itself is expected to raise, so it alone sits in the context.
    with pytest.raises(KeyError, match="not registered"):
        empty_registry.get("nope")


def test_list_plugins():
    """``list`` yields the names of every registered plugin."""
    registry = PluginRegistry()
    for plugin_name, third_arg in (("a", "TEXT_SIMPLE"), ("b", "IMAGE_CAPTION")):
        registry.register(ModelPlugin(plugin_name, DummyModel, third_arg))

    # Sort before comparing so the check does not depend on listing order.
    listed = sorted(registry.list())
    assert listed == ["a", "b"]


def test_create_model():
    """``create_model`` builds a model whose ``model_id`` comes from the plugin's dict argument."""
    registry = PluginRegistry()
    plugin_kwargs = {"model_id": "default"}
    registry.register(ModelPlugin("test", DummyModel, "TEXT_SIMPLE", plugin_kwargs))

    created = registry.create_model("test")

    assert created.model_id == "default"


def test_create_model_with_overrides():
    """A keyword passed to ``create_model`` wins over the value in the plugin's dict argument."""
    registry = PluginRegistry()
    plugin_kwargs = {"model_id": "default"}
    registry.register(ModelPlugin("test", DummyModel, "TEXT_SIMPLE", plugin_kwargs))

    created = registry.create_model("test", model_id="custom")

    assert created.model_id == "custom"