File size: 11,488 Bytes
31a2688
 
 
 
 
 
 
a493f04
31a2688
 
 
 
 
 
 
 
3f19c23
 
31a2688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a120767
 
 
 
 
 
 
 
31a2688
b205d63
 
 
 
 
 
 
 
 
 
31a2688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f19c23
 
 
 
 
 
 
 
 
31a2688
 
 
 
 
 
 
4d2a2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a493f04
 
 
 
 
 
3f19c23
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f19c23
 
 
 
 
 
 
 
31a2688
 
 
 
 
9612292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""Factory functions for creating LLM and embedding instances.

All provider-specific imports are isolated here. The rest of the codebase
interacts only with LangChain abstract interfaces returned by these factories.
"""

import logging
from dataclasses import replace

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel

from src.config import Settings

logger = logging.getLogger(__name__)

# Provider identifiers accepted by ``create_llm``. Kept as a list because it
# is interpolated verbatim into the "Unknown LLM provider" error message.
_SUPPORTED_LLM_PROVIDERS = [
    "ollama",
    "azure_openai",
    "openai",
    "groq",
    "anthropic",
    "google_genai",
    "bedrock",
]
# Provider identifiers accepted by ``create_embeddings``. Kept as a list for
# the same reason (shown in the "Unknown embedding provider" error message).
_SUPPORTED_EMBEDDING_PROVIDERS = [
    "local",
    "azure_openai",
    "openai",
    "google_genai",
    "bedrock",
]


def create_llm(settings: Settings) -> BaseChatModel:
    """Create an LLM instance based on the configured provider.

    Args:
        settings: Application settings with provider configuration.

    Returns:
        A LangChain BaseChatModel instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    provider = settings.llm_provider.lower()
    logger.info("Creating LLM with provider: %s", provider)

    match provider:
        case "ollama":
            from langchain_ollama import ChatOllama

            return ChatOllama(
                base_url=settings.ollama_base_url,
                model=settings.ollama_model,
                temperature=0.0,
            )

        case "azure_openai":
            from langchain_openai import AzureChatOpenAI

            return AzureChatOpenAI(
                azure_endpoint=settings.azure_openai_endpoint,
                api_key=settings.azure_openai_api_key,
                api_version=settings.azure_openai_api_version,
                azure_deployment=settings.azure_openai_deployment,
                temperature=0.0,
            )

        case "openai":
            from langchain_openai import ChatOpenAI

            kwargs: dict = {
                "model": settings.openai_model,
                "api_key": settings.openai_api_key,
                "temperature": 0.0,
            }
            if settings.openai_base_url:
                kwargs["base_url"] = settings.openai_base_url
            return ChatOpenAI(**kwargs)

        case "groq":
            from langchain_openai import ChatOpenAI

            return ChatOpenAI(
                model=settings.groq_model,
                api_key=settings.groq_api_key,
                base_url="https://api.groq.com/openai/v1",
                temperature=0.0,
            )

        case "anthropic":
            from langchain_anthropic import ChatAnthropic

            return ChatAnthropic(
                model=settings.anthropic_model,
                api_key=settings.anthropic_api_key,
                temperature=0.0,
            )

        case "google_genai":
            from langchain_google_genai import ChatGoogleGenerativeAI

            return ChatGoogleGenerativeAI(
                model=settings.google_model,
                google_api_key=settings.google_api_key,
                temperature=0.0,
            )

        case "bedrock":
            from langchain_aws import ChatBedrockConverse

            return ChatBedrockConverse(
                model=settings.aws_bedrock_model,
                region_name=settings.aws_region,
                temperature=0.0,
            )

        case _:
            raise ValueError(
                f"Unknown LLM provider: '{provider}'. "
                f"Supported providers: {_SUPPORTED_LLM_PROVIDERS}"
            )


# Exception types that engage the fallback chain (passed to ``with_fallbacks``
# as ``exceptions_to_handle``). Deliberately set to the broad ``Exception``
# because real-world LLM SDK errors (openai.RateLimitError,
# openai.APIConnectionError, httpx.ConnectError, anthropic.APIError, ...)
# do NOT inherit from stdlib ``ConnectionError`` / ``TimeoutError`` / ``OSError``.
# A narrower set would silently let the most common transient failures bypass
# the fallback. The trade-off is that non-transient errors raised by the
# primary also trigger the fallback. Safety relies on three layers instead:
#   1. The whole feature is opt-in via ``settings.llm_fallback_enabled``.
#   2. Every fallback activation logs a WARNING naming the destination provider.
#   3. Startup logs the full chain at WARNING with cost / privacy reminders.
_FALLBACK_EXCEPTIONS: tuple[type[BaseException], ...] = (Exception,)


