"""Minimal OpenAI-compatible clients for benchmark-only LLM baselines."""

from __future__ import annotations

import json
import logging
import math
import time
from dataclasses import dataclass
from typing import cast

import httpx


class ProviderRequestError(RuntimeError):
    """Raised when a provider rejects a benchmark request payload."""


class ProviderRateLimitError(ProviderRequestError):
    """Raised when a provider asks us to wait longer than the configured cap."""


def _is_rate_limit_error(exc: BaseException) -> bool:
    """Return whether *exc* is an ``httpx.HTTPStatusError`` carrying HTTP 429."""
    return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429


def _is_retryable_provider_error(exc: BaseException) -> bool:
    """Return whether *exc* is an HTTP status error worth retrying (429 or 503)."""
    return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code in {429, 503}


def _retry_after_s(exc: httpx.HTTPStatusError, *, fallback_s: float) -> float:
    """Return how many seconds to wait before retrying a failed request.

    The response's ``retry-after`` header is honoured when it holds a number,
    but the result is never smaller than ``fallback_s``.

    Args:
        exc: The HTTP status error whose response headers are inspected.
        fallback_s: Minimum delay, also used when the header is missing or
            cannot be interpreted as a number of seconds.

    Returns:
        ``max(header value, fallback_s)``, or ``fallback_s`` when the header is
        absent, non-numeric (``float()`` raises ``ValueError``, which includes
        HTTP-date values), or NaN.
    """
    raw_retry_after = exc.response.headers.get("retry-after")
    if raw_retry_after is None:
        return fallback_s
    try:
        requested_s = float(raw_retry_after)
    except ValueError:
        return fallback_s
    # float("nan") parses fine, but max(nan, x) returns nan, which would slip
    # past any "exceeds cap" comparison and then make time.sleep() raise.
    if math.isnan(requested_s):
        return fallback_s
    return max(requested_s, fallback_s)


@dataclass(frozen=True, kw_only=True)
class GroqCompletion:
    """Completion payload plus conservative usage accounting."""

    # Generated text extracted from the provider response.
    text: str
    # Token counts as reported by the provider's usage block.
    prompt_tokens: int
    completion_tokens: int
    # Machine-readable warning tags attached while parsing the response.
    warnings: tuple[str, ...]
# Every client in this module logs under the same historical logger name.
_LOGGER = logging.getLogger("dataforge.bench.groq_client")


def _usage_count(usage: object, key: str) -> int:
    """Return ``usage[key]`` as an int, treating a missing or null entry as 0.

    Anything that is not a dict (including ``None``) yields 0.
    """
    if not isinstance(usage, dict):
        return 0
    raw_count = usage.get(key)
    return 0 if raw_count is None else int(raw_count)


def _sleep_for_spacing(last_success_at: float | None, min_interval_s: float) -> None:
    """Sleep until at least ``min_interval_s`` has elapsed since ``last_success_at``.

    Returns immediately when there has been no successful request yet.
    """
    if last_success_at is None:
        return
    remaining = min_interval_s - (time.monotonic() - last_success_at)
    if remaining > 0:
        time.sleep(remaining)


def _post_json_with_retries(
    client: httpx.Client,
    url: str,
    payload: dict[str, object],
    *,
    provider: str,
    max_retries: int,
    max_retry_after_s: float,
    timeout_s: float,
) -> dict[str, object]:
    """POST ``payload`` as JSON, retrying retryable HTTP errors up to ``max_retries`` times.

    Args:
        client: The httpx client used to send the request.
        url: Target URL.
        payload: JSON-serialisable request body.
        provider: Provider name used as a prefix in error and log messages.
        max_retries: Total number of attempts.
        max_retry_after_s: Largest wait this function is willing to sleep.
        timeout_s: Configured timeout; only used in the timeout error message.

    Returns:
        The decoded JSON response body as a dict.

    Raises:
        ProviderRequestError: The error is not retryable, or the final attempt
            still failed with an HTTP error status.
        ProviderRateLimitError: The computed wait exceeds ``max_retry_after_s``.
        TimeoutError: httpx reported a timeout.
        RuntimeError: No attempt was made at all (``max_retries`` <= 0).
    """
    for attempt in range(max_retries):
        try:
            response = client.post(url, json=payload)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            body = exc.response.text[:500].replace("\n", " ")
            out_of_attempts = attempt == max_retries - 1
            if out_of_attempts or not _is_retryable_provider_error(exc):
                raise ProviderRequestError(
                    f"{provider} request rejected with HTTP {exc.response.status_code}: {body}"
                ) from exc
            # Linear back-off (2s, 4s, ...) unless the provider asks for longer.
            retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
            if retry_s > max_retry_after_s:
                raise ProviderRateLimitError(
                    f"{provider} rate limit retry-after {retry_s:.2f}s "
                    f"exceeds cap {max_retry_after_s:.2f}s: {body}"
                ) from exc
            _LOGGER.warning(
                "%s_rate_limit attempt=%d retry_after_s=%.2f",
                provider,
                attempt + 1,
                retry_s,
            )
            time.sleep(retry_s)
            continue
        except httpx.TimeoutException as exc:
            raise TimeoutError(
                f"{provider} request timed out after {timeout_s:.1f} seconds."
            ) from exc
        return dict(response.json())
    # The last attempt always returns or raises, so the loop only falls through
    # when it never ran.
    raise RuntimeError(f"{provider} request failed without a response.")


