Spaces:
Running
Running
| """Factory functions for creating LLM and embedding instances. | |
| All provider-specific imports are isolated here. The rest of the codebase | |
| interacts only with LangChain abstract interfaces returned by these factories. | |
| """ | |
| import logging | |
| from dataclasses import replace | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from src.config import Settings | |
# Module-level logger shared by every factory in this file.
logger = logging.getLogger(__name__)

# Provider identifiers handled by create_llm(). Kept as a list because the
# list's repr is interpolated verbatim into the "unknown provider" error.
_SUPPORTED_LLM_PROVIDERS = [
    "ollama",
    "azure_openai",
    "openai",
    "groq",
    "anthropic",
    "google_genai",
    "bedrock",
]

# Provider identifiers handled by create_embeddings(); same repr caveat.
_SUPPORTED_EMBEDDING_PROVIDERS = [
    "local",
    "azure_openai",
    "openai",
    "google_genai",
    "bedrock",
]
| def create_llm(settings: Settings) -> BaseChatModel: | |
| """Create an LLM instance based on the configured provider. | |
| Args: | |
| settings: Application settings with provider configuration. | |
| Returns: | |
| A LangChain BaseChatModel instance. | |
| Raises: | |
| ValueError: If the provider is not supported. | |
| """ | |
| provider = settings.llm_provider.lower() | |
| logger.info("Creating LLM with provider: %s", provider) | |
| match provider: | |
| case "ollama": | |
| from langchain_ollama import ChatOllama | |
| return ChatOllama( | |
| base_url=settings.ollama_base_url, | |
| model=settings.ollama_model, | |
| temperature=0.0, | |
| ) | |
| case "azure_openai": | |
| from langchain_openai import AzureChatOpenAI | |
| return AzureChatOpenAI( | |
| azure_endpoint=settings.azure_openai_endpoint, | |
| api_key=settings.azure_openai_api_key, | |
| api_version=settings.azure_openai_api_version, | |
| azure_deployment=settings.azure_openai_deployment, | |
| temperature=0.0, | |
| ) | |
| case "openai": | |
| from langchain_openai import ChatOpenAI | |
| kwargs: dict = { | |
| "model": settings.openai_model, | |
| "api_key": settings.openai_api_key, | |
| "temperature": 0.0, | |
| } | |
| if settings.openai_base_url: | |
| kwargs["base_url"] = settings.openai_base_url | |
| return ChatOpenAI(**kwargs) | |
| case "groq": | |
| from langchain_openai import ChatOpenAI | |
| return ChatOpenAI( | |
| model=settings.groq_model, | |
| api_key=settings.groq_api_key, | |
| base_url="https://api.groq.com/openai/v1", | |
| temperature=0.0, | |
| ) | |
| case "anthropic": | |
| from langchain_anthropic import ChatAnthropic | |
| return ChatAnthropic( | |
| model=settings.anthropic_model, | |
| api_key=settings.anthropic_api_key, | |
| temperature=0.0, | |
| ) | |
| case "google_genai": | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| return ChatGoogleGenerativeAI( | |
| model=settings.google_model, | |
| google_api_key=settings.google_api_key, | |
| temperature=0.0, | |
| ) | |
| case "bedrock": | |
| from langchain_aws import ChatBedrockConverse | |
| return ChatBedrockConverse( | |
| model=settings.aws_bedrock_model, | |
| region_name=settings.aws_region, | |
| temperature=0.0, | |
| ) | |
| case _: | |
| raise ValueError( | |
| f"Unknown LLM provider: '{provider}'. " | |
| f"Supported providers: {_SUPPORTED_LLM_PROVIDERS}" | |
| ) | |
# Exception types that cause the fallback chain to move on to the next
# provider. This is deliberately the broad ``Exception``: the errors raised by
# real-world LLM SDKs (openai.RateLimitError, openai.APIConnectionError,
# httpx.ConnectError, anthropic.APIError, ...) do NOT derive from the stdlib
# ``ConnectionError`` / ``TimeoutError`` / ``OSError``, so a narrower tuple
# would silently let the most common transient failures bypass the fallback.
# Safety is provided by three other layers instead:
#   1. The whole feature is opt-in via ``LLM_FALLBACK_ENABLED`` (default off).
#   2. Every fallback activation logs a WARNING naming the destination provider.
#   3. Startup logs the full chain at WARNING with cost / privacy reminders.
_FALLBACK_EXCEPTIONS: tuple[type[BaseException], ...] = (Exception,)
def _wrap_with_fallback_logging(llm: BaseChatModel, provider: str) -> BaseChatModel:
    """Attach a start listener to ``llm`` that logs a WARNING naming ``provider``.

    The listener is registered through ``with_listeners(on_start=...)``, so it
    runs when the wrapped runnable is actually invoked. For an entry in a
    fallback chain that means the primary (and any earlier fallbacks) already
    failed, which leaves operators a clear trail of when data left the
    primary provider — important for this project's privacy-aware default.

    Args:
        llm: The chat model to wrap.
        provider: Provider label shown in the log message.

    Returns:
        A Runnable that transparently delegates to ``llm``.
    """

    def _log_activation(_run, _config=None) -> None:  # noqa: ANN001
        # Both arguments are ignored; only the provider label is reported.
        logger.warning(
            "LLM fallback activated: routing request to provider '%s'. "
            "Check cost / privacy implications.",
            provider,
        )

    wrapped = llm.with_listeners(on_start=_log_activation)
    return wrapped
