# Provenance (copied from the hosting page header):
# commit eed1cab (verified) by Praneshrajan15 — "Deploy DataForge playground API"
"""Minimal OpenAI-compatible clients for benchmark-only LLM baselines."""
from __future__ import annotations

import json
import logging
import math
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import cast

import httpx
class ProviderRequestError(RuntimeError):
    """Signal that an LLM provider refused or failed a benchmark request.

    Raised for non-retryable HTTP status codes and when the last retry
    attempt still comes back with an error status.
    """
class ProviderRateLimitError(ProviderRequestError):
    """Signal that a provider's requested back-off is longer than the client's cap.

    It subclasses ``ProviderRequestError``, so handlers written for generic
    request rejections also catch it.
    """
def _is_rate_limit_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an httpx status error carrying HTTP 429 (Too Many Requests)."""
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 429
def _is_retryable_provider_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an HTTP status error worth retrying (429 or 503)."""
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    retryable_statuses = (429, 503)
    return exc.response.status_code in retryable_statuses
def _retry_after_s(exc: httpx.HTTPStatusError, *, fallback_s: float) -> float:
"""Return provider retry-after delay when present."""
raw_retry_after = exc.response.headers.get("retry-after")
if raw_retry_after is None:
return fallback_s
try:
return max(float(raw_retry_after), fallback_s)
except ValueError:
return fallback_s
@dataclass(frozen=True, kw_only=True)
class GroqCompletion:
"""Completion payload plus conservative usage accounting."""
text: str
prompt_tokens: int
completion_tokens: int
warnings: tuple[str, ...]
class OpenAICompatBenchClient:
    """Sequential OpenAI-compatible chat-completions client for benchmark baselines.

    Requests are sent one at a time with ``temperature=0.0``. After a
    successful call, the next request waits until ``min_interval_s`` has
    elapsed. HTTP 429/503 responses are retried with a linearly growing
    back-off, or with the provider's ``Retry-After`` when that is longer.

    The instance owns an ``httpx.Client`` connection pool. Call :meth:`close`
    when finished, or use the client as a context manager.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        endpoint: str,
        provider: str,
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store configuration and open the HTTP client.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model name placed in every request body.
            endpoint: Full chat-completions URL.
            provider: Short provider label used in errors and log messages.
            min_interval_s: Minimum seconds between the end of one successful
                call and the start of the next.
            max_tokens: ``max_tokens`` value sent with each request.
            max_retries: Total attempts per request. At least one attempt is
                always made.
            max_retry_after_s: Longest back-off the client will sleep.
                Anything longer raises ``ProviderRateLimitError``.
            timeout_s: httpx timeout applied to requests.
        """
        self._api_key = api_key
        self._model = model
        self._endpoint = endpoint
        self._provider = provider
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    @property
    def model(self) -> str:
        """Return the configured provider model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the configured provider identifier."""
        return self._provider

    def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        self._client.close()

    def __enter__(self) -> OpenAICompatBenchClient:
        """Return the client itself for use in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the client on leaving the ``with`` block; never suppresses errors."""
        self.close()

    def _respect_spacing(self) -> None:
        """Sleep long enough to keep requests sequential with a fixed gap."""
        if self._last_success_at is None:
            return
        elapsed = time.monotonic() - self._last_success_at
        remaining = self._min_interval_s - elapsed
        if remaining > 0:
            time.sleep(remaining)

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """POST one chat-completions request, retrying 429/503 responses.

        Returns:
            The decoded JSON response object.

        Raises:
            ProviderRequestError: On a non-retryable HTTP status, or when the
                final attempt still fails.
            ProviderRateLimitError: When the requested back-off exceeds
                ``max_retry_after_s``.
            TimeoutError: When httpx reports a timeout (timeouts are not retried).
            ValueError: When the response body is not a JSON object.
        """
        payload = {
            "model": self._model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": self._max_tokens,
        }
        logger = logging.getLogger("dataforge.bench.groq_client")
        # Always send at least once, even if max_retries is configured as 0.
        attempts = max(1, self._max_retries)
        for attempt in range(attempts):
            try:
                response = self._client.post(self._endpoint, json=payload)
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                body = exc.response.text[:500].replace("\n", " ")
                if not _is_retryable_provider_error(exc) or attempt == attempts - 1:
                    raise ProviderRequestError(
                        f"{self._provider} request rejected with HTTP "
                        f"{exc.response.status_code}: {body}"
                    ) from exc
                # Linear back-off (2s, 4s, ...) unless the provider asks for longer.
                retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
                if retry_s > self._max_retry_after_s:
                    raise ProviderRateLimitError(
                        f"{self._provider} rate limit retry-after {retry_s:.2f}s "
                        f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
                    ) from exc
                # This log label is also used for 503 retries, not only 429s.
                logger.warning(
                    "%s_rate_limit attempt=%d retry_after_s=%.2f",
                    self._provider,
                    attempt + 1,
                    retry_s,
                )
                time.sleep(retry_s)
            except httpx.TimeoutException as exc:
                raise TimeoutError(
                    f"{self._provider} request timed out after {self._timeout_s:.1f} seconds."
                ) from exc
            else:
                decoded = response.json()
                if not isinstance(decoded, dict):
                    raise ValueError(
                        f"Unexpected {self._provider} response payload: {json.dumps(decoded)}"
                    )
                return decoded
        # Unreachable: the final attempt either returns or raises above.
        raise RuntimeError(f"{self._provider} request failed without a response.")

    def _extract_content(self, payload: dict[str, object]) -> str | None:
        """Return ``choices[0].message.content`` from a chat-completions payload.

        Returns:
            The content as a string, or ``None`` when the provider sent JSON ``null``.

        Raises:
            ValueError: When the payload does not contain that path.
        """
        try:
            choices = cast(list[dict[str, object]], payload["choices"])
            message = cast(dict[str, object], choices[0]["message"])
            content = message["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise ValueError(
                f"Unexpected {self._provider} response payload: {json.dumps(payload)}"
            ) from exc
        return None if content is None else str(content)

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request and return its text and token usage.

        Waits out the minimum spacing first. After the request succeeds, it
        records the success time so the next call is paced from that point.

        Warnings that may be added to the result:
            ``missing_usage_payload``: the response had no ``usage`` object, or
                an empty one. Token counts then default to 0.
            ``null_message_content``: the provider returned ``content: null``.
                The text is then reported as an empty string.

        Raises:
            ValueError: When the response lacks ``choices[0].message.content``.
            Any error raised by the request itself (see ``_post``) also propagates.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()
        logger = logging.getLogger("dataforge.bench.groq_client")
        warnings: list[str] = []
        usage = payload.get("usage", {})
        usage_map = usage if isinstance(usage, dict) else {}
        # `or 0` also covers providers that send explicit nulls for counts.
        prompt_tokens = int(usage_map.get("prompt_tokens") or 0)
        completion_tokens = int(usage_map.get("completion_tokens") or 0)
        if not usage:
            warnings.append("missing_usage_payload")
            logger.warning("%s_missing_usage_payload", self._provider)
        content = self._extract_content(payload)
        if content is None:
            # str(None) would have produced the literal text "None".
            warnings.append("null_message_content")
            logger.warning("%s_null_message_content", self._provider)
            content = ""
        return GroqCompletion(
            text=content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=tuple(warnings),
        )
