"""
Unified LLM Handler with Multiple Provider Instances
Supports multiple Groq, Gemini, and OpenAI instances with different API keys.
"""
import asyncio
import time
from typing import Optional, Dict, Any, List, Tuple
from dataclasses import dataclass
from enum import Enum
import openai
import google.generativeai as genai
from groq import Groq
from config.config import get_provider_configs, MAX_TOKENS, TEMPERATURE
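# Expected shape of get_provider_configs(), inferred from how it is consumed below
# (the keys 'name', 'api_key', and 'model' are what this module reads; the example
# model strings are illustrative, not prescribed by the config module):
#   {
#       "groq":   [{"name": "primary", "api_key": "...", "model": "qwen/qwen3-32b"}, ...],
#       "gemini": [{"name": "primary", "api_key": "...", "model": "<gemini model id>"}, ...],
#       "openai": [{"name": "primary", "api_key": "...", "model": "<openai model id>"}, ...],
#   }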
class ProviderType(Enum):
"""Enum for LLM provider types."""
GROQ = "groq"
GEMINI = "gemini"
OPENAI = "openai"
@dataclass
class ProviderInstance:
    """Represents a single provider instance."""
    provider_type: ProviderType
    instance_name: str
    client: Any
    model: str
    api_key: str  # truncated key, safe for logging
    full_api_key: Optional[str] = None  # full key, for SDKs with global configuration (Gemini)
@dataclass
class ProviderStatus:
"""Tracks provider instance status and cooldown."""
is_available: bool = True
cooldown_until: float = 0.0
error_count: int = 0
last_success: float = 0.0
class UnifiedLLMHandler:
"""Unified handler supporting multiple instances of each LLM provider."""
def __init__(self):
"""Initialize the LLM handler with all configured provider instances."""
self.provider_instances: Dict[str, ProviderInstance] = {}
self.provider_status: Dict[str, ProviderStatus] = {}
self.cooldown_duration = 60.0 # 1 minute cooldown
# Priority order: Groq instances first, then Gemini, then OpenAI
self.provider_priority: List[str] = []
# Initialize all available providers
self._init_providers()
if not self.provider_instances:
raise ValueError("No LLM providers could be initialized. Check your configuration.")
print(f"✅ Initialized {len(self.provider_instances)} LLM provider instance(s)")
self._print_provider_summary()
def _init_providers(self):
"""Initialize all configured LLM provider instances."""
provider_configs = get_provider_configs()
# Initialize Groq instances
for groq_config in provider_configs.get('groq', []):
self._init_groq_instance(groq_config)
# Initialize Gemini instances
for gemini_config in provider_configs.get('gemini', []):
self._init_gemini_instance(gemini_config)
# Initialize OpenAI instances
for openai_config in provider_configs.get('openai', []):
self._init_openai_instance(openai_config)
def _init_groq_instance(self, config: Dict[str, Any]):
"""Initialize a Groq instance."""
try:
instance_name = f"groq_{config['name']}"
client = Groq(api_key=config['api_key'])
instance = ProviderInstance(
provider_type=ProviderType.GROQ,
instance_name=instance_name,
client=client,
model=config['model'],
api_key=config['api_key'][:8] + "..." # Store truncated key for logging
)
self.provider_instances[instance_name] = instance
self.provider_status[instance_name] = ProviderStatus()
self.provider_priority.append(instance_name)
print(f"✅ Initialized {instance_name}: {config['model']}")
except Exception as e:
print(f"❌ Failed to initialize Groq instance '{config['name']}': {e}")
def _init_gemini_instance(self, config: Dict[str, Any]):
"""Initialize a Gemini instance."""
try:
instance_name = f"gemini_{config['name']}"
            # genai.configure() is global state, so the full key is stored on the
            # instance and re-applied immediately before each request in _generate_gemini
            client = genai.GenerativeModel(config['model'])
            instance = ProviderInstance(
                provider_type=ProviderType.GEMINI,
                instance_name=instance_name,
                client=client,
                model=config['model'],
                api_key=config['api_key'][:8] + "...",
                full_api_key=config['api_key']
            )
self.provider_instances[instance_name] = instance
self.provider_status[instance_name] = ProviderStatus()
self.provider_priority.append(instance_name)
print(f"✅ Initialized {instance_name}: {config['model']}")
except Exception as e:
print(f"❌ Failed to initialize Gemini instance '{config['name']}': {e}")
def _init_openai_instance(self, config: Dict[str, Any]):
"""Initialize an OpenAI instance."""
try:
instance_name = f"openai_{config['name']}"
# Create OpenAI client instance
client = openai.OpenAI(api_key=config['api_key'])
instance = ProviderInstance(
provider_type=ProviderType.OPENAI,
instance_name=instance_name,
client=client,
model=config['model'],
api_key=config['api_key'][:8] + "..."
)
self.provider_instances[instance_name] = instance
self.provider_status[instance_name] = ProviderStatus()
self.provider_priority.append(instance_name)
print(f"✅ Initialized {instance_name}: {config['model']}")
except Exception as e:
print(f"❌ Failed to initialize OpenAI instance '{config['name']}': {e}")
def _print_provider_summary(self):
"""Print a summary of initialized providers."""
provider_counts = {}
for instance in self.provider_instances.values():
provider_type = instance.provider_type.value
provider_counts[provider_type] = provider_counts.get(provider_type, 0) + 1
print("📊 Provider Summary:")
for provider_type, count in provider_counts.items():
print(f" {provider_type.upper()}: {count} instance(s)")
def _get_available_provider(self) -> Optional[str]:
"""Get the next available provider instance based on priority and cooldowns."""
current_time = time.time()
# Check each provider in priority order
for instance_name in self.provider_priority:
if instance_name not in self.provider_instances:
continue
status = self.provider_status[instance_name]
# Check if cooldown has expired
if not status.is_available and current_time >= status.cooldown_until:
status.is_available = True
status.error_count = 0
print(f"🔄 {instance_name} cooldown expired, marking as available")
# Return first available provider
if status.is_available:
return instance_name
return None
def _handle_rate_limit(self, instance_name: str, error: Exception):
"""Handle rate limit by putting provider instance on cooldown."""
current_time = time.time()
status = self.provider_status[instance_name]
# Check if this is a rate limit error
error_str = str(error).lower()
is_rate_limit = any(keyword in error_str for keyword in [
'rate limit', 'quota', 'too many requests', '429', 'ratelimit'
])
if is_rate_limit:
status.is_available = False
status.cooldown_until = current_time + self.cooldown_duration
status.error_count += 1
print(f"⏳ {instance_name} hit rate limit. Cooldown until {time.strftime('%H:%M:%S', time.localtime(status.cooldown_until))}")
return True
return False
    async def generate_text(self,
                            system_prompt: str,
                            user_prompt: str,
                            temperature: Optional[float] = None,
                            max_tokens: Optional[int] = None,
                            reasoning_format: str = "hidden") -> Tuple[str, str, str]:
        """
        Generate text using available LLM provider instances with automatic fallback.

        Args:
            system_prompt: System instruction for the LLM
            user_prompt: User query/prompt
            temperature: Sampling temperature (uses config default if None)
            max_tokens: Maximum tokens to generate (uses config default if None)
            reasoning_format: For reasoning models - "hidden" (default), "raw", or "parsed"

        Returns:
            Tuple of (generated_text, provider_type, instance_name)
        """
temp = temperature if temperature is not None else TEMPERATURE
max_tok = max_tokens if max_tokens is not None else MAX_TOKENS
last_error = None
# Try each available provider instance
for attempt in range(len(self.provider_instances)):
instance_name = self._get_available_provider()
if instance_name is None:
# All providers are on cooldown
available_times = [
status.cooldown_until
for status in self.provider_status.values()
if not status.is_available
]
if available_times:
min_cooldown = min(available_times) - time.time()
if min_cooldown > 0:
print(f"⏳ All providers on cooldown. Waiting {min_cooldown:.1f}s for next available...")
await asyncio.sleep(min(min_cooldown, 5)) # Wait max 5 seconds
continue
break
instance = self.provider_instances[instance_name]
try:
print(f"🚀 Attempting generation with {instance_name} ({instance.model})")
                if instance.provider_type == ProviderType.GROQ:
                    result = await self._generate_groq(instance, system_prompt, user_prompt, temp, max_tok, reasoning_format)
                elif instance.provider_type == ProviderType.GEMINI:
                    result = await self._generate_gemini(instance, system_prompt, user_prompt, temp, max_tok)
                elif instance.provider_type == ProviderType.OPENAI:
                    result = await self._generate_openai(instance, system_prompt, user_prompt, temp, max_tok)
                else:
                    # Guard against an unhandled provider type leaving `result` unbound
                    raise ValueError(f"Unknown provider type: {instance.provider_type}")
                # Mark success
                self.provider_status[instance_name].last_success = time.time()
                return result, instance.provider_type.value, instance_name
except Exception as e:
print(f"❌ {instance_name} error: {e}")
last_error = e
# Handle rate limiting
if self._handle_rate_limit(instance_name, e):
continue # Try next provider
else:
# Non-rate-limit error, still try next provider but with a short delay
await asyncio.sleep(1)
continue
# If we get here, all providers failed
raise Exception(f"All LLM provider instances failed. Last error: {last_error}")
async def _generate_groq(self, instance: ProviderInstance, system_prompt: str, user_prompt: str,
temperature: float, max_tokens: int, reasoning_format: str = "hidden") -> str:
"""Generate text using Groq with reasoning format support."""
# Prepare the request parameters
request_params = {
"model": instance.model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"temperature": temperature,
"max_tokens": max_tokens
}
# Add reasoning_format for reasoning models
reasoning_models = [
"qwen/qwen3-32b"
]
if any(model in instance.model.lower() for model in reasoning_models):
request_params["reasoning_format"] = reasoning_format
print(f"🧠 Using reasoning format: {reasoning_format}")
response = await asyncio.to_thread(
instance.client.chat.completions.create,
**request_params
)
        # Handle different response formats
        if reasoning_format == "parsed" and hasattr(response.choices[0].message, 'reasoning'):
            # For parsed format, reasoning is returned in a separate field
            reasoning = response.choices[0].message.reasoning or ""
            content = response.choices[0].message.content or ""
            if reasoning:
                return f"Reasoning: {reasoning}\n\nAnswer: {content}"
            return content
        else:
            # For raw and hidden formats, content contains everything or just the answer
            return response.choices[0].message.content
async def _generate_gemini(self, instance: ProviderInstance, system_prompt: str, user_prompt: str,
temperature: float, max_tokens: int) -> str:
"""Generate text using Gemini."""
# Configure API key for this specific request
genai.configure(api_key=instance.full_api_key)
# Combine system and user prompts for Gemini
combined_prompt = f"{system_prompt}\n\nUser Query: {user_prompt}"
# Configure generation parameters
generation_config = genai.types.GenerationConfig(
temperature=temperature,
max_output_tokens=max_tokens,
)
# Generate response
response = await asyncio.to_thread(
instance.client.generate_content,
combined_prompt,
generation_config=generation_config
)
return response.text
async def _generate_openai(self, instance: ProviderInstance, system_prompt: str, user_prompt: str,
temperature: float, max_tokens: int) -> str:
"""Generate text using OpenAI."""
response = await asyncio.to_thread(
instance.client.chat.completions.create,
model=instance.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=temperature,
max_tokens=max_tokens
)
return response.choices[0].message.content
    async def generate_simple(self, prompt: str,
                              temperature: Optional[float] = None,
                              max_tokens: Optional[int] = None,
                              reasoning_format: str = "hidden") -> Tuple[str, str, str]:
        """
        Generate text with a simple prompt (no system message).

        Args:
            prompt: The prompt text
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            reasoning_format: For reasoning models - "hidden" (default), "raw", or "parsed"

        Returns:
            Tuple of (generated_text, provider_type, instance_name)
        """
return await self.generate_text("", prompt, temperature, max_tokens, reasoning_format)
def get_provider_status(self) -> Dict[str, Dict[str, Any]]:
"""Get status information for all provider instances."""
current_time = time.time()
status_info = {}
for instance_name, instance in self.provider_instances.items():
status = self.provider_status[instance_name]
cooldown_remaining = max(0, status.cooldown_until - current_time)
status_info[instance_name] = {
"provider_type": instance.provider_type.value,
"model": instance.model,
"available": status.is_available,
"cooldown_remaining_seconds": cooldown_remaining,
"error_count": status.error_count,
"last_success": status.last_success,
"last_success_ago": current_time - status.last_success if status.last_success > 0 else None,
"api_key": instance.api_key # Truncated version
}
return status_info
def reset_cooldowns(self):
"""Reset all provider instance cooldowns."""
for status in self.provider_status.values():
status.is_available = True
status.cooldown_until = 0.0
status.error_count = 0
print("🔄 All provider instance cooldowns reset")
def get_provider_info(self) -> Dict[str, Any]:
"""Get information about all configured provider instances."""
provider_summary = {}
for instance_name, instance in self.provider_instances.items():
provider_type = instance.provider_type.value
if provider_type not in provider_summary:
provider_summary[provider_type] = []
provider_summary[provider_type].append({
"instance_name": instance_name,
"model": instance.model,
"api_key": instance.api_key
})
return {
"total_instances": len(self.provider_instances),
"provider_summary": provider_summary,
"provider_priority": self.provider_priority,
"cooldown_duration_seconds": self.cooldown_duration,
"max_tokens": MAX_TOKENS,
"temperature": TEMPERATURE,
"provider_status": self.get_provider_status()
}
# Global instance
llm_handler = UnifiedLLMHandler()
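

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes config.config supplies at least one valid
    # API key (the module raises at import time otherwise). generate_simple returns
    # a (text, provider_type, instance_name) tuple.
    async def _demo():
        text, provider, instance = await llm_handler.generate_simple(
            "Reply with a single short sentence."
        )
        print(f"[{provider}/{instance}] {text}")

    asyncio.run(_demo())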