# audit_assistant/src/llm/adapters.py
"""LLM client adapters for different providers."""
from typing import Dict, Any, List, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass
# LangChain imports
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_openai.chat_models import ChatOpenAI
from langchain_ollama import ChatOllama
# Legacy client dependencies
from huggingface_hub import InferenceClient
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace
# Configuration loader
from ..config.loader import load_config
# Load configuration once at module level
_config = load_config()
# Legacy client factory functions (inlined from auditqa_old.reader)
def _create_inf_provider_client():
    """Create INF_PROVIDERS client."""
    reader_config = _config.get("reader", {})
    inf_config = reader_config.get("INF_PROVIDERS", {})
    api_key = inf_config.get("api_key")
    if not api_key:
        raise ValueError("INF_PROVIDERS api_key not found in configuration")
    provider = inf_config.get("provider")
    if not provider:
        raise ValueError("INF_PROVIDERS provider not found in configuration")
    return InferenceClient(
        provider=provider,
        api_key=api_key,
        bill_to="GIZ",
    )

def _create_nvidia_client():
    """Create NVIDIA client."""
    reader_config = _config.get("reader", {})
    nvidia_config = reader_config.get("NVIDIA", {})
    api_key = nvidia_config.get("api_key")
    if not api_key:
        raise ValueError("NVIDIA api_key not found in configuration")
    endpoint = nvidia_config.get("endpoint")
    if not endpoint:
        raise ValueError("NVIDIA endpoint not found in configuration")
    return InferenceClient(
        base_url=endpoint,
        api_key=api_key,
    )

def _create_serverless_client():
    """Create serverless API client."""
    reader_config = _config.get("reader", {})
    serverless_config = reader_config.get("SERVERLESS", {})
    api_key = serverless_config.get("api_key")
    if not api_key:
        raise ValueError("SERVERLESS api_key not found in configuration")
    model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")
    return InferenceClient(
        model=model_id,
        api_key=api_key,
    )

def _create_dedicated_endpoint_client():
    """Create dedicated endpoint client."""
    reader_config = _config.get("reader", {})
    dedicated_config = reader_config.get("DEDICATED", {})
    api_key = dedicated_config.get("api_key")
    if not api_key:
        raise ValueError("DEDICATED api_key not found in configuration")
    endpoint = dedicated_config.get("endpoint")
    if not endpoint:
        raise ValueError("DEDICATED endpoint not found in configuration")
    max_tokens = dedicated_config.get("max_tokens", 768)
    # Set up the streaming callback handler
    callback = StreamingStdOutCallbackHandler()
    # Initialize the HuggingFaceEndpoint with streaming enabled
    llm_qa = HuggingFaceEndpoint(
        endpoint_url=endpoint,
        max_new_tokens=int(max_tokens),
        repetition_penalty=1.03,
        timeout=70,
        huggingfacehub_api_token=api_key,
        streaming=True,
        callbacks=[callback],
    )
    # Create a ChatHuggingFace instance with the streaming-enabled endpoint
    return ChatHuggingFace(llm=llm_qa)
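
# Illustrative only: the `reader` config shape the factory functions above
# expect, inferred from the lookups in this module and shown in YAML-style
# notation for readability. Key names match the code; all values below are
# placeholders, not real credentials or endpoints.
#
#   reader:
#     INF_PROVIDERS:
#       api_key: "<hf-token>"
#       provider: "<inference-provider-name>"
#       model: "<model-id>"
#     NVIDIA:
#       api_key: "<api-key>"
#       endpoint: "<endpoint-url>"
#     SERVERLESS:
#       api_key: "<hf-token>"
#       model: "meta-llama/Meta-Llama-3-8B-Instruct"
#     DEDICATED:
#       api_key: "<hf-token>"
#       endpoint: "<endpoint-url>"
#       max_tokens: 768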


@dataclass
class LLMResponse:
    """Standardized LLM response format."""
    content: str
    model: str
    provider: str
    # Optional so the None default is valid; avoids a mutable {} default.
    metadata: Optional[Dict[str, Any]] = None


class BaseLLMAdapter(ABC):
    """Base class for LLM adapters."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @abstractmethod
    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response from messages."""
        pass

    @abstractmethod
    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response from messages."""
        pass


class MistralAdapter(BaseLLMAdapter):
    """Adapter for Mistral AI models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatMistralAI(
            model=config.get("model", "mistral-medium-latest")
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using Mistral."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "mistral-medium-latest"),
            provider="mistral",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using Mistral."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class OpenAIAdapter(BaseLLMAdapter):
    """Adapter for OpenAI models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatOpenAI(
            model=config.get("model", "gpt-4o-mini")
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenAI."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "gpt-4o-mini"),
            provider="openai",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenAI."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class OllamaAdapter(BaseLLMAdapter):
    """Adapter for Ollama models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatOllama(
            model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            base_url=config.get("base_url", "http://localhost:11434/"),
            temperature=config.get("temperature", 0.8),
            num_predict=config.get("num_predict", 256)
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using Ollama."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            provider="ollama",
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using Ollama."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class OpenRouterAdapter(BaseLLMAdapter):
    """Adapter for OpenRouter models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Prepare custom headers for OpenRouter (optional)
        headers = {}
        if config.get("site_url"):
            headers["HTTP-Referer"] = config["site_url"]
        if config.get("site_name"):
            headers["X-Title"] = config["site_name"]
        # Initialize ChatOpenAI against the OpenRouter-compatible endpoint
        self.model = ChatOpenAI(
            model=config.get("model", "openai/gpt-3.5-turbo"),
            api_key=config.get("api_key"),
            base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
            default_headers=headers,
            temperature=config.get("temperature", 0.7),
            max_tokens=config.get("max_tokens", 1000)
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenRouter."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "openai/gpt-3.5-turbo"),
            provider="openrouter",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenRouter."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class LegacyAdapter(BaseLLMAdapter):
    """Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, etc.)."""

    def __init__(self, config: Dict[str, Any], client_type: str):
        super().__init__(config)
        self.client_type = client_type
        self.client = self._create_client()

    def _create_client(self):
        """Create legacy client based on type."""
        if self.client_type == "INF_PROVIDERS":
            return _create_inf_provider_client()
        elif self.client_type == "NVIDIA":
            return _create_nvidia_client()
        elif self.client_type == "DEDICATED":
            return _create_dedicated_endpoint_client()
        else:  # SERVERLESS
            return _create_serverless_client()

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using legacy client."""
        max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))
        if self.client_type == "INF_PROVIDERS":
            response = self.client.chat.completions.create(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
        elif self.client_type == "NVIDIA":
            response = self.client.chat_completion(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
        else:  # DEDICATED or SERVERLESS
            response = self.client.chat_completion(
                messages=messages,
                max_tokens=max_tokens
            )
        # Extract the first choice's message content (OpenAI-style response shape)
        content = response.choices[0].message.content
        return LLMResponse(
            content=content,
            model=self.config.get("model", "unknown"),
            provider=self.client_type.lower(),
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using legacy client."""
        # Legacy clients may not support streaming in the same way;
        # this simplified implementation generates the full response
        # and yields it word by word.
        response = self.generate(messages, **kwargs)
        for word in response.content.split():
            yield word + " "


class LLMRegistry:
    """Registry for managing different LLM adapters."""

    def __init__(self):
        self.adapters = {}
        self.adapter_configs = {}

    def register_adapter(self, name: str, adapter_class: type, config: Dict[str, Any]):
        """Register an LLM adapter (lazy instantiation)."""
        self.adapter_configs[name] = (adapter_class, config)

    def get_adapter(self, name: str) -> BaseLLMAdapter:
        """Get an LLM adapter by name (lazy instantiation)."""
        if name not in self.adapter_configs:
            raise ValueError(f"Unknown LLM adapter: {name}")
        # Lazy instantiation - only create when first requested
        if name not in self.adapters:
            adapter_class, config = self.adapter_configs[name]
            self.adapters[name] = adapter_class(config)
        return self.adapters[name]

    def list_adapters(self) -> List[str]:
        """List available adapter names."""
        return list(self.adapter_configs.keys())
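
# Illustrative only: the registry contract. register_adapter records the class
# and config; the adapter is constructed on the first get_adapter call and
# cached for subsequent calls. The model name below is a placeholder.
#
#     registry = LLMRegistry()
#     registry.register_adapter("openai", OpenAIAdapter, {"model": "gpt-4o-mini"})
#     adapter = registry.get_adapter("openai")  # instantiated here, then cached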


def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
    """
    Create and populate LLM registry from configuration.

    Args:
        config: Configuration dictionary

    Returns:
        Populated LLMRegistry
    """
    registry = LLMRegistry()
    reader_config = config.get("reader", {})
    # Register simple adapters
    if "MISTRAL" in reader_config:
        registry.register_adapter("mistral", MistralAdapter, reader_config["MISTRAL"])
    if "OPENAI" in reader_config:
        registry.register_adapter("openai", OpenAIAdapter, reader_config["OPENAI"])
    if "OLLAMA" in reader_config:
        registry.register_adapter("ollama", OllamaAdapter, reader_config["OLLAMA"])
    if "OPENROUTER" in reader_config:
        registry.register_adapter("openrouter", OpenRouterAdapter, reader_config["OPENROUTER"])
    # Register legacy adapters
    # legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
    legacy_types = ["INF_PROVIDERS"]
    for legacy_type in legacy_types:
        if legacy_type in reader_config:
            registry.register_adapter(
                legacy_type.lower(),
                lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
                reader_config[legacy_type]
            )
    return registry


def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
    """
    Get LLM client for specified provider.

    Args:
        provider: Provider name (mistral, openai, ollama, etc.)
        config: Configuration dictionary

    Returns:
        LLM adapter instance
    """
    registry = create_llm_registry(config)
    return registry.get_adapter(provider)
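

if __name__ == "__main__":
    # Minimal usage sketch, assuming the loaded config defines one of the
    # registered reader sections (e.g. MISTRAL); the provider name must match
    # a key registered in create_llm_registry (mistral, openai, ollama,
    # openrouter, inf_providers).
    client = get_llm_client("mistral", _config)
    result = client.generate([{"role": "user", "content": "Say hello."}])
    print(f"[{result.provider}/{result.model}] {result.content}")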