Praneshrajan15's picture
feat: initial playground deployment
5143557 verified
"""Multi-provider LLM client for DataForge.
Reads ``DATAFORGE_LLM_PROVIDER`` from the environment and dispatches to the
matching provider. Week 1 implements **groq** and **gemini** only; other
providers raise ``NotImplementedError``.
No LLM calls are made by detectors — this module is for the agent loop
(Week 2+) and is stubbed here to establish the interface.
The interface is:
``async def complete(messages, model, temperature) -> str``
"""
from __future__ import annotations
import os
from typing import Literal, TypedDict
import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
# ── Message type ──────────────────────────────────────────────────────────
class Message(TypedDict):
    """A single chat message in OpenAI-style ``role``/``content`` form.

    Attributes:
        role: The speaker role: ``"system"``, ``"user"``, or ``"assistant"``.
        content: The text content of the message.
    """

    role: Literal["system", "user", "assistant"]
    content: str
# ── Exceptions ────────────────────────────────────────────────────────────
class ProviderError(Exception):
    """Raised when an LLM provider call cannot produce a usable result.

    The exception text is ``"[<provider>] <message>"``; the two pieces are
    also kept as separate attributes so callers can branch on the provider
    without parsing the string.

    Args:
        provider: The provider name that failed.
        message: Description of the failure.

    Attributes:
        provider: The provider name passed to the constructor.
        message: The undecorated failure description.
    """

    def __init__(self, provider: str, message: str) -> None:
        self.provider = provider
        self.message = message
        super().__init__(f"[{provider}] {message}")

    def __reduce__(self) -> "tuple[type[ProviderError], tuple[str, str]]":
        # ``Exception`` pickles as ``cls(*self.args)``, and ``args`` holds only
        # the single formatted string, so the default round-trip would call the
        # two-argument ``__init__`` with one argument and raise ``TypeError``.
        # Rebuild from the original constructor arguments instead.
        return (type(self), (self.provider, self.message))
# ── Provider dispatch ─────────────────────────────────────────────────────
# Provider names recognised by ``complete``. Only "groq" and "gemini" are
# dispatched to an implementation; the remaining names make ``complete`` raise
# ``NotImplementedError`` with a "planned but not yet implemented" message,
# while any name outside this set is reported as unknown.
_SUPPORTED_PROVIDERS = frozenset({"groq", "gemini", "cerebras", "openrouter", "hf", "cloudflare"})
def get_provider_name() -> str:
    """Read the active provider from the environment.

    The value of ``DATAFORGE_LLM_PROVIDER`` is stripped of surrounding
    whitespace and lowercased. A variable that is unset, empty, or contains
    only whitespace falls back to the default, so a blank entry left in a
    ``.env`` file behaves the same as no entry at all.

    Returns:
        The normalised provider name, or ``"groq"`` when no usable value is
        configured.

    Example:
        >>> import os
        >>> os.environ["DATAFORGE_LLM_PROVIDER"] = "gemini"
        >>> get_provider_name()
        'gemini'
    """
    configured = os.environ.get("DATAFORGE_LLM_PROVIDER", "")
    # ``or`` covers both the unset case and values that normalise to "".
    return configured.strip().lower() or "groq"
