# DEPENDENCIES
import os
import json
import time
import aiohttp
import requests
from typing import List
from typing import Dict
from typing import Optional
from typing import AsyncGenerator
from config.models import LLMProvider
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import LLMClientError

# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)
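
# NOTE: this module assumes the settings object exposes OLLAMA_MODEL, OLLAMA_BASE_URL,
# OLLAMA_TIMEOUT, IS_HF_SPACE, USE_OPENAI and OPENAI_API_KEY; all of these fields are
# referenced below.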

class LLMClient:
    """
    Unified LLM client supporting multiple providers (Ollama, OpenAI).
    Provides a consistent interface for text generation across different LLM services.
    """
    def __init__(self, provider: LLMProvider = None, model_name: str = None, api_key: str = None, base_url: str = None):
        """
        Initialize LLM client

        Arguments:
        ----------
            provider   { LLMProvider } : LLM provider to use
            model_name { str }         : Model name to use
            api_key    { str }         : API key (for OpenAI)
            base_url   { str }         : Base URL for API (for Ollama)
        """
        self.logger     = logger
        self.settings   = get_settings()
        self.provider   = provider or LLMProvider.OLLAMA
        self.model_name = model_name or self.settings.OLLAMA_MODEL
        self.api_key    = api_key
        self.base_url   = base_url or self.settings.OLLAMA_BASE_URL
        self.timeout    = self.settings.OLLAMA_TIMEOUT

        # Initialize provider-specific configurations
        self._initialize_provider()

    def _initialize_provider(self):
        """
        Initialize provider-specific configurations
        """
        # Auto-detect provider if not explicitly set
        if self.settings.IS_HF_SPACE:
            # In a Hugging Face Space, we can't use Ollama
            if (self.provider == LLMProvider.OLLAMA):
                self.logger.warning("Ollama not available in Hugging Face Space. Switching provider.")

                # Try to use OpenAI if an API key is available
                if (self.settings.USE_OPENAI and self.settings.OPENAI_API_KEY):
                    self.provider = LLMProvider.OPENAI
                    self.logger.info("HF Space detected: Using OpenAI API")
                else:
                    # No OpenAI key either, so fail fast with a clear error
                    raise LLMClientError("Running in Hugging Face Space without Ollama or OpenAI. Please set OPENAI_API_KEY in Space secrets or use a different provider.")

        # Provider initialization
        if (self.provider == LLMProvider.OLLAMA):
            if not self.base_url:
                raise LLMClientError("Ollama base URL is required")

            self.logger.info(f"Initialized Ollama client: {self.base_url}, model: {self.model_name}")

        elif (self.provider == LLMProvider.OPENAI):
            if not self.api_key:
                # Try to get the key from the environment
                self.api_key = os.getenv('OPENAI_API_KEY')

            if not self.api_key:
                raise LLMClientError("OpenAI API key is required")

            self.base_url = "https://api.openai.com/v1"
            self.logger.info(f"Initialized OpenAI client, model: {self.model_name}")

        else:
            raise LLMClientError(f"Unsupported provider: {self.provider}")

    async def generate(self, messages: List[Dict], **generation_params) -> Dict:
        """
        Generate text completion (async)

        Arguments:
        ----------
            messages { list } : List of message dictionaries
            **generation_params : Generation parameters (temperature, max_tokens, etc.)

        Returns:
        --------
            { dict } : Generation response
        """
        try:
            if (self.provider == LLMProvider.OLLAMA):
                return await self._generate_ollama(messages, **generation_params)

            elif (self.provider == LLMProvider.OPENAI):
                return await self._generate_openai(messages, **generation_params)

            else:
                raise LLMClientError(f"Unsupported provider: {self.provider}")

        except Exception as e:
            self.logger.error(f"Generation failed: {repr(e)}")
            raise LLMClientError(f"Generation failed: {repr(e)}")

    async def generate_stream(self, messages: List[Dict], **generation_params) -> AsyncGenerator[str, None]:
        """
        Generate text completion with streaming (async)

        Arguments:
        ----------
            messages { list } : List of message dictionaries
            **generation_params : Generation parameters

        Returns:
        --------
            { AsyncGenerator } : Async generator yielding response chunks
        """
        try:
            if (self.provider == LLMProvider.OLLAMA):
                async for chunk in self._generate_ollama_stream(messages, **generation_params):
                    yield chunk

            elif (self.provider == LLMProvider.OPENAI):
                async for chunk in self._generate_openai_stream(messages, **generation_params):
                    yield chunk

            else:
                raise LLMClientError(f"Unsupported provider: {self.provider}")

        except Exception as e:
            self.logger.error(f"Stream generation failed: {repr(e)}")
            raise LLMClientError(f"Stream generation failed: {repr(e)}")

    async def _generate_ollama(self, messages: List[Dict], **generation_params) -> Dict:
        """
        Generate using Ollama API
        """
        url = f"{self.base_url}/api/chat"

        # Prepare request payload
        payload = {"model"    : self.model_name,
                   "messages" : messages,
                   "stream"   : False,
                   "options"  : {"temperature" : generation_params.get("temperature", 0.1),
                                 "top_p"       : generation_params.get("top_p", 0.9),
                                 "num_predict" : generation_params.get("max_tokens", 1000),
                                },
                  }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json = payload, timeout = self.timeout) as response:
                if (response.status != 200):
                    error_text = await response.text()
                    raise LLMClientError(f"Ollama API error: {response.status} - {error_text}")

                result = await response.json()

                return {"content"       : result["message"]["content"],
                        "usage"         : {"prompt_tokens"     : result.get("prompt_eval_count", 0),
                                           "completion_tokens" : result.get("eval_count", 0),
                                           "total_tokens"      : result.get("prompt_eval_count", 0) + result.get("eval_count", 0),
                                          },
                        "finish_reason" : result.get("done_reason", "stop"),
                       }

    async def _generate_ollama_stream(self, messages: List[Dict], **generation_params) -> AsyncGenerator[str, None]:
        """
        Generate stream using Ollama API
        """
        url = f"{self.base_url}/api/chat"

        payload = {"model"    : self.model_name,
                   "messages" : messages,
                   "stream"   : True,
                   "options"  : {"temperature" : generation_params.get("temperature", 0.1),
                                 "top_p"       : generation_params.get("top_p", 0.9),
                                 "num_predict" : generation_params.get("max_tokens", 1000),
                                },
                  }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json = payload, timeout = self.timeout) as response:
                if (response.status != 200):
                    error_text = await response.text()
                    raise LLMClientError(f"Ollama API error: {response.status} - {error_text}")

                async for line in response.content:
                    line_str = line.decode('utf-8').strip()

                    # Skip empty lines
                    if not line_str:
                        continue

                    try:
                        chunk_data = json.loads(line_str)

                        # Check if this is the final chunk
                        if (chunk_data.get("done", False)):
                            break

                        # Ollama sends incremental content in each chunk
                        if ("message" in chunk_data):
                            content = chunk_data["message"].get("content", "")

                            # Only yield non-empty content
                            if content:
                                yield content

                    except json.JSONDecodeError:
                        self.logger.warning(f"Failed to parse streaming chunk: {line_str[:100]}")
                        continue
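
    # Illustrative only (not captured from a live server): the streaming endpoint above
    # emits newline-delimited JSON objects, which the parser in _generate_ollama_stream
    # expects to look roughly like:
    #
    #     {"model": "...", "message": {"role": "assistant", "content": "Hel"}, "done": false}
    #     {"model": "...", "message": {"role": "assistant", "content": "lo"},  "done": false}
    #     {"model": "...", "done": true, "done_reason": "stop", ...}
    #
    # Each non-final chunk carries an incremental "content" fragment; the final chunk
    # sets "done" to true and is where the non-streaming path reads its token counts from.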

    async def _generate_openai(self, messages: List[Dict], **generation_params) -> Dict:
        """
        Generate using OpenAI API
        """
        url = f"{self.base_url}/chat/completions"

        headers = {"Authorization" : f"Bearer {self.api_key}",
                   "Content-Type"  : "application/json",
                  }

        payload = {"model"       : self.model_name,
                   "messages"    : messages,
                   "temperature" : generation_params.get("temperature", 0.1),
                   "top_p"       : generation_params.get("top_p", 0.9),
                   "max_tokens"  : generation_params.get("max_tokens", 1000),
                   "stream"      : False,
                  }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers = headers, json = payload, timeout = self.timeout) as response:
                if (response.status != 200):
                    error_text = await response.text()
                    raise LLMClientError(f"OpenAI API error: {response.status} - {error_text}")

                result = await response.json()

                return {"content"       : result["choices"][0]["message"]["content"],
                        "usage"         : result["usage"],
                        "finish_reason" : result["choices"][0]["finish_reason"],
                       }

    async def _generate_openai_stream(self, messages: List[Dict], **generation_params) -> AsyncGenerator[str, None]:
        """
        Generate stream using OpenAI API
        """
        url = f"{self.base_url}/chat/completions"

        headers = {"Authorization" : f"Bearer {self.api_key}",
                   "Content-Type"  : "application/json",
                  }

        payload = {"model"       : self.model_name,
                   "messages"    : messages,
                   "temperature" : generation_params.get("temperature", 0.1),
                   "top_p"       : generation_params.get("top_p", 0.9),
                   "max_tokens"  : generation_params.get("max_tokens", 1000),
                   "stream"      : True,
                  }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers = headers, json = payload, timeout = self.timeout) as response:
                if (response.status != 200):
                    error_text = await response.text()
                    raise LLMClientError(f"OpenAI API error: {response.status} - {error_text}")

                async for line in response.content:
                    line = line.decode('utf-8').strip()

                    if (line.startswith('data: ')):
                        # Remove the 'data: ' prefix
                        data = line[6:]

                        if (data == '[DONE]'):
                            break

                        try:
                            chunk_data = json.loads(data)

                            if ("choices" in chunk_data) and (chunk_data["choices"]):
                                delta = chunk_data["choices"][0].get("delta", {})

                                if ("content" in delta):
                                    yield delta["content"]

                        except json.JSONDecodeError:
                            continue
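
    # Illustrative only: the OpenAI streaming endpoint returns Server-Sent Events,
    # so the lines parsed in _generate_openai_stream look roughly like:
    #
    #     data: {"choices": [{"delta": {"content": "Hel"}, "index": 0}], ...}
    #     data: {"choices": [{"delta": {"content": "lo"},  "index": 0}], ...}
    #     data: [DONE]
    #
    # Only the incremental "content" fragments from each delta are yielded to the caller.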

    def check_health(self) -> bool:
        """
        Check if LLM provider is healthy and accessible

        Returns:
        --------
            { bool } : True if healthy
        """
        try:
            if (self.provider == LLMProvider.OLLAMA):
                response = requests.get(f"{self.base_url}/api/tags", timeout = 30)
                return (response.status_code == 200)

            elif (self.provider == LLMProvider.OPENAI):
                # Simple models list check
                headers  = {"Authorization": f"Bearer {self.api_key}"}
                response = requests.get(f"{self.base_url}/models", headers = headers, timeout = 10)
                return (response.status_code == 200)

            return False

        except Exception as e:
            self.logger.warning(f"Health check failed: {repr(e)}")
            return False

    def get_provider_info(self) -> Dict:
        """
        Get provider information

        Returns:
        --------
            { dict } : Provider information
        """
        return {"provider" : self.provider.value,
                "model"    : self.model_name,
                "base_url" : self.base_url,
                "healthy"  : self.check_health(),
                "timeout"  : self.timeout,
               }

# Global LLM client instance
_llm_client = None


def get_llm_client(provider: LLMProvider = None, **kwargs) -> LLMClient:
    """
    Get global LLM client instance (singleton)

    Arguments:
    ----------
        provider { LLMProvider } : LLM provider to use
        **kwargs : Additional client configuration

    Returns:
    --------
        { LLMClient } : LLMClient instance
    """
    global _llm_client

    if _llm_client is None or (provider and _llm_client.provider != provider):
        _llm_client = LLMClient(provider, **kwargs)

    return _llm_client

async def generate_text(messages: List[Dict], provider: LLMProvider = LLMProvider.OLLAMA, **kwargs) -> str:
    """
    Convenience function for text generation

    Arguments:
    ----------
        messages { list } : List of message dictionaries
        provider { LLMProvider } : LLM provider to use
        **kwargs : Generation parameters

    Returns:
    --------
        { str } : Generated text
    """
    # kwargs are generation parameters, so they are passed to generate() only;
    # forwarding them to the LLMClient constructor would raise a TypeError
    client   = get_llm_client(provider)
    response = await client.generate(messages, **kwargs)

    return response["content"]
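
# Hypothetical usage sketch (not part of the original module): assumes an Ollama server
# is reachable at the configured OLLAMA_BASE_URL and that config.settings provides the
# fields referenced above. Run this file directly to smoke-test the client.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        messages = [{"role" : "system", "content" : "You are a helpful assistant."},
                    {"role" : "user",   "content" : "Say hello in one short sentence."},
                   ]

        # One-shot generation through the convenience wrapper
        text = await generate_text(messages, provider = LLMProvider.OLLAMA, temperature = 0.2, max_tokens = 100)
        print(f"Full response: {text}")

        # Streaming generation through the client directly
        client = get_llm_client(LLMProvider.OLLAMA)
        print(f"Provider info: {client.get_provider_info()}")

        async for chunk in client.generate_stream(messages, temperature = 0.2, max_tokens = 100):
            print(chunk, end = "", flush = True)
        print()

    asyncio.run(_demo())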