Spaces:
Running
Running
File size: 11,766 Bytes
ad45209 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 | """LLM client β provider-agnostic wrapper for OpenAI and Gemini.
Why a wrapper:
- Two-tier model selection (reasoning vs bulk) without scattering model names
- Two-provider support (OpenAI / Gemini), switchable via LLM_PROVIDER env var
- Built-in retry on transient errors
- Pydantic-validated structured outputs
- Single chokepoint for logging / token accounting
The provider is chosen at construction time from settings.llm_provider:
- 'openai' (default) — gpt-4o + gpt-4o-mini via langchain-openai
- 'gemini' — gemini-2.5-flash + gemini-2.5-flash-lite via langchain-google-genai
Both providers share the same interface, so calling code never needs to
care which one is active.
Usage:
llm = LLMClient()
answer = llm.complete("Why is the sky blue?", model="bulk")
parsed = llm.structured(prompt, ReviewOutput, model="reasoning")
"""
from __future__ import annotations
import logging
import time
from typing import Any, Type, TypeVar
from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
from tenacity import (retry, stop_after_attempt, wait_exponential,
retry_if_exception)
from core.config import settings
log = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
def _build_openai_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Construct the OpenAI reasoning and bulk chat models.

    Args:
        temp_reasoning: Sampling temperature for the reasoning-tier model.
        temp_bulk: Sampling temperature for the bulk-tier model.

    Returns:
        A ``(reasoning, bulk)`` pair of ``ChatOpenAI`` instances, built from
        ``settings.openai_reasoning_model`` and ``settings.openai_bulk_model``.

    Raises:
        ImportError: If the ``langchain-openai`` package cannot be imported.
        RuntimeError: If ``settings.openai_api_key`` is empty.
    """
    # Guard the import the same way the Gemini builder does, so a missing
    # package produces an actionable message instead of a bare
    # ModuleNotFoundError.
    try:
        from langchain_openai import ChatOpenAI
    except ImportError as e:
        raise ImportError(
            "LLM_PROVIDER=openai but langchain-openai is not installed. "
            "Run: pip install langchain-openai"
        ) from e

    if not settings.openai_api_key:
        raise RuntimeError(
            "LLM_PROVIDER=openai but OPENAI_API_KEY not set. "
            "Add it to .env or switch LLM_PROVIDER to 'gemini'."
        )

    reasoning = ChatOpenAI(
        model=settings.openai_reasoning_model,
        temperature=temp_reasoning,
        api_key=settings.openai_api_key,
    )
    bulk = ChatOpenAI(
        model=settings.openai_bulk_model,
        temperature=temp_bulk,
        api_key=settings.openai_api_key,
    )
    return reasoning, bulk
def _build_gemini_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Build the Gemini chat models for both tiers.

    Args:
        temp_reasoning: Sampling temperature for the reasoning-tier model.
        temp_bulk: Sampling temperature for the bulk-tier model.

    Returns:
        A ``(reasoning, bulk)`` pair of ``ChatGoogleGenerativeAI`` instances,
        built from ``settings.gemini_reasoning_model`` and
        ``settings.gemini_bulk_model``.

    Raises:
        ImportError: If ``langchain-google-genai`` cannot be imported.
        RuntimeError: If ``settings.gemini_api_key`` is empty.
    """
    try:
        from langchain_google_genai import ChatGoogleGenerativeAI
    except ImportError as import_err:
        hint = (
            "LLM_PROVIDER=gemini but langchain-google-genai is not installed. "
            "Run: pip install langchain-google-genai"
        )
        raise ImportError(hint) from import_err

    if not settings.gemini_api_key:
        raise RuntimeError(
            "LLM_PROVIDER=gemini but GEMINI_API_KEY not set. "
            "Get a key at https://aistudio.google.com/apikey and add it to .env."
        )

    # Both tiers differ only in model name and temperature.
    tier_specs = (
        (settings.gemini_reasoning_model, temp_reasoning),
        (settings.gemini_bulk_model, temp_bulk),
    )
    reasoning_model, bulk_model = (
        ChatGoogleGenerativeAI(
            model=model_name,
            temperature=temperature,
            google_api_key=settings.gemini_api_key,
        )
        for model_name, temperature in tier_specs
    )
    return reasoning_model, bulk_model
