Dokumentassistent / src /provider.py
XQ
Add LLM Provider Fallback
4d2a2da
raw
history blame
11.5 kB
"""Factory functions for creating LLM and embedding instances.
All provider-specific imports are isolated here. The rest of the codebase
interacts only with LangChain abstract interfaces returned by these factories.
"""
import logging
from dataclasses import replace
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from src.config import Settings
logger = logging.getLogger(__name__)

# Provider names listed in the "unknown provider" error raised by create_llm.
_SUPPORTED_LLM_PROVIDERS = [
    "ollama",
    "azure_openai",
    "openai",
    "groq",
    "anthropic",
    "google_genai",
    "bedrock",
]

# Provider names listed in the "unknown provider" error raised by
# create_embeddings.
_SUPPORTED_EMBEDDING_PROVIDERS = [
    "local",
    "azure_openai",
    "openai",
    "google_genai",
    "bedrock",
]
def create_llm(settings: Settings) -> BaseChatModel:
"""Create an LLM instance based on the configured provider.
Args:
settings: Application settings with provider configuration.
Returns:
A LangChain BaseChatModel instance.
Raises:
ValueError: If the provider is not supported.
"""
provider = settings.llm_provider.lower()
logger.info("Creating LLM with provider: %s", provider)
match provider:
case "ollama":
from langchain_ollama import ChatOllama
return ChatOllama(
base_url=settings.ollama_base_url,
model=settings.ollama_model,
temperature=0.0,
)
case "azure_openai":
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(
azure_endpoint=settings.azure_openai_endpoint,
api_key=settings.azure_openai_api_key,
api_version=settings.azure_openai_api_version,
azure_deployment=settings.azure_openai_deployment,
temperature=0.0,
)
case "openai":
from langchain_openai import ChatOpenAI
kwargs: dict = {
"model": settings.openai_model,
"api_key": settings.openai_api_key,
"temperature": 0.0,
}
if settings.openai_base_url:
kwargs["base_url"] = settings.openai_base_url
return ChatOpenAI(**kwargs)
case "groq":
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.groq_model,
api_key=settings.groq_api_key,
base_url="https://api.groq.com/openai/v1",
temperature=0.0,
)
case "anthropic":
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(
model=settings.anthropic_model,
api_key=settings.anthropic_api_key,
temperature=0.0,
)
case "google_genai":
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=settings.google_model,
google_api_key=settings.google_api_key,
temperature=0.0,
)
case "bedrock":
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(
model=settings.aws_bedrock_model,
region_name=settings.aws_region,
temperature=0.0,
)
case _:
raise ValueError(
f"Unknown LLM provider: '{provider}'. "
f"Supported providers: {_SUPPORTED_LLM_PROVIDERS}"
)
# Exception types that engage the fallback chain; passed as
# ``exceptions_to_handle`` to ``with_fallbacks`` in create_llm_with_fallback.
#
# Deliberately set to the broad ``Exception`` because real-world LLM SDK errors
# (openai.RateLimitError, openai.APIConnectionError, httpx.ConnectError,
# anthropic.APIError, ...) do NOT inherit from stdlib ``ConnectionError`` /
# ``TimeoutError`` / ``OSError``. A narrower set would silently let the most
# common transient failures bypass the fallback. The consequence is that ANY
# exception raised by the primary triggers the fallbacks, not only transient
# ones. Safety relies on three layers instead:
#   1. The whole feature is opt-in via ``LLM_FALLBACK_ENABLED`` (default off).
#   2. Every fallback activation logs a WARNING naming the destination provider.
#   3. Startup logs the full chain at WARNING with cost / privacy reminders.
_FALLBACK_EXCEPTIONS: tuple[type[BaseException], ...] = (Exception,)
def _wrap_with_fallback_logging(llm: BaseChatModel, provider: str) -> BaseChatModel:
    """Attach a start listener to ``llm`` that logs a fallback WARNING.

    The listener is registered through ``llm.with_listeners(on_start=...)``
    and emits one WARNING naming ``provider`` each time it is called. For a
    model used as a fallback entry, a run only starts after the primary (and
    any earlier fallbacks) have failed, so the log line marks the moment a
    request leaves the primary provider.

    Args:
        llm: The chat model to wrap.
        provider: Provider label shown in the log message.

    Returns:
        Whatever ``llm.with_listeners`` returns, i.e. a runnable that
        delegates to ``llm``.
    """

    def _announce_fallback(_run, _config=None) -> None:  # noqa: ANN001
        # Both arguments are ignored; only the captured provider is logged.
        logger.warning(
            "LLM fallback activated: routing request to provider '%s'. "
            "Check cost / privacy implications.",
            provider,
        )

    instrumented = llm.with_listeners(on_start=_announce_fallback)
    return instrumented
