Spaces:
Running
Running
| """Unit tests for src.inference.""" | |
| from __future__ import annotations | |
| import json | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from src.evaluation import EvaluationParseError, ParsedTurn | |
| from src.inference import InferenceError, PrismaInferenceClient | |
# Canonical well-formed model reply used by the tests below: a JSON string
# holding the assistant "response" text and an "evaluation" object with six
# integer ratings. It is fed to the client as the mocked completion content.
VALID_PAYLOAD = json.dumps({
    "response": "Hi there!",
    "evaluation": {
        "competent": 5,
        "likeable": 5,
        "considerate": 5,
        "polite": 5,
        "formal": 5,
        "demanding": 3,
    },
})
def _mock_completion(content: str) -> MagicMock:
    """Build a MagicMock mimicking the HF chat_completion return shape.

    Args:
        content: Text to expose as ``choices[0].message.content``.

    Returns:
        A mock whose ``choices`` attribute is a one-element list; the
        element's ``message.content`` attribute equals ``content``.
    """
    completion = MagicMock()
    completion.choices = [MagicMock()]
    completion.choices[0].message.content = content
    return completion
# ---- Construction ----
def test_rejects_empty_token():
    """Constructing the client with an empty token raises ValueError.

    The error message is expected to match the pattern "token".
    """
    with pytest.raises(ValueError, match="token"):
        PrismaInferenceClient(token="")
def test_exposes_model_id():
    """The ``model_id`` passed to the constructor is readable on the client."""
    client = PrismaInferenceClient(token="hf_test", model_id="some/model")
    assert client.model_id == "some/model"
# ---- generate(): happy paths ----
def test_generate_returns_parsed_turn():
    """A valid JSON completion is returned from generate() as a ParsedTurn.

    The inner ``_client`` is patched so no network call is made; the mocked
    completion carries VALID_PAYLOAD as its message content.
    """
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as mock_inner:
        mock_inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
        result = client.generate([{"role": "user", "content": "hi"}])

    assert isinstance(result, ParsedTurn)
    assert result.response == "Hi there!"
    assert result.evaluation["competent"] == 5
def test_generate_forces_json_response_format():
    """The wrapper must always pass response_format={'type': 'json_object'}."""
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as mock_inner:
        mock_inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
        client.generate([{"role": "user", "content": "hi"}])

    # The mock keeps its recorded call after the patch context exits.
    call = mock_inner.chat_completion.call_args
    assert call.kwargs["response_format"] == {"type": "json_object"}
def test_generate_passes_messages_and_model():
    """generate() forwards the configured model id and the message list.

    Both are checked as keyword arguments of the underlying
    ``chat_completion`` call.
    """
    client = PrismaInferenceClient(token="hf_test", model_id="custom/model")
    messages = [
        {"role": "system", "content": "sys"},
        {"role": "user", "content": "hi"},
    ]
    with patch.object(client, "_client") as mock_inner:
        mock_inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
        client.generate(messages)

    # The mock keeps its recorded call after the patch context exits.
    call = mock_inner.chat_completion.call_args
    assert call.kwargs["model"] == "custom/model"
    assert call.kwargs["messages"] == messages
# ---- generate(): error paths ----
def test_generate_rejects_empty_messages():
    """Calling generate() with an empty message list raises ValueError.

    The error message is expected to match the pattern "messages".
    """
    client = PrismaInferenceClient(token="hf_test")
    with pytest.raises(ValueError, match="messages"):
        client.generate([])
def test_generate_wraps_unexpected_exception():
    """An arbitrary error from chat_completion surfaces as InferenceError.

    The mocked call raises ``RuntimeError("boom")``; the resulting
    InferenceError message is expected to match "boom".
    """
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as mock_inner:
        mock_inner.chat_completion.side_effect = RuntimeError("boom")
        with pytest.raises(InferenceError, match="boom"):
            client.generate([{"role": "user", "content": "hi"}])
def test_generate_rejects_empty_content():
    """A completion whose message content is "" raises InferenceError.

    The error message is expected to match the pattern "empty".
    """
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as mock_inner:
        mock_inner.chat_completion.return_value = _mock_completion("")
        with pytest.raises(InferenceError, match="empty"):
            client.generate([{"role": "user", "content": "hi"}])
def test_generate_rejects_missing_choices():
    """A completion with an empty ``choices`` list raises InferenceError.

    The error message is expected to match "missing expected fields".
    """
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as mock_inner:
        bad = MagicMock()
        bad.choices = []
        mock_inner.chat_completion.return_value = bad
        with pytest.raises(InferenceError, match="missing expected fields"):
            client.generate([{"role": "user", "content": "hi"}])
def test_generate_propagates_parse_errors():
    """Parse failures bubble up as EvaluationParseError, not InferenceError."""
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as mock_inner:
        # Non-JSON content: the completion itself is well-formed, only the
        # payload cannot be parsed.
        mock_inner.chat_completion.return_value = _mock_completion("not json")
        with pytest.raises(EvaluationParseError):
            client.generate([{"role": "user", "content": "hi"}])