Spaces:
Running
Running
| """HF Inference API client wrapper for Prisma. | |
| Provides PrismaInferenceClient, a small wrapper around huggingface_hub's | |
| InferenceClient that: | |
| - Forces JSON output via response_format={"type": "json_object"}. | |
| This is required for reliable structured output with Llama 3.3 70B, | |
| which otherwise produces conversational text before/instead of JSON. | |
| - Parses and validates the response via src.evaluation. | |
| - Raises typed errors for API failures (InferenceError) and parse | |
| failures (EvaluationParseError, propagated from evaluation). | |
| The wrapper is initialized once per session with an HF token; each | |
| generate() call sends a full message history (system + conversation) | |
| and returns a validated ParsedTurn. | |
| """ | |
| from __future__ import annotations | |
| from typing import Sequence | |
| from huggingface_hub import InferenceClient | |
| from huggingface_hub.utils import HfHubHTTPError | |
| from .config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, MODEL_ID | |
| from .evaluation import ParsedTurn, parse_model_output | |
# One chat message in OpenAI chat format, modelled as a plain str -> str
# mapping. Deliberately permissive for v1; if message shapes start to
# diverge this can be narrowed to a TypedDict.
ChatMessage = dict[str, str]
class InferenceError(Exception):
    """Signal that the inference API call failed or returned a bad envelope.

    Used for network errors, authentication failures, rate-limit errors,
    and responses that lack the expected fields. Failures to parse the
    model's *content* are deliberately not wrapped in this type; those
    surface as EvaluationParseError so the app layer can tell the two
    failure classes apart.
    """
class PrismaInferenceClient:
    """Prisma-specific facade over ``huggingface_hub.InferenceClient``.

    A single ``InferenceClient`` is created at construction time and reused
    for every request. ``generate()`` accepts the complete message history
    and hands back a validated ``ParsedTurn``.

    Every request carries ``response_format={"type": "json_object"}``.
    Llama 3.3 70B needs this to emit JSON reliably, and it is harmless for
    models that already honour prompt-level JSON instructions, so it is
    applied uniformly rather than per model family.

    Args:
        token: HuggingFace access token with inference permissions.
        model_id: Model to call. Defaults to ``MODEL_ID`` from config.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens per response.

    Raises:
        ValueError: If ``token`` is empty.
    """

    def __init__(
        self,
        token: str,
        model_id: str = MODEL_ID,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        # Fail fast on a missing token instead of deferring to an auth
        # error on the first request.
        if not token:
            raise ValueError("token must be a non-empty string")

        self._client = InferenceClient(token=token)
        self._model_id = model_id
        self._temperature = temperature
        self._max_tokens = max_tokens

    def model_id(self) -> str:
        """Return the model ID this client is configured to use."""
        return self._model_id

    def generate(self, messages: Sequence[ChatMessage]) -> ParsedTurn:
        """Run one chat completion and return the parsed, validated turn.

        Args:
            messages: Full chat history, system message first. Each entry
                is a dict with ``role`` and ``content`` keys (OpenAI
                format).

        Returns:
            A ``ParsedTurn`` holding the response text and validated
            evaluation scores.

        Raises:
            ValueError: If ``messages`` is empty.
            InferenceError: If the API call itself fails (auth, rate limit,
                network, malformed response envelope).
            EvaluationParseError: If the model's content cannot be parsed
                or validated against the expected attribute schema.
        """
        if not messages:
            raise ValueError("messages must not be empty")

        completion = self._request_completion(messages)
        content = self._extract_content(completion)
        # Parse errors are intentionally left unwrapped so callers can
        # distinguish them from transport/envelope failures.
        return parse_model_output(content)

    def _request_completion(self, messages: Sequence[ChatMessage]):
        """Call the chat-completion endpoint, wrapping any failure."""
        try:
            return self._client.chat_completion(
                model=self._model_id,
                messages=list(messages),
                max_tokens=self._max_tokens,
                temperature=self._temperature,
                response_format={"type": "json_object"},
            )
        except HfHubHTTPError as exc:
            raise InferenceError(
                f"HF Inference API request failed: {exc}"
            ) from exc
        except Exception as exc:
            # Boundary catch-all: anything else raised by the call is
            # reported to the caller as an InferenceError.
            raise InferenceError(
                f"Unexpected error during inference call: {exc}"
            ) from exc

    @staticmethod
    def _extract_content(completion) -> str:
        """Pull the first choice's message text out of the response."""
        try:
            first_choice = completion.choices[0]
            content = first_choice.message.content
        except (AttributeError, IndexError, TypeError) as exc:
            raise InferenceError(
                f"Inference response missing expected fields: {exc}"
            ) from exc

        if isinstance(content, str) and content.strip():
            return content
        raise InferenceError(
            "Inference response content is empty or non-text"
        )