class OpenAICompatBenchClient:
    """Sequential OpenAI-compatible client with fixed 429 retry and spacing.

    The instance owns an ``httpx.Client``; call :meth:`close` (or use the
    instance as a context manager) to release it.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        endpoint: str,
        provider: str,
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store the configuration and open the underlying HTTP client.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model name placed in each request payload.
            endpoint: Full chat-completions URL.
            provider: Provider identifier used in error and log messages.
            min_interval_s: Minimum gap after a successful request.
            max_tokens: ``max_tokens`` value sent with each request.
            max_retries: Total attempts per request.
            max_retry_after_s: Longest retry wait tolerated before giving up.
            timeout_s: httpx timeout in seconds.
        """
        self._api_key = api_key
        self._model = model
        self._endpoint = endpoint
        self._provider = provider
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    @property
    def model(self) -> str:
        """Return the configured provider model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the configured provider identifier."""
        return self._provider

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> OpenAICompatBenchClient:
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the HTTP client when the ``with`` block exits."""
        self.close()

    def _respect_spacing(self) -> None:
        """Sleep long enough to keep requests sequential with a fixed gap."""
        _sleep_for_spacing(self._last_success_at, self._min_interval_s)

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Issue the underlying chat-completions request and return the JSON body."""
        payload: dict[str, object] = {
            "model": self._model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": self._max_tokens,
        }
        return _post_json_with_retries(
            self._client,
            self._endpoint,
            payload,
            provider=self._provider,
            max_retries=self._max_retries,
            max_retry_after_s=self._max_retry_after_s,
            timeout_s=self._timeout_s,
        )

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to the configured provider.

        Args:
            messages: OpenAI-style chat messages.

        Returns:
            The first choice's message content plus reported token usage. A
            ``missing_usage_payload`` warning is attached when the response has
            no (or an empty) ``usage`` block.

        Raises:
            ValueError: The response lacks ``choices[0].message.content``.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()

        warnings: list[str] = []
        usage = payload.get("usage", {})
        prompt_tokens = _usage_count(usage, "prompt_tokens")
        completion_tokens = _usage_count(usage, "completion_tokens")
        if not usage:
            warnings.append("missing_usage_payload")
            _LOGGER.warning("%s_missing_usage_payload", self._provider)

        try:
            choices = cast(list[dict[str, object]], payload["choices"])
            message = cast(dict[str, object], choices[0]["message"])
            content = str(message["content"])
        except (KeyError, IndexError, TypeError) as exc:
            raise ValueError(
                f"Unexpected {self._provider} response payload: {json.dumps(payload)}"
            ) from exc

        return GroqCompletion(
            text=content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=tuple(warnings),
        )


class GroqBenchClient(OpenAICompatBenchClient):
    """Sequential Groq client with fixed 429 retry and spacing."""

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Configure the base client for Groq's chat-completions endpoint."""
        super().__init__(
            api_key=api_key,
            model=model,
            endpoint="https://api.groq.com/openai/v1/chat/completions",
            provider="groq",
            min_interval_s=min_interval_s,
            max_tokens=max_tokens,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
            timeout_s=timeout_s,
        )


