Israelbliz's picture
Upload llm.py
ad45209 verified
"""LLM client — provider-agnostic wrapper for OpenAI and Gemini.
Why a wrapper:
- Two-tier model selection (reasoning vs bulk) without scattering model names
- Two-provider support (OpenAI / Gemini), switchable via LLM_PROVIDER env var
- Built-in retry on transient errors
- Pydantic-validated structured outputs
- Single chokepoint for logging / token accounting
The provider is chosen at construction time from settings.llm_provider:
- 'openai' (default) → gpt-4o + gpt-4o-mini via langchain-openai
- 'gemini' → gemini-2.5-flash + gemini-2.5-flash-lite via langchain-google-genai
Both providers share the same interface, so calling code never needs to
care which one is active.
Usage:
llm = LLMClient()
answer = llm.complete("Why is the sky blue?", model="bulk")
parsed = llm.structured(prompt, ReviewOutput, model="reasoning")
"""
from __future__ import annotations

import logging
import time
from typing import Any, Callable, Type, TypeVar

from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential,
)

from core.config import settings
log = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
def _build_openai_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Build the OpenAI (reasoning, bulk) chat-model pair.

    The ``langchain_openai`` import is deferred until this function is
    called. Model names and the API key are read from ``settings``.

    Args:
        temp_reasoning: Sampling temperature for the reasoning-tier model.
        temp_bulk: Sampling temperature for the bulk-tier model.

    Returns:
        A ``(reasoning, bulk)`` tuple of ``ChatOpenAI`` instances.

    Raises:
        RuntimeError: If ``settings.openai_api_key`` is empty or unset.
    """
    from langchain_openai import ChatOpenAI

    api_key = settings.openai_api_key
    if not api_key:
        raise RuntimeError(
            "LLM_PROVIDER=openai but OPENAI_API_KEY not set. "
            "Add it to .env or switch LLM_PROVIDER to 'gemini'."
        )

    # One (model name, temperature) pair per tier, reasoning first.
    tier_specs = (
        (settings.openai_reasoning_model, temp_reasoning),
        (settings.openai_bulk_model, temp_bulk),
    )
    reasoning, bulk = (
        ChatOpenAI(model=name, temperature=temperature, api_key=api_key)
        for name, temperature in tier_specs
    )
    return reasoning, bulk
def _build_gemini_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Build the Gemini (reasoning, bulk) chat-model pair.

    The ``langchain_google_genai`` import is deferred until this function
    is called, so the package is only needed when Gemini is actually used.
    Model names and the API key are read from ``settings``.

    Args:
        temp_reasoning: Sampling temperature for the reasoning-tier model.
        temp_bulk: Sampling temperature for the bulk-tier model.

    Returns:
        A ``(reasoning, bulk)`` tuple of ``ChatGoogleGenerativeAI`` instances.

    Raises:
        ImportError: If ``langchain_google_genai`` cannot be imported.
        RuntimeError: If ``settings.gemini_api_key`` is empty or unset.
    """
    try:
        from langchain_google_genai import ChatGoogleGenerativeAI
    except ImportError as import_err:
        raise ImportError(
            "LLM_PROVIDER=gemini but langchain-google-genai is not installed. "
            "Run: pip install langchain-google-genai"
        ) from import_err

    api_key = settings.gemini_api_key
    if not api_key:
        raise RuntimeError(
            "LLM_PROVIDER=gemini but GEMINI_API_KEY not set. "
            "Get a key at https://aistudio.google.com/apikey and add it to .env."
        )

    # One (model name, temperature) pair per tier, reasoning first.
    tier_specs = (
        (settings.gemini_reasoning_model, temp_reasoning),
        (settings.gemini_bulk_model, temp_bulk),
    )
    reasoning, bulk = (
        ChatGoogleGenerativeAI(
            model=name, temperature=temperature, google_api_key=api_key)
        for name, temperature in tier_specs
    )
    return reasoning, bulk
def _should_failover(exc: Exception) -> bool:
    """Report whether *exc* looks like a reason to try the fallback provider.

    The exception's class name and message are lower-cased and scanned for
    known substrings. A hit on a quota / rate-limit marker or on a
    transient-service marker (5xx codes, timeouts, connection problems)
    returns True. Anything else — e.g. a plain bad-request or schema
    error that the fallback could not fix either — returns False.

    This is a substring heuristic: any message that happens to contain one
    of the markers will match.
    """
    quota_markers = (
        "429", "quota", "rate limit", "ratelimit", "resource exhausted",
        "resource_exhausted", "exceeded", "too many requests",
    )
    transient_markers = (
        "500", "502", "503", "504", "overloaded", "unavailable",
        "timeout", "timed out", "connection", "internal error",
        "service",
    )
    haystack = "{} {}".format(type(exc).__name__, exc).lower()
    for marker in (*quota_markers, *transient_markers):
        if marker in haystack:
            return True
    return False