def _should_failover(exc: Exception) -> bool:
    """Decide whether an exception warrants trying the fallback provider.

    The decision is a text heuristic over ``"<ExceptionClassName> <message>"``
    (lower-cased). It triggers on quota / rate-limit signals AND on transient
    service signals (5xx codes, timeouts, connection failures), i.e. signs
    that the primary provider is currently unable to serve the request. It is
    meant NOT to trigger on clear client-side mistakes (bad request, malformed
    schema), which the fallback could not fix either.

    Args:
        exc: The exception raised by the primary provider call.

    Returns:
        True if the error text contains a quota/transient phrase or one of
        the HTTP status codes 429/500/502/503/504 as a standalone number.
    """
    text = f"{type(exc).__name__} {exc}".lower()

    quota_phrases = ("quota", "rate limit", "ratelimit", "resource exhausted",
                     "resource_exhausted", "resourceexhausted", "exceeded",
                     "too many requests")
    # NOTE(review): "service" and "exceeded" are broad and may also match
    # some client-side errors; kept to preserve existing behaviour.
    transient_phrases = ("overloaded", "unavailable", "timeout", "timed out",
                         "connection", "internal error", "service")
    if any(phrase in text for phrase in quota_phrases + transient_phrases):
        return True

    # Status codes must match a whole run of digits. A plain substring test
    # would see "500" inside "15000 tokens" and fail over on a client error.
    status_codes = {"429", "500", "502", "503", "504"}
    digit_runs = "".join(ch if ch.isdigit() else " " for ch in text).split()
    return not status_codes.isdisjoint(digit_runs)
def _is_quota_error(exc: Exception) -> bool:
    """Return True only for rate-limit / quota-exhausted errors.

    Used as the retry predicate's negation: quota errors skip the slow
    retry-backoff so failover can happen quickly. The check is a text
    heuristic over ``"<ExceptionClassName> <message>"`` (lower-cased).

    Args:
        exc: The exception to classify.

    Returns:
        True if the text contains a quota/rate-limit phrase or the HTTP
        status code 429 as a standalone number.
    """
    text = f"{type(exc).__name__} {exc}".lower()

    # NOTE(review): "exceeded" is broad (e.g. "deadline exceeded"); kept to
    # preserve existing behaviour.
    phrases = ("quota", "rate limit", "ratelimit",
               "resource exhausted", "resource_exhausted", "resourceexhausted",
               "exceeded", "too many requests")
    if any(phrase in text for phrase in phrases):
        return True

    # "429" must be a whole run of digits. A plain substring test would also
    # match unrelated numbers such as "14290".
    digit_runs = "".join(ch if ch.isdigit() else " " for ch in text).split()
    return "429" in digit_runs
