| | |
| | import pytest |
| | from inference import chat_completion, stream_chat_completion |
| |
|
class DummyStream:
    """Minimal stand-in for a streaming chat-completion response.

    Iterating over an instance yields the chunk objects it was constructed
    with, in their original order.
    """

    def __init__(self, chunks):
        # Hold a reference to the caller's sequence (not a copy); each
        # iteration walks it afresh.
        self._chunks = chunks

    def __iter__(self):
        """Return a new iterator over the stored chunks."""
        chunk_iterator = iter(self._chunks)
        return chunk_iterator
| |
|
class DummyClient:
    """Fake inference client exposing an OpenAI-style ``chat.completions`` API.

    ``client.chat.completions.create(**kwargs)`` returns canned data:

    * without ``stream``: an object whose ``choices[0].message.content``
      is ``response``;
    * with ``stream=True``: a ``DummyStream`` of chunks whose
      ``choices[0].delta.content`` values are ``stream_chunks`` in order.

    Args:
        response: Content returned for non-streaming requests.
        stream_chunks: Delta contents emitted for streaming requests.
            Defaults to ``("h", "i")``.
    """

    def __init__(self, response, stream_chunks=("h", "i")):
        self.response = response
        self.stream_chunks = tuple(stream_chunks)
        # ``client.chat`` resolves back to the client itself so that the
        # attribute chain ``client.chat.completions.create`` works.
        self.chat = self

    @property
    def completions(self):
        """Return the client itself so ``chat.completions.create`` resolves.

        This was previously a plain method, which made
        ``chat.completions.create`` fail with AttributeError.
        """
        return self

    def __call__(self, **kwargs):
        # Keeps the old call style ``chat.completions(**kwargs)`` working;
        # like the former method, it ignores kwargs and returns the client.
        return self

    def create(self, **kwargs):
        """Return a canned completion; streamed when ``stream`` is truthy."""
        from types import SimpleNamespace

        if kwargs.get("stream"):
            chunks = [
                SimpleNamespace(
                    choices=[SimpleNamespace(delta=SimpleNamespace(content=text))]
                )
                for text in self.stream_chunks
            ]
            return DummyStream(chunks)

        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=self.response))]
        )
| |
|
@pytest.fixture(autouse=True)
def patch_client(monkeypatch):
    """Route every inference-client lookup to a ``DummyClient("hello")``.

    Patches ``hf_client.get_inference_client``. It also patches
    ``inference.get_inference_client`` when that name exists: if
    ``inference`` imported it via ``from hf_client import ...``, it holds
    its own reference, and patching ``hf_client`` alone would not reach it.
    """
    import inference

    def fake(model_id, provider=None, *args, **kwargs):
        # Accept any call shape (provider omitted, positional, or keyword)
        # so the fake never breaks on a signature mismatch.
        return DummyClient("hello")

    monkeypatch.setattr("hf_client.get_inference_client", fake)
    if hasattr(inference, "get_inference_client"):
        # Patch the name where it is looked up, not only where it is defined.
        monkeypatch.setattr(inference, "get_inference_client", fake)
| |
|
def test_chat_completion():
    """A non-streaming completion returns the fake client's canned reply."""
    messages = [{"role": "user", "content": "hi"}]
    reply = chat_completion("any-model", messages)
    assert reply == "hello"
| |
|
def test_stream_chat_completion():
    """Streamed pieces concatenate to the fake client's chunked text."""
    messages = [{"role": "user", "content": "stream"}]
    pieces = stream_chat_completion("any-model", messages)
    combined = "".join(pieces)
    assert combined == "hi"
| |
|