| """Minimal OpenAI-compatible clients for benchmark-only LLM baselines.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import time |
| from dataclasses import dataclass |
| from typing import cast |
|
|
| import httpx |
|
|
|
|
class ProviderRequestError(RuntimeError):
    """Signal that a provider refused a benchmark request.

    The clients in this module raise it when an HTTP error response is not
    retryable, or when the final allowed attempt still ends in an HTTP error.
    """
|
|
|
|
class ProviderRateLimitError(ProviderRequestError):
    """Signal that the required retry delay is longer than the configured cap.

    Being a ``ProviderRequestError`` subclass, it is also caught by handlers
    written for the general request failure.
    """
|
|
|
|
def _is_rate_limit_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an ``httpx.HTTPStatusError`` carrying HTTP 429."""
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 429
|
|
|
|
def _is_retryable_provider_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an HTTP status error that should be retried.

    Only ``httpx.HTTPStatusError`` instances whose status is 429 or 503
    qualify; every other exception yields ``False``.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    retryable_statuses = (429, 503)
    return exc.response.status_code in retryable_statuses
|
|
|
|
| def _retry_after_s(exc: httpx.HTTPStatusError, *, fallback_s: float) -> float: |
| """Return provider retry-after delay when present.""" |
| raw_retry_after = exc.response.headers.get("retry-after") |
| if raw_retry_after is None: |
| return fallback_s |
| try: |
| return max(float(raw_retry_after), fallback_s) |
| except ValueError: |
| return fallback_s |
|
|
|
|
| @dataclass(frozen=True, kw_only=True) |
| class GroqCompletion: |
| """Completion payload plus conservative usage accounting.""" |
|
|
| text: str |
| prompt_tokens: int |
| completion_tokens: int |
| warnings: tuple[str, ...] |
|
|
|
|
class OpenAICompatBenchClient:
    """Sequential client for OpenAI-compatible chat-completions endpoints.

    Requests are sent one at a time, leaving at least ``min_interval_s``
    between a successful response and the next request.  HTTP 429 and 503
    responses are retried, waiting for the provider's ``retry-after`` value
    (never less than a linearly growing fallback) until ``max_retries``
    attempts have been made.

    The instance owns an ``httpx.Client``.  Call :meth:`close`, or use the
    instance as a context manager, to release its connections.
    """

    # Logger name shared by every warning this class emits.
    _LOGGER_NAME = "dataforge.bench.groq_client"

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        endpoint: str,
        provider: str,
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store the configuration and open the underlying HTTP client.

        Args:
            api_key: Secret sent as a ``Bearer`` token on every request.
            model: Model name placed in the request payload.
            endpoint: Full URL the chat-completions request is posted to.
            provider: Short label used in error and log messages.
            min_interval_s: Minimum gap between a success and the next request.
            max_tokens: Value sent as ``max_tokens`` in the payload.
            max_retries: Total number of attempts per request.
            max_retry_after_s: Longest retry delay tolerated before giving up.
            timeout_s: Timeout handed to ``httpx.Client``.
        """
        self._api_key = api_key
        self._model = model
        self._endpoint = endpoint
        self._provider = provider
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    def close(self) -> None:
        """Close the underlying HTTP client and its pooled connections."""
        self._client.close()

    def __enter__(self) -> OpenAICompatBenchClient:
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
        """Close the HTTP client when the ``with`` block ends."""
        self.close()

    @property
    def model(self) -> str:
        """Return the configured provider model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the configured provider identifier."""
        return self._provider

    def _respect_spacing(self) -> None:
        """Sleep until ``min_interval_s`` has passed since the last success."""
        if self._last_success_at is None:
            return
        elapsed = time.monotonic() - self._last_success_at
        remaining = self._min_interval_s - elapsed
        if remaining > 0:
            time.sleep(remaining)

    @staticmethod
    def _body_snippet(exc: httpx.HTTPStatusError) -> str:
        """Return the first 500 characters of the error body on one line."""
        return exc.response.text[:500].replace("\n", " ")

    @staticmethod
    def _token_count(usage: object, key: str) -> int:
        """Read one token counter from a usage mapping, defaulting to 0.

        A missing key, an explicit ``null`` value, or a usage payload that is
        not a dict all count as 0.
        """
        if not isinstance(usage, dict):
            return 0
        value = usage.get(key)
        return 0 if value is None else int(value)

    def _message_text(self, payload: dict[str, object]) -> str | None:
        """Return the first choice's message content.

        Returns:
            The content as a string, or ``None`` when the provider sent an
            explicit JSON ``null``.

        Raises:
            ValueError: If the payload lacks ``choices[0].message.content``.
        """
        try:
            choices = cast(list[dict[str, object]], payload["choices"])
            message = cast(dict[str, object], choices[0]["message"])
            content = message["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise ValueError(
                f"Unexpected {self._provider} response payload: {json.dumps(payload)}"
            ) from exc
        return None if content is None else str(content)

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Issue the chat-completions request, retrying HTTP 429 and 503.

        Raises:
            ProviderRequestError: On a non-retryable HTTP error, or when the
                last allowed attempt still returns an HTTP error.
            ProviderRateLimitError: When the required retry delay exceeds
                ``max_retry_after_s``.
            TimeoutError: When the request times out (not retried).
            RuntimeError: When ``max_retries`` allows no attempt at all.
        """
        payload = {
            "model": self._model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": self._max_tokens,
        }
        final_attempt = self._max_retries - 1
        for attempt in range(self._max_retries):
            try:
                response = self._client.post(self._endpoint, json=payload)
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                if not _is_retryable_provider_error(exc) or attempt == final_attempt:
                    raise ProviderRequestError(
                        f"{self._provider} request rejected with HTTP "
                        f"{exc.response.status_code}: {self._body_snippet(exc)}"
                    ) from exc
                # Fallback delay grows linearly with the attempt number.
                retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
                if retry_s > self._max_retry_after_s:
                    raise ProviderRateLimitError(
                        f"{self._provider} rate limit retry-after {retry_s:.2f}s "
                        f"exceeds cap {self._max_retry_after_s:.2f}s: "
                        f"{self._body_snippet(exc)}"
                    ) from exc
                logging.getLogger(self._LOGGER_NAME).warning(
                    "%s_rate_limit attempt=%d retry_after_s=%.2f",
                    self._provider,
                    attempt + 1,
                    retry_s,
                )
                time.sleep(retry_s)
            except httpx.TimeoutException as exc:
                raise TimeoutError(
                    f"{self._provider} request timed out after {self._timeout_s:.1f} seconds."
                ) from exc
            else:
                return dict(response.json())
        # Only reachable when max_retries <= 0: every executed iteration
        # either returns a response or raises.
        raise RuntimeError(f"{self._provider} request failed without a response.")

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to the configured provider.

        Args:
            messages: OpenAI-style chat messages, forwarded unchanged.

        Returns:
            The first choice's text with token counts and parsing warnings.
            Missing usage data adds ``"missing_usage_payload"``; a ``null``
            message content yields empty text and adds ``"null_content"``.

        Raises:
            ValueError: If the response lacks ``choices[0].message.content``.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()

        logger = logging.getLogger(self._LOGGER_NAME)
        warnings: list[str] = []
        usage = payload.get("usage", {})
        prompt_tokens = self._token_count(usage, "prompt_tokens")
        completion_tokens = self._token_count(usage, "completion_tokens")
        if not usage:
            warnings.append("missing_usage_payload")
            logger.warning("%s_missing_usage_payload", self._provider)

        content = self._message_text(payload)
        if content is None:
            # str(None) would hand callers the literal text "None".
            content = ""
            warnings.append("null_content")
            logger.warning("%s_null_content", self._provider)

        return GroqCompletion(
            text=content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=tuple(warnings),
        )
