File size: 11,766 Bytes
ad45209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""LLM client β€” provider-agnostic wrapper for OpenAI and Gemini.

Why a wrapper:
  - Two-tier model selection (reasoning vs bulk) without scattering model names
  - Two-provider support (OpenAI / Gemini), switchable via LLM_PROVIDER env var,
    with automatic failover to the other provider when its API key is also set
  - Built-in retry on transient errors
  - Pydantic-validated structured outputs
  - Single chokepoint for logging / token accounting

The provider is chosen at construction time from settings.llm_provider:
  - 'openai' (default) → gpt-4o + gpt-4o-mini via langchain-openai
  - 'gemini'           → gemini-2.5-flash + gemini-2.5-flash-lite via langchain-google-genai

Both providers share the same interface, so calling code never needs to
care which one is active.

Usage:
    llm = LLMClient()
    answer = llm.complete("Why is the sky blue?", model="bulk")
    parsed = llm.structured(prompt, ReviewOutput, model="reasoning")
"""
from __future__ import annotations

import logging
import re
import time
from typing import Any, Callable, Type, TypeVar

from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
from tenacity import (retry, retry_if_exception, stop_after_attempt,
                      wait_exponential)

from core.config import settings

log = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


def _build_openai_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Construct OpenAI reasoning + bulk models."""
    from langchain_openai import ChatOpenAI
    if not settings.openai_api_key:
        raise RuntimeError(
            "LLM_PROVIDER=openai but OPENAI_API_KEY not set. "
            "Add it to .env or switch LLM_PROVIDER to 'gemini'."
        )
    reasoning = ChatOpenAI(
        model=settings.openai_reasoning_model,
        temperature=temp_reasoning,
        api_key=settings.openai_api_key,
    )
    bulk = ChatOpenAI(
        model=settings.openai_bulk_model,
        temperature=temp_bulk,
        api_key=settings.openai_api_key,
    )
    return reasoning, bulk


def _build_gemini_models(temp_reasoning: float, temp_bulk: float) -> tuple[BaseChatModel, BaseChatModel]:
    """Construct Gemini reasoning + bulk models."""
    try:
        from langchain_google_genai import ChatGoogleGenerativeAI
    except ImportError as e:
        raise ImportError(
            "LLM_PROVIDER=gemini but langchain-google-genai is not installed. "
            "Run: pip install langchain-google-genai"
        ) from e

    if not settings.gemini_api_key:
        raise RuntimeError(
            "LLM_PROVIDER=gemini but GEMINI_API_KEY not set. "
            "Get a key at https://aistudio.google.com/apikey and add it to .env."
        )
    reasoning = ChatGoogleGenerativeAI(
        model=settings.gemini_reasoning_model,
        temperature=temp_reasoning,
        google_api_key=settings.gemini_api_key,
    )
    bulk = ChatGoogleGenerativeAI(
        model=settings.gemini_bulk_model,
        temperature=temp_bulk,
        google_api_key=settings.gemini_api_key,
    )
    return reasoning, bulk


def _should_failover(exc: Exception) -> bool:
    """Decide whether an exception warrants trying the fallback provider.

    Triggers on quota / rate-limit errors AND on transient service errors
    (5xx, timeouts, connection failures) β€” i.e. any sign the primary
    provider is currently unable to serve the request. Does NOT trigger on
    clear client-side mistakes (bad request, malformed schema), which the
    fallback could not fix either.
    """
    text = f"{type(exc).__name__} {exc}".lower()
    quota = ("429", "quota", "rate limit", "ratelimit", "resource exhausted",
             "resource_exhausted", "exceeded", "too many requests")
    transient = ("500", "502", "503", "504", "overloaded", "unavailable",
                 "timeout", "timed out", "connection", "internal error",
                 "service")
    return any(s in text for s in quota + transient)


def _is_quota_error(exc: Exception) -> bool:
    """True only for rate-limit / quota-exhausted errors (used to skip the
    slow retry-backoff so failover happens fast on quota limits)."""
    text = f"{type(exc).__name__} {exc}".lower()
    signals = ("429", "quota", "rate limit", "ratelimit",
               "resource exhausted", "resource_exhausted",
               "exceeded", "too many requests")
    return any(s in text for s in signals)


