"""HF Inference API client wrapper for Prisma.

Provides PrismaInferenceClient, a small wrapper around huggingface_hub's
InferenceClient that:

- Forces JSON output via response_format={"type": "json_object"}.
  This is required for reliable structured output with Llama 3.3 70B,
  which otherwise produces conversational text before/instead of JSON.
- Parses and validates the response via src.evaluation.
- Raises typed errors for API failures (InferenceError) and parse
  failures (EvaluationParseError, propagated from evaluation).

The wrapper is initialized once per session with an HF token; each
generate() call sends a full message history (system + conversation)
and returns a validated ParsedTurn.
"""
from __future__ import annotations
from typing import Sequence
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError
from .config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, MODEL_ID
from .evaluation import ParsedTurn, parse_model_output
# A single chat message in OpenAI format, typed as a dict with string keys
# and string values. Kept loose for v1; can tighten to a TypedDict later if
# message shapes diversify.
ChatMessage = dict[str, str]
class InferenceError(Exception):
    """Signal that the inference API call failed or its response was malformed.

    Covers network errors, authentication failures, rate-limit errors, and
    API responses that lack expected fields. Failures to parse the model's
    content are deliberately *not* wrapped here: those surface as
    EvaluationParseError so the app layer can tell the two apart.
    """
class PrismaInferenceClient:
    """Prisma-specific facade over ``huggingface_hub.InferenceClient``.

    One ``InferenceClient`` is created at construction time and reused by
    every ``generate()`` call. ``generate()`` accepts the complete message
    history and hands back a validated ``ParsedTurn``.

    Every request carries ``response_format={"type": "json_object"}``, with
    no opt-out. Llama 3.3 70B requires it, and it is harmless on models that
    already comply with prompt-level JSON instructions, so it is applied
    uniformly for consistency across model families.

    Args:
        token: HuggingFace access token with inference permissions.
        model_id: Model to call; ``MODEL_ID`` from config when omitted.
        temperature: Sampling temperature sent with each request.
        max_tokens: Maximum tokens per response.

    Raises:
        ValueError: If ``token`` is empty.
    """

    def __init__(
        self,
        token: str,
        model_id: str = MODEL_ID,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        # Reject a missing token up front, before any client is built.
        if not token:
            raise ValueError("token must be a non-empty string")

        self._client = InferenceClient(token=token)
        # Request settings are captured once and reused by every generate().
        self._model_id, self._temperature, self._max_tokens = (
            model_id,
            temperature,
            max_tokens,
        )

    @property
    def model_id(self) -> str:
        """Return the ID of the model this client sends requests to."""
        return self._model_id

    def generate(self, messages: Sequence[ChatMessage]) -> ParsedTurn:
        """Run one chat completion and return the parsed, validated turn.

        Args:
            messages: The whole chat history, system message first. Every
                entry is a dict with ``role`` and ``content`` keys
                (OpenAI format).

        Returns:
            The ``ParsedTurn`` produced by ``parse_model_output`` from the
            text content of the first completion choice.

        Raises:
            ValueError: If ``messages`` is empty.
            InferenceError: If the API call itself fails (auth, rate limit,
                network) or the response envelope is malformed or has no
                usable text content.
            EvaluationParseError: If the model's content cannot be parsed
                or validated; propagated unchanged from the evaluation
                module.
        """
        if not messages:
            raise ValueError("messages must not be empty")

        completion = self._request_completion(messages)
        content = self._extract_content(completion)
        return parse_model_output(content)

    def _request_completion(self, messages: Sequence[ChatMessage]) -> object:
        """Call the chat-completion endpoint, wrapping any failure."""
        try:
            # list(...) stays inside the try so that a failure while
            # materializing the history is also reported as InferenceError.
            return self._client.chat_completion(
                model=self._model_id,
                messages=list(messages),
                max_tokens=self._max_tokens,
                temperature=self._temperature,
                response_format={"type": "json_object"},
            )
        except HfHubHTTPError as exc:
            message = f"HF Inference API request failed: {exc}"
            raise InferenceError(message) from exc
        except Exception as exc:
            # Boundary catch-all: anything else raised by the call is still
            # surfaced to callers as InferenceError, with the cause chained.
            message = f"Unexpected error during inference call: {exc}"
            raise InferenceError(message) from exc

    @staticmethod
    def _extract_content(completion: object) -> str:
        """Pull the first choice's text out of the response envelope."""
        try:
            content = completion.choices[0].message.content
        except (AttributeError, IndexError, TypeError) as exc:
            message = f"Inference response missing expected fields: {exc}"
            raise InferenceError(message) from exc

        # Only a non-blank string is usable by the parser.
        if isinstance(content, str) and content.strip():
            return content
        raise InferenceError("Inference response content is empty or non-text")
|