class GroqBenchClient(OpenAICompatBenchClient):
    """Benchmark client preconfigured for Groq's OpenAI-compatible endpoint.

    This subclass only fixes the endpoint URL and the provider label.
    Request pacing, 429/503 retries, and response parsing come from the
    base class.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        groq_chat_url = "https://api.groq.com/openai/v1/chat/completions"
        super().__init__(
            provider="groq",
            endpoint=groq_chat_url,
            api_key=api_key,
            model=model,
            max_tokens=max_tokens,
            min_interval_s=min_interval_s,
            timeout_s=timeout_s,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
        )
class CerebrasBenchClient(OpenAICompatBenchClient):
    """Benchmark client preconfigured for Cerebras' OpenAI-compatible endpoint.

    Its default spacing (0.5 s) is tighter than Groq's. Everything else,
    including 429/503 retries, is inherited from the base class.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "qwen-3-235b-a22b-instruct-2507",
        min_interval_s: float = 0.5,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        cerebras_chat_url = "https://api.cerebras.ai/v1/chat/completions"
        super().__init__(
            api_key=api_key,
            provider="cerebras",
            model=model,
            endpoint=cerebras_chat_url,
            timeout_s=timeout_s,
            min_interval_s=min_interval_s,
            max_retry_after_s=max_retry_after_s,
            max_retries=max_retries,
            max_tokens=max_tokens,
        )
