Spaces:
Running
Running
| """LLM client — provider-agnostic wrapper for OpenAI and Gemini. | |
| Why a wrapper: | |
| - Two-tier model selection (reasoning vs bulk) without scattering model names | |
| - Two-provider support (OpenAI / Gemini), switchable via LLM_PROVIDER env var | |
| - Built-in retry on transient errors | |
| - Pydantic-validated structured outputs | |
| - Single chokepoint for logging / token accounting | |
| The provider is chosen at construction time from settings.llm_provider: | |
| - 'openai' (default) → gpt-4o + gpt-4o-mini via langchain-openai | |
| - 'gemini' → gemini-2.5-flash + gemini-2.5-flash-lite via langchain-google-genai | |
| Both providers share the same interface, so calling code never needs to | |
| care which one is active. | |
| Usage: | |
| llm = LLMClient() | |
| answer = llm.complete("Why is the sky blue?", model="bulk") | |
| parsed = llm.structured(prompt, ReviewOutput, model="reasoning") | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import time | |
| from typing import Any, Type, TypeVar | |
| from langchain_core.language_models import BaseChatModel | |
| from langchain_core.output_parsers import PydanticOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from pydantic import BaseModel | |
| from tenacity import (retry, stop_after_attempt, wait_exponential, | |
| retry_if_exception) | |
| from core.config import settings | |
| log = logging.getLogger(__name__) | |
| T = TypeVar("T", bound=BaseModel) | |
def _build_openai_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Return the ``(reasoning, bulk)`` pair of OpenAI chat models.

    Model names and the API key are read from ``settings``. The
    ``langchain_openai`` import is deferred until this function is called.

    Args:
        temp_reasoning: Sampling temperature for the reasoning-tier model.
        temp_bulk: Sampling temperature for the bulk-tier model.

    Raises:
        RuntimeError: If ``settings.openai_api_key`` is empty.
    """
    from langchain_openai import ChatOpenAI

    api_key = settings.openai_api_key
    if not api_key:
        raise RuntimeError(
            "LLM_PROVIDER=openai but OPENAI_API_KEY not set. "
            "Add it to .env or switch LLM_PROVIDER to 'gemini'."
        )

    # (model name, temperature) per tier, in the order they are returned.
    tiers = (
        (settings.openai_reasoning_model, temp_reasoning),
        (settings.openai_bulk_model, temp_bulk),
    )
    reasoning, bulk = (
        ChatOpenAI(model=name, temperature=temperature, api_key=api_key)
        for name, temperature in tiers
    )
    return reasoning, bulk
