"""LLM client — provider-agnostic wrapper for OpenAI and Gemini. Why a wrapper: - Two-tier model selection (reasoning vs bulk) without scattering model names - Two-provider support (OpenAI / Gemini), switchable via LLM_PROVIDER env var - Built-in retry on transient errors - Pydantic-validated structured outputs - Single chokepoint for logging / token accounting The provider is chosen at construction time from settings.llm_provider: - 'openai' (default) → gpt-4o + gpt-4o-mini via langchain-openai - 'gemini' → gemini-2.5-flash + gemini-2.5-flash-lite via langchain-google-genai Both providers share the same interface, so calling code never needs to care which one is active. Usage: llm = LLMClient() answer = llm.complete("Why is the sky blue?", model="bulk") parsed = llm.structured(prompt, ReviewOutput, model="reasoning") """ from __future__ import annotations import logging import time from typing import Any, Type, TypeVar from langchain_core.language_models import BaseChatModel from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel from tenacity import (retry, stop_after_attempt, wait_exponential, retry_if_exception) from core.config import settings log = logging.getLogger(__name__) T = TypeVar("T", bound=BaseModel) def _build_openai_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]: """Construct OpenAI reasoning + bulk models.""" from langchain_openai import ChatOpenAI if not settings.openai_api_key: raise RuntimeError( "LLM_PROVIDER=openai but OPENAI_API_KEY not set. " "Add it to .env or switch LLM_PROVIDER to 'gemini'." ) reasoning = ChatOpenAI( model=settings.openai_reasoning_model, temperature=temp_reasoning, api_key=settings.openai_api_key, ) bulk = ChatOpenAI( model=settings.openai_bulk_model, temperature=temp_bulk, api_key=settings.openai_api_key, ) return reasoning, bulk def _build_gemini_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]: """Construct Gemini reasoning + bulk models.""" try: from langchain_google_genai import ChatGoogleGenerativeAI except ImportError as e: raise ImportError( "LLM_PROVIDER=gemini but langchain-google-genai is not installed. " "Run: pip install langchain-google-genai" ) from e if not settings.gemini_api_key: raise RuntimeError( "LLM_PROVIDER=gemini but GEMINI_API_KEY not set. " "Get a key at https://aistudio.google.com/apikey and add it to .env." ) reasoning = ChatGoogleGenerativeAI( model=settings.gemini_reasoning_model, temperature=temp_reasoning, google_api_key=settings.gemini_api_key, ) bulk = ChatGoogleGenerativeAI( model=settings.gemini_bulk_model, temperature=temp_bulk, google_api_key=settings.gemini_api_key, ) return reasoning, bulk def _should_failover(exc: Exception) -> bool: """Decide whether an exception warrants trying the fallback provider. Triggers on quota / rate-limit errors AND on transient service errors (5xx, timeouts, connection failures) — i.e. any sign the primary provider is currently unable to serve the request. Does NOT trigger on clear client-side mistakes (bad request, malformed schema), which the fallback could not fix either. """ text = f"{type(exc).__name__} {exc}".lower() quota = ("429", "quota", "rate limit", "ratelimit", "resource exhausted", "resource_exhausted", "exceeded", "too many requests") transient = ("500", "502", "503", "504", "overloaded", "unavailable", "timeout", "timed out", "connection", "internal error", "service") return any(s in text for s in quota + transient) def _is_quota_error(exc: Exception) -> bool: """True only for rate-limit / quota-exhausted errors (used to skip the slow retry-backoff so failover happens fast on quota limits).""" text = f"{type(exc).__name__} {exc}".lower() signals = ("429", "quota", "rate limit", "ratelimit", "resource exhausted", "resource_exhausted", "exceeded", "too many requests") return any(s in text for s in signals) class LLMClient: """Two-tier, two-provider LLM client with automatic failover. Tier 'reasoning' → flagship model (gpt-4o / gemini-2.5-flash). Tier 'bulk' → cheap/fast model (gpt-4o-mini / gemini-2.5-flash-lite). Failover: the primary provider is chosen from settings.llm_provider. If the other provider's API key is also present, it is built as a fallback. When a call to the primary fails with a quota / rate-limit error, the identical call is retried on the fallback provider — so a judge hitting the free Gemini tier's limit mid-demo never sees an error. If no fallback key is configured, the client behaves exactly as a single-provider client. """ def __init__(self, temperature_reasoning: float = 0.7, temperature_bulk: float = 0.3, provider: str | None = None): self.provider = (provider or settings.llm_provider).lower() log.info(f"LLMClient initializing with primary provider={self.provider!r}") if self.provider == "openai": self._reasoning, self._bulk = _build_openai_models( temperature_reasoning, temperature_bulk) elif self.provider == "gemini": self._reasoning, self._bulk = _build_gemini_models( temperature_reasoning, temperature_bulk) else: raise ValueError( f"Unknown LLM_PROVIDER={self.provider!r}; expected 'openai' or 'gemini'") # Build the OTHER provider as a fallback, if its key is available. self.fallback_provider: str | None = None self._fb_reasoning: BaseChatModel | None = None self._fb_bulk: BaseChatModel | None = None try: if self.provider == "gemini" and settings.openai_api_key: self._fb_reasoning, self._fb_bulk = _build_openai_models( temperature_reasoning, temperature_bulk) self.fallback_provider = "openai" elif self.provider == "openai" and settings.gemini_api_key: self._fb_reasoning, self._fb_bulk = _build_gemini_models( temperature_reasoning, temperature_bulk) self.fallback_provider = "gemini" except Exception as e: # fallback is best-effort; never block startup log.warning(f"Fallback provider unavailable, continuing without it: {e}") self.fallback_provider = None if self.fallback_provider: log.info(f"Failover enabled: {self.provider} → {self.fallback_provider} " f"on quota errors") else: log.info("No fallback provider configured; running single-provider") def _model(self, tier: str) -> BaseChatModel: if tier == "reasoning": return self._reasoning if tier == "bulk": return self._bulk raise ValueError(f"Unknown tier {tier!r}; expected 'reasoning' or 'bulk'") def _fb_model(self, tier: str) -> BaseChatModel | None: if tier == "reasoning": return self._fb_reasoning if tier == "bulk": return self._fb_bulk return None # ────────────────────────────────────────────────────────────────── # Free-form completion # ────────────────────────────────────────────────────────────────── @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10), retry=retry_if_exception(lambda e: not _is_quota_error(e))) def complete(self, prompt: str, model: str = "bulk", system: str | None = None) -> str: messages: list[Any] = [] if system: messages.append(("system", system)) messages.append(("human", "{input}")) template = ChatPromptTemplate.from_messages(messages) def _run(model_obj: BaseChatModel) -> str: t0 = time.time() result = (template | model_obj).invoke({"input": prompt}) content = result.content if isinstance(content, list): content = "".join( p.get("text", "") if isinstance(p, dict) else str(p) for p in content) log.info(f"LLM complete [{model}] {time.time() - t0:.2f}s · " f"prompt {len(prompt)} chars · output {len(content)} chars") return content try: return _run(self._model(model)) except Exception as e: fb = self._fb_model(model) if fb is not None and _should_failover(e): log.warning(f"Primary provider {self.provider} failed " f"({type(e).__name__}); failing over to " f"{self.fallback_provider}") return _run(fb) raise # ────────────────────────────────────────────────────────────────── # Structured output — pydantic-validated # ────────────────────────────────────────────────────────────────── @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10), retry=retry_if_exception(lambda e: not _is_quota_error(e))) def structured(self, prompt: str, schema: Type[T], model: str = "reasoning", system: str | None = None) -> T: """Run prompt, parse output into the given Pydantic schema. Uses LangChain's PydanticOutputParser. On a quota / rate-limit error from the primary provider, the same call is retried on the fallback. """ parser = PydanticOutputParser(pydantic_object=schema) format_instructions = parser.get_format_instructions() messages: list[Any] = [] if system: messages.append(("system", system)) messages.append(("human", "{input}\n\n{format_instructions}")) template = ChatPromptTemplate.from_messages(messages) def _run(model_obj: BaseChatModel) -> T: t0 = time.time() chain = template | model_obj | parser out = chain.invoke({ "input": prompt, "format_instructions": format_instructions, }) log.info(f"LLM structured [{model}] {time.time() - t0:.2f}s · " f"schema {schema.__name__} · prompt {len(prompt)} chars") return out try: return _run(self._model(model)) except Exception as e: fb = self._fb_model(model) if fb is not None and _should_failover(e): log.warning(f"Primary provider {self.provider} failed " f"({type(e).__name__}); failing over to " f"{self.fallback_provider}") return _run(fb) raise