|
|
|
|
class GroqBenchClient(OpenAICompatBenchClient):
    """Groq preset of the OpenAI-compatible benchmark client.

    Pins the endpoint to Groq's chat-completions URL and the provider label
    to ``"groq"``; every other setting is forwarded to the base class.
    """

    _ENDPOINT = "https://api.groq.com/openai/v1/chat/completions"
    _PROVIDER = "groq"

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Configure the base client with the Groq endpoint and label."""
        super().__init__(
            endpoint=self._ENDPOINT,
            provider=self._PROVIDER,
            api_key=api_key,
            model=model,
            timeout_s=timeout_s,
            max_tokens=max_tokens,
            min_interval_s=min_interval_s,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
        )
|
|
|
|
class CerebrasBenchClient(OpenAICompatBenchClient):
    """Cerebras preset of the OpenAI-compatible benchmark client.

    Pins the endpoint to Cerebras' chat-completions URL and the provider
    label to ``"cerebras"``; every other setting is forwarded to the base
    class.  The default request gap here is 0.5 seconds.
    """

    _ENDPOINT = "https://api.cerebras.ai/v1/chat/completions"
    _PROVIDER = "cerebras"

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "qwen-3-235b-a22b-instruct-2507",
        min_interval_s: float = 0.5,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Configure the base client with the Cerebras endpoint and label."""
        super().__init__(
            endpoint=self._ENDPOINT,
            provider=self._PROVIDER,
            api_key=api_key,
            model=model,
            timeout_s=timeout_s,
            max_tokens=max_tokens,
            min_interval_s=min_interval_s,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
        )