def _wrap_with_fallback_logging(llm: BaseChatModel, provider: str) -> BaseChatModel:
    """Attach a start listener to ``llm`` that warns on every invocation.

    The listener runs only when the wrapped Runnable is actually started.
    When the result is used as a fallback entry, that means the providers
    ahead of it in the chain have already failed, so each WARNING marks a
    request being routed away from the primary provider.

    Args:
        llm: The chat model to wrap.
        provider: Provider label included in the log message.

    Returns:
        The object returned by ``llm.with_listeners``, which delegates to
        ``llm``.
    """

    # Parameter names are kept as-is: the listener is registered by reference
    # and neither argument is used by the log call.
    def _announce_fallback(_run_obj, _config=None) -> None:  # noqa: ANN001
        message = (
            "LLM fallback activated: routing request to provider '%s'. "
            "Check cost / privacy implications."
        )
        logger.warning(message, provider)

    wrapped = llm.with_listeners(on_start=_announce_fallback)
    return wrapped


def create_llm_with_fallback(settings: Settings) -> BaseChatModel:
    """Create the generation LLM, optionally wrapping it in a fallback chain.

    When ``settings.llm_fallback_enabled`` is False OR the fallback list is
    empty, this is a drop-in equivalent of :func:`create_llm`. Otherwise the
    primary LLM is wrapped via LangChain's ``with_fallbacks`` so that when
    the primary raises any exception matched by ``_FALLBACK_EXCEPTIONS``
    (currently every ``Exception``, not only network / timeout errors), each
    fallback provider is tried in order.

    Fallback providers that cannot be constructed are logged at ERROR and
    skipped. If none can be constructed, the primary is returned on its own.

    Args:
        settings: Application settings.

    Returns:
        A BaseChatModel (primary on its own, or primary-with-fallbacks).
    """
    primary = create_llm(settings)
    if not settings.llm_fallback_enabled or not settings.llm_fallback_providers:
        return primary

    fallbacks: list[BaseChatModel] = []
    # Names of the providers that were actually built, so the startup log
    # reports the real chain rather than the configured one.
    active_providers: list[str] = []
    for provider in settings.llm_fallback_providers:
        try:
            # Reuse all settings, switching only the provider selector.
            fallback_settings = replace(settings, llm_provider=provider)
            raw = create_llm(fallback_settings)
        except Exception as exc:  # noqa: BLE001 — log and skip broken fallbacks
            logger.error(
                "Skipping LLM fallback provider '%s' due to construction error: %s",
                provider, exc,
            )
            continue
        fallbacks.append(_wrap_with_fallback_logging(raw, provider))
        active_providers.append(provider)

    if not fallbacks:
        logger.warning(
            "LLM_FALLBACK_ENABLED is true but no fallback providers could be "
            "constructed; running without fallback."
        )
        return primary

    # Only list providers that made it into the chain; skipped ones were
    # already reported at ERROR above.
    chain_repr = " -> ".join([settings.llm_provider, *active_providers])
    logger.warning(
        "LLM fallback chain is ACTIVE: %s. "
        "On transient failure of the primary, requests will be routed to the "
        "next provider. This may incur API costs and send data to third-party "
        "providers.",
        chain_repr,
    )

    return primary.with_fallbacks(
        fallbacks, exceptions_to_handle=_FALLBACK_EXCEPTIONS
    )


# Maps a provider identifier to the name of the ``Settings`` attribute that
# holds that provider's model (or, for Azure, deployment) name.
# ``create_evaluator_llm`` uses it to decide which field to overwrite with
# ``settings.evaluator_llm_model`` when a judge-model override is configured.
_EVALUATOR_MODEL_FIELD: dict[str, str] = {
    "groq": "groq_model",
    "openai": "openai_model",
    "anthropic": "anthropic_model",
    "google_genai": "google_model",
    "azure_openai": "azure_openai_deployment",
    "bedrock": "aws_bedrock_model",
    "ollama": "ollama_model",
}