class LLMClient:
    """Two-tier, two-provider LLM client with automatic failover.

    Tier 'reasoning' — flagship model (gpt-4o / gemini-2.5-flash).
    Tier 'bulk' — cheap/fast model (gpt-4o-mini / gemini-2.5-flash-lite).
    The actual model names are read from ``settings``.

    Failover: the primary provider is chosen from ``settings.llm_provider``
    (or the ``provider`` argument). If the other provider's API key is also
    present, it is built as a fallback. When a call to the primary fails with
    an error that ``_should_failover`` accepts (quota / rate-limit errors and
    transient service errors), the identical call is retried once on the
    fallback provider. If no fallback is configured, the client behaves as a
    single-provider client.

    Retry: ``complete`` and ``structured`` are wrapped by tenacity. Any
    exception that ``_is_quota_error`` does not classify as a quota error is
    retried, for up to 3 attempts with exponential backoff (1–10 s).
    NOTE(review): ``reraise=True`` is not set, so after the last attempt
    tenacity is expected to raise its own ``RetryError`` — confirm callers
    rely on that. Deterministic errors such as an unknown tier are retried
    as well.
    """

    def __init__(self, temperature_reasoning: float = 0.7,
                 temperature_bulk: float = 0.3,
                 provider: str | None = None):
        """Build the primary models and, when possible, the fallback models.

        Args:
            temperature_reasoning: Temperature for the reasoning-tier model.
            temperature_bulk: Temperature for the bulk-tier model.
            provider: 'openai' or 'gemini'; defaults to
                ``settings.llm_provider``. Case and surrounding whitespace
                are ignored.

        Raises:
            ValueError: If the provider name is not 'openai' or 'gemini'.
            ImportError, RuntimeError: Propagated from the primary provider's
                builder (missing package / missing API key).
        """
        self.provider = (provider or settings.llm_provider).strip().lower()
        log.info("LLMClient initializing with primary provider=%r", self.provider)

        if self.provider == "openai":
            self._reasoning, self._bulk = _build_openai_models(
                temperature_reasoning, temperature_bulk)
        elif self.provider == "gemini":
            self._reasoning, self._bulk = _build_gemini_models(
                temperature_reasoning, temperature_bulk)
        else:
            raise ValueError(
                f"Unknown LLM_PROVIDER={self.provider!r}; expected 'openai' or 'gemini'")

        # Build the OTHER provider as a fallback, if its key is available.
        self.fallback_provider: str | None = None
        self._fb_reasoning: BaseChatModel | None = None
        self._fb_bulk: BaseChatModel | None = None
        try:
            if self.provider == "gemini" and settings.openai_api_key:
                self._fb_reasoning, self._fb_bulk = _build_openai_models(
                    temperature_reasoning, temperature_bulk)
                self.fallback_provider = "openai"
            elif self.provider == "openai" and settings.gemini_api_key:
                self._fb_reasoning, self._fb_bulk = _build_gemini_models(
                    temperature_reasoning, temperature_bulk)
                self.fallback_provider = "gemini"
        except Exception as e:  # fallback is best-effort; never block startup
            log.warning("Fallback provider unavailable, continuing without it: %s", e)
            self.fallback_provider = None

        if self.fallback_provider:
            log.info("Failover enabled: %s → %s on quota / transient errors",
                     self.provider, self.fallback_provider)
        else:
            log.info("No fallback provider configured; running single-provider")

    def _model(self, tier: str) -> BaseChatModel:
        """Return the primary model for ``tier``; raise ValueError if unknown."""
        if tier == "reasoning":
            return self._reasoning
        if tier == "bulk":
            return self._bulk
        raise ValueError(f"Unknown tier {tier!r}; expected 'reasoning' or 'bulk'")

    def _fb_model(self, tier: str) -> BaseChatModel | None:
        """Return the fallback model for ``tier``, or None if there is none."""
        if tier == "reasoning":
            return self._fb_reasoning
        if tier == "bulk":
            return self._fb_bulk
        return None

    def _with_failover(self, tier: str, run: Any) -> Any:
        """Call ``run(model)`` on the primary model, then on the fallback.

        ``run`` is a one-argument callable taking a chat model. The fallback
        is used only when one exists for ``tier`` and ``_should_failover``
        accepts the primary's exception; otherwise that exception is
        re-raised unchanged. An exception from the fallback call propagates.
        """
        try:
            return run(self._model(tier))
        except Exception as exc:
            fallback = self._fb_model(tier)
            if fallback is None or not _should_failover(exc):
                raise
            log.warning("Primary provider %s failed (%s); failing over to %s",
                        self.provider, type(exc).__name__, self.fallback_provider)
            return run(fallback)

    # ──────────────────────────────────────────────────────────────────
    # Free-form completion
    # ──────────────────────────────────────────────────────────────────
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10),
           retry=retry_if_exception(lambda e: not _is_quota_error(e)))
    def complete(self, prompt: str, model: str = "bulk",
                 system: str | None = None) -> str:
        """Run a free-form completion and return the response text.

        Args:
            prompt: The user message, passed verbatim.
            model: Tier to use, 'reasoning' or 'bulk'.
            system: Optional system message, passed verbatim.

        Returns:
            The response content as a string. List-shaped content is joined
            from each part's "text" entry (or ``str(part)`` for non-dicts).
        """
        # Both prompt and system go in as template *variables*, never as
        # template text, so literal "{" / "}" in them (e.g. JSON examples)
        # are not parsed as placeholders.
        variables: dict[str, Any] = {"input": prompt}
        messages: list[Any] = []
        if system:
            messages.append(("system", "{system}"))
            variables["system"] = system
        messages.append(("human", "{input}"))
        template = ChatPromptTemplate.from_messages(messages)

        def _run(model_obj: BaseChatModel) -> str:
            started = time.perf_counter()
            result = (template | model_obj).invoke(variables)
            content = result.content
            if isinstance(content, list):
                content = "".join(
                    p.get("text", "") if isinstance(p, dict) else str(p)
                    for p in content)
            log.info("LLM complete [%s] %.2fs · prompt %d chars · output %d chars",
                     model, time.perf_counter() - started, len(prompt), len(content))
            return content

        return self._with_failover(model, _run)

    # ──────────────────────────────────────────────────────────────────
    # Structured output — pydantic-validated
    # ──────────────────────────────────────────────────────────────────
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10),
           retry=retry_if_exception(lambda e: not _is_quota_error(e)))
    def structured(self, prompt: str, schema: Type[T], model: str = "reasoning",
                   system: str | None = None) -> T:
        """Run prompt and parse the output into the given Pydantic schema.

        Uses LangChain's PydanticOutputParser; its format instructions are
        appended to the human message.

        Args:
            prompt: The user message, passed verbatim.
            schema: Pydantic model class the output is parsed into.
            model: Tier to use, 'reasoning' or 'bulk'.
            system: Optional system message, passed verbatim.

        Returns:
            The parser's output for ``schema``.
        """
        parser = PydanticOutputParser(pydantic_object=schema)
        # All free text goes in as template variables so that braces in it
        # are not parsed as placeholders.
        variables: dict[str, Any] = {
            "input": prompt,
            "format_instructions": parser.get_format_instructions(),
        }
        messages: list[Any] = []
        if system:
            messages.append(("system", "{system}"))
            variables["system"] = system
        messages.append(("human", "{input}\n\n{format_instructions}"))
        template = ChatPromptTemplate.from_messages(messages)

        def _run(model_obj: BaseChatModel) -> T:
            started = time.perf_counter()
            chain = template | model_obj | parser
            out = chain.invoke(variables)
            log.info("LLM structured [%s] %.2fs · schema %s · prompt %d chars",
                     model, time.perf_counter() - started, schema.__name__, len(prompt))
            return out

        return self._with_failover(model, _run)
|