def _build_gemini_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Return the ``(reasoning, bulk)`` pair of Gemini chat models.

    Model names and the API key are read from ``settings``. The
    ``langchain_google_genai`` import is deferred until this function is
    called, so the package is only required when Gemini is actually used.

    Args:
        temp_reasoning: Sampling temperature for the reasoning-tier model.
        temp_bulk: Sampling temperature for the bulk-tier model.

    Raises:
        ImportError: If ``langchain-google-genai`` is not installed.
        RuntimeError: If ``settings.gemini_api_key`` is empty.
    """
    try:
        from langchain_google_genai import ChatGoogleGenerativeAI
    except ImportError as exc:
        raise ImportError(
            "LLM_PROVIDER=gemini but langchain-google-genai is not installed. "
            "Run: pip install langchain-google-genai"
        ) from exc

    api_key = settings.gemini_api_key
    if not api_key:
        raise RuntimeError(
            "LLM_PROVIDER=gemini but GEMINI_API_KEY not set. "
            "Get a key at https://aistudio.google.com/apikey and add it to .env."
        )

    # (model name, temperature) per tier, in the order they are returned.
    tiers = (
        (settings.gemini_reasoning_model, temp_reasoning),
        (settings.gemini_bulk_model, temp_bulk),
    )
    reasoning, bulk = (
        ChatGoogleGenerativeAI(model=name, temperature=temperature,
                               google_api_key=api_key)
        for name, temperature in tiers
    )
    return reasoning, bulk
| def _should_failover(exc: Exception) -> bool: | |
| """Decide whether an exception warrants trying the fallback provider. | |
| Triggers on quota / rate-limit errors AND on transient service errors | |
| (5xx, timeouts, connection failures) β i.e. any sign the primary | |
| provider is currently unable to serve the request. Does NOT trigger on | |
| clear client-side mistakes (bad request, malformed schema), which the | |
| fallback could not fix either. | |
| """ | |
| text = f"{type(exc).__name__} {exc}".lower() | |
| quota = ("429", "quota", "rate limit", "ratelimit", "resource exhausted", | |
| "resource_exhausted", "exceeded", "too many requests") | |
| transient = ("500", "502", "503", "504", "overloaded", "unavailable", | |
| "timeout", "timed out", "connection", "internal error", | |
| "service") | |
| return any(s in text for s in quota + transient) | |
| def _is_quota_error(exc: Exception) -> bool: | |
| """True only for rate-limit / quota-exhausted errors (used to skip the | |
| slow retry-backoff so failover happens fast on quota limits).""" | |
| text = f"{type(exc).__name__} {exc}".lower() | |
| signals = ("429", "quota", "rate limit", "ratelimit", | |
| "resource exhausted", "resource_exhausted", | |
| "exceeded", "too many requests") | |
| return any(s in text for s in signals) | |
class LLMClient:
    """Two-tier, two-provider LLM client with automatic failover.

    Tiers (the concrete model names come from ``settings``):
      - 'reasoning' → flagship model (e.g. gpt-4o / gemini-2.5-flash).
      - 'bulk'      → cheap/fast model (e.g. gpt-4o-mini / gemini-2.5-flash-lite).

    Failover: the primary provider is chosen from ``settings.llm_provider``
    (or the ``provider`` constructor argument). If the other provider's API
    key is also present, that provider is built as a fallback. When a call to
    the primary raises an error accepted by ``_should_failover`` (quota /
    rate-limit errors and transient service errors), the identical call is
    run once on the fallback provider — so a judge hitting the free Gemini
    tier's limit mid-demo never sees an error. If no fallback is configured,
    the client behaves exactly as a single-provider client.
    """

    def __init__(self, temperature_reasoning: float = 0.7,
                 temperature_bulk: float = 0.3,
                 provider: str | None = None):
        """Build the primary models and, when possible, the fallback models.

        Args:
            temperature_reasoning: Temperature for the 'reasoning' tier.
            temperature_bulk: Temperature for the 'bulk' tier.
            provider: 'openai' or 'gemini' (case-insensitive). Defaults to
                ``settings.llm_provider``.

        Raises:
            ValueError: If the provider is neither 'openai' nor 'gemini'.
            RuntimeError / ImportError: Propagated from the primary
                provider's builder (missing key or missing package).
        """
        self.provider = (provider or settings.llm_provider).lower()
        log.info("LLMClient initializing with primary provider=%r", self.provider)

        builders = {
            "openai": _build_openai_models,
            "gemini": _build_gemini_models,
        }
        if self.provider not in builders:
            raise ValueError(
                f"Unknown LLM_PROVIDER={self.provider!r}; expected 'openai' or 'gemini'")
        self._reasoning, self._bulk = builders[self.provider](
            temperature_reasoning, temperature_bulk)

        # Build the OTHER provider as a fallback, if its key is available.
        self.fallback_provider: str | None = None
        self._fb_reasoning: BaseChatModel | None = None
        self._fb_bulk: BaseChatModel | None = None
        other = "openai" if self.provider == "gemini" else "gemini"
        try:
            other_key = (settings.openai_api_key if other == "openai"
                         else settings.gemini_api_key)
            if other_key:
                self._fb_reasoning, self._fb_bulk = builders[other](
                    temperature_reasoning, temperature_bulk)
                self.fallback_provider = other
        except Exception as exc:  # fallback is best-effort; never block startup
            log.warning("Fallback provider unavailable, continuing without it: %s", exc)
            self.fallback_provider = None
            self._fb_reasoning = None
            self._fb_bulk = None

        if self.fallback_provider:
            log.info("Failover enabled: %s → %s on quota or transient errors",
                     self.provider, self.fallback_provider)
        else:
            log.info("No fallback provider configured; running single-provider")

    def _model(self, tier: str) -> BaseChatModel:
        """Return the primary model for ``tier``; raise ValueError if unknown."""
        if tier == "reasoning":
            return self._reasoning
        if tier == "bulk":
            return self._bulk
        raise ValueError(f"Unknown tier {tier!r}; expected 'reasoning' or 'bulk'")

    def _fb_model(self, tier: str) -> BaseChatModel | None:
        """Return the fallback model for ``tier``, or None if there is none."""
        if tier == "reasoning":
            return self._fb_reasoning
        if tier == "bulk":
            return self._fb_bulk
        return None

    @staticmethod
    def _build_prompt(system: str | None,
                      human_template: str) -> tuple[ChatPromptTemplate, dict[str, Any]]:
        """Build the chat template and the variables that carry ``system``.

        The system text is passed as the ``{system}`` template variable
        rather than being embedded as template source, so literal braces in
        it (e.g. a JSON example) are sent verbatim instead of being parsed
        as placeholders.
        """
        messages: list[Any] = []
        variables: dict[str, Any] = {}
        if system:
            messages.append(("system", "{system}"))
            variables["system"] = system
        messages.append(("human", human_template))
        return ChatPromptTemplate.from_messages(messages), variables

    def _invoke_with_failover(self, tier: str, run: Any) -> Any:
        """Run ``run(model)`` on the primary model, then once on the fallback.

        ``run`` is a callable taking a chat model and returning the result.
        The fallback is tried only if one exists for ``tier`` and
        ``_should_failover`` accepts the primary's exception; otherwise the
        primary's exception is re-raised unchanged. An exception from the
        fallback call propagates to the caller.
        """
        primary = self._model(tier)  # unknown tier raises ValueError here
        try:
            return run(primary)
        except Exception as exc:
            fallback = self._fb_model(tier)
            if fallback is None or not _should_failover(exc):
                raise
            log.warning("Primary provider %s failed (%s); failing over to %s",
                        self.provider, type(exc).__name__, self.fallback_provider)
            return run(fallback)

    # ------------------------------------------------------------------
    # Free-form completion
    # ------------------------------------------------------------------
    def complete(self, prompt: str, model: str = "bulk",
                 system: str | None = None) -> str:
        """Run ``prompt`` and return the model's text output.

        Args:
            prompt: User message, sent verbatim.
            model: Tier to use, 'reasoning' or 'bulk'.
            system: Optional system message, sent verbatim (braces are not
                interpreted as template placeholders).

        Returns:
            The response text. List-shaped content is flattened by joining
            the ``"text"`` field of dict parts and ``str()`` of other parts.

        Raises:
            ValueError: If ``model`` is not a known tier.
        """
        template, variables = self._build_prompt(system, "{input}")
        variables["input"] = prompt

        def _run(model_obj: BaseChatModel) -> str:
            started = time.perf_counter()
            result = (template | model_obj).invoke(variables)
            content = result.content
            if isinstance(content, list):
                content = "".join(
                    part.get("text", "") if isinstance(part, dict) else str(part)
                    for part in content)
            log.info("LLM complete [%s] %.2fs · prompt %d chars · output %d chars",
                     model, time.perf_counter() - started, len(prompt), len(content))
            return content

        return self._invoke_with_failover(model, _run)

    # ------------------------------------------------------------------
    # Structured output — pydantic-validated
    # ------------------------------------------------------------------
    def structured(self, prompt: str, schema: Type[T], model: str = "reasoning",
                   system: str | None = None) -> T:
        """Run ``prompt`` and parse the output into the given Pydantic schema.

        Uses LangChain's PydanticOutputParser; its format instructions are
        appended to the user message.

        Args:
            prompt: User message, sent verbatim.
            schema: Pydantic model class to parse the output into.
            model: Tier to use, 'reasoning' or 'bulk'.
            system: Optional system message, sent verbatim (braces are not
                interpreted as template placeholders).

        Returns:
            The object produced by the parser for ``schema``.

        Raises:
            ValueError: If ``model`` is not a known tier.
        """
        parser = PydanticOutputParser(pydantic_object=schema)
        template, variables = self._build_prompt(
            system, "{input}\n\n{format_instructions}")
        variables["input"] = prompt
        variables["format_instructions"] = parser.get_format_instructions()

        def _run(model_obj: BaseChatModel) -> T:
            started = time.perf_counter()
            out = (template | model_obj | parser).invoke(variables)
            log.info("LLM structured [%s] %.2fs · schema %s · prompt %d chars",
                     model, time.perf_counter() - started,
                     schema.__name__, len(prompt))
            return out

        return self._invoke_with_failover(model, _run)