class GeminiBenchClient:
    """Sequential Gemini ``generateContent`` client with the benchmark completion interface.

    It mirrors the OpenAI-compatible clients in this module:
    - one request at a time, with ``temperature=0.0``;
    - a minimum gap after each successful call;
    - retries on HTTP 429/503.

    The API key is sent in the ``x-goog-api-key`` header, never in the URL.
    The instance owns an ``httpx.Client``; call :meth:`close` or use the
    client as a context manager to release it.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "gemini-3.1-pro-preview",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store configuration and open the HTTP client.

        Args:
            api_key: Gemini API key, sent as the ``x-goog-api-key`` header.
            model: Model name. A leading ``models/`` prefix is stripped.
            min_interval_s: Minimum seconds between the end of one successful
                call and the start of the next.
            max_tokens: Sent as ``generationConfig.maxOutputTokens``.
            max_retries: Total attempts per request. At least one attempt is
                always made.
            max_retry_after_s: Longest back-off the client will sleep.
                Anything longer raises ``ProviderRateLimitError``.
            timeout_s: httpx timeout applied to requests.
        """
        self._api_key = api_key
        self._model = model.removeprefix("models/")
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Content-Type": "application/json",
                # Header auth keeps the key out of request URLs. Those URLs
                # appear in httpx request logs and HTTPStatusError messages.
                "x-goog-api-key": self._api_key,
            },
        )

    @property
    def model(self) -> str:
        """Return the configured Gemini model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the provider identifier."""
        return "gemini"

    def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        self._client.close()

    def __enter__(self) -> GeminiBenchClient:
        """Return the client itself for use in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the client on leaving the ``with`` block; never suppresses errors."""
        self.close()

    def _respect_spacing(self) -> None:
        """Sleep long enough to keep requests sequential with a fixed gap."""
        if self._last_success_at is None:
            return
        elapsed = time.monotonic() - self._last_success_at
        remaining = self._min_interval_s - elapsed
        if remaining > 0:
            time.sleep(remaining)

    def _payload(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Convert OpenAI-style chat messages into a Gemini generateContent body.

        All ``system`` messages are joined with blank lines into
        ``systemInstruction``. ``assistant`` maps to Gemini's ``model`` role,
        and every other role (or a missing one) maps to ``user``.
        """
        system_texts: list[str] = []
        contents: list[dict[str, object]] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if role == "system":
                system_texts.append(content)
                continue
            gemini_role = "model" if role == "assistant" else "user"
            contents.append({"role": gemini_role, "parts": [{"text": content}]})
        payload: dict[str, object] = {
            "contents": contents,
            "generationConfig": {
                "temperature": 0.0,
                "maxOutputTokens": self._max_tokens,
            },
        }
        if system_texts:
            payload["systemInstruction"] = {
                "parts": [{"text": "\n\n".join(system_texts)}],
            }
        return payload

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """POST one generateContent request, retrying 429/503 responses.

        Returns:
            The decoded JSON response object.

        Raises:
            ProviderRequestError: On a non-retryable HTTP status, or when the
                final attempt still fails.
            ProviderRateLimitError: When the requested back-off exceeds
                ``max_retry_after_s``.
            TimeoutError: When httpx reports a timeout (timeouts are not retried).
            ValueError: When the response body is not a JSON object.
        """
        endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/{self._model}:generateContent"
        )
        # The request body is identical for every attempt, so build it once.
        request_body = self._payload(messages)
        logger = logging.getLogger("dataforge.bench.groq_client")
        # Always send at least once, even if max_retries is configured as 0.
        attempts = max(1, self._max_retries)
        for attempt in range(attempts):
            try:
                response = self._client.post(endpoint, json=request_body)
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                body = exc.response.text[:500].replace("\n", " ")
                if not _is_retryable_provider_error(exc) or attempt == attempts - 1:
                    raise ProviderRequestError(
                        f"gemini request rejected with HTTP {exc.response.status_code}: {body}"
                    ) from exc
                # Linear back-off (2s, 4s, ...) unless the provider asks for longer.
                retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
                if retry_s > self._max_retry_after_s:
                    raise ProviderRateLimitError(
                        f"gemini rate limit retry-after {retry_s:.2f}s "
                        f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
                    ) from exc
                logger.warning(
                    "gemini_rate_limit attempt=%d retry_after_s=%.2f",
                    attempt + 1,
                    retry_s,
                )
                time.sleep(retry_s)
            except httpx.TimeoutException as exc:
                raise TimeoutError(
                    f"gemini request timed out after {self._timeout_s:.1f} seconds."
                ) from exc
            else:
                decoded = response.json()
                if not isinstance(decoded, dict):
                    raise ValueError(f"Unexpected gemini response payload: {json.dumps(decoded)}")
                return decoded
        # Unreachable: the final attempt either returns or raises above.
        raise RuntimeError("gemini request failed without a response.")

    @staticmethod
    def _extract_text(payload: dict[str, object]) -> str:
        """Join the answer text parts of the first candidate.

        Parts flagged ``thought: true`` (thought summaries) are skipped, so
        only answer text is returned. A part whose ``text`` is missing or
        null contributes nothing.

        Raises:
            ValueError: When the payload lacks ``candidates[0].content.parts``
                or the parts are malformed.
        """
        try:
            candidates = cast(list[dict[str, object]], payload["candidates"])
            content = cast(dict[str, object], candidates[0]["content"])
            parts = cast(list[dict[str, object]], content["parts"])
            return "".join(
                str(part.get("text") or "") for part in parts if not part.get("thought")
            )
        except (AttributeError, KeyError, IndexError, TypeError) as exc:
            raise ValueError(f"Unexpected gemini response payload: {json.dumps(payload)}") from exc

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to Gemini and return text plus usage.

        Waits out the minimum spacing first. After the request succeeds, it
        records the success time so the next call is paced from that point.
        When ``usageMetadata`` is missing or empty, token counts default to 0
        and a ``missing_usage_payload`` warning is attached.

        Raises:
            ValueError: When the response has no usable candidate text.
            Any error raised by the request itself (see ``_post``) also propagates.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()
        warnings: list[str] = []
        usage = payload.get("usageMetadata", {})
        usage_map = usage if isinstance(usage, dict) else {}
        # `or 0` also covers explicit nulls for counts.
        prompt_tokens = int(usage_map.get("promptTokenCount") or 0)
        # Thinking tokens are reported separately from candidatesTokenCount but
        # billed as output. Count them here to keep accounting conservative.
        completion_tokens = int(usage_map.get("candidatesTokenCount") or 0) + int(
            usage_map.get("thoughtsTokenCount") or 0
        )
        if not usage:
            warnings.append("missing_usage_payload")
            logging.getLogger("dataforge.bench.groq_client").warning("gemini_missing_usage_payload")
        text = self._extract_text(payload)
        return GroqCompletion(
            text=text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=tuple(warnings),
        )