File size: 4,976 Bytes
7db5adc
1e436e0
7db5adc
 
 
 
 
 
 
 
1e436e0
7db5adc
 
 
1e436e0
7db5adc
 
 
 
 
 
 
 
 
1d0171b
7db5adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12364c0
7db5adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12364c0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""HF Inference API client wrapper for Prisma.

Provides PrismaInferenceClient, a small wrapper around huggingface_hub's
InferenceClient that:
- Forces JSON output via response_format={"type": "json_object"}.
  This is required for reliable structured output with Llama 3.3 70B,
  which otherwise produces conversational text before/instead of JSON.
- Parses and validates the response via src.evaluation.
- Raises typed errors for API failures (InferenceError) and parse
  failures (EvaluationParseError, propagated from evaluation).

The wrapper is initialized once per session with an HF token; each
generate() call sends a full message history (system + conversation)
and returns a validated ParsedTurn.
"""

from __future__ import annotations

from typing import Sequence

from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError

from .config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, MODEL_ID
from .evaluation import ParsedTurn, parse_model_output

# Type alias for one OpenAI-format chat message: a plain mapping with
# ``role`` and ``content`` string entries. Deliberately loose for v1;
# switch to a TypedDict if message shapes start to diverge.
ChatMessage = dict[str, str]


class InferenceError(Exception):
    """Signals that the inference API call failed or returned unusable data.

    Covers transport-level problems (network, authentication, rate limits)
    as well as a response envelope that lacks the expected fields or text
    content. Problems with the *content* the model produced are reported
    separately as EvaluationParseError, so the app layer can tell the two
    failure classes apart.
    """


class PrismaInferenceClient:
    """Wrapper around huggingface_hub.InferenceClient configured for Prisma.

    Holds a single InferenceClient instance and exposes a ``generate()``
    method that takes a full message history and returns a validated
    ``ParsedTurn``.

    JSON output is forced unconditionally via the ``response_format``
    parameter. This is required for Llama 3.3 70B and harmless on models
    that already comply with prompt-level JSON instructions, so we apply
    it uniformly for consistency across model families.

    Args:
        token: HuggingFace access token with inference permissions.
            Surrounding whitespace (e.g. a trailing newline from a pasted
            or file-read token) is stripped before use.
        model_id: Model to call. Defaults to ``MODEL_ID`` from config.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens per response.

    Raises:
        ValueError: If ``token`` is not a string, or is empty or
            whitespace-only.
    """

    def __init__(
        self,
        token: str,
        model_id: str = MODEL_ID,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        # Reject unusable tokens up front instead of letting them reach the
        # HTTP layer, where they would presumably only fail at request time
        # and surface as a less obvious InferenceError from generate().
        if not isinstance(token, str) or not token.strip():
            raise ValueError("token must be a non-empty string")
        # HF tokens never legitimately contain surrounding whitespace, so
        # normalizing here is safe and avoids a malformed auth header.
        self._client = InferenceClient(token=token.strip())
        self._model_id = model_id
        self._temperature = temperature
        self._max_tokens = max_tokens

    @property
    def model_id(self) -> str:
        """The model ID this client is configured to use."""
        return self._model_id

    def generate(self, messages: Sequence[ChatMessage]) -> ParsedTurn:
        """Send a chat completion request and return a parsed turn.

        Args:
            messages: Full chat history including the system message as the
                first entry. Each message is a dict with ``role`` and
                ``content`` keys (OpenAI format).

        Returns:
            A ``ParsedTurn`` with the response text and validated
            evaluation scores.

        Raises:
            ValueError: If ``messages`` is empty.
            InferenceError: If the API call itself fails (auth, rate limit,
                network, malformed response envelope).
            EvaluationParseError: If the model's content cannot be parsed
                or validated against the expected attribute schema.
        """
        if not messages:
            raise ValueError("messages must not be empty")

        try:
            completion = self._client.chat_completion(
                model=self._model_id,
                messages=list(messages),
                max_tokens=self._max_tokens,
                temperature=self._temperature,
                response_format={"type": "json_object"},
            )
        except HfHubHTTPError as exc:
            raise InferenceError(
                f"HF Inference API request failed: {exc}"
            ) from exc
        except Exception as exc:
            # Deliberate catch-all at the API boundary: the contract of this
            # method is that any failure of the call itself is reported as
            # InferenceError (original exception kept as __cause__).
            raise InferenceError(
                f"Unexpected error during inference call: {exc}"
            ) from exc

        # Dig the text out of the response envelope; a missing/empty
        # ``choices`` list or absent attributes is an API-side problem.
        try:
            raw = completion.choices[0].message.content
        except (AttributeError, IndexError, TypeError) as exc:
            raise InferenceError(
                f"Inference response missing expected fields: {exc}"
            ) from exc

        if not isinstance(raw, str) or not raw.strip():
            raise InferenceError(
                "Inference response content is empty or non-text"
            )

        # Parse/validation errors propagate unwrapped (EvaluationParseError)
        # so callers can distinguish them from transport failures.
        return parse_model_output(raw)