RolandM commited on
Commit
7db5adc
·
1 Parent(s): ba58dc4

Add inference module with HF API wrapper

Browse files

- src/inference.py: PrismaInferenceClient wraps huggingface_hub's
InferenceClient with forced JSON mode (required for Llama 3.3 70B)
- Typed errors: InferenceError for API issues, EvaluationParseError
bubbles up from src.evaluation for parse issues
- src/config.py: MODEL_ID, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
- tests/test_inference.py: 10 tests with mocked InferenceClient,
no real API calls

Files changed (3) hide show
  1. src/config.py +5 -0
  2. src/inference.py +132 -6
  3. tests/test_inference.py +123 -0
src/config.py CHANGED
@@ -11,6 +11,11 @@ MIN_SCORE: int = 1
11
  MAX_SCORE: int = 7
12
  SESSION_TURN_CAP: int = 12
13
 
 
 
 
 
 
14
  DEFAULT_ATTRIBUTES: list[str] = [
15
  "competent",
16
  "likeable",
 
11
  MAX_SCORE: int = 7
12
  SESSION_TURN_CAP: int = 12
13
 
14
+ MODEL_ID: str = "meta-llama/Llama-3.3-70B-Instruct"
15
+ DEFAULT_TEMPERATURE: float = 0.7
16
+ DEFAULT_MAX_TOKENS: int = 600
17
+
18
+
19
  DEFAULT_ATTRIBUTES: list[str] = [
20
  "competent",
21
  "likeable",
src/inference.py CHANGED
@@ -1,9 +1,135 @@
1
- """Hugging Face Inference API client wrapper.
2
 
3
- Thin wrapper around `huggingface_hub`'s inference client that issues a
4
- single LLM call per turn and returns the raw model output. Keeps API
5
- concerns (auth, model selection, retries) isolated from prompt and
6
- evaluation logic.
 
 
 
 
7
 
8
- Implementation pending scaffolding only.
 
 
9
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HF Inference API client wrapper for Prisma.
2
 
3
+ Provides PrismaInferenceClient, a small wrapper around huggingface_hub's
4
+ InferenceClient that:
5
+ - Forces JSON output via response_format={"type": "json_object"}.
6
+ This is required for reliable structured output with Llama 3.3 70B,
7
+ which otherwise produces conversational text before/instead of JSON.
8
+ - Parses and validates the response via src.evaluation.
9
+ - Raises typed errors for API failures (InferenceError) and parse
10
+ failures (EvaluationParseError, propagated from evaluation).
11
 
12
+ The wrapper is initialized once per session with an HF token; each
13
+ generate() call sends a full message history (system + conversation)
14
+ and returns a validated ParsedTurn.
15
  """
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Sequence
20
+
21
+ from huggingface_hub import InferenceClient
22
+ from huggingface_hub.utils import HfHubHTTPError
23
+
24
+ from .config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, MODEL_ID
25
+ from .evaluation import ParsedTurn, parse_model_output
26
+
27
+
28
+ # Single chat message in OpenAI format. Kept loose for v1; can tighten to
29
+ # a TypedDict later if message shapes diversify.
30
+ ChatMessage = dict[str, str]
31
+
32
+
33
+ class InferenceError(Exception):
34
+ """Raised when the inference API call fails or returns malformed data.
35
+
36
+ Wraps network errors, authentication failures, rate-limit errors,
37
+ and missing-field errors in the API response. Parse errors on the
38
+ model's content are *not* wrapped here — they surface as
39
+ EvaluationParseError so the app layer can distinguish them.
40
+ """
41
+
42
+
43
+ class PrismaInferenceClient:
44
+ """Wrapper around huggingface_hub.InferenceClient configured for Prisma.
45
+
46
+ Holds a single InferenceClient instance and exposes a ``generate()``
47
+ method that takes a full message history and returns a validated
48
+ ``ParsedTurn``.
49
+
50
+ JSON output is forced unconditionally via the ``response_format``
51
+ parameter. This is required for Llama 3.3 70B and harmless on models
52
+ that already comply with prompt-level JSON instructions, so we apply
53
+ it uniformly for consistency across model families.
54
+
55
+ Args:
56
+ token: HuggingFace access token with inference permissions.
57
+ model_id: Model to call. Defaults to ``MODEL_ID`` from config.
58
+ temperature: Sampling temperature.
59
+ max_tokens: Maximum tokens per response.
60
+
61
+ Raises:
62
+ ValueError: If ``token`` is empty.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ token: str,
68
+ model_id: str = MODEL_ID,
69
+ temperature: float = DEFAULT_TEMPERATURE,
70
+ max_tokens: int = DEFAULT_MAX_TOKENS,
71
+ ) -> None:
72
+ if not token:
73
+ raise ValueError("token must be a non-empty string")
74
+ self._client = InferenceClient(token=token)
75
+ self._model_id = model_id
76
+ self._temperature = temperature
77
+ self._max_tokens = max_tokens
78
+
79
+ @property
80
+ def model_id(self) -> str:
81
+ """The model ID this client is configured to use."""
82
+ return self._model_id
83
+
84
+ def generate(self, messages: Sequence[ChatMessage]) -> ParsedTurn:
85
+ """Send a chat completion request and return a parsed turn.
86
+
87
+ Args:
88
+ messages: Full chat history including the system message as the
89
+ first entry. Each message is a dict with ``role`` and
90
+ ``content`` keys (OpenAI format).
91
+
92
+ Returns:
93
+ A ``ParsedTurn`` with the response text and validated
94
+ evaluation scores.
95
+
96
+ Raises:
97
+ ValueError: If ``messages`` is empty.
98
+ InferenceError: If the API call itself fails (auth, rate limit,
99
+ network, malformed response envelope).
100
+ EvaluationParseError: If the model's content cannot be parsed
101
+ or validated against the expected attribute schema.
102
+ """
103
+ if not messages:
104
+ raise ValueError("messages must not be empty")
105
+
106
+ try:
107
+ completion = self._client.chat_completion(
108
+ model=self._model_id,
109
+ messages=list(messages),
110
+ max_tokens=self._max_tokens,
111
+ temperature=self._temperature,
112
+ response_format={"type": "json_object"},
113
+ )
114
+ except HfHubHTTPError as exc:
115
+ raise InferenceError(
116
+ f"HF Inference API request failed: {exc}"
117
+ ) from exc
118
+ except Exception as exc:
119
+ raise InferenceError(
120
+ f"Unexpected error during inference call: {exc}"
121
+ ) from exc
122
+
123
+ try:
124
+ raw = completion.choices[0].message.content
125
+ except (AttributeError, IndexError, TypeError) as exc:
126
+ raise InferenceError(
127
+ f"Inference response missing expected fields: {exc}"
128
+ ) from exc
129
+
130
+ if not isinstance(raw, str) or not raw.strip():
131
+ raise InferenceError(
132
+ "Inference response content is empty or non-text"
133
+ )
134
+
135
+ return parse_model_output(raw)
tests/test_inference.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for src.inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from unittest.mock import MagicMock, patch
7
+
8
+ import pytest
9
+
10
+ from src.evaluation import EvaluationParseError, ParsedTurn
11
+ from src.inference import InferenceError, PrismaInferenceClient
12
+
13
+
14
+ VALID_PAYLOAD = json.dumps({
15
+ "response": "Hi there!",
16
+ "evaluation": {
17
+ "competent": 5,
18
+ "likeable": 5,
19
+ "considerate": 5,
20
+ "polite": 5,
21
+ "formal": 5,
22
+ "demanding": 3,
23
+ },
24
+ })
25
+
26
+
27
+ def _mock_completion(content: str) -> MagicMock:
28
+ """Build a MagicMock mimicking the HF chat_completion return shape."""
29
+ completion = MagicMock()
30
+ completion.choices = [MagicMock()]
31
+ completion.choices[0].message.content = content
32
+ return completion
33
+
34
+
35
+ # ---- Construction ----
36
+
37
+ def test_rejects_empty_token():
38
+ with pytest.raises(ValueError, match="token"):
39
+ PrismaInferenceClient(token="")
40
+
41
+
42
+ def test_exposes_model_id():
43
+ client = PrismaInferenceClient(token="hf_test", model_id="some/model")
44
+ assert client.model_id == "some/model"
45
+
46
+
47
+ # ---- generate(): happy paths ----
48
+
49
+ def test_generate_returns_parsed_turn():
50
+ client = PrismaInferenceClient(token="hf_test")
51
+ with patch.object(client, "_client") as mock_inner:
52
+ mock_inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
53
+ result = client.generate([{"role": "user", "content": "hi"}])
54
+ assert isinstance(result, ParsedTurn)
55
+ assert result.response == "Hi there!"
56
+ assert result.evaluation["competent"] == 5
57
+
58
+
59
+ def test_generate_forces_json_response_format():
60
+ """The wrapper must always pass response_format={'type': 'json_object'}."""
61
+ client = PrismaInferenceClient(token="hf_test")
62
+ with patch.object(client, "_client") as mock_inner:
63
+ mock_inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
64
+ client.generate([{"role": "user", "content": "hi"}])
65
+ call = mock_inner.chat_completion.call_args
66
+ assert call.kwargs["response_format"] == {"type": "json_object"}
67
+
68
+
69
+ def test_generate_passes_messages_and_model():
70
+ client = PrismaInferenceClient(token="hf_test", model_id="custom/model")
71
+ messages = [
72
+ {"role": "system", "content": "sys"},
73
+ {"role": "user", "content": "hi"},
74
+ ]
75
+ with patch.object(client, "_client") as mock_inner:
76
+ mock_inner.chat_completion.return_value = _mock_completion(VALID_PAYLOAD)
77
+ client.generate(messages)
78
+ call = mock_inner.chat_completion.call_args
79
+ assert call.kwargs["model"] == "custom/model"
80
+ assert call.kwargs["messages"] == messages
81
+
82
+
83
+ # ---- generate(): error paths ----
84
+
85
+ def test_generate_rejects_empty_messages():
86
+ client = PrismaInferenceClient(token="hf_test")
87
+ with pytest.raises(ValueError, match="messages"):
88
+ client.generate([])
89
+
90
+
91
+ def test_generate_wraps_unexpected_exception():
92
+ client = PrismaInferenceClient(token="hf_test")
93
+ with patch.object(client, "_client") as mock_inner:
94
+ mock_inner.chat_completion.side_effect = RuntimeError("boom")
95
+ with pytest.raises(InferenceError, match="boom"):
96
+ client.generate([{"role": "user", "content": "hi"}])
97
+
98
+
99
+ def test_generate_rejects_empty_content():
100
+ client = PrismaInferenceClient(token="hf_test")
101
+ with patch.object(client, "_client") as mock_inner:
102
+ mock_inner.chat_completion.return_value = _mock_completion("")
103
+ with pytest.raises(InferenceError, match="empty"):
104
+ client.generate([{"role": "user", "content": "hi"}])
105
+
106
+
107
+ def test_generate_rejects_missing_choices():
108
+ client = PrismaInferenceClient(token="hf_test")
109
+ with patch.object(client, "_client") as mock_inner:
110
+ bad = MagicMock()
111
+ bad.choices = []
112
+ mock_inner.chat_completion.return_value = bad
113
+ with pytest.raises(InferenceError, match="missing expected fields"):
114
+ client.generate([{"role": "user", "content": "hi"}])
115
+
116
+
117
+ def test_generate_propagates_parse_errors():
118
+ """Parse failures bubble up as EvaluationParseError, not InferenceError."""
119
+ client = PrismaInferenceClient(token="hf_test")
120
+ with patch.object(client, "_client") as mock_inner:
121
+ mock_inner.chat_completion.return_value = _mock_completion("not json")
122
+ with pytest.raises(EvaluationParseError):
123
+ client.generate([{"role": "user", "content": "hi"}])