def create_llm_with_fallback(settings: Settings) -> BaseChatModel:
    """Create the generation LLM, optionally wrapping it in a fallback chain.

    When ``settings.llm_fallback_enabled`` is False OR the fallback list is
    empty, this is a drop-in equivalent of :func:`create_llm`. Otherwise the
    primary LLM is wrapped via LangChain's ``with_fallbacks`` so that when
    the primary raises one of the exception types in ``_FALLBACK_EXCEPTIONS``,
    each fallback provider is tried in order.

    A fallback provider whose construction raises is logged at ERROR and
    skipped. If every fallback is skipped, the primary is returned on its own.
    The chain reported in the startup WARNING lists only the providers that
    were actually constructed, so it matches the chain that is really active.

    Args:
        settings: Application settings.

    Returns:
        A BaseChatModel (primary on its own, or primary-with-fallbacks).
    """
    primary = create_llm(settings)
    if not settings.llm_fallback_enabled or not settings.llm_fallback_providers:
        return primary

    fallbacks: list[BaseChatModel] = []
    # Names of the fallbacks that were built successfully, kept in step with
    # ``fallbacks`` so the logged chain never mentions a skipped provider.
    active_providers: list[str] = []
    for provider in settings.llm_fallback_providers:
        try:
            fallback_settings = replace(settings, llm_provider=provider)
            raw = create_llm(fallback_settings)
        except Exception as exc:  # noqa: BLE001 — log and skip broken fallbacks
            logger.error(
                "Skipping LLM fallback provider '%s' due to construction error: %s",
                provider, exc,
            )
            continue
        fallbacks.append(_wrap_with_fallback_logging(raw, provider))
        active_providers.append(provider)

    if not fallbacks:
        logger.warning(
            "LLM_FALLBACK_ENABLED is true but no fallback providers could be "
            "constructed; running without fallback."
        )
        return primary

    chain_repr = " -> ".join([settings.llm_provider, *active_providers])
    logger.warning(
        "LLM fallback chain is ACTIVE: %s. "
        "On transient failure of the primary, requests will be routed to the "
        "next provider. This may incur API costs and send data to third-party "
        "providers.",
        chain_repr,
    )
    return primary.with_fallbacks(
        fallbacks, exceptions_to_handle=_FALLBACK_EXCEPTIONS
    )
