| """LLM client adapters for different providers.""" | |
| from typing import Dict, Any, List, Optional, Union | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| # LangChain imports | |
| from langchain_mistralai.chat_models import ChatMistralAI | |
| from langchain_openai.chat_models import ChatOpenAI | |
| from langchain_ollama import ChatOllama | |
| # Legacy client dependencies | |
| from huggingface_hub import InferenceClient | |
| from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
| from langchain_community.llms import HuggingFaceEndpoint | |
| from langchain_community.chat_models.huggingface import ChatHuggingFace | |
| # Configuration loader | |
| from ..config.loader import load_config | |
| # Load configuration once at module level | |
| _config = load_config() | |

# Legacy client factory functions (inlined from auditqa_old.reader)
def _create_inf_provider_client():
    """Create INF_PROVIDERS client."""
    reader_config = _config.get("reader", {})
    inf_config = reader_config.get("INF_PROVIDERS", {})
    api_key = inf_config.get("api_key")
    if not api_key:
        raise ValueError("INF_PROVIDERS api_key not found in configuration")
    provider = inf_config.get("provider")
    if not provider:
        raise ValueError("INF_PROVIDERS provider not found in configuration")
    return InferenceClient(
        provider=provider,
        api_key=api_key,
        bill_to="GIZ",
    )


def _create_nvidia_client():
    """Create NVIDIA client."""
    reader_config = _config.get("reader", {})
    nvidia_config = reader_config.get("NVIDIA", {})
    api_key = nvidia_config.get("api_key")
    if not api_key:
        raise ValueError("NVIDIA api_key not found in configuration")
    endpoint = nvidia_config.get("endpoint")
    if not endpoint:
        raise ValueError("NVIDIA endpoint not found in configuration")
    return InferenceClient(
        base_url=endpoint,
        api_key=api_key
    )


def _create_serverless_client():
    """Create serverless API client."""
    reader_config = _config.get("reader", {})
    serverless_config = reader_config.get("SERVERLESS", {})
    api_key = serverless_config.get("api_key")
    if not api_key:
        raise ValueError("SERVERLESS api_key not found in configuration")
    model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")
    return InferenceClient(
        model=model_id,
        api_key=api_key,
    )


def _create_dedicated_endpoint_client():
    """Create dedicated endpoint client."""
    reader_config = _config.get("reader", {})
    dedicated_config = reader_config.get("DEDICATED", {})
    api_key = dedicated_config.get("api_key")
    if not api_key:
        raise ValueError("DEDICATED api_key not found in configuration")
    endpoint = dedicated_config.get("endpoint")
    if not endpoint:
        raise ValueError("DEDICATED endpoint not found in configuration")
    max_tokens = dedicated_config.get("max_tokens", 768)
    # Set up the streaming callback handler
    callback = StreamingStdOutCallbackHandler()
    # Initialize the HuggingFaceEndpoint with streaming enabled
    llm_qa = HuggingFaceEndpoint(
        endpoint_url=endpoint,
        max_new_tokens=int(max_tokens),
        repetition_penalty=1.03,
        timeout=70,
        huggingfacehub_api_token=api_key,
        streaming=True,
        callbacks=[callback]
    )
    # Create a ChatHuggingFace instance with the streaming-enabled endpoint
    return ChatHuggingFace(llm=llm_qa)

@dataclass
class LLMResponse:
    """Standardized LLM response format."""

    content: str
    model: str
    provider: str
    metadata: Optional[Dict[str, Any]] = None

class BaseLLMAdapter(ABC):
    """Base class for LLM adapters."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @abstractmethod
    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response from messages."""
        pass

    @abstractmethod
    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response from messages."""
        pass

class MistralAdapter(BaseLLMAdapter):
    """Adapter for Mistral AI models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatMistralAI(
            model=config.get("model", "mistral-medium-latest")
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using Mistral."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "mistral-medium-latest"),
            provider="mistral",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using Mistral."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content

class OpenAIAdapter(BaseLLMAdapter):
    """Adapter for OpenAI models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatOpenAI(
            model=config.get("model", "gpt-4o-mini")
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenAI."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "gpt-4o-mini"),
            provider="openai",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenAI."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content

class OllamaAdapter(BaseLLMAdapter):
    """Adapter for Ollama models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatOllama(
            model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            base_url=config.get("base_url", "http://localhost:11434/"),
            temperature=config.get("temperature", 0.8),
            num_predict=config.get("num_predict", 256)
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using Ollama."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            provider="ollama",
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using Ollama."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content

class OpenRouterAdapter(BaseLLMAdapter):
    """Adapter for OpenRouter models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Prepare custom headers for OpenRouter (optional)
        headers = {}
        if config.get("site_url"):
            headers["HTTP-Referer"] = config["site_url"]
        if config.get("site_name"):
            headers["X-Title"] = config["site_name"]
        # Initialize ChatOpenAI with OpenRouter configuration
        self.model = ChatOpenAI(
            model=config.get("model", "openai/gpt-3.5-turbo"),
            api_key=config.get("api_key"),
            base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
            default_headers=headers,
            temperature=config.get("temperature", 0.7),
            max_tokens=config.get("max_tokens", 1000)
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenRouter."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "openai/gpt-3.5-turbo"),
            provider="openrouter",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenRouter."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content

class LegacyAdapter(BaseLLMAdapter):
    """Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, etc.)."""

    def __init__(self, config: Dict[str, Any], client_type: str):
        super().__init__(config)
        self.client_type = client_type
        self.client = self._create_client()

    def _create_client(self):
        """Create legacy client based on type."""
        if self.client_type == "INF_PROVIDERS":
            return _create_inf_provider_client()
        elif self.client_type == "NVIDIA":
            return _create_nvidia_client()
        elif self.client_type == "DEDICATED":
            return _create_dedicated_endpoint_client()
        else:  # SERVERLESS
            return _create_serverless_client()

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using legacy client."""
        max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))
        if self.client_type == "INF_PROVIDERS":
            response = self.client.chat.completions.create(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        elif self.client_type == "NVIDIA":
            response = self.client.chat_completion(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        elif self.client_type == "DEDICATED":
            # The dedicated client is a LangChain ChatHuggingFace model, so it is
            # invoked like any other chat model; max_tokens is already fixed at
            # client creation via max_new_tokens.
            response = self.client.invoke(messages)
            content = response.content
        else:  # SERVERLESS
            response = self.client.chat_completion(
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        return LLMResponse(
            content=content,
            model=self.config.get("model", "unknown"),
            provider=self.client_type.lower(),
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using legacy client."""
        # Legacy clients may not support streaming in the same way.
        # Simplified implementation: generate the full answer, then yield it word by word.
        response = self.generate(messages, **kwargs)
        words = response.content.split()
        for word in words:
            yield word + " "

class LLMRegistry:
    """Registry for managing different LLM adapters."""

    def __init__(self):
        self.adapters = {}
        self.adapter_configs = {}

    def register_adapter(self, name: str, adapter_class: type, config: Dict[str, Any]):
        """Register an LLM adapter (lazy instantiation)."""
        self.adapter_configs[name] = (adapter_class, config)

    def get_adapter(self, name: str) -> BaseLLMAdapter:
        """Get an LLM adapter by name (lazy instantiation)."""
        if name not in self.adapter_configs:
            raise ValueError(f"Unknown LLM adapter: {name}")
        # Lazy instantiation - only create when needed
        if name not in self.adapters:
            adapter_class, config = self.adapter_configs[name]
            self.adapters[name] = adapter_class(config)
        return self.adapters[name]

    def list_adapters(self) -> List[str]:
        """List available adapter names."""
        return list(self.adapter_configs.keys())
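
# A minimal sketch of using the registry directly. The adapter name and config
# values are illustrative; in practice create_llm_registry() below wires the
# registry up from the loaded configuration:
#
#     registry = LLMRegistry()
#     registry.register_adapter("openai", OpenAIAdapter, {"model": "gpt-4o-mini"})
#     adapter = registry.get_adapter("openai")   # instantiated on first access
#     print(registry.list_adapters())            # -> ["openai"]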

def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
    """
    Create and populate LLM registry from configuration.

    Args:
        config: Configuration dictionary

    Returns:
        Populated LLMRegistry
    """
    registry = LLMRegistry()
    reader_config = config.get("reader", {})

    # Register simple adapters
    if "MISTRAL" in reader_config:
        registry.register_adapter("mistral", MistralAdapter, reader_config["MISTRAL"])
    if "OPENAI" in reader_config:
        registry.register_adapter("openai", OpenAIAdapter, reader_config["OPENAI"])
    if "OLLAMA" in reader_config:
        registry.register_adapter("ollama", OllamaAdapter, reader_config["OLLAMA"])
    if "OPENROUTER" in reader_config:
        registry.register_adapter("openrouter", OpenRouterAdapter, reader_config["OPENROUTER"])

    # Register legacy adapters
    # legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
    legacy_types = ["INF_PROVIDERS"]
    for legacy_type in legacy_types:
        if legacy_type in reader_config:
            registry.register_adapter(
                legacy_type.lower(),
                lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
                reader_config[legacy_type]
            )

    return registry

def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
    """
    Get LLM client for specified provider.

    Args:
        provider: Provider name (mistral, openai, ollama, etc.)
        config: Configuration dictionary

    Returns:
        LLM adapter instance
    """
    registry = create_llm_registry(config)
    return registry.get_adapter(provider)
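
# A minimal end-to-end sketch of how this module is intended to be used. The
# provider name assumes a matching "MISTRAL" section in the loaded configuration,
# and the import paths are placeholders for wherever this module and the config
# loader live in the package:
#
#     from <package>.llm.adapters import get_llm_client
#     from <package>.config.loader import load_config
#
#     config = load_config()
#     client = get_llm_client("mistral", config)
#     response = client.generate([{"role": "user", "content": "Summarise the audit findings."}])
#     print(response.content)
#
#     # stream_generate() yields text chunks as they are produced
#     for chunk in client.stream_generate([{"role": "user", "content": "Summarise the audit findings."}]):
#         print(chunk, end="")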