| """ | |
| Multi-LLM Handler with failover support | |
| Uses Groq, Gemini, and OpenAI with automatic failover for reliability | |
| """ | |
| import asyncio | |
| import re | |
| import time | |
| from typing import Optional, Dict, Any, List | |
| import os | |
| import requests | |
| import google.generativeai as genai | |
| import openai | |
| from dotenv import load_dotenv | |
| from config.config import get_provider_configs | |
| load_dotenv() | |
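
# NOTE: get_provider_configs() is assumed (inferred from how it is used below)
# to return a dict mapping provider name -> a list of config dicts, e.g.:
#
#   {
#       "groq":   [{"name": "groq-primary",   "api_key": "...", "model": "..."}],
#       "gemini": [{"name": "gemini-primary", "api_key": "...", "model": "..."}],
#       "openai": [{"name": "openai-primary", "api_key": "...", "model": "..."}],
#   }
#
# The "name" values are illustrative placeholders; only the keys ("name",
# "api_key", "model") are relied on by the code below.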

class MultiLLMHandler:
    """Multi-LLM handler with automatic failover across providers."""

    def __init__(self):
        """Initialize the multi-LLM handler with all available providers."""
        self.providers = get_provider_configs()
        self.current_provider = None
        self.current_config = None
        # Initialize the first available provider (prefer Gemini/OpenAI for general RAG)
        self._initialize_provider()
        print(f"✅ Initialized Multi-LLM Handler with {self.provider.upper()}: {self.model_name}")

    def _initialize_provider(self):
        """Initialize the first available provider."""
        # Prefer Gemini first for general text tasks
        if self.providers["gemini"]:
            self.current_provider = "gemini"
            self.current_config = self.providers["gemini"][0]
            genai.configure(api_key=self.current_config["api_key"])
        # Then OpenAI
        elif self.providers["openai"]:
            self.current_provider = "openai"
            self.current_config = self.providers["openai"][0]
            openai.api_key = self.current_config["api_key"]
        # Finally Groq
        elif self.providers["groq"]:
            self.current_provider = "groq"
            self.current_config = self.providers["groq"][0]
        else:
            raise ValueError("No LLM providers available with valid API keys")

    @property
    def provider(self):
        """Current provider name (accessed as an attribute, e.g. self.provider.upper())."""
        return self.current_provider

    @property
    def model_name(self):
        """Current model name."""
        return self.current_config["model"] if self.current_config else "unknown"

    async def _call_groq(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call the Groq chat-completions API."""
        headers = {
            "Authorization": f"Bearer {self.current_config['api_key']}",
            "Content-Type": "application/json"
        }
        data = {
            "model": self.current_config["model"],
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens
        }
        # Hide reasoning tokens (e.g., <think>) for Qwen reasoning models.
        # Groq's chat-completions parameter for this is reasoning_format, which
        # accepts "hidden" (reasoning_effort takes values like "none"/"default"
        # and does not accept "hidden").
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name:
                data["reasoning_format"] = "hidden"
        except Exception:
            # Be resilient if the config shape changes
            pass
        # Run the blocking requests call in a worker thread so it doesn't
        # stall the event loop (matching the pattern used for Gemini/OpenAI)
        response = await asyncio.to_thread(
            requests.post,
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        result = response.json()
        text = result["choices"][0]["message"]["content"].strip()
        # Safety net: strip any <think>...</think> blocks if present
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name and "<think>" in text.lower():
                text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
        except Exception:
            pass
        return text

    async def _call_gemini(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call the Gemini API."""
        model = genai.GenerativeModel(self.current_config["model"])
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        response = await asyncio.to_thread(
            model.generate_content,
            prompt,
            generation_config=generation_config
        )
        return response.text.strip()

    async def _call_openai(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call the OpenAI API (uses the legacy openai<1.0 SDK interface)."""
        response = await asyncio.to_thread(
            openai.ChatCompletion.create,
            model=self.current_config["model"],
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content.strip()

    async def _try_with_failover(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Try to generate text with automatic failover."""
        # Build the failover order: Gemini -> OpenAI -> Groq for general text
        provider_order = []
        if self.providers["gemini"]:
            provider_order.extend([("gemini", config) for config in self.providers["gemini"]])
        if self.providers["openai"]:
            provider_order.extend([("openai", config) for config in self.providers["openai"]])
        if self.providers["groq"]:
            provider_order.extend([("groq", config) for config in self.providers["groq"]])

        last_error = None
        for provider_name, config in provider_order:
            # Remember the active provider so it can be restored on failure
            old_provider = self.current_provider
            old_config = self.current_config
            try:
                # Switch to the candidate provider
                self.current_provider = provider_name
                self.current_config = config
                # Configure the client if needed
                if provider_name == "gemini":
                    genai.configure(api_key=config["api_key"])
                elif provider_name == "openai":
                    openai.api_key = config["api_key"]
                # Try the API call
                if provider_name == "groq":
                    return await self._call_groq(prompt, temperature, max_tokens)
                elif provider_name == "gemini":
                    return await self._call_gemini(prompt, temperature, max_tokens)
                elif provider_name == "openai":
                    return await self._call_openai(prompt, temperature, max_tokens)
            except Exception as e:
                print(f"⚠️ {provider_name.upper()} ({config['name']}) failed: {e}")
                last_error = e
                # Restore the previous provider before trying the next one
                self.current_provider = old_provider
                self.current_config = old_config
                continue
        # If all providers failed
        raise RuntimeError(f"All LLM providers failed. Last error: {last_error}")

    async def generate_text(self,
                            prompt: Optional[str] = None,
                            system_prompt: Optional[str] = None,
                            user_prompt: Optional[str] = None,
                            temperature: Optional[float] = 0.4,
                            max_tokens: Optional[int] = 1200) -> str:
        """Generate text using multi-LLM with failover."""
        # Handle both single-prompt and system/user prompt formats
        if prompt:
            final_prompt = prompt
        elif system_prompt and user_prompt:
            final_prompt = f"{system_prompt}\n\n{user_prompt}"
        elif user_prompt:
            final_prompt = user_prompt
        else:
            raise ValueError("Must provide either 'prompt' or 'user_prompt'")
        # Use explicit None checks so that temperature=0.0 is not silently
        # replaced by the default (as `temperature or 0.4` would do)
        return await self._try_with_failover(
            final_prompt,
            temperature if temperature is not None else 0.4,
            max_tokens if max_tokens is not None else 1200
        )

    async def generate_simple(self,
                              prompt: str,
                              temperature: Optional[float] = 0.4,
                              max_tokens: Optional[int] = 1200) -> str:
        """Simple text generation (alias for generate_text, kept for compatibility)."""
        return await self.generate_text(prompt=prompt, temperature=temperature, max_tokens=max_tokens)

    def get_provider_info(self) -> Dict[str, Any]:
        """Get information about the current provider."""
        return {
            "provider": self.current_provider,
            "model": self.model_name,
            "config_name": self.current_config["name"] if self.current_config else "none",
            "available_providers": {
                "groq": len(self.providers["groq"]),
                "gemini": len(self.providers["gemini"]),
                "openai": len(self.providers["openai"])
            }
        }

    async def test_connection(self) -> bool:
        """Test the connection to the current LLM provider."""
        try:
            test_prompt = "Say 'Hello' if you can read this."
            response = await self.generate_simple(test_prompt, temperature=0.1, max_tokens=10)
            return "hello" in response.lower()
        except Exception as e:
            print(f"❌ Connection test failed: {e}")
            return False

# Create a global instance
llm_handler = MultiLLMHandler()
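
# Minimal usage sketch: drives the global handler end-to-end, assuming at
# least one valid API key is present in .env. The prompt strings below are
# illustrative only.
if __name__ == "__main__":
    async def _demo():
        # Show which provider was selected and how many configs are available
        print(llm_handler.get_provider_info())
        # Smoke-test the current provider, then request a short completion
        if await llm_handler.test_connection():
            answer = await llm_handler.generate_text(
                system_prompt="You are a concise assistant.",
                user_prompt="Explain failover in one sentence.",
                temperature=0.2,
                max_tokens=100,
            )
            print(answer)

    asyncio.run(_demo())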