# Maps an LLM provider identifier to the name of the settings field that holds
# that provider's model (or deployment) name. create_evaluator_llm() uses it to
# decide which field an evaluator model override must replace.
_EVALUATOR_MODEL_FIELD: dict[str, str] = {
    "ollama": "ollama_model",
    "azure_openai": "azure_openai_deployment",
    "openai": "openai_model",
    "groq": "groq_model",
    "anthropic": "anthropic_model",
    "google_genai": "google_model",
    "bedrock": "aws_bedrock_model",
}
def create_evaluator_llm(settings: Settings) -> BaseChatModel:
    """Create the LLM used as a RAGAS judge.

    The judge LLM is independent of the generation LLM so a strong cloud
    model (e.g. Qwen3-32B via Groq) can score outputs produced by a small
    local generation model. If ``EVALUATOR_LLM_PROVIDER`` is unset (``None``,
    empty, or whitespace only), falls back to ``create_llm(settings)`` which
    reuses the generation LLM.

    When a provider is set, a copy of ``settings`` is made with
    ``llm_provider`` replaced by it. If ``settings.evaluator_llm_model`` is
    also set, the provider's model field (looked up in
    ``_EVALUATOR_MODEL_FIELD``) is replaced in that copy as well.

    Args:
        settings: Application settings with provider configuration.

    Returns:
        A LangChain BaseChatModel instance to use as the RAGAS judge.

    Raises:
        ValueError: If ``EVALUATOR_LLM_PROVIDER`` is set to an unknown value.
    """
    # ``or ""`` makes a None setting behave like an empty one instead of
    # raising AttributeError on ``.lower()``.
    provider = (settings.evaluator_llm_provider or "").lower().strip()
    if not provider:
        logger.info("EVALUATOR_LLM_PROVIDER unset; reusing generation LLM as judge")
        return create_llm(settings)

    overrides: dict[str, str] = {"llm_provider": provider}
    if settings.evaluator_llm_model:
        model_field = _EVALUATOR_MODEL_FIELD.get(provider)
        if model_field is None:
            # Without a known model field there is nowhere to put the override.
            raise ValueError(
                f"Cannot override evaluator model for unknown provider: '{provider}'"
            )
        overrides[model_field] = settings.evaluator_llm_model

    overridden = replace(settings, **overrides)
    logger.info(
        "Creating evaluator (judge) LLM with provider: %s | model override: %s",
        provider,
        settings.evaluator_llm_model or "(provider default)",
    )
    # An unknown provider without a model override is rejected here by
    # create_llm().
    return create_llm(overridden)
| def create_embeddings(settings: Settings) -> Embeddings: | |
| """Create an embeddings instance based on the configured provider. | |
| Args: | |
| settings: Application settings with provider configuration. | |
| Returns: | |
| A LangChain Embeddings instance. | |
| Raises: | |
| ValueError: If the provider is not supported. | |
| """ | |
| provider = settings.embedding_provider.lower() | |
| logger.info("Creating embeddings with provider: %s", provider) | |
| match provider: | |
| case "local": | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| return HuggingFaceEmbeddings( | |
| model_name=settings.local_embedding_model, | |
| ) | |
| case "azure_openai": | |
| from langchain_openai import AzureOpenAIEmbeddings | |
| return AzureOpenAIEmbeddings( | |
| azure_endpoint=settings.azure_openai_endpoint, | |
| api_key=settings.azure_openai_api_key, | |
| api_version=settings.azure_openai_api_version, | |
| azure_deployment=settings.azure_openai_embedding_deployment, | |
| ) | |
| case "openai": | |
| from langchain_openai import OpenAIEmbeddings | |
| return OpenAIEmbeddings( | |
| model=settings.openai_embedding_model, | |
| api_key=settings.openai_api_key, | |
| ) | |
| case "google_genai": | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
| return GoogleGenerativeAIEmbeddings( | |
| model=settings.google_embedding_model, | |
| google_api_key=settings.google_api_key, | |
| ) | |
| case "bedrock": | |
| from langchain_aws import BedrockEmbeddings | |
| return BedrockEmbeddings( | |
| model_id=settings.aws_bedrock_embedding_model, | |
| region_name=settings.aws_region, | |
| ) | |
| case _: | |
| raise ValueError( | |
| f"Unknown embedding provider: '{provider}'. " | |
| f"Supported providers: {_SUPPORTED_EMBEDDING_PROVIDERS}" | |
| ) | |
def create_reranker(model_name: str) -> object:
    """Build a cross-encoder reranker for the given model.

    ``sentence_transformers`` is imported inside the function, so the import
    only happens when a reranker is requested.

    Args:
        model_name: HuggingFace model name for the cross-encoder.

    Returns:
        A CrossEncoder model instance.
    """
    from sentence_transformers import CrossEncoder

    logger.info("Creating cross-encoder reranker: %s", model_name)
    reranker = CrossEncoder(model_name)
    return reranker