def create_llm_with_fallback(settings: Settings) -> BaseChatModel:
    """Create the generation LLM, optionally wrapping it in a fallback chain.

    When ``settings.llm_fallback_enabled`` is False OR the fallback list is
    empty, this is a drop-in equivalent of :func:`create_llm`. Otherwise the
    primary LLM is wrapped via LangChain's ``with_fallbacks`` so that when
    the primary raises any exception covered by ``_FALLBACK_EXCEPTIONS``
    (currently every ``Exception``, not only network / timeout errors),
    each fallback provider is tried in order.

    A fallback provider whose model cannot be constructed is logged at ERROR
    and left out of the chain. The startup WARNING lists only the providers
    that were actually added, so the logged chain matches the real one.

    Args:
        settings: Application settings.

    Returns:
        The primary model on its own, or the result of
        ``primary.with_fallbacks(...)``.
        NOTE(review): ``with_fallbacks`` returns a wrapping runnable rather
        than the chat model itself -- confirm callers only rely on the
        generic Runnable interface.
    """
    primary = create_llm(settings)
    if not settings.llm_fallback_enabled or not settings.llm_fallback_providers:
        return primary

    fallbacks: list[BaseChatModel] = []
    # Names of the providers that made it into ``fallbacks``, kept in step
    # with it so the chain log below cannot mention a skipped provider.
    active_providers: list[str] = []
    for provider in settings.llm_fallback_providers:
        try:
            fallback_settings = replace(settings, llm_provider=provider)
            raw = create_llm(fallback_settings)
        except Exception as exc:  # noqa: BLE001 — log and skip broken fallbacks
            logger.error(
                "Skipping LLM fallback provider '%s' due to construction error: %s",
                provider, exc,
            )
            continue
        fallbacks.append(_wrap_with_fallback_logging(raw, provider))
        active_providers.append(provider)

    if not fallbacks:
        logger.warning(
            "LLM_FALLBACK_ENABLED is true but no fallback providers could be "
            "constructed; running without fallback."
        )
        return primary

    chain_repr = " -> ".join([settings.llm_provider, *active_providers])
    logger.warning(
        "LLM fallback chain is ACTIVE: %s. "
        "On transient failure of the primary, requests will be routed to the "
        "next provider. This may incur API costs and send data to third-party "
        "providers.",
        chain_repr,
    )
    return primary.with_fallbacks(
        fallbacks, exceptions_to_handle=_FALLBACK_EXCEPTIONS
    )
# Maps a provider name to the Settings field that carries that provider's
# model (or, for Azure, deployment) name. create_evaluator_llm uses it to
# decide which field to overwrite with ``settings.evaluator_llm_model``.
_EVALUATOR_MODEL_FIELD: dict[str, str] = {
"groq": "groq_model",
"openai": "openai_model",
"anthropic": "anthropic_model",
"google_genai": "google_model",
"azure_openai": "azure_openai_deployment",
"bedrock": "aws_bedrock_model",
"ollama": "ollama_model",
}
def create_evaluator_llm(settings: Settings) -> BaseChatModel:
    """Build the chat model that acts as the RAGAS judge.

    The judge is configured separately from the generation model, so a
    stronger model can score answers produced by a smaller one. With an
    empty ``settings.evaluator_llm_provider`` the generation LLM from
    ``create_llm(settings)`` is reused. Otherwise a copy of the settings is
    made with ``llm_provider`` swapped for the judge provider and, when
    ``settings.evaluator_llm_model`` is set, with that provider's model field
    replaced as well; the copy is then handed to ``create_llm``.

    Args:
        settings: Application settings with provider configuration.

    Returns:
        A LangChain BaseChatModel instance to use as the RAGAS judge.

    Raises:
        ValueError: If ``EVALUATOR_LLM_PROVIDER`` is set to an unknown value.
    """
    judge_provider = settings.evaluator_llm_provider.lower().strip()
    if judge_provider == "":
        logger.info("EVALUATOR_LLM_PROVIDER unset; reusing generation LLM as judge")
        return create_llm(settings)

    judge_model = settings.evaluator_llm_model
    field_overrides: dict[str, str] = {"llm_provider": judge_provider}
    if judge_model:
        # A model override needs to know which settings field to replace;
        # without a mapping for this provider the override cannot be applied.
        if judge_provider not in _EVALUATOR_MODEL_FIELD:
            raise ValueError(
                f"Cannot override evaluator model for unknown provider: '{judge_provider}'"
            )
        field_overrides[_EVALUATOR_MODEL_FIELD[judge_provider]] = judge_model

    judge_settings = replace(settings, **field_overrides)
    logger.info(
        "Creating evaluator (judge) LLM with provider: %s | model override: %s",
        judge_provider,
        judge_model if judge_model else "(provider default)",
    )
    return create_llm(judge_settings)
def create_embeddings(settings: Settings) -> Embeddings:
"""Create an embeddings instance based on the configured provider.
Args:
settings: Application settings with provider configuration.
Returns:
A LangChain Embeddings instance.
Raises:
ValueError: If the provider is not supported.
"""
provider = settings.embedding_provider.lower()
logger.info("Creating embeddings with provider: %s", provider)
match provider:
case "local":
from langchain_huggingface import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(
model_name=settings.local_embedding_model,
)
case "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(
azure_endpoint=settings.azure_openai_endpoint,
api_key=settings.azure_openai_api_key,
api_version=settings.azure_openai_api_version,
azure_deployment=settings.azure_openai_embedding_deployment,
)
case "openai":
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(
model=settings.openai_embedding_model,
api_key=settings.openai_api_key,
)
case "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
return GoogleGenerativeAIEmbeddings(
model=settings.google_embedding_model,
google_api_key=settings.google_api_key,
)
case "bedrock":
from langchain_aws import BedrockEmbeddings
return BedrockEmbeddings(
model_id=settings.aws_bedrock_embedding_model,
region_name=settings.aws_region,
)
case _:
raise ValueError(
f"Unknown embedding provider: '{provider}'. "
f"Supported providers: {_SUPPORTED_EMBEDDING_PROVIDERS}"
)
def create_reranker(model_name: str) -> object:
    """Instantiate a sentence-transformers cross-encoder for reranking.

    ``sentence_transformers`` is imported inside the function, so the import
    only runs when a reranker is actually requested.

    Args:
        model_name: HuggingFace model name for the cross-encoder.

    Returns:
        A CrossEncoder model instance.
    """
    from sentence_transformers import CrossEncoder

    logger.info("Creating cross-encoder reranker: %s", model_name)
    reranker = CrossEncoder(model_name)
    return reranker