Kintsugi-Garden / tests /test_backend_dispatch.py
AI-Sherpa
Initial commit: Kintsugi Garden — Build Small Hackathon submission
4f04d05
"""
Tests for the model-backend dispatcher (run_model).
These tests mock the inner _run_* functions so they don't load any real
model — only the routing logic is exercised.
"""
import os
import sys
import unittest
from unittest import mock
# Put the parent of this file's directory at the front of sys.path so that
# the top-level ``app`` module can be imported from here.
HERE = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(HERE)
sys.path.insert(0, ROOT)
# Provide a backend value before importing app so module-level reads don't
# trigger a real model load.
# NOTE(review): setdefault only fills in KINTSUGI_BACKEND when it is not
# already present in the environment; a value exported by the caller is left
# untouched. Confirm that this (rather than an unconditional override) is the
# intended behaviour.
os.environ.setdefault("KINTSUGI_BACKEND", "ollama")
# Must come after the sys.path / environment setup above, hence the E402 waiver.
import app  # noqa: E402
class DispatcherRoutingTests(unittest.TestCase):
    """Verify that ``app.run_model`` delegates to the runner chosen by ``app.BACKEND``.

    Every test swaps ``BACKEND`` and the expected ``_run_*`` function for a
    mock, so no model is loaded and only the routing decision is checked.
    """

    # Passed positionally to ``run_model``; the tests expect the selected
    # runner to receive exactly these values.
    SAMPLE_ARGS = (
        "I dreamt of a river.",  # text
        "Dream",  # entry_type
        2,  # depth (DEPTH_LABELS key)
        [],  # symbol_matches
        False,  # grounded_jungian
        True,  # include_question
    )

    def _call_with_backend(self, backend, runner_name, canned_result):
        """Call ``run_model`` with ``BACKEND`` and one runner patched.

        ``backend`` is the value installed as ``app.BACKEND``; ``runner_name``
        is the attribute of ``app`` replaced by a mock that returns
        ``canned_result``. Returns ``(run_model_result, runner_mock)``.
        """
        backend_patch = mock.patch.object(app, "BACKEND", backend)
        runner_patch = mock.patch.object(
            app, runner_name, return_value=canned_result
        )
        with backend_patch, runner_patch as runner:
            result = app.run_model(*self.SAMPLE_ARGS)
        return result, runner

    def test_routes_to_ollama_when_backend_is_ollama(self):
        result, runner = self._call_with_backend(
            "ollama", "_run_ollama", ("ok", None)
        )
        text, err = result
        runner.assert_called_once_with(*self.SAMPLE_ARGS)
        self.assertEqual(text, "ok")
        self.assertIsNone(err)

    def test_routes_to_transformers_when_backend_is_transformers(self):
        result, runner = self._call_with_backend(
            "transformers", "_run_transformers", ("tx", None)
        )
        text, err = result
        runner.assert_called_once_with(*self.SAMPLE_ARGS)
        self.assertEqual(text, "tx")
        self.assertIsNone(err)

    def test_routes_to_llama_cpp_when_backend_is_llama_cpp(self):
        result, runner = self._call_with_backend(
            "llama_cpp", "_run_llama_cpp", ("lc", None)
        )
        text, err = result
        runner.assert_called_once_with(*self.SAMPLE_ARGS)
        self.assertEqual(text, "lc")
        self.assertIsNone(err)

    def test_routes_to_llama_cpp_for_unknown_backend(self):
        # Unknown values fall through to the default (llama_cpp); only the
        # routing is asserted here, not the returned value.
        _, runner = self._call_with_backend(
            "something-unrecognised", "_run_llama_cpp", ("lc", None)
        )
        runner.assert_called_once_with(*self.SAMPLE_ARGS)
