zico-agent / src /llm /factory.py
github-actions[bot]
Deploy from GitHub Actions: b307fb4145d2a061408e2bf8587cb6dd8a90be79
171371a
"""
LLM Factory - Multi-provider LLM abstraction.
Supports:
- Google (Gemini)
- OpenAI (GPT)
- Anthropic (Claude)
"""
import os
from typing import Literal
from langchain_core.language_models import BaseChatModel
from .exceptions import LLMInvalidModelError, LLMProviderError
# Identifier of a supported LLM provider.
Provider = Literal["google", "openai", "anthropic"]

# Known model names grouped by provider. Order within each list matters:
# LLMFactory.get_default_model returns the first entry for a provider.
MODEL_PROVIDERS: dict[Provider, list[str]] = {
    "google": [
        "gemini-3-flash-preview",
        "gemini-3-pro-preview",
        "gemini-2.5-flash",
        "gemini-2.5-pro",
        "gemini-2.0-flash",
        "gemini-1.5-flash",
        "gemini-1.5-pro",
    ],
    "openai": [
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
    ],
    "anthropic": [
        "claude-sonnet-4-20250514",
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
    ],
}

# Every known model across all providers, as a set for O(1) membership tests.
ALL_MODELS: set[str] = set().union(*MODEL_PROVIDERS.values())
def detect_provider(model: str) -> Provider:
    """
    Detect the provider based on model name.

    Matching is case-insensitive. The name is first checked against the
    model-family prefixes ('gemini', 'gpt', 'claude'), then looked up in the
    explicit MODEL_PROVIDERS catalogue.

    Args:
        model: The model name (e.g., 'gemini-2.5-flash', 'gpt-4o')

    Returns:
        The provider name ('google', 'openai', 'anthropic')

    Raises:
        LLMInvalidModelError: If the model is not recognized
    """
    model_lower = model.lower()

    # Check by prefix
    for prefix, provider in (
        ("gemini", "google"),
        ("gpt", "openai"),
        ("claude", "anthropic"),
    ):
        if model_lower.startswith(prefix):
            return provider

    # Check in known models, case-insensitively like the prefix check above
    # (comparing the raw name here made the two paths disagree on case).
    for provider, models in MODEL_PROVIDERS.items():
        if any(model_lower == known.lower() for known in models):
            return provider

    # Sorted so the error lists the valid models in a stable order.
    raise LLMInvalidModelError(model, sorted(ALL_MODELS))
class LLMFactory:
"""Factory for creating LLM instances across multiple providers."""
# Cache for LLM instances (singleton per model+config)
_instances: dict[str, BaseChatModel] = {}
@classmethod
def create(
cls,
model: str,
temperature: float = 0.7,
max_retries: int = 3,
timeout: int = 60,
api_key: str | None = None,
use_cache: bool = True,
**kwargs,
) -> BaseChatModel:
"""
Create an LLM instance for the specified model.
Args:
model: Model name (e.g., 'gemini-2.5-flash', 'gpt-4o', 'claude-sonnet-4-20250514')
temperature: Sampling temperature (0.0 to 1.0)
max_retries: Maximum number of retries on failure
timeout: Request timeout in seconds
api_key: Optional API key (defaults to environment variable)
use_cache: Whether to use cached instances
**kwargs: Additional provider-specific arguments
Returns:
BaseChatModel instance
Raises:
LLMInvalidModelError: If model is not recognized
LLMProviderError: If provider initialization fails
"""
# Check cache
cache_key = f"{model}:{temperature}:{timeout}"
if use_cache and cache_key in cls._instances:
return cls._instances[cache_key]
provider = detect_provider(model)
try:
llm = cls._create_for_provider(
provider=provider,
model=model,
temperature=temperature,
max_retries=max_retries,
timeout=timeout,
api_key=api_key,
**kwargs,
)
if use_cache:
cls._instances[cache_key] = llm
return llm
except ImportError as e:
raise LLMProviderError(
f"Provider '{provider}' dependencies not installed: {e}",
provider=provider,
model=model,
)
except Exception as e:
raise LLMProviderError(
f"Failed to create LLM for '{model}': {e}",
provider=provider,
model=model,
)
@classmethod
def _create_for_provider(
cls,
provider: Provider,
model: str,
temperature: float,
max_retries: int,
timeout: int,
api_key: str | None,
**kwargs,
) -> BaseChatModel:
"""Create LLM instance for a specific provider."""
match provider:
case "google":
return cls._create_google(
model, temperature, max_retries, timeout, api_key, **kwargs
)
case "openai":
return cls._create_openai(
model, temperature, max_retries, timeout, api_key, **kwargs
)
case "anthropic":
return cls._create_anthropic(
model, temperature, max_retries, timeout, api_key, **kwargs
)
@staticmethod
def _create_google(
model: str,
temperature: float,
max_retries: int,
timeout: int,
api_key: str | None,
callbacks: list | None = None,
**kwargs,
) -> BaseChatModel:
"""Create Google Gemini LLM instance."""
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=model,
temperature=temperature,
max_retries=max_retries,
timeout=timeout,
google_api_key=api_key or os.getenv("GEMINI_API_KEY"),
callbacks=callbacks,
**kwargs,
)
@staticmethod
def _create_openai(
model: str,
temperature: float,
max_retries: int,
timeout: int,
api_key: str | None,
callbacks: list | None = None,
**kwargs,
) -> BaseChatModel:
"""Create OpenAI LLM instance."""
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=model,
temperature=temperature,
max_retries=max_retries,
timeout=timeout,
api_key=api_key or os.getenv("OPENAI_API_KEY"),
callbacks=callbacks,
**kwargs,
)
@staticmethod
def _create_anthropic(
model: str,
temperature: float,
max_retries: int,
timeout: int,
api_key: str | None,
callbacks: list | None = None,
**kwargs,
) -> BaseChatModel:
"""Create Anthropic Claude LLM instance."""
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(
model=model,
temperature=temperature,
max_retries=max_retries,
timeout=timeout,
api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
callbacks=callbacks,
**kwargs,
)
@classmethod
def list_models(cls, provider: Provider | None = None) -> list[str]:
"""
List available models.
Args:
provider: Optional provider to filter by
Returns:
List of model names
"""
if provider:
return MODEL_PROVIDERS.get(provider, [])
return list(ALL_MODELS)
@classmethod
def list_providers(cls) -> list[Provider]:
"""List available providers."""
return list(MODEL_PROVIDERS.keys())
@classmethod
def clear_cache(cls) -> None:
"""Clear the LLM instance cache."""
cls._instances.clear()
@classmethod
def get_default_model(cls, provider: Provider | None = None) -> str:
"""
Get the default model for a provider.
Args:
provider: Provider name (defaults to 'google')
Returns:
Default model name
"""
provider = provider or "google"
models = MODEL_PROVIDERS.get(provider, [])
if not models:
raise LLMProviderError(f"No models available for provider: {provider}")
return models[0]