class CerebrasBenchClient(OpenAICompatBenchClient):
    """Sequential Cerebras client with fixed 429 retry and spacing."""

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "qwen-3-235b-a22b-instruct-2507",
        min_interval_s: float = 0.5,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Configure the base client for Cerebras' chat-completions endpoint."""
        super().__init__(
            api_key=api_key,
            model=model,
            endpoint="https://api.cerebras.ai/v1/chat/completions",
            provider="cerebras",
            min_interval_s=min_interval_s,
            max_tokens=max_tokens,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
            timeout_s=timeout_s,
        )


class GeminiBenchClient:
    """Sequential Gemini client adapted to the benchmark completion interface.

    The instance owns an ``httpx.Client``; call :meth:`close` (or use the
    instance as a context manager) to release it.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "gemini-3.1-pro-preview",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store the configuration and open the underlying HTTP client.

        Args:
            api_key: Gemini API key.
            model: Model name; a leading ``models/`` prefix is stripped.
            min_interval_s: Minimum gap after a successful request.
            max_tokens: ``maxOutputTokens`` value sent with each request.
            max_retries: Total attempts per request.
            max_retry_after_s: Longest retry wait tolerated before giving up.
            timeout_s: httpx timeout in seconds.
        """
        self._api_key = api_key
        self._model = model.removeprefix("models/")
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        # The key goes in a header rather than a ``?key=`` query parameter:
        # httpx includes the request URL in HTTPStatusError messages, so a key
        # in the URL would leak into tracebacks and logs.
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Content-Type": "application/json",
                "x-goog-api-key": self._api_key,
            },
        )

    @property
    def model(self) -> str:
        """Return the configured Gemini model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the provider identifier."""
        return "gemini"

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> GeminiBenchClient:
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the HTTP client when the ``with`` block exits."""
        self.close()

    def _respect_spacing(self) -> None:
        """Sleep long enough to keep requests sequential with a fixed gap."""
        _sleep_for_spacing(self._last_success_at, self._min_interval_s)

    def _payload(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Convert OpenAI-style chat messages to Gemini generateContent payload.

        ``system`` messages are joined with blank lines into
        ``systemInstruction``; ``assistant`` maps to the ``model`` role and
        every other role maps to ``user``.
        """
        system_texts: list[str] = []
        contents: list[dict[str, object]] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if role == "system":
                system_texts.append(content)
                continue
            gemini_role = "model" if role == "assistant" else "user"
            contents.append({"role": gemini_role, "parts": [{"text": content}]})

        payload: dict[str, object] = {
            "contents": contents,
            "generationConfig": {
                "temperature": 0.0,
                "maxOutputTokens": self._max_tokens,
            },
        }
        if system_texts:
            payload["systemInstruction"] = {
                "parts": [{"text": "\n\n".join(system_texts)}],
            }
        return payload

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Issue the underlying Gemini generateContent request and return the JSON body."""
        endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/{self._model}:generateContent"
        )
        return _post_json_with_retries(
            self._client,
            endpoint,
            self._payload(messages),
            provider="gemini",
            max_retries=self._max_retries,
            max_retry_after_s=self._max_retry_after_s,
            timeout_s=self._timeout_s,
        )

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to Gemini.

        Args:
            messages: OpenAI-style chat messages.

        Returns:
            The concatenated text parts of the first candidate plus reported
            token usage. A ``missing_usage_payload`` warning is attached when
            the response has no (or an empty) ``usageMetadata`` block.

        Raises:
            ValueError: The response lacks ``candidates[0].content.parts`` or
                the parts are not objects.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()

        warnings: list[str] = []
        usage = payload.get("usageMetadata", {})
        prompt_tokens = _usage_count(usage, "promptTokenCount")
        completion_tokens = _usage_count(usage, "candidatesTokenCount")
        if not usage:
            warnings.append("missing_usage_payload")
            _LOGGER.warning("gemini_missing_usage_payload")

        try:
            candidates = cast(list[dict[str, object]], payload["candidates"])
            content = cast(dict[str, object], candidates[0]["content"])
            parts = cast(list[dict[str, object]], content["parts"])
            text = "".join(str(part.get("text", "")) for part in parts)
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            # AttributeError covers a part that is not a JSON object.
            raise ValueError(f"Unexpected gemini response payload: {json.dumps(payload)}") from exc

        return GroqCompletion(
            text=text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=tuple(warnings),
        )