def create_evaluator_llm(settings: Settings) -> BaseChatModel:
    """Create the LLM used as a RAGAS judge.

    The judge LLM is independent of the generation LLM so a strong cloud
    model (e.g. Qwen3-32B via Groq) can score outputs produced by a small
    local generation model. If ``EVALUATOR_LLM_PROVIDER`` is unset (``None``,
    empty, or whitespace only), falls back to ``create_llm(settings)``, i.e.
    a model built from the generation LLM configuration.

    When ``settings.evaluator_llm_model`` is set, it replaces the model field
    of the chosen provider (see ``_EVALUATOR_MODEL_FIELD``); otherwise the
    provider's configured model is used.

    Args:
        settings: Application settings with provider configuration.

    Returns:
        A LangChain BaseChatModel instance to use as the RAGAS judge.

    Raises:
        ValueError: If ``EVALUATOR_LLM_PROVIDER`` is set to an unknown value.
    """
    # ``or ""`` keeps a None provider from raising AttributeError and routes
    # it to the documented "unset" path instead.
    provider = (settings.evaluator_llm_provider or "").strip().lower()
    if not provider:
        logger.info("EVALUATOR_LLM_PROVIDER unset; reusing generation LLM as judge")
        return create_llm(settings)

    overrides: dict[str, str] = {"llm_provider": provider}
    if settings.evaluator_llm_model:
        model_field = _EVALUATOR_MODEL_FIELD.get(provider)
        if model_field is None:
            raise ValueError(
                f"Cannot override evaluator model for unknown provider: '{provider}'"
            )
        overrides[model_field] = settings.evaluator_llm_model

    # Build a modified copy; the caller's settings object is left untouched.
    overridden = replace(settings, **overrides)
    logger.info(
        "Creating evaluator (judge) LLM with provider: %s | model override: %s",
        provider,
        settings.evaluator_llm_model or "(provider default)",
    )
    # An unknown provider without a model override is rejected here by
    # create_llm's own ValueError.
    return create_llm(overridden)


def create_embeddings(settings: Settings) -> Embeddings:
    """Create an embeddings instance based on the configured provider.

    Args:
        settings: Application settings with provider configuration.

    Returns:
        A LangChain Embeddings instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    provider = settings.embedding_provider.lower()
    logger.info("Creating embeddings with provider: %s", provider)

    match provider:
        case "local":
            from langchain_huggingface import HuggingFaceEmbeddings

            return HuggingFaceEmbeddings(
                model_name=settings.local_embedding_model,
            )

        case "azure_openai":
            from langchain_openai import AzureOpenAIEmbeddings

            return AzureOpenAIEmbeddings(
                azure_endpoint=settings.azure_openai_endpoint,
                api_key=settings.azure_openai_api_key,
                api_version=settings.azure_openai_api_version,
                azure_deployment=settings.azure_openai_embedding_deployment,
            )

        case "openai":
            from langchain_openai import OpenAIEmbeddings

            return OpenAIEmbeddings(
                model=settings.openai_embedding_model,
                api_key=settings.openai_api_key,
            )

        case "google_genai":
            from langchain_google_genai import GoogleGenerativeAIEmbeddings

            return GoogleGenerativeAIEmbeddings(
                model=settings.google_embedding_model,
                google_api_key=settings.google_api_key,
            )

        case "bedrock":
            from langchain_aws import BedrockEmbeddings

            return BedrockEmbeddings(
                model_id=settings.aws_bedrock_embedding_model,
                region_name=settings.aws_region,
            )

        case _:
            raise ValueError(
                f"Unknown embedding provider: '{provider}'. "
                f"Supported providers: {_SUPPORTED_EMBEDDING_PROVIDERS}"
            )


def create_reranker(model_name: str) -> object:
    """Build the cross-encoder model used for reranking.

    ``sentence_transformers`` is imported inside the function, so the package
    is only required when a reranker is actually created.

    Args:
        model_name: HuggingFace model name for the cross-encoder.

    Returns:
        A CrossEncoder model instance.
    """
    # Import first: if the package is missing, fail before logging anything.
    from sentence_transformers import CrossEncoder

    logger.info("Creating cross-encoder reranker: %s", model_name)
    reranker = CrossEncoder(model_name)
    return reranker