def _is_quota_error(exc: Exception) -> bool:
    """Report whether *exc* looks like a rate-limit / quota-exhausted error.

    Used by the retry predicate so that quota errors skip the slow
    retry-backoff and failover can happen quickly. The exception's class
    name and message are lower-cased and scanned for known substrings.
    """
    quota_markers = (
        "429", "quota", "rate limit", "ratelimit",
        "resource exhausted", "resource_exhausted",
        "exceeded", "too many requests",
    )
    haystack = "{} {}".format(type(exc).__name__, exc).lower()
    for marker in quota_markers:
        if marker in haystack:
            return True
    return False
class LLMClient:
    """Two-tier, two-provider LLM client with automatic failover.

    Tiers:
        'reasoning' → flagship model (e.g. gpt-4o / gemini-2.5-flash).
        'bulk'      → cheap/fast model (e.g. gpt-4o-mini /
                      gemini-2.5-flash-lite).
        The actual model names are read from ``settings``.

    Failover:
        The primary provider is chosen from the ``provider`` argument or,
        when that is not given, ``settings.llm_provider``. If the other
        provider's API key is also present, it is built as a fallback.
        When a call on the primary raises an error that
        ``_should_failover`` accepts (quota / rate-limit errors and
        transient service errors), the identical call is made once on the
        fallback provider. If no fallback is configured, the client
        behaves exactly as a single-provider client.

    Retry:
        ``complete`` and ``structured`` are wrapped with tenacity: up to
        3 attempts with exponential backoff (1-10 s), retrying every
        exception that ``_is_quota_error`` does not classify as a quota
        error. Failover happens inside the retried call, so each attempt
        tries the primary first and then, if applicable, the fallback.
    """

    def __init__(self, temperature_reasoning: float = 0.7,
                 temperature_bulk: float = 0.3,
                 provider: str | None = None):
        """Build the primary models and, best-effort, the fallback models.

        Args:
            temperature_reasoning: Temperature for the reasoning tier.
            temperature_bulk: Temperature for the bulk tier.
            provider: 'openai' or 'gemini' (case-insensitive); defaults to
                ``settings.llm_provider``.

        Raises:
            ValueError: If the provider is neither 'openai' nor 'gemini'.
            RuntimeError / ImportError: Propagated from the primary
                provider's builder (missing key or package).
        """
        self.provider = (provider or settings.llm_provider).lower()
        log.info(f"LLMClient initializing with primary provider={self.provider!r}")

        if self.provider == "openai":
            self._reasoning, self._bulk = _build_openai_models(
                temperature_reasoning, temperature_bulk)
        elif self.provider == "gemini":
            self._reasoning, self._bulk = _build_gemini_models(
                temperature_reasoning, temperature_bulk)
        else:
            raise ValueError(
                f"Unknown LLM_PROVIDER={self.provider!r}; expected 'openai' or 'gemini'")

        # Build the OTHER provider as a fallback, if its key is available.
        self.fallback_provider: str | None = None
        self._fb_reasoning: BaseChatModel | None = None
        self._fb_bulk: BaseChatModel | None = None
        try:
            if self.provider == "gemini" and settings.openai_api_key:
                self._fb_reasoning, self._fb_bulk = _build_openai_models(
                    temperature_reasoning, temperature_bulk)
                self.fallback_provider = "openai"
            elif self.provider == "openai" and settings.gemini_api_key:
                self._fb_reasoning, self._fb_bulk = _build_gemini_models(
                    temperature_reasoning, temperature_bulk)
                self.fallback_provider = "gemini"
        except Exception as e:  # fallback is best-effort; never block startup
            log.warning(f"Fallback provider unavailable, continuing without it: {e}")
            self.fallback_provider = None

        if self.fallback_provider:
            log.info(f"Failover enabled: {self.provider} → {self.fallback_provider} "
                     f"on quota / transient errors")
        else:
            log.info("No fallback provider configured; running single-provider")

    def _model(self, tier: str) -> BaseChatModel:
        """Return the primary model for *tier*; raise ValueError if unknown."""
        if tier == "reasoning":
            return self._reasoning
        if tier == "bulk":
            return self._bulk
        raise ValueError(f"Unknown tier {tier!r}; expected 'reasoning' or 'bulk'")

    def _fb_model(self, tier: str) -> BaseChatModel | None:
        """Return the fallback model for *tier*, or None if there is none."""
        if tier == "reasoning":
            return self._fb_reasoning
        if tier == "bulk":
            return self._fb_bulk
        return None

    def _invoke_with_failover(self, tier: str,
                              run: Callable[[BaseChatModel], Any]) -> Any:
        """Call ``run`` on the primary model, then once on the fallback if warranted.

        The fallback is used only when one exists for *tier* and
        ``_should_failover`` accepts the primary's exception; otherwise the
        primary's exception is re-raised unchanged.
        """
        try:
            return run(self._model(tier))
        except Exception as e:
            fb = self._fb_model(tier)
            if fb is not None and _should_failover(e):
                log.warning(f"Primary provider {self.provider} failed "
                            f"({type(e).__name__}); failing over to "
                            f"{self.fallback_provider}")
                return run(fb)
            raise

    # ------------------------------------------------------------------
    # Free-form completion
    # ------------------------------------------------------------------
    # NOTE(review): the retry predicate retries every non-quota exception,
    # which includes the ValueError raised by _model() for an unknown tier.
    # Left as-is to keep the exception behavior callers currently observe.
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10),
           retry=retry_if_exception(lambda e: not _is_quota_error(e)))
    def complete(self, prompt: str, model: str = "bulk",
                 system: str | None = None) -> str:
        """Run a free-form completion and return the response text.

        Args:
            prompt: User message, sent verbatim.
            model: Tier to use, 'reasoning' or 'bulk'.
            system: Optional system message, sent verbatim.

        Returns:
            The response content as a string. List-style content is
            flattened by concatenating each part's ``"text"`` value (or
            ``str(part)`` for non-dict parts).
        """
        # Both user-supplied strings are passed as template *variables*, not
        # as template text, so literal '{' / '}' in them are never parsed as
        # placeholders.
        variables: dict[str, Any] = {"input": prompt}
        messages: list[Any] = []
        if system:
            messages.append(("system", "{system}"))
            variables["system"] = system
        messages.append(("human", "{input}"))
        template = ChatPromptTemplate.from_messages(messages)

        def _run(model_obj: BaseChatModel) -> str:
            t0 = time.perf_counter()
            result = (template | model_obj).invoke(variables)
            content = result.content
            if isinstance(content, list):
                content = "".join(
                    p.get("text", "") if isinstance(p, dict) else str(p)
                    for p in content)
            log.info(f"LLM complete [{model}] {time.perf_counter() - t0:.2f}s · "
                     f"prompt {len(prompt)} chars · output {len(content)} chars")
            return content

        return self._invoke_with_failover(model, _run)

    # ------------------------------------------------------------------
    # Structured output — pydantic-validated
    # ------------------------------------------------------------------
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10),
           retry=retry_if_exception(lambda e: not _is_quota_error(e)))
    def structured(self, prompt: str, schema: Type[T], model: str = "reasoning",
                   system: str | None = None) -> T:
        """Run prompt, parse output into the given Pydantic schema.

        Uses LangChain's PydanticOutputParser; its format instructions are
        appended to the human message. If the primary provider fails with
        an error ``_should_failover`` accepts (quota / rate-limit or
        transient), the same call is made once on the fallback.

        Args:
            prompt: User message, sent verbatim.
            schema: Pydantic model class to parse the response into.
            model: Tier to use, 'reasoning' or 'bulk'.
            system: Optional system message, sent verbatim.

        Returns:
            An instance of ``schema`` produced by the parser.
        """
        parser = PydanticOutputParser(pydantic_object=schema)

        # All caller-supplied text goes in as variables so that braces in it
        # (common in JSON examples) are not treated as template placeholders.
        variables: dict[str, Any] = {
            "input": prompt,
            "format_instructions": parser.get_format_instructions(),
        }
        messages: list[Any] = []
        if system:
            messages.append(("system", "{system}"))
            variables["system"] = system
        messages.append(("human", "{input}\n\n{format_instructions}"))
        template = ChatPromptTemplate.from_messages(messages)

        def _run(model_obj: BaseChatModel) -> T:
            t0 = time.perf_counter()
            chain = template | model_obj | parser
            out = chain.invoke(variables)
            log.info(f"LLM structured [{model}] {time.perf_counter() - t0:.2f}s · "
                     f"schema {schema.__name__} · prompt {len(prompt)} chars")
            return out

        return self._invoke_with_failover(model, _run)