"""Tests for ``PluginRegistry`` registration, lookup, listing and model creation."""

import pytest

from omniff.models.base import OmniModel
from omniff.plugins import ModelPlugin, PluginRegistry


class DummyModel(OmniModel):
    """Minimal ``OmniModel`` subclass used as the plugin's model class.

    ``load``/``unload`` only toggle an internal flag and ``infer`` always
    returns the same fixed payload, so tests can exercise the registry
    without any real model behind it.
    """

    def __init__(self, model_id: str = "dummy", device: str = "cpu") -> None:
        # NOTE(review): does not call OmniModel.__init__; kept as-is because
        # the base constructor's signature is not visible here — confirm.
        self.model_id = model_id
        self.device = device
        self._loaded = False

    @property
    def is_loaded(self) -> bool:
        """Return whether ``load`` has been called more recently than ``unload``."""
        return self._loaded

    def load(self) -> None:
        """Mark the model as loaded."""
        self._loaded = True

    def unload(self) -> None:
        """Mark the model as unloaded."""
        self._loaded = False

    def infer(self, inputs) -> dict:
        """Ignore *inputs* and return a fixed ``{"text": "dummy"}`` result."""
        return {"text": "dummy"}


def test_register_and_get():
    """A registered plugin is reported by ``has`` and returned as-is by ``get``."""
    reg = PluginRegistry()
    # Third argument is presumably the plugin's task/capability label — confirm.
    plugin = ModelPlugin("test", DummyModel, "TEXT_SIMPLE")
    reg.register(plugin)
    assert reg.has("test")
    assert reg.get("test") is plugin


def test_get_missing():
    """Looking up an unknown name raises ``KeyError`` mentioning 'not registered'."""
    reg = PluginRegistry()
    with pytest.raises(KeyError, match="not registered"):
        reg.get("nope")


def test_list_plugins():
    """``list`` reports the names of every registered plugin."""
    reg = PluginRegistry()
    reg.register(ModelPlugin("a", DummyModel, "TEXT_SIMPLE"))
    reg.register(ModelPlugin("b", DummyModel, "IMAGE_CAPTION"))
    # Sorted so the assertion does not depend on the registry's ordering.
    assert sorted(reg.list()) == ["a", "b"]


def test_create_model():
    """``create_model`` builds the model using the plugin's stored default kwargs."""
    reg = PluginRegistry()
    reg.register(ModelPlugin("test", DummyModel, "TEXT_SIMPLE", {"model_id": "default"}))
    model = reg.create_model("test")
    assert model.model_id == "default"


def test_create_model_with_overrides():
    """Keyword arguments passed to ``create_model`` override the plugin's defaults."""
    reg = PluginRegistry()
    reg.register(ModelPlugin("test", DummyModel, "TEXT_SIMPLE", {"model_id": "default"}))
    model = reg.create_model("test", model_id="custom")
    assert model.model_id == "custom"