class LLMClient:
    """Two-tier, two-provider LLM client with automatic failover.

    Tier 'reasoning' → flagship model (configured via settings, e.g.
    gpt-4o / gemini-2.5-flash).
    Tier 'bulk' → cheap/fast model (e.g. gpt-4o-mini / gemini-2.5-flash-lite).

    Failover: the primary provider is chosen from the ``provider`` argument
    or ``settings.llm_provider``. If the other provider's API key is also
    present, it is built as a fallback. When a call to the primary fails
    with an error that ``_should_failover`` classifies as quota / rate-limit
    or transient service trouble (5xx, timeouts, connection failures), the
    identical call is retried on the fallback provider — so a judge hitting
    the free Gemini tier's limit mid-demo never sees an error. If no fallback
    key is configured (or building the fallback fails), the client behaves
    exactly as a single-provider client.
    """

    def __init__(self, temperature_reasoning: float = 0.7,
                 temperature_bulk: float = 0.3,
                 provider: str | None = None):
        """Build the primary models and, best-effort, the fallback models.

        Args:
            temperature_reasoning: Sampling temperature for the 'reasoning' tier.
            temperature_bulk: Sampling temperature for the 'bulk' tier.
            provider: 'openai' or 'gemini' (case-insensitive); defaults to
                ``settings.llm_provider``.

        Raises:
            ValueError: the provider is neither 'openai' nor 'gemini'.
            RuntimeError / ImportError: propagated from the primary
                provider's model builder (missing API key or library).
        """
        self.provider = (provider or settings.llm_provider).lower()
        log.info(f"LLMClient initializing with primary provider={self.provider!r}")

        if self.provider == "openai":
            self._reasoning, self._bulk = _build_openai_models(
                temperature_reasoning, temperature_bulk)
        elif self.provider == "gemini":
            self._reasoning, self._bulk = _build_gemini_models(
                temperature_reasoning, temperature_bulk)
        else:
            raise ValueError(
                f"Unknown LLM_PROVIDER={self.provider!r}; expected 'openai' or 'gemini'")

        # Build the OTHER provider as a fallback, if its key is available.
        self.fallback_provider: str | None = None
        self._fb_reasoning: BaseChatModel | None = None
        self._fb_bulk: BaseChatModel | None = None
        try:
            if self.provider == "gemini" and settings.openai_api_key:
                self._fb_reasoning, self._fb_bulk = _build_openai_models(
                    temperature_reasoning, temperature_bulk)
                self.fallback_provider = "openai"
            elif self.provider == "openai" and settings.gemini_api_key:
                self._fb_reasoning, self._fb_bulk = _build_gemini_models(
                    temperature_reasoning, temperature_bulk)
                self.fallback_provider = "gemini"
        except Exception as e:  # fallback is best-effort; never block startup
            log.warning(f"Fallback provider unavailable, continuing without it: {e}")
            self.fallback_provider = None

        if self.fallback_provider:
            log.info(f"Failover enabled: {self.provider} β†’ {self.fallback_provider} "
                     f"on quota errors")
        else:
            log.info("No fallback provider configured; running single-provider")

    def _model(self, tier: str) -> BaseChatModel:
        """Return the primary model for ``tier``; ValueError on an unknown tier."""
        if tier == "reasoning":
            return self._reasoning
        if tier == "bulk":
            return self._bulk
        raise ValueError(f"Unknown tier {tier!r}; expected 'reasoning' or 'bulk'")

    def _fb_model(self, tier: str) -> BaseChatModel | None:
        """Return the fallback model for ``tier``, or None if there is none."""
        if tier == "reasoning":
            return self._fb_reasoning
        if tier == "bulk":
            return self._fb_bulk
        return None

    def _invoke_with_failover(self, tier: str,
                              run: Callable[[BaseChatModel], Any]) -> Any:
        """Call ``run`` on the primary model for ``tier``; if that raises an
        error ``_should_failover`` accepts and a fallback model exists, repeat
        the identical call once on the fallback model.

        Any other error is re-raised unchanged so the caller's retry policy
        still sees it.
        """
        try:
            return run(self._model(tier))
        except Exception as e:
            fb = self._fb_model(tier)
            if fb is None or not _should_failover(e):
                raise
            log.warning(f"Primary provider {self.provider} failed "
                        f"({type(e).__name__}); failing over to "
                        f"{self.fallback_provider}")
            return run(fb)

    # ──────────────────────────────────────────────────────────────────
    # Free-form completion
    # ──────────────────────────────────────────────────────────────────
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10),
           retry=retry_if_exception(lambda e: not _is_quota_error(e)))
    def complete(self, prompt: str, model: str = "bulk",
                 system: str | None = None) -> str:
        """Return the model's plain-text reply to ``prompt``.

        Args:
            prompt: User message, passed verbatim (braces are not interpreted).
            model: Tier name, 'bulk' (default) or 'reasoning'.
            system: Optional system message, also passed verbatim.

        Returns:
            The reply text. If the provider returns a list of content parts,
            dict parts contribute their ``"text"`` entry (or nothing) and
            other parts are converted with ``str``.

        The whole call, including any failover, gets up to 3 attempts with
        exponential backoff unless the error looks like a quota error.
        With tenacity's defaults, exhausting the attempts raises
        ``tenacity.RetryError``.
        """
        # Bind the system text and the prompt as template *variables*, never
        # as template source. A literal "{" or "}" in either (JSON examples,
        # code) would otherwise be parsed as a placeholder and fail.
        messages: list[Any] = []
        variables: dict[str, Any] = {"input": prompt}
        if system:
            messages.append(("system", "{system}"))
            variables["system"] = system
        messages.append(("human", "{input}"))
        template = ChatPromptTemplate.from_messages(messages)

        def _run(model_obj: BaseChatModel) -> str:
            t0 = time.time()
            result = (template | model_obj).invoke(variables)
            content = result.content
            if isinstance(content, list):
                # Multi-part content: flatten to a single string.
                content = "".join(
                    p.get("text", "") if isinstance(p, dict) else str(p)
                    for p in content)
            log.info(f"LLM complete [{model}] {time.time() - t0:.2f}s Β· "
                     f"prompt {len(prompt)} chars Β· output {len(content)} chars")
            return content

        return self._invoke_with_failover(model, _run)

    # ──────────────────────────────────────────────────────────────────
    # Structured output — pydantic-validated
    # ──────────────────────────────────────────────────────────────────
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10),
           retry=retry_if_exception(lambda e: not _is_quota_error(e)))
    def structured(self, prompt: str, schema: Type[T], model: str = "reasoning",
                   system: str | None = None) -> T:
        """Run prompt, parse output into the given Pydantic schema.

        Uses LangChain's PydanticOutputParser: its format instructions are
        appended to the user message, and the reply is parsed into ``schema``.
        If the primary provider fails with an error ``_should_failover``
        accepts (quota / rate-limit or transient), the same call is retried
        on the fallback.

        Args:
            prompt: User message, passed verbatim (braces are not interpreted).
            schema: Pydantic model class the reply must parse into.
            model: Tier name, 'reasoning' (default) or 'bulk'.
            system: Optional system message, also passed verbatim.

        Returns:
            An instance of ``schema``.

        Retries follow the same policy as ``complete``.
        """
        parser = PydanticOutputParser(pydantic_object=schema)
        # Format instructions (and any system text) are full of braces, so
        # they are supplied as variables rather than template source.
        variables: dict[str, Any] = {
            "input": prompt,
            "format_instructions": parser.get_format_instructions(),
        }
        messages: list[Any] = []
        if system:
            messages.append(("system", "{system}"))
            variables["system"] = system
        messages.append(("human", "{input}\n\n{format_instructions}"))
        template = ChatPromptTemplate.from_messages(messages)

        def _run(model_obj: BaseChatModel) -> T:
            t0 = time.time()
            out = (template | model_obj | parser).invoke(variables)
            log.info(f"LLM structured [{model}] {time.time() - t0:.2f}s Β· "
                     f"schema {schema.__name__} Β· prompt {len(prompt)} chars")
            return out

        return self._invoke_with_failover(model, _run)