class LlamaCppLoaderTests(unittest.TestCase):
    """_load_llama_cpp_model must lazy-load once and cache the instance.

    ``llama_cpp.Llama.from_pretrained`` is replaced by a mock in every test,
    so nothing is downloaded. ``mock.patch`` resolves its dotted target by
    importing it, so the ``llama_cpp`` package itself must be importable.
    """

    def setUp(self):
        """Give each test an empty loader cache and restore it afterwards.

        The cache attributes are patched rather than assigned so that whatever
        a test leaves behind (a fake model or an error string) is rolled back
        on cleanup instead of leaking into tests that run later in the same
        process.
        """
        for cache_attr in ("_LLAMA_CPP_MODEL", "_LLAMA_CPP_ERROR"):
            # create=True keeps the old "assign even if missing" behaviour.
            patcher = mock.patch.object(app, cache_attr, None, create=True)
            patcher.start()
            self.addCleanup(patcher.stop)

    def test_returns_cached_instance_on_second_call(self):
        fake_llama = mock.MagicMock(name="FakeLlama")
        with mock.patch(
            "llama_cpp.Llama.from_pretrained", return_value=fake_llama
        ) as ctor:
            first, err1 = app._load_llama_cpp_model()
            second, err2 = app._load_llama_cpp_model()

        self.assertIs(first, fake_llama)
        self.assertIs(second, fake_llama)
        self.assertIsNone(err1)
        self.assertIsNone(err2)
        # Two loads, one construction: the second call must hit the cache.
        ctor.assert_called_once()

        # And the call used our configured repo / file.
        kwargs = ctor.call_args.kwargs
        self.assertEqual(kwargs.get("repo_id"), app.LLAMA_REPO)
        self.assertEqual(kwargs.get("filename"), app.LLAMA_FILE)
        self.assertEqual(kwargs.get("n_ctx"), app.LLAMA_CTX)

    def test_returns_error_when_loader_raises(self):
        with mock.patch(
            "llama_cpp.Llama.from_pretrained",
            side_effect=RuntimeError("download failed"),
        ):
            instance, err = app._load_llama_cpp_model()

        # The failure is expected to be reported as a value, not raised.
        self.assertIsNone(instance)
        self.assertIsNotNone(err)
        self.assertIn("download failed", err)
class RunLlamaCppTests(unittest.TestCase):
    """Exercise ``app._run_llama_cpp`` against a stand-in model object.

    ``_load_llama_cpp_model`` is always patched, so the tests only look at the
    messages and parameters handed to ``create_chat_completion`` and at how
    its response is turned into a ``(text, err)`` pair.
    """

    SAMPLE_ARGS = (
        "I dreamt of a river.",
        "Dream",
        2,
        [],
        False,
        True,
    )

    def _fake_llama(self, content="A symbolic reading."):
        """Build a mock whose chat completion replies with ``content``."""
        reply = {"role": "assistant", "content": content}
        fake = mock.MagicMock(name="FakeLlama")
        fake.create_chat_completion.return_value = {
            "choices": [{"message": reply}]
        }
        return fake

    def _run_with_loader(self, loader_result):
        """Call ``_run_llama_cpp`` while the loader returns ``loader_result``."""
        with mock.patch.object(
            app, "_load_llama_cpp_model", return_value=loader_result
        ):
            return app._run_llama_cpp(*self.SAMPLE_ARGS)

    def _completion_kwargs(self):
        """Run with a default fake model and return the completion call kwargs."""
        fake = self._fake_llama()
        self._run_with_loader((fake, None))
        return fake.create_chat_completion.call_args.kwargs

    def test_returns_extracted_content(self):
        fake = self._fake_llama("River as threshold symbol.")
        text, err = self._run_with_loader((fake, None))
        self.assertEqual(text, "River as threshold symbol.")
        self.assertIsNone(err)

    def test_appends_no_think_to_user_message(self):
        messages = self._completion_kwargs()["messages"]
        user_msg = next(m for m in messages if m["role"] == "user")
        self.assertIn("/no_think", user_msg["content"])

    def test_passes_system_prompt(self):
        messages = self._completion_kwargs()["messages"]
        system_msg = next(m for m in messages if m["role"] == "system")
        self.assertEqual(system_msg["content"], app.SYSTEM_PROMPT)

    def test_passes_generation_params(self):
        kwargs = self._completion_kwargs()
        # llama-cpp keyword names on the left, GEN_CONFIG keys on the right.
        self.assertEqual(kwargs["temperature"], app.GEN_CONFIG["temperature"])
        self.assertEqual(kwargs["top_p"], app.GEN_CONFIG["top_p"])
        self.assertEqual(kwargs["max_tokens"], app.GEN_CONFIG["max_new_tokens"])
        self.assertEqual(
            kwargs["repeat_penalty"], app.GEN_CONFIG["repetition_penalty"]
        )

    def test_returns_error_when_loader_fails(self):
        text, err = self._run_with_loader((None, "boom"))
        self.assertEqual(text, "")
        self.assertEqual(err, "boom")

    def test_returns_error_on_empty_output(self):
        # Whitespace-only model output must be reported as an error.
        blank = self._fake_llama(content=" ")
        text, err = self._run_with_loader((blank, None))
        self.assertEqual(text, "")
        self.assertIsNotNone(err)
# Allow running this file directly as a script; unittest.main() loads and
# runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()