File size: 4,336 Bytes
7db5adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Unit tests for src.inference."""

from __future__ import annotations

import json
from unittest.mock import MagicMock, patch

import pytest

from src.evaluation import EvaluationParseError, ParsedTurn
from src.inference import InferenceError, PrismaInferenceClient


# JSON text of a well-formed model reply: a response string plus six
# integer evaluation scores. Used as the canned completion content below.
VALID_PAYLOAD = json.dumps(
    dict(
        response="Hi there!",
        evaluation=dict(
            competent=5,
            likeable=5,
            considerate=5,
            polite=5,
            formal=5,
            demanding=3,
        ),
    )
)


def _mock_completion(content: str) -> MagicMock:
    """Build a MagicMock mimicking the HF chat_completion return shape."""
    completion = MagicMock()
    completion.choices = [MagicMock()]
    completion.choices[0].message.content = content
    return completion


# ---- Construction ----

def test_rejects_empty_token():
    """An empty token must be refused with a ValueError mentioning 'token'."""
    expect_value_error = pytest.raises(ValueError, match="token")
    with expect_value_error:
        PrismaInferenceClient(token="")


def test_exposes_model_id():
    """The model_id given at construction is readable back from the client."""
    chosen_model = PrismaInferenceClient(
        token="hf_test",
        model_id="some/model",
    ).model_id
    assert chosen_model == "some/model"


# ---- generate(): happy paths ----

def test_generate_returns_parsed_turn():
    """A valid JSON completion comes back from generate() as a ParsedTurn."""
    conversation = [{"role": "user", "content": "hi"}]
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as inner:
        # Stub the underlying client so no real inference call is made.
        inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
        turn = client.generate(conversation)
    assert isinstance(turn, ParsedTurn)
    assert turn.response == "Hi there!"
    assert turn.evaluation["competent"] == 5


def test_generate_forces_json_response_format():
    """generate() must request JSON output from chat_completion.

    Checks that the keyword arguments of the recorded call include
    ``response_format={'type': 'json_object'}``.
    """
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as inner:
        inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
        client.generate([{"role": "user", "content": "hi"}])
    # The mock keeps its call record after the patch is undone.
    sent_kwargs = inner.chat_completion.call_args.kwargs
    assert sent_kwargs["response_format"] == {"type": "json_object"}


def test_generate_passes_messages_and_model():
    """generate() forwards the message list and configured model as kwargs."""
    system_turn = {"role": "system", "content": "sys"}
    user_turn = {"role": "user", "content": "hi"}
    messages = [system_turn, user_turn]
    client = PrismaInferenceClient(token="hf_test", model_id="custom/model")
    with patch.object(client, "_client") as inner:
        inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
        client.generate(messages)
    sent_kwargs = inner.chat_completion.call_args.kwargs
    assert sent_kwargs["model"] == "custom/model"
    assert sent_kwargs["messages"] == messages


# ---- generate(): error paths ----

def test_generate_rejects_empty_messages():
    """An empty message list is refused with a ValueError mentioning 'messages'."""
    client = PrismaInferenceClient(token="hf_test")
    no_messages = []
    with pytest.raises(ValueError, match="messages"):
        client.generate(no_messages)


def test_generate_wraps_unexpected_exception():
    """A RuntimeError from chat_completion surfaces as InferenceError.

    The raised InferenceError must still carry the original 'boom' text.
    """
    conversation = [{"role": "user", "content": "hi"}]
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as inner:
        # Make the stubbed backend blow up on any call.
        inner.chat_completion.side_effect = RuntimeError("boom")
        with pytest.raises(InferenceError, match="boom"):
            client.generate(conversation)


def test_generate_rejects_empty_content():
    """A completion with empty content raises InferenceError matching 'empty'."""
    conversation = [{"role": "user", "content": "hi"}]
    blank_completion = _mock_completion("")
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as inner:
        inner.chat_completion.return_value = blank_completion
        with pytest.raises(InferenceError, match="empty"):
            client.generate(conversation)


def test_generate_rejects_missing_choices():
    """A completion with an empty ``choices`` list raises InferenceError.

    The error message is expected to match 'missing expected fields'.
    """
    conversation = [{"role": "user", "content": "hi"}]
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as inner:
        choiceless = MagicMock()
        choiceless.choices = []
        inner.chat_completion.return_value = choiceless
        with pytest.raises(InferenceError, match="missing expected fields"):
            client.generate(conversation)


def test_generate_propagates_parse_errors():
    """Non-JSON content must raise EvaluationParseError from generate().

    The test pins the exception type to EvaluationParseError rather than
    InferenceError, so parse failures are not expected to be wrapped.
    """
    conversation = [{"role": "user", "content": "hi"}]
    unparseable = _mock_completion("not json")
    client = PrismaInferenceClient(token="hf_test")
    with patch.object(client, "_client") as inner:
        inner.chat_completion.return_value = unparseable
        with pytest.raises(EvaluationParseError):
            client.generate(conversation)