| """ | |
| LLM Factory - Multi-provider LLM abstraction. | |
| Supports: | |
| - Google (Gemini) | |
| - OpenAI (GPT) | |
| - Anthropic (Claude) | |
| """ | |
| import os | |
| from typing import Literal | |
| from langchain_core.language_models import BaseChatModel | |
| from .exceptions import LLMInvalidModelError, LLMProviderError | |

Provider = Literal["google", "openai", "anthropic"]

MODEL_PROVIDERS: dict[Provider, list[str]] = {
    "google": [
        "gemini-3-flash-preview",
        "gemini-3-pro-preview",
        "gemini-2.5-flash",
        "gemini-2.5-pro",
        "gemini-2.0-flash",
        "gemini-1.5-flash",
        "gemini-1.5-pro",
    ],
    "openai": [
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
    ],
    "anthropic": [
        "claude-sonnet-4-20250514",
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
    ],
}

# Flatten for quick lookup
ALL_MODELS: set[str] = {model for models in MODEL_PROVIDERS.values() for model in models}


def detect_provider(model: str) -> Provider:
    """
    Detect the provider based on the model name.

    Args:
        model: The model name (e.g., 'gemini-2.5-flash', 'gpt-4o')

    Returns:
        The provider name ('google', 'openai', 'anthropic')

    Raises:
        LLMInvalidModelError: If the model is not recognized
    """
    model_lower = model.lower()

    # Check by prefix
    if model_lower.startswith("gemini"):
        return "google"
    if model_lower.startswith("gpt"):
        return "openai"
    if model_lower.startswith("claude"):
        return "anthropic"

    # Check in known models
    for provider, models in MODEL_PROVIDERS.items():
        if model in models:
            return provider

    raise LLMInvalidModelError(model, list(ALL_MODELS))
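

# Note: the prefix checks run before the exact-name lookup, so a model name
# that is absent from MODEL_PROVIDERS but carries a known prefix (say, a
# hypothetical "gpt-99") still resolves. A quick sketch of the behavior:
#
#     detect_provider("gemini-2.5-flash")  # -> "google"  (prefix match)
#     detect_provider("gpt-99")            # -> "openai"  (prefix, not listed)
#     detect_provider("mistral-large")     # raises LLMInvalidModelError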


class LLMFactory:
    """Factory for creating LLM instances across multiple providers."""

    # Cache for LLM instances (singleton per model + config)
    _instances: dict[str, BaseChatModel] = {}

    @classmethod
    def create(
        cls,
        model: str,
        temperature: float = 0.7,
        max_retries: int = 3,
        timeout: int = 60,
        api_key: str | None = None,
        use_cache: bool = True,
        **kwargs,
    ) -> BaseChatModel:
        """
        Create an LLM instance for the specified model.

        Args:
            model: Model name (e.g., 'gemini-2.5-flash', 'gpt-4o', 'claude-sonnet-4-20250514')
            temperature: Sampling temperature (0.0 to 1.0; some providers accept up to 2.0)
            max_retries: Maximum number of retries on failure
            timeout: Request timeout in seconds
            api_key: Optional API key (defaults to environment variable)
            use_cache: Whether to use cached instances
            **kwargs: Additional provider-specific arguments

        Returns:
            BaseChatModel instance

        Raises:
            LLMInvalidModelError: If model is not recognized
            LLMProviderError: If provider initialization fails
        """
        # Check the cache. The key covers model, temperature, max_retries and
        # timeout; callers passing extra **kwargs should set use_cache=False,
        # since kwargs are not part of the key.
        cache_key = f"{model}:{temperature}:{max_retries}:{timeout}"
        if use_cache and cache_key in cls._instances:
            return cls._instances[cache_key]

        provider = detect_provider(model)

        try:
            llm = cls._create_for_provider(
                provider=provider,
                model=model,
                temperature=temperature,
                max_retries=max_retries,
                timeout=timeout,
                api_key=api_key,
                **kwargs,
            )
            if use_cache:
                cls._instances[cache_key] = llm
            return llm
        except ImportError as e:
            raise LLMProviderError(
                f"Provider '{provider}' dependencies not installed: {e}",
                provider=provider,
                model=model,
            ) from e
        except Exception as e:
            raise LLMProviderError(
                f"Failed to create LLM for '{model}': {e}",
                provider=provider,
                model=model,
            ) from e
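
    # Caching sketch: two calls with the same model, temperature, max_retries
    # and timeout share one instance; clear_cache() resets this.
    #
    #     a = LLMFactory.create("gemini-2.5-flash")
    #     b = LLMFactory.create("gemini-2.5-flash")
    #     assert a is b  # served from _instances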

    @classmethod
    def _create_for_provider(
        cls,
        provider: Provider,
        model: str,
        temperature: float,
        max_retries: int,
        timeout: int,
        api_key: str | None,
        **kwargs,
    ) -> BaseChatModel:
        """Create an LLM instance for a specific provider."""
        match provider:
            case "google":
                return cls._create_google(
                    model, temperature, max_retries, timeout, api_key, **kwargs
                )
            case "openai":
                return cls._create_openai(
                    model, temperature, max_retries, timeout, api_key, **kwargs
                )
            case "anthropic":
                return cls._create_anthropic(
                    model, temperature, max_retries, timeout, api_key, **kwargs
                )

    @staticmethod
    def _create_google(
        model: str,
        temperature: float,
        max_retries: int,
        timeout: int,
        api_key: str | None,
        callbacks: list | None = None,
        **kwargs,
    ) -> BaseChatModel:
        """Create a Google Gemini LLM instance."""
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            max_retries=max_retries,
            timeout=timeout,
            google_api_key=api_key or os.getenv("GEMINI_API_KEY"),
            callbacks=callbacks,
            **kwargs,
        )

    @staticmethod
    def _create_openai(
        model: str,
        temperature: float,
        max_retries: int,
        timeout: int,
        api_key: str | None,
        callbacks: list | None = None,
        **kwargs,
    ) -> BaseChatModel:
        """Create an OpenAI LLM instance."""
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(
            model=model,
            temperature=temperature,
            max_retries=max_retries,
            timeout=timeout,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            callbacks=callbacks,
            **kwargs,
        )

    @staticmethod
    def _create_anthropic(
        model: str,
        temperature: float,
        max_retries: int,
        timeout: int,
        api_key: str | None,
        callbacks: list | None = None,
        **kwargs,
    ) -> BaseChatModel:
        """Create an Anthropic Claude LLM instance."""
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(
            model=model,
            temperature=temperature,
            max_retries=max_retries,
            timeout=timeout,
            api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
            callbacks=callbacks,
            **kwargs,
        )

    @classmethod
    def list_models(cls, provider: Provider | None = None) -> list[str]:
        """
        List available models.

        Args:
            provider: Optional provider to filter by

        Returns:
            List of model names
        """
        if provider:
            return MODEL_PROVIDERS.get(provider, [])
        return list(ALL_MODELS)

    @classmethod
    def list_providers(cls) -> list[Provider]:
        """List available providers."""
        return list(MODEL_PROVIDERS.keys())

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the LLM instance cache."""
        cls._instances.clear()

    @classmethod
    def get_default_model(cls, provider: Provider | None = None) -> str:
        """
        Get the default model for a provider.

        Args:
            provider: Provider name (defaults to 'google')

        Returns:
            Default model name
        """
        provider = provider or "google"
        models = MODEL_PROVIDERS.get(provider, [])
        if not models:
            raise LLMProviderError(f"No models available for provider: {provider}")
        return models[0]
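

# Minimal usage sketch, not part of the factory itself. It assumes the
# relevant langchain-* provider package is installed and the matching API
# key (e.g., GEMINI_API_KEY) is set; the final .invoke() call makes a real
# network request. Because of the relative import above, run it as a module
# (python -m <your_package>.llm_factory), not as a plain script.
if __name__ == "__main__":
    print("Providers:", LLMFactory.list_providers())
    print("Default Google model:", LLMFactory.get_default_model("google"))
    print("OpenAI models:", LLMFactory.list_models("openai"))

    llm = LLMFactory.create(LLMFactory.get_default_model(), temperature=0.2)
    print(llm.invoke("Say hello in one word.").content)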