"""LLM client adapters for different providers.""" from typing import Dict, Any, List, Optional, Union from abc import ABC, abstractmethod from dataclasses import dataclass # LangChain imports from langchain_mistralai.chat_models import ChatMistralAI from langchain_openai.chat_models import ChatOpenAI from langchain_ollama import ChatOllama # Legacy client dependencies from huggingface_hub import InferenceClient from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain_community.llms import HuggingFaceEndpoint from langchain_community.chat_models.huggingface import ChatHuggingFace # Configuration loader from ..config.loader import load_config # Load configuration once at module level _config = load_config() # Legacy client factory functions (inlined from auditqa_old.reader) def _create_inf_provider_client(): """Create INF_PROVIDERS client.""" reader_config = _config.get("reader", {}) inf_config = reader_config.get("INF_PROVIDERS", {}) api_key = inf_config.get("api_key") if not api_key: raise ValueError("INF_PROVIDERS api_key not found in configuration") provider = inf_config.get("provider") if not provider: raise ValueError("INF_PROVIDERS provider not found in configuration") return InferenceClient( provider=provider, api_key=api_key, bill_to="GIZ", ) def _create_nvidia_client(): """Create NVIDIA client.""" reader_config = _config.get("reader", {}) nvidia_config = reader_config.get("NVIDIA", {}) api_key = nvidia_config.get("api_key") if not api_key: raise ValueError("NVIDIA api_key not found in configuration") endpoint = nvidia_config.get("endpoint") if not endpoint: raise ValueError("NVIDIA endpoint not found in configuration") return InferenceClient( base_url=endpoint, api_key=api_key ) def _create_serverless_client(): """Create serverless API client.""" reader_config = _config.get("reader", {}) serverless_config = reader_config.get("SERVERLESS", {}) api_key = serverless_config.get("api_key") if not api_key: raise ValueError("SERVERLESS api_key not found in configuration") model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct") return InferenceClient( model=model_id, api_key=api_key, ) def _create_dedicated_endpoint_client(): """Create dedicated endpoint client.""" reader_config = _config.get("reader", {}) dedicated_config = reader_config.get("DEDICATED", {}) api_key = dedicated_config.get("api_key") if not api_key: raise ValueError("DEDICATED api_key not found in configuration") endpoint = dedicated_config.get("endpoint") if not endpoint: raise ValueError("DEDICATED endpoint not found in configuration") max_tokens = dedicated_config.get("max_tokens", 768) # Set up the streaming callback handler callback = StreamingStdOutCallbackHandler() # Initialize the HuggingFaceEndpoint with streaming enabled llm_qa = HuggingFaceEndpoint( endpoint_url=endpoint, max_new_tokens=int(max_tokens), repetition_penalty=1.03, timeout=70, huggingfacehub_api_token=api_key, streaming=True, callbacks=[callback] ) # Create a ChatHuggingFace instance with the streaming-enabled endpoint return ChatHuggingFace(llm=llm_qa) @dataclass class LLMResponse: """Standardized LLM response format.""" content: str model: str provider: str metadata: Dict[str, Any] = None class BaseLLMAdapter(ABC): """Base class for LLM adapters.""" def __init__(self, config: Dict[str, Any]): self.config = config @abstractmethod def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse: """Generate response from messages.""" pass @abstractmethod def 


# Legacy client factory functions (inlined from auditqa_old.reader)

def _create_inf_provider_client():
    """Create INF_PROVIDERS client."""
    reader_config = _config.get("reader", {})
    inf_config = reader_config.get("INF_PROVIDERS", {})

    api_key = inf_config.get("api_key")
    if not api_key:
        raise ValueError("INF_PROVIDERS api_key not found in configuration")

    provider = inf_config.get("provider")
    if not provider:
        raise ValueError("INF_PROVIDERS provider not found in configuration")

    return InferenceClient(
        provider=provider,
        api_key=api_key,
        bill_to="GIZ",
    )


def _create_nvidia_client():
    """Create NVIDIA client."""
    reader_config = _config.get("reader", {})
    nvidia_config = reader_config.get("NVIDIA", {})

    api_key = nvidia_config.get("api_key")
    if not api_key:
        raise ValueError("NVIDIA api_key not found in configuration")

    endpoint = nvidia_config.get("endpoint")
    if not endpoint:
        raise ValueError("NVIDIA endpoint not found in configuration")

    return InferenceClient(
        base_url=endpoint,
        api_key=api_key
    )


def _create_serverless_client():
    """Create serverless API client."""
    reader_config = _config.get("reader", {})
    serverless_config = reader_config.get("SERVERLESS", {})

    api_key = serverless_config.get("api_key")
    if not api_key:
        raise ValueError("SERVERLESS api_key not found in configuration")

    model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")

    return InferenceClient(
        model=model_id,
        api_key=api_key,
    )


def _create_dedicated_endpoint_client():
    """Create dedicated endpoint client."""
    reader_config = _config.get("reader", {})
    dedicated_config = reader_config.get("DEDICATED", {})

    api_key = dedicated_config.get("api_key")
    if not api_key:
        raise ValueError("DEDICATED api_key not found in configuration")

    endpoint = dedicated_config.get("endpoint")
    if not endpoint:
        raise ValueError("DEDICATED endpoint not found in configuration")

    max_tokens = dedicated_config.get("max_tokens", 768)

    # Set up the streaming callback handler
    callback = StreamingStdOutCallbackHandler()

    # Initialize the HuggingFaceEndpoint with streaming enabled
    llm_qa = HuggingFaceEndpoint(
        endpoint_url=endpoint,
        max_new_tokens=int(max_tokens),
        repetition_penalty=1.03,
        timeout=70,
        huggingfacehub_api_token=api_key,
        streaming=True,
        callbacks=[callback]
    )

    # Create a ChatHuggingFace instance with the streaming-enabled endpoint
    return ChatHuggingFace(llm=llm_qa)


@dataclass
class LLMResponse:
    """Standardized LLM response format."""
    content: str
    model: str
    provider: str
    metadata: Optional[Dict[str, Any]] = None


class BaseLLMAdapter(ABC):
    """Base class for LLM adapters."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @abstractmethod
    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response from messages."""
        pass

    @abstractmethod
    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response from messages."""
        pass


class MistralAdapter(BaseLLMAdapter):
    """Adapter for Mistral AI models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatMistralAI(
            model=config.get("model", "mistral-medium-latest")
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using Mistral."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "mistral-medium-latest"),
            provider="mistral",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using Mistral."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class OpenAIAdapter(BaseLLMAdapter):
    """Adapter for OpenAI models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatOpenAI(
            model=config.get("model", "gpt-4o-mini")
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenAI."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "gpt-4o-mini"),
            provider="openai",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenAI."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class OllamaAdapter(BaseLLMAdapter):
    """Adapter for Ollama models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatOllama(
            model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            base_url=config.get("base_url", "http://localhost:11434/"),
            temperature=config.get("temperature", 0.8),
            num_predict=config.get("num_predict", 256)
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using Ollama."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            provider="ollama",
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using Ollama."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class OpenRouterAdapter(BaseLLMAdapter):
    """Adapter for OpenRouter models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Prepare custom headers for OpenRouter (optional)
        headers = {}
        if config.get("site_url"):
            headers["HTTP-Referer"] = config["site_url"]
        if config.get("site_name"):
            headers["X-Title"] = config["site_name"]

        # Initialize ChatOpenAI with OpenRouter configuration
        self.model = ChatOpenAI(
            model=config.get("model", "openai/gpt-3.5-turbo"),
            api_key=config.get("api_key"),
            base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
            default_headers=headers,
            temperature=config.get("temperature", 0.7),
            max_tokens=config.get("max_tokens", 1000)
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenRouter."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "openai/gpt-3.5-turbo"),
            provider="openrouter",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenRouter."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content
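
# Example (illustrative, not executed): using an adapter directly, outside the
# registry. Assumes ChatMistralAI can pick up credentials from the environment
# (e.g. MISTRAL_API_KEY); the config dict here is hypothetical.
#
#     adapter = MistralAdapter({"model": "mistral-medium-latest"})
#     result = adapter.generate([{"role": "user", "content": "Hello"}])
#     print(result.content, result.provider)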


class LegacyAdapter(BaseLLMAdapter):
    """Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, etc.)."""

    def __init__(self, config: Dict[str, Any], client_type: str):
        super().__init__(config)
        self.client_type = client_type
        self.client = self._create_client()

    def _create_client(self):
        """Create legacy client based on type."""
        if self.client_type == "INF_PROVIDERS":
            return _create_inf_provider_client()
        elif self.client_type == "NVIDIA":
            return _create_nvidia_client()
        elif self.client_type == "DEDICATED":
            return _create_dedicated_endpoint_client()
        else:  # SERVERLESS
            return _create_serverless_client()

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using legacy client."""
        max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))

        if self.client_type == "INF_PROVIDERS":
            response = self.client.chat.completions.create(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        elif self.client_type == "NVIDIA":
            response = self.client.chat_completion(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        else:  # DEDICATED or SERVERLESS
            response = self.client.chat_completion(
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content

        return LLMResponse(
            content=content,
            model=self.config.get("model", "unknown"),
            provider=self.client_type.lower(),
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using legacy client."""
        # Legacy clients may not support streaming in the same way, so this
        # simplified implementation generates the full response and yields it
        # back word by word.
        response = self.generate(messages, **kwargs)
        words = response.content.split()
        for word in words:
            yield word + " "


class LLMRegistry:
    """Registry for managing different LLM adapters."""

    def __init__(self):
        self.adapters = {}
        self.adapter_configs = {}

    def register_adapter(self, name: str,
                         adapter_class: Callable[[Dict[str, Any]], BaseLLMAdapter],
                         config: Dict[str, Any]):
        """Register an LLM adapter (lazy instantiation).

        ``adapter_class`` may be an adapter class or any callable that takes
        a config dict and returns a BaseLLMAdapter.
        """
        self.adapter_configs[name] = (adapter_class, config)

    def get_adapter(self, name: str) -> BaseLLMAdapter:
        """Get an LLM adapter by name (lazy instantiation)."""
        if name not in self.adapter_configs:
            raise ValueError(f"Unknown LLM adapter: {name}")

        # Lazy instantiation - only create when needed
        if name not in self.adapters:
            adapter_class, config = self.adapter_configs[name]
            self.adapters[name] = adapter_class(config)

        return self.adapters[name]

    def list_adapters(self) -> List[str]:
        """List available adapter names."""
        return list(self.adapter_configs.keys())
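
# Example (illustrative, not executed): registering and resolving an adapter
# by hand. Construction is deferred until get_adapter() is first called, and
# the instance is cached for subsequent lookups.
#
#     registry = LLMRegistry()
#     registry.register_adapter("openai", OpenAIAdapter, {"model": "gpt-4o-mini"})
#     adapter = registry.get_adapter("openai")  # instantiated here, then cached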


def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
    """
    Create and populate LLM registry from configuration.

    Args:
        config: Configuration dictionary

    Returns:
        Populated LLMRegistry
    """
    registry = LLMRegistry()
    reader_config = config.get("reader", {})

    # Register simple adapters
    if "MISTRAL" in reader_config:
        registry.register_adapter("mistral", MistralAdapter, reader_config["MISTRAL"])

    if "OPENAI" in reader_config:
        registry.register_adapter("openai", OpenAIAdapter, reader_config["OPENAI"])

    if "OLLAMA" in reader_config:
        registry.register_adapter("ollama", OllamaAdapter, reader_config["OLLAMA"])

    if "OPENROUTER" in reader_config:
        registry.register_adapter("openrouter", OpenRouterAdapter, reader_config["OPENROUTER"])

    # Register legacy adapters
    # legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
    legacy_types = ["INF_PROVIDERS"]
    for legacy_type in legacy_types:
        if legacy_type in reader_config:
            registry.register_adapter(
                legacy_type.lower(),
                # Bind legacy_type as a default argument so each lambda
                # captures its own value rather than the loop variable.
                lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
                reader_config[legacy_type]
            )

    return registry


def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
    """
    Get LLM client for specified provider.

    Args:
        provider: Provider name (mistral, openai, ollama, etc.)
        config: Configuration dictionary

    Returns:
        LLM adapter instance
    """
    registry = create_llm_registry(config)
    return registry.get_adapter(provider)
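

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): resolve an adapter from the
    # module-level config and stream a short completion. Assumes the loaded
    # config defines a MISTRAL section with valid credentials; swap the
    # provider name to match your setup.
    client = get_llm_client("mistral", _config)
    for token in client.stream_generate(
        [{"role": "user", "content": "Say hello in one sentence."}]
    ):
        print(token, end="", flush=True)
    print()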