# audit_assistant/src/llm/adapters.py
"""LLM client adapters for different providers."""
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Union
# from langchain_ollama import ChatOllama
from langchain_openai.chat_models import ChatOpenAI
# from langchain_mistralai.chat_models import ChatMistralAI
# Legacy dependencies
from huggingface_hub import InferenceClient
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from ..config.loader import load_config
_config = load_config()


# Legacy functions
def _create_nvidia_client():
    """Create NVIDIA client."""
    reader_config = _config.get("reader", {})
    nvidia_config = reader_config.get("NVIDIA", {})
    api_key = nvidia_config.get("api_key")
    if not api_key:
        raise ValueError("NVIDIA api_key not found in configuration")
    endpoint = nvidia_config.get("endpoint")
    if not endpoint:
        raise ValueError("NVIDIA endpoint not found in configuration")
    return InferenceClient(
        base_url=endpoint,
        api_key=api_key,
    )


def _create_serverless_client():
    """Create serverless API client."""
    reader_config = _config.get("reader", {})
    serverless_config = reader_config.get("SERVERLESS", {})
    api_key = serverless_config.get("api_key")
    if not api_key:
        raise ValueError("SERVERLESS api_key not found in configuration")
    model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")
    return InferenceClient(
        model=model_id,
        api_key=api_key,
    )


def _create_dedicated_endpoint_client():
    """Create dedicated endpoint client."""
    reader_config = _config.get("reader", {})
    dedicated_config = reader_config.get("DEDICATED", {})
    api_key = dedicated_config.get("api_key")
    if not api_key:
        raise ValueError("DEDICATED api_key not found in configuration")
    endpoint = dedicated_config.get("endpoint")
    if not endpoint:
        raise ValueError("DEDICATED endpoint not found in configuration")
    max_tokens = dedicated_config.get("max_tokens", 768)
    # Set up the streaming callback handler
    # callback = StreamingStdOutCallbackHandler()
    # Initialize the HuggingFaceEndpoint with streaming enabled
    llm_qa = HuggingFaceEndpoint(
        endpoint_url=endpoint,
        max_new_tokens=int(max_tokens),
        repetition_penalty=1.03,
        timeout=70,
        huggingfacehub_api_token=api_key,
        streaming=True,
        # callbacks=[callback]
    )
    return ChatHuggingFace(llm=llm_qa)
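

# NOTE: LegacyAdapter below dispatches to _create_inf_provider_client(), but this
# revision of the file no longer defines it (it appears to have been dropped by
# the "remove extra adapters" change). The following is a hedged reconstruction
# modeled on the other factory helpers; the INF_PROVIDERS "provider" config key
# and its "auto" default are assumptions, not confirmed from the original source.
def _create_inf_provider_client():
    """Create an inference-providers client (reconstructed sketch)."""
    reader_config = _config.get("reader", {})
    inf_config = reader_config.get("INF_PROVIDERS", {})
    api_key = inf_config.get("api_key")
    if not api_key:
        raise ValueError("INF_PROVIDERS api_key not found in configuration")
    return InferenceClient(
        provider=inf_config.get("provider", "auto"),  # assumed config key
        api_key=api_key,
    )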


@dataclass
class LLMResponse:
    """Standardized LLM response format."""
    content: str
    model: str
    provider: str
    metadata: Optional[Dict[str, Any]] = None


class BaseLLMAdapter(ABC):
    """Base class for LLM adapters."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @abstractmethod
    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response from messages."""

    @abstractmethod
    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response from messages."""


# class MistralAdapter(BaseLLMAdapter):
#     """Adapter for Mistral AI models."""
#     def __init__(self, config: Dict[str, Any]):
#         super().__init__(config)
#         self.model = ChatMistralAI(
#             model=config.get("model", "mistral-medium-latest")
#         )
#
#     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
#         """Generate response using Mistral."""
#         response = self.model.invoke(messages)
#         return LLMResponse(
#             content=response.content,
#             model=self.config.get("model", "mistral-medium-latest"),
#             provider="mistral",
#             metadata={"usage": getattr(response, 'usage_metadata', {})}
#         )
#
#     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
#         """Generate streaming response using Mistral."""
#         for chunk in self.model.stream(messages):
#             if chunk.content:
#                 yield chunk.content


class OpenAIAdapter(BaseLLMAdapter):
    """Adapter for OpenAI models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model = ChatOpenAI(
            model=config.get("model", "gpt-4o-mini")
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenAI."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "gpt-4o-mini"),
            provider="openai",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenAI."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


# class OllamaAdapter(BaseLLMAdapter):
#     """Adapter for Ollama models."""
#     def __init__(self, config: Dict[str, Any]):
#         super().__init__(config)
#         self.model = ChatOllama(
#             model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
#             base_url=config.get("base_url", "http://localhost:11434/"),
#             temperature=config.get("temperature", 0.8),
#             num_predict=config.get("num_predict", 256)
#         )
#
#     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
#         """Generate response using Ollama."""
#         response = self.model.invoke(messages)
#         return LLMResponse(
#             content=response.content,
#             model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
#             provider="ollama",
#             metadata={}
#         )
#
#     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
#         """Generate streaming response using Ollama."""
#         for chunk in self.model.stream(messages):
#             if chunk.content:
#                 yield chunk.content


class OpenRouterAdapter(BaseLLMAdapter):
    """Adapter for OpenRouter models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Prepare optional attribution headers recognized by OpenRouter
        headers = {}
        if config.get("site_url"):
            headers["HTTP-Referer"] = config["site_url"]
        if config.get("site_name"):
            headers["X-Title"] = config["site_name"]
        # Initialize ChatOpenAI against the OpenAI-compatible OpenRouter endpoint
        self.model = ChatOpenAI(
            model=config.get("model", "openai/gpt-3.5-turbo"),
            api_key=config.get("api_key"),
            base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
            default_headers=headers,
            temperature=config.get("temperature", 0.7),
            max_tokens=config.get("max_tokens", 1000)
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using OpenRouter."""
        response = self.model.invoke(messages)
        return LLMResponse(
            content=response.content,
            model=self.config.get("model", "openai/gpt-3.5-turbo"),
            provider="openrouter",
            metadata={"usage": getattr(response, 'usage_metadata', {})}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using OpenRouter."""
        for chunk in self.model.stream(messages):
            if chunk.content:
                yield chunk.content


class LegacyAdapter(BaseLLMAdapter):
    """Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, etc.)."""

    def __init__(self, config: Dict[str, Any], client_type: str):
        super().__init__(config)
        self.client_type = client_type
        self.client = self._create_client()

    def _create_client(self):
        """Create legacy client based on type."""
        if self.client_type == "INF_PROVIDERS":
            return _create_inf_provider_client()
        elif self.client_type == "NVIDIA":
            return _create_nvidia_client()
        elif self.client_type == "DEDICATED":
            return _create_dedicated_endpoint_client()
        else:  # SERVERLESS
            return _create_serverless_client()

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using legacy client."""
        max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))
        if self.client_type == "INF_PROVIDERS":
            response = self.client.chat.completions.create(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        elif self.client_type == "NVIDIA":
            response = self.client.chat_completion(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        elif self.client_type == "DEDICATED":
            # _create_dedicated_endpoint_client returns a LangChain ChatHuggingFace
            # model, which exposes invoke() rather than chat_completion(); its max
            # token budget is already set on the underlying HuggingFaceEndpoint
            response = self.client.invoke(messages)
            content = response.content
        else:  # SERVERLESS
            response = self.client.chat_completion(
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
        return LLMResponse(
            content=content,
            model=self.config.get("model", "unknown"),
            provider=self.client_type.lower(),
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate streaming response using legacy client."""
        # Legacy clients do not share a streaming interface, so fall back to a
        # full generation and yield it word by word to mimic streaming
        response = self.generate(messages, **kwargs)
        for word in response.content.split():
            yield word + " "


class LLMRegistry:
    """Registry for managing different LLM adapters."""

    def __init__(self):
        self.adapters = {}
        self.adapter_configs = {}

    def register_adapter(self, name: str, adapter_class: type, config: Dict[str, Any]):
        """Register an LLM adapter (lazy instantiation)."""
        self.adapter_configs[name] = (adapter_class, config)

    def get_adapter(self, name: str) -> BaseLLMAdapter:
        """Get an LLM adapter by name (lazy instantiation)."""
        if name not in self.adapter_configs:
            raise ValueError(f"Unknown LLM adapter: {name}")
        # Lazy instantiation - only create when needed
        if name not in self.adapters:
            adapter_class, config = self.adapter_configs[name]
            self.adapters[name] = adapter_class(config)
        return self.adapters[name]

    def list_adapters(self) -> List[str]:
        """List available adapter names."""
        return list(self.adapter_configs.keys())


def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
    """Create and populate LLM registry from configuration.

    Args:
        config: Configuration dictionary

    Returns:
        Populated LLMRegistry
    """
    registry = LLMRegistry()
    reader_config = config.get("reader", {})
    # Register simple adapters
    # if "MISTRAL" in reader_config:
    #     registry.register_adapter("mistral", MistralAdapter, reader_config["MISTRAL"])
    if "OPENAI" in reader_config:
        registry.register_adapter("openai", OpenAIAdapter, reader_config["OPENAI"])
    # if "OLLAMA" in reader_config:
    #     registry.register_adapter("ollama", OllamaAdapter, reader_config["OLLAMA"])
    # if "OPENROUTER" in reader_config:
    #     registry.register_adapter("openrouter", OpenRouterAdapter, reader_config["OPENROUTER"])
    # Register legacy adapters
    # legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
    legacy_types = ["INF_PROVIDERS"]
    for legacy_type in legacy_types:
        if legacy_type in reader_config:
            registry.register_adapter(
                legacy_type.lower(),
                # Bind the legacy type into the factory so instantiation stays lazy
                lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
                reader_config[legacy_type]
            )
    return registry


def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
    """Get LLM client for the specified provider.

    Args:
        provider: Provider name as registered in the registry
            (e.g. "openai" or "inf_providers" with the current configuration)
        config: Configuration dictionary

    Returns:
        LLM adapter instance
    """
    registry = create_llm_registry(config)
    return registry.get_adapter(provider)
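

# Example usage (a hedged sketch, not part of the original module): assumes the
# loaded configuration has a "reader.OPENAI" section and that OpenAI credentials
# are available in the environment; message dicts use the OpenAI-style
# role/content format accepted by LangChain chat models. Kept commented out
# because this module's relative imports prevent running it as a script.
# if __name__ == "__main__":
#     cfg = load_config()
#     adapter = get_llm_client("openai", cfg)
#     reply = adapter.generate([{"role": "user", "content": "Say hello."}])
#     print(reply.provider, reply.model, reply.content)
#     for token in adapter.stream_generate([{"role": "user", "content": "Count to 3."}]):
#         print(token, end="", flush=True)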