| """Multi-provider LLM client for DataForge. |
| |
| Reads ``DATAFORGE_LLM_PROVIDER`` from the environment and dispatches to the |
| matching provider. Week 1 implements **groq** and **gemini** only; other |
| providers raise ``NotImplementedError``. |
| |
| No LLM calls are made by detectors β this module is for the agent loop |
| (Week 2+) and is stubbed here to establish the interface. |
| |
| The interface is: |
| ``async def complete(messages, model, temperature) -> str`` |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Literal, TypedDict |
|
|
| import httpx |
| from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential |
|
|
| |
|
|
|
|
| class Message(TypedDict): |
| """A single chat message. |
| |
| Args: |
| role: The speaker role β ``"system"``, ``"user"``, or ``"assistant"``. |
| content: The text content of the message. |
| """ |
|
|
| role: Literal["system", "user", "assistant"] |
| content: str |
|
|
|
|
| |
|
|
|
|
| class ProviderError(Exception): |
| """Raised when an LLM provider call fails after retries. |
| |
| Args: |
| provider: The provider name that failed. |
| message: Description of the failure. |
| """ |
|
|
| def __init__(self, provider: str, message: str) -> None: |
| self.provider = provider |
| super().__init__(f"[{provider}] {message}") |
|
|
|
|
| |
|
|
| _SUPPORTED_PROVIDERS = frozenset({"groq", "gemini", "cerebras", "openrouter", "hf", "cloudflare"}) |
|
|
|
|
| def get_provider_name() -> str: |
| """Read the active provider from the environment. |
| |
| Returns: |
| The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``, |
| defaulting to ``"groq"`` if not set. |
| |
| Example: |
| >>> import os |
| >>> os.environ["DATAFORGE_LLM_PROVIDER"] = "gemini" |
| >>> get_provider_name() |
| 'gemini' |
| """ |
| return os.environ.get("DATAFORGE_LLM_PROVIDER", "groq").lower() |
|
|
|
|
| async def complete( |
| messages: list[Message], |
| *, |
| model: str | None = None, |
| temperature: float = 0.0, |
| ) -> str: |
| """Send a chat completion request to the active LLM provider. |
| |
| Args: |
| messages: List of chat messages forming the conversation. |
| model: Optional model override. If None, uses the provider default. |
| temperature: Sampling temperature (0.0 = deterministic). |
| |
| Returns: |
| The assistant's response text. |
| |
| Raises: |
| NotImplementedError: If the provider is not yet implemented. |
| ProviderError: If the API call fails after retries. |
| |
| Example: |
| >>> import asyncio |
| >>> msgs = [{"role": "user", "content": "What is 2+2?"}] |
| >>> # result = asyncio.run(complete(msgs)) # requires API key |
| """ |
| provider = get_provider_name() |
|
|
| if provider == "groq": |
| return await _complete_groq(messages, model=model, temperature=temperature) |
| if provider == "gemini": |
| return await _complete_gemini(messages, model=model, temperature=temperature) |
|
|
| if provider in _SUPPORTED_PROVIDERS: |
| raise NotImplementedError( |
| f"Provider '{provider}' is planned but not yet implemented. " |
| f"Use 'groq' or 'gemini' for Week 1." |
| ) |
|
|
| raise NotImplementedError( |
| f"Unknown provider '{provider}'. Supported: {sorted(_SUPPORTED_PROVIDERS)}" |
| ) |
|
|
|
|
| |
|
|
| _GROQ_URL = "https://api.groq.com/openai/v1/chat/completions" |
| _GROQ_DEFAULT_MODEL = "llama-3.1-70b-versatile" |
|
|
|
|
| @retry( |
| retry=retry_if_exception_type(httpx.HTTPStatusError), |
| wait=wait_exponential(multiplier=1, min=1, max=30), |
| stop=stop_after_attempt(3), |
| reraise=True, |
| ) |
| async def _complete_groq( |
| messages: list[Message], |
| *, |
| model: str | None = None, |
| temperature: float = 0.0, |
| ) -> str: |
| """Call Groq's OpenAI-compatible chat completions API. |
| |
| Args: |
| messages: Chat messages. |
| model: Model name (defaults to llama-3.1-70b-versatile). |
| temperature: Sampling temperature. |
| |
| Returns: |
| The assistant's response text. |
| |
| Raises: |
| ProviderError: If the response is malformed. |
| """ |
| api_key = os.environ.get("GROQ_API_KEY", "") |
| if not api_key: |
| raise ProviderError("groq", "GROQ_API_KEY environment variable not set") |
|
|
| payload = { |
| "model": model or _GROQ_DEFAULT_MODEL, |
| "messages": [dict(m) for m in messages], |
| "temperature": temperature, |
| } |
|
|
| async with httpx.AsyncClient(timeout=60.0) as client: |
| response = await client.post( |
| _GROQ_URL, |
| json=payload, |
| headers={ |
| "Authorization": f"Bearer {api_key}", |
| "Content-Type": "application/json", |
| }, |
| ) |
| response.raise_for_status() |
|
|
| data = response.json() |
| try: |
| return str(data["choices"][0]["message"]["content"]) |
| except (KeyError, IndexError) as exc: |
| raise ProviderError("groq", f"Unexpected response format: {data}") from exc |
|
|
|
|
| |
|
|
| _GEMINI_URL_TEMPLATE = ( |
| "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent" |
| ) |
| _GEMINI_DEFAULT_MODEL = "gemini-2.0-flash" |
|
|
|
|
| @retry( |
| retry=retry_if_exception_type(httpx.HTTPStatusError), |
| wait=wait_exponential(multiplier=1, min=1, max=30), |
| stop=stop_after_attempt(3), |
| reraise=True, |
| ) |
| async def _complete_gemini( |
| messages: list[Message], |
| *, |
| model: str | None = None, |
| temperature: float = 0.0, |
| ) -> str: |
| """Call Google's Gemini generativeLanguage API. |
| |
| Args: |
| messages: Chat messages (converted to Gemini's content format). |
| model: Model name (defaults to gemini-2.0-flash). |
| temperature: Sampling temperature. |
| |
| Returns: |
| The assistant's response text. |
| |
| Raises: |
| ProviderError: If the response is malformed. |
| """ |
| api_key = os.environ.get("GEMINI_API_KEY", "") |
| if not api_key: |
| raise ProviderError("gemini", "GEMINI_API_KEY environment variable not set") |
|
|
| model_name = model or _GEMINI_DEFAULT_MODEL |
| url = _GEMINI_URL_TEMPLATE.format(model=model_name) |
|
|
| |
| contents: list[dict[str, object]] = [] |
| system_instruction: str | None = None |
| for msg in messages: |
| if msg["role"] == "system": |
| system_instruction = msg["content"] |
| else: |
| role = "user" if msg["role"] == "user" else "model" |
| contents.append( |
| { |
| "role": role, |
| "parts": [{"text": msg["content"]}], |
| } |
| ) |
|
|
| payload: dict[str, object] = { |
| "contents": contents, |
| "generationConfig": {"temperature": temperature}, |
| } |
| if system_instruction: |
| payload["systemInstruction"] = {"parts": [{"text": system_instruction}]} |
|
|
| async with httpx.AsyncClient(timeout=60.0) as client: |
| response = await client.post( |
| url, |
| json=payload, |
| params={"key": api_key}, |
| headers={"Content-Type": "application/json"}, |
| ) |
| response.raise_for_status() |
|
|
| data = response.json() |
| try: |
| return str(data["candidates"][0]["content"]["parts"][0]["text"]) |
| except (KeyError, IndexError) as exc: |
| raise ProviderError("gemini", f"Unexpected response format: {data}") from exc |
|
|