# openLLMbenchmark / tests / test_engine.py
# hf-space-deployer
# HF Space deploy from main - 0b1e82967585f1407bf51086f2e5a962f178218a
# 371efe0
from __future__ import annotations
import pytest
from engine import ChatStreamEvent, get_cloud_client, get_local_client, list_models, stream_chat, stream_chat_events
class _Message:
    """Test double for a chat message: an object with a single ``content`` attribute."""

    def __init__(self, content: str) -> None:
        """Store ``content`` unchanged on the instance."""
        self.content = content
class _Chunk:
    """Object-style stream chunk used as a test double.

    Exposes ``message``, ``response``, ``done``, ``eval_count`` and
    ``prompt_eval_count`` as plain attributes.
    """

    def __init__(
        self,
        content: str = "",
        response: str = "",
        *,
        done: bool = False,
        eval_count: int | None = None,
        prompt_eval_count: int | None = None,
    ) -> None:
        """Build a chunk; ``message`` is only populated for non-empty ``content``."""
        # An empty ``content`` leaves ``message`` as None instead of wrapping
        # an empty string in a _Message.
        message: _Message | None = None
        if content:
            message = _Message(content)
        self.message = message

        self.response = response
        self.done = done
        self.eval_count = eval_count
        self.prompt_eval_count = prompt_eval_count
class _ModelItem:
    """Test double for one model-listing entry: an object with a ``model`` attribute."""

    def __init__(self, model: str) -> None:
        """Store the model name unchanged under ``model``."""
        self.model = model
class _ListPayload:
    """Test double for a model-listing response: an object with a ``models`` attribute."""

    def __init__(self, models: list[_ModelItem]) -> None:
        """Keep a reference to the given list under ``models`` (no copy is made)."""
        self.models = models
class _ClientForStream:
    """Fake client whose ``chat`` replays a fixed stream of object and dict chunks."""

    def chat(self, **_ignored: object):  # type: ignore[no-untyped-def]
        """Return an iterator over four canned chunks; keyword arguments are ignored.

        The stream is two ``_Chunk`` objects, then a dict chunk carrying
        message content, then a dict chunk marking completion with token counts.
        """
        object_chunks = [
            _Chunk(content="Merhaba "),
            _Chunk(content="dünya"),
        ]
        dict_chunks = [
            {"message": {"content": "!"}},
            {"done": True, "eval_count": 3, "prompt_eval_count": 5},
        ]
        return iter([*object_chunks, *dict_chunks])
class _ClientForModels:
    """Fake client whose ``list`` reports two models through an object payload."""

    def list(self):  # type: ignore[no-untyped-def]
        """Return a ``_ListPayload`` with ``llama3`` and ``qwen3`` items, in that order."""
        items = [_ModelItem(name) for name in ("llama3", "qwen3")]
        return _ListPayload(items)
class _FailingClientForModels:
    """Fake client whose ``list`` call always fails."""

    def list(self):  # type: ignore[no-untyped-def]
        """Raise ``RuntimeError("offline")`` unconditionally."""
        failure = RuntimeError("offline")
        raise failure
def test_stream_chat_handles_object_and_dict_chunks() -> None:
    """Joining what ``stream_chat`` yields for the mixed object/dict stream gives the full text."""
    fake_client = _ClientForStream()
    streamed_text = "".join(stream_chat(client=fake_client, model="x", prompt="p", system_prompt="s"))  # type: ignore[arg-type]
    assert streamed_text == "Merhaba dünya!"
def test_stream_chat_events_preserve_final_usage_metadata() -> None:
    """The last event from ``stream_chat_events`` carries the final chunk's token counts.

    Also checks that the ``content`` of all events concatenates to the full text.
    """
    fake_client = _ClientForStream()
    events = list(stream_chat_events(client=fake_client, model="x", prompt="p", system_prompt="s"))  # type: ignore[arg-type]

    # Mirrors the final dict chunk of _ClientForStream: eval_count=3, prompt_eval_count=5.
    expected_final = ChatStreamEvent(content="", done=True, generated_tokens=3, prompt_tokens=5)
    assert events[-1] == expected_final

    full_text = "".join([event.content for event in events])
    assert full_text == "Merhaba dünya!"
def test_list_models_handles_object_payload() -> None:
    """``list_models`` returns the names from an object-style payload, in payload order."""
    model_names = list_models(client=_ClientForModels())  # type: ignore[arg-type]
    expected_names = ["llama3", "qwen3"]
    assert model_names == expected_names
def test_get_cloud_client_requires_api_key(monkeypatch) -> None:
    """``get_cloud_client()`` raises ``RuntimeError`` with ``OLLAMA_API_KEY`` unset and no key passed."""
    # raising=False: the variable may already be absent from the environment.
    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)

    with pytest.raises(RuntimeError):
        get_cloud_client()
def test_get_cloud_client_accepts_explicit_api_key(monkeypatch) -> None:
    """An explicit ``api_key`` reaches ``engine.Client`` as a Bearer ``Authorization`` header."""
    captured: dict[str, object] = {}

    class _FakeClient:
        """Replacement for ``engine.Client`` that records its constructor keyword arguments."""

        def __init__(self, **kwargs):  # type: ignore[no-untyped-def]
            captured.update(kwargs)

    # The environment variable is removed so only the explicit argument supplies a key.
    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
    monkeypatch.setattr("engine.Client", _FakeClient)

    get_cloud_client(api_key="session-key")

    expected_headers = {"Authorization": "Bearer session-key"}
    assert captured["headers"] == expected_headers
def test_get_local_client_uses_default_host_without_auth(monkeypatch) -> None:
    """``get_local_client()`` builds ``engine.Client`` with the localhost host and no ``headers``."""
    captured: dict[str, object] = {}

    class _FakeClient:
        """Replacement for ``engine.Client`` that records its constructor keyword arguments."""

        def __init__(self, **kwargs):  # type: ignore[no-untyped-def]
            captured.update(kwargs)

    monkeypatch.setattr("engine.Client", _FakeClient)

    get_local_client()

    expected_host = "http://localhost:11434"
    assert captured["host"] == expected_host
    # No auth: the constructor must not have received a ``headers`` keyword at all.
    assert "headers" not in captured
def test_list_models_tolerates_local_errors() -> None:
    """With ``source="local"``, a client whose ``list`` raises yields an empty model list."""
    discovered = list_models(client=_FailingClientForModels(), source="local")  # type: ignore[arg-type]
    assert discovered == []
def test_list_models_raises_for_cloud_errors() -> None:
    """With ``source="cloud"``, ``list_models`` raises ``RuntimeError`` when the client's ``list`` fails."""
    offline_client = _FailingClientForModels()

    with pytest.raises(RuntimeError):
        list_models(client=offline_client, source="cloud")  # type: ignore[arg-type]