|
|
|
|
class GeminiBenchClient:
    """Sequential Gemini client adapted to the benchmark completion interface.

    Accepts OpenAI-style chat messages, converts them to a ``generateContent``
    payload, and returns the same ``GroqCompletion`` result type as the
    OpenAI-compatible clients.  HTTP 429 and 503 responses are retried up to
    ``max_retries`` attempts.

    The instance owns an ``httpx.Client``.  Call :meth:`close`, or use the
    instance as a context manager, to release its connections.
    """

    # Logger name shared by every warning this class emits.
    _LOGGER_NAME = "dataforge.bench.groq_client"

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "gemini-3.1-pro-preview",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store the configuration and open the underlying HTTP client.

        Args:
            api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
            model: Model name; a leading ``models/`` prefix is stripped.
            min_interval_s: Minimum gap between a success and the next request.
            max_tokens: Value sent as ``maxOutputTokens``.
            max_retries: Total number of attempts per request.
            max_retry_after_s: Longest retry delay tolerated before giving up.
            timeout_s: Timeout handed to ``httpx.Client``.
        """
        self._api_key = api_key
        self._model = model.removeprefix("models/")
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Content-Type": "application/json",
                # Header instead of a ``?key=`` query parameter: the request
                # URL appears in httpx log lines and in HTTPStatusError
                # messages, so a key in the query string leaks into logs and
                # chained tracebacks.
                "x-goog-api-key": self._api_key,
            },
        )

    def close(self) -> None:
        """Close the underlying HTTP client and its pooled connections."""
        self._client.close()

    def __enter__(self) -> GeminiBenchClient:
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
        """Close the HTTP client when the ``with`` block ends."""
        self.close()

    @property
    def model(self) -> str:
        """Return the configured Gemini model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the provider identifier."""
        return "gemini"

    def _respect_spacing(self) -> None:
        """Sleep until ``min_interval_s`` has passed since the last success."""
        if self._last_success_at is None:
            return
        elapsed = time.monotonic() - self._last_success_at
        remaining = self._min_interval_s - elapsed
        if remaining > 0:
            time.sleep(remaining)

    @staticmethod
    def _body_snippet(exc: httpx.HTTPStatusError) -> str:
        """Return the first 500 characters of the error body on one line."""
        return exc.response.text[:500].replace("\n", " ")

    @staticmethod
    def _token_count(usage: object, key: str) -> int:
        """Read one token counter from a usage mapping, defaulting to 0.

        A missing key, an explicit ``null`` value, or a usage payload that is
        not a dict all count as 0.
        """
        if not isinstance(usage, dict):
            return 0
        value = usage.get(key)
        return 0 if value is None else int(value)

    def _payload(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Convert OpenAI-style chat messages to a generateContent payload.

        ``system`` messages are joined with blank lines into
        ``systemInstruction``; ``assistant`` messages become role ``model``;
        every other role is sent as ``user``.
        """
        system_texts: list[str] = []
        contents: list[dict[str, object]] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if role == "system":
                system_texts.append(content)
                continue
            gemini_role = "model" if role == "assistant" else "user"
            contents.append({"role": gemini_role, "parts": [{"text": content}]})

        payload: dict[str, object] = {
            "contents": contents,
            "generationConfig": {
                "temperature": 0.0,
                "maxOutputTokens": self._max_tokens,
            },
        }
        if system_texts:
            payload["systemInstruction"] = {
                "parts": [{"text": "\n\n".join(system_texts)}],
            }
        return payload

    def _candidate_text(self, payload: dict[str, object]) -> str:
        """Concatenate the text parts of the first candidate.

        Raises:
            ValueError: If the payload lacks ``candidates[0].content.parts``
                or a part is not a mapping.
        """
        try:
            candidates = cast(list[dict[str, object]], payload["candidates"])
            content = cast(dict[str, object], candidates[0]["content"])
            parts = cast(list[dict[str, object]], content["parts"])
            return "".join(str(part.get("text", "")) for part in parts)
        except (AttributeError, KeyError, IndexError, TypeError) as exc:
            # AttributeError covers parts that are not dicts and so have no .get().
            raise ValueError(f"Unexpected gemini response payload: {json.dumps(payload)}") from exc

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Issue the generateContent request, retrying HTTP 429 and 503.

        Raises:
            ProviderRequestError: On a non-retryable HTTP error, or when the
                last allowed attempt still returns an HTTP error.
            ProviderRateLimitError: When the required retry delay exceeds
                ``max_retry_after_s``.
            TimeoutError: When the request times out (not retried).
            RuntimeError: When ``max_retries`` allows no attempt at all.
        """
        endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/{self._model}:generateContent"
        )
        # The body is identical on every attempt, so build it once.
        request_body = self._payload(messages)
        final_attempt = self._max_retries - 1
        for attempt in range(self._max_retries):
            try:
                response = self._client.post(endpoint, json=request_body)
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                if not _is_retryable_provider_error(exc) or attempt == final_attempt:
                    raise ProviderRequestError(
                        f"gemini request rejected with HTTP {exc.response.status_code}: "
                        f"{self._body_snippet(exc)}"
                    ) from exc
                # Fallback delay grows linearly with the attempt number.
                retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
                if retry_s > self._max_retry_after_s:
                    raise ProviderRateLimitError(
                        f"gemini rate limit retry-after {retry_s:.2f}s "
                        f"exceeds cap {self._max_retry_after_s:.2f}s: "
                        f"{self._body_snippet(exc)}"
                    ) from exc
                logging.getLogger(self._LOGGER_NAME).warning(
                    "gemini_rate_limit attempt=%d retry_after_s=%.2f",
                    attempt + 1,
                    retry_s,
                )
                time.sleep(retry_s)
            except httpx.TimeoutException as exc:
                raise TimeoutError(
                    f"gemini request timed out after {self._timeout_s:.1f} seconds."
                ) from exc
            else:
                return dict(response.json())
        # Only reachable when max_retries <= 0: every executed iteration
        # either returns a response or raises.
        raise RuntimeError("gemini request failed without a response.")

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to Gemini.

        Args:
            messages: OpenAI-style chat messages; converted by ``_payload``.

        Returns:
            The first candidate's concatenated text with token counts read
            from ``usageMetadata``.  Missing usage data adds the
            ``"missing_usage_payload"`` warning.

        Raises:
            ValueError: If the response lacks ``candidates[0].content.parts``.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()

        warnings: list[str] = []
        usage = payload.get("usageMetadata", {})
        prompt_tokens = self._token_count(usage, "promptTokenCount")
        completion_tokens = self._token_count(usage, "candidatesTokenCount")
        if not usage:
            warnings.append("missing_usage_payload")
            logging.getLogger(self._LOGGER_NAME).warning("gemini_missing_usage_payload")

        text = self._candidate_text(payload)
        return GroqCompletion(
            text=text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=tuple(warnings),
        )
|
|