async def complete(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Run a chat completion against whichever provider is configured.

    The provider is looked up through ``get_provider_name`` on every call and
    the request is forwarded unchanged to that provider's private coroutine.

    Args:
        messages: The conversation so far, oldest message first.
        model: Model identifier to use instead of the provider's default.
        temperature: Sampling temperature (0.0 = deterministic).

    Returns:
        The assistant's reply as plain text.

    Raises:
        NotImplementedError: If the configured provider is known but has no
            implementation yet, or is not a known provider name at all.
        ProviderError: Raised by the provider coroutine, e.g. when its API
            key is missing or the response has an unexpected shape.

    Example:
        >>> import asyncio
        >>> msgs = [{"role": "user", "content": "What is 2+2?"}]
        >>> # result = asyncio.run(complete(msgs))  # requires API key
    """
    provider = get_provider_name()

    # Built per call so the lookup always sees the current module-level
    # coroutines.
    implemented = {"groq": _complete_groq, "gemini": _complete_gemini}
    backend = implemented.get(provider)
    if backend is not None:
        return await backend(messages, model=model, temperature=temperature)

    if provider not in _SUPPORTED_PROVIDERS:
        raise NotImplementedError(
            f"Unknown provider '{provider}'. Supported: {sorted(_SUPPORTED_PROVIDERS)}"
        )

    raise NotImplementedError(
        f"Provider '{provider}' is planned but not yet implemented. "
        f"Use 'groq' or 'gemini' for Week 1."
    )
# ── Groq provider ────────────────────────────────────────────────────────
# Chat-completions endpoint that ``_complete_groq`` POSTs to.
_GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
# Model sent when the caller passes no ``model`` override.
# NOTE(review): hosted model ids get retired over time; confirm this one is
# still served by Groq before relying on the default.
_GROQ_DEFAULT_MODEL = "llama-3.1-70b-versatile"
# Only ``httpx.HTTPStatusError`` triggers a retry; ``ProviderError`` (missing
# key, malformed body) is raised straight through on the first attempt.
# NOTE(review): this also retries non-transient 4xx responses such as 401.
@retry(
    retry=retry_if_exception_type(httpx.HTTPStatusError),
    wait=wait_exponential(multiplier=1, min=1, max=30),
    stop=stop_after_attempt(3),
    reraise=True,
)
async def _complete_groq(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Call Groq's OpenAI-compatible chat completions API.

    Args:
        messages: Chat messages, sent as-is in OpenAI ``role``/``content`` form.
        model: Model name; falls back to ``_GROQ_DEFAULT_MODEL`` when not given.
        temperature: Sampling temperature.

    Returns:
        The text of ``choices[0].message.content`` from the response.

    Raises:
        ProviderError: If ``GROQ_API_KEY`` is not set, the response body is
            not valid JSON, or the body does not contain a non-null
            ``choices[0].message.content``.
        httpx.HTTPStatusError: If the API still answers with an error status
            after the final retry attempt.
    """
    api_key = os.environ.get("GROQ_API_KEY", "")
    if not api_key:
        raise ProviderError("groq", "GROQ_API_KEY environment variable not set")

    payload = {
        "model": model or _GROQ_DEFAULT_MODEL,
        # TypedDicts are plain dicts at runtime; copy so the caller's objects
        # are never shared with the request machinery.
        "messages": [dict(m) for m in messages],
        "temperature": temperature,
    }

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            _GROQ_URL,
            json=payload,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError as exc:
        # A 2xx answer with a non-JSON body (e.g. a proxy error page) is a
        # malformed response, not something the caller should see as a raw
        # JSON decoding error.
        raise ProviderError("groq", "Response body is not valid JSON") from exc

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        # TypeError covers bodies whose pieces have the wrong JSON type,
        # e.g. a top-level list or ``"message": null``.
        raise ProviderError("groq", f"Unexpected response format: {data}") from exc

    if content is None:
        # ``str(None)`` would hand the caller the literal text "None".
        raise ProviderError("groq", f"Unexpected response format: {data}")
    return str(content)
# ── Gemini provider ──────────────────────────────────────────────────────
# ``generateContent`` endpoint template; ``_complete_gemini`` fills ``{model}``
# via ``str.format`` with the requested or default model name.
_GEMINI_URL_TEMPLATE = (
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
)
# Model used when the caller passes no ``model`` override.
_GEMINI_DEFAULT_MODEL = "gemini-2.0-flash"
# Only ``httpx.HTTPStatusError`` triggers a retry; ``ProviderError`` (missing
# key, malformed body) is raised straight through on the first attempt.
# NOTE(review): this also retries non-transient 4xx responses such as 400/403.
@retry(
    retry=retry_if_exception_type(httpx.HTTPStatusError),
    wait=wait_exponential(multiplier=1, min=1, max=30),
    stop=stop_after_attempt(3),
    reraise=True,
)
async def _complete_gemini(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Call Google's Gemini ``generateContent`` API.

    OpenAI-style messages are translated to Gemini's format: ``"system"``
    messages become the request's ``systemInstruction`` (several are joined
    with a blank line, in order), ``"user"`` stays ``"user"``, and every other
    role is sent as ``"model"``.

    Args:
        messages: Chat messages (converted to Gemini's content format).
        model: Model name; falls back to ``_GEMINI_DEFAULT_MODEL`` when not
            given.
        temperature: Sampling temperature.

    Returns:
        The concatenated text of all text parts of the first candidate.

    Raises:
        ProviderError: If ``GEMINI_API_KEY`` is not set, the response body is
            not valid JSON, or the first candidate carries no text part.
        httpx.HTTPStatusError: If the API still answers with an error status
            after the final retry attempt.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if not api_key:
        raise ProviderError("gemini", "GEMINI_API_KEY environment variable not set")

    url = _GEMINI_URL_TEMPLATE.format(model=model or _GEMINI_DEFAULT_MODEL)

    # Convert OpenAI-style messages to Gemini format.
    system_texts: list[str] = []
    contents: list[dict[str, object]] = []
    for msg in messages:
        if msg["role"] == "system":
            # Collect every system message so earlier ones are not silently
            # dropped when a conversation carries more than one.
            system_texts.append(msg["content"])
            continue
        contents.append(
            {
                "role": "user" if msg["role"] == "user" else "model",
                "parts": [{"text": msg["content"]}],
            }
        )

    payload: dict[str, object] = {
        "contents": contents,
        "generationConfig": {"temperature": temperature},
    }
    system_instruction = "\n\n".join(text for text in system_texts if text)
    if system_instruction:
        payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            url,
            json=payload,
            headers={
                "Content-Type": "application/json",
                # Send the key as a header, not as a ``?key=`` query
                # parameter: the request URL is included in the message of
                # ``httpx.HTTPStatusError``, so a key in the URL would end up
                # in logs and tracebacks.
                "x-goog-api-key": api_key,
            },
        )
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError as exc:
        # A 2xx answer with a non-JSON body is a malformed response.
        raise ProviderError("gemini", "Response body is not valid JSON") from exc

    try:
        parts = data["candidates"][0]["content"]["parts"]
        # A candidate may be split over several parts; keep every text part
        # instead of returning only the first one.
        texts = [
            part["text"]
            for part in parts
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        ]
    except (KeyError, IndexError, TypeError) as exc:
        # TypeError covers bodies whose pieces have the wrong JSON type,
        # e.g. ``"content": null`` or a non-list ``parts``.
        raise ProviderError("gemini", f"Unexpected response format: {data}") from exc

    if not texts:
        raise ProviderError("gemini", f"Unexpected response format: {data}")
    return "".join(texts)