Spaces:
Sleeping
Sleeping
File size: 8,415 Bytes
0c591a7 494ea7a de66f3b 0c591a7 3e0da39 0c591a7 3e0da39 0c591a7 3e0da39 0c591a7 3e0da39 0c591a7 3e0da39 0c591a7 3e0da39 87e16be 0c591a7 3e0da39 0c591a7 3e0da39 6a70bab 0c591a7 494ea7a 0c591a7 494ea7a de66f3b 0c591a7 494ea7a 0c591a7 494ea7a 0c591a7 494ea7a 0c591a7 494ea7a 0c591a7 3e0da39 0c591a7 3e0da39 0c591a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
"""
Multi-provider LLM client with cascading fallback.
Adopts pattern from Enterprise-AI-Gateway for resilient LLM access.
"""
import os
import time
import requests
from typing import Optional, Tuple
# Retry configuration - rotate through providers instead of retrying one
# provider consecutively (a different provider is more likely to succeed
# than an immediate retry against the one that just failed).
MAX_ROUNDS = 3 # Number of times to cycle through all providers
PROVIDER_DELAY = 10 # seconds to sleep between provider attempts in query()
class LLMClient:
    """LLM client with automatic cascading fallback across providers.

    Providers (Groq, Gemini, OpenRouter) are registered in that priority
    order based on which API keys are available; query() rotates through
    them until one returns content.
    """

    def __init__(self, override_keys: Optional[dict] = None):
        """Initialize client with available providers based on API keys.

        Args:
            override_keys: Optional dict with user-provided API keys.
                Recognized keys: "groq", "gemini", "openrouter". Any key
                not supplied here falls back to the matching env var.

        Raises:
            ValueError: If no API key is available for any provider.
        """
        self.providers = []
        override_keys = override_keys or {}

        # Build providers list - use override keys if provided, else env vars
        groq_key = override_keys.get("groq") or os.getenv("GROQ_API_KEY")
        gemini_key = override_keys.get("gemini") or os.getenv("GEMINI_API_KEY")
        openrouter_key = override_keys.get("openrouter") or os.getenv("OPENROUTER_API_KEY")

        if groq_key:
            self.providers.append({
                "name": "groq",
                "key": groq_key,
                "model": os.getenv("GROQ_MODEL", "llama-3.1-8b-instant"),
                "url": "https://api.groq.com/openai/v1/chat/completions",
            })
        if gemini_key:
            # Gemini's endpoint embeds the model name, so the URL is built
            # per-call in _call_provider rather than stored here.
            self.providers.append({
                "name": "gemini",
                "key": gemini_key,
                "model": os.getenv("GEMINI_MODEL", "gemini-2.0-flash"),
            })
        if openrouter_key:
            self.providers.append({
                "name": "openrouter",
                "key": openrouter_key,
                "model": os.getenv("OPENROUTER_MODEL", "google/gemma-3n-e4b-it:free"),
                "url": "https://openrouter.ai/api/v1/chat/completions",
            })

        if not self.providers:
            raise ValueError(
                "No LLM API keys configured. Set at least one of: "
                "GROQ_API_KEY, GEMINI_API_KEY, OPENROUTER_API_KEY"
            )

    def query(self, prompt: str, temperature: float = 0, max_tokens: int = 2048) -> Tuple[Optional[str], Optional[str], Optional[str], list]:
        """Query LLM with rotating fallback across providers.

        Instead of retrying the same provider consecutively, rotates:
        Groq -> Gemini -> OpenRouter -> Groq -> ... for up to MAX_ROUNDS
        cycles, sleeping PROVIDER_DELAY seconds before every attempt after
        the first.

        Args:
            prompt: User prompt, sent as a single message/content block.
            temperature: Sampling temperature forwarded to the provider.
            max_tokens: Output-token cap forwarded to the provider.

        Returns:
            Tuple of (response_content, provider_used, error_message, providers_failed):
            - response_content: text reply, or None if every attempt failed
            - provider_used: "name:model" of the winning provider, or None
            - error_message: aggregate failure summary, or None on success
            - providers_failed: list of dicts [{"name": ..., "error": ...}]
              for every failed attempt (may be non-empty even on success)
        """
        errors = []
        providers_failed = []
        is_first_attempt = True

        # Rotate through providers for MAX_ROUNDS cycles
        for round_num in range(MAX_ROUNDS):
            for provider in self.providers:
                # Throttle between attempts; the very first call goes immediately.
                if not is_first_attempt:
                    print(f"Waiting {PROVIDER_DELAY}s before trying {provider['name']} (round {round_num + 1})...")
                    time.sleep(PROVIDER_DELAY)
                is_first_attempt = False

                print(f"Attempting LLM call with {provider['name']} (round {round_num + 1}/{MAX_ROUNDS})...")
                start_time = time.perf_counter()
                try:
                    content, error = self._call_provider(
                        provider=provider,
                        prompt=prompt,
                        temperature=temperature,
                        max_tokens=max_tokens
                    )
                    latency_ms = int((time.perf_counter() - start_time) * 1000)
                    # NOTE: an empty-string reply is treated as a failure here.
                    if content:
                        print(f"Success with {provider['name']} ({latency_ms}ms)")
                        provider_info = f"{provider['name']}:{provider['model']}"
                        return content, provider_info, None, providers_failed
                    errors.append(f"{provider['name']}: {error}")
                    providers_failed.append({"name": provider['name'], "error": error})
                    print(f"Provider {provider['name']} failed: {error}")
                except Exception as e:
                    # Network errors, non-2xx HTTP (raise_for_status), bad JSON, etc.
                    errors.append(f"{provider['name']}: {str(e)}")
                    providers_failed.append({"name": provider['name'], "error": str(e)})
                    print(f"Provider {provider['name']} exception: {e}")

        return None, None, f"All LLM providers failed after {MAX_ROUNDS} rounds: {'; '.join(errors)}", providers_failed

    def _make_request(self, url: str, headers: dict, payload: dict, provider_name: str) -> "requests.Response":
        """POST payload to url and return the response.

        No internal retry - the rotation in query() handles retries.
        provider_name is currently unused but kept for signature
        compatibility (hook for per-provider logging/metrics).

        Raises:
            requests.HTTPError: on non-2xx responses (raise_for_status).
        """
        response = requests.post(url, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        return response

    def _openai_style_call(self, provider: dict, headers: dict, prompt: str, temperature: float, max_tokens: int, label: str) -> Tuple[Optional[str], Optional[str]]:
        """Call an OpenAI-compatible chat-completions endpoint (Groq/OpenRouter).

        label is the human-readable provider name used in error messages.
        """
        payload = {
            "model": provider["model"],
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        response = self._make_request(provider["url"], headers, payload, provider["name"])
        data = response.json()
        if data and "choices" in data and data["choices"]:
            return data["choices"][0]["message"]["content"], None
        return None, f"No content in {label} response"

    def _call_provider(self, provider: dict, prompt: str, temperature: float, max_tokens: int) -> Tuple[Optional[str], Optional[str]]:
        """Call a specific LLM provider.

        Returns:
            Tuple of (content, error): exactly one of the two is non-None.
        """
        headers = {"Content-Type": "application/json"}
        name = provider["name"]

        if name == "groq":
            headers["Authorization"] = f"Bearer {provider['key']}"
            return self._openai_style_call(provider, headers, prompt, temperature, max_tokens, "Groq")

        elif name == "gemini":
            # Send the API key via header rather than a ?key= query parameter
            # so it cannot leak into access logs or tracebacks via the URL.
            headers["x-goog-api-key"] = provider["key"]
            url = f"https://generativelanguage.googleapis.com/v1beta/models/{provider['model']}:generateContent"
            payload = {
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": temperature,
                    "maxOutputTokens": max_tokens,
                }
            }
            response = self._make_request(url, headers, payload, name)
            data = response.json()
            if data and "candidates" in data and data["candidates"]:
                first_candidate = data["candidates"][0]
                if "content" in first_candidate and "parts" in first_candidate["content"]:
                    # Return the first part that actually carries text.
                    for part in first_candidate["content"]["parts"]:
                        if "text" in part:
                            return part["text"], None
            return None, "No text content in Gemini response"

        elif name == "openrouter":
            headers["Authorization"] = f"Bearer {provider['key']}"
            # OpenRouter attribution headers (optional but recommended).
            headers["HTTP-Referer"] = "https://huggingface.co/spaces"
            headers["X-Title"] = "Instant SWOT Agent"
            return self._openai_style_call(provider, headers, prompt, temperature, max_tokens, "OpenRouter")

        return None, f"Unknown provider: {name}"
# Singleton instance for the default (env-based) client.
_client = None


def get_llm_client(override_keys: Optional[dict] = None) -> LLMClient:
    """Get or create an LLM client instance.

    Args:
        override_keys: If provided (non-empty), creates a fresh client with
            these keys for this request only - never cached, so one user's
            keys cannot leak into another request. If None/empty, returns
            the shared singleton built from environment variables.

    Returns:
        An LLMClient configured from override keys or environment variables.

    Raises:
        ValueError: Propagated from LLMClient when no API keys are set.
    """
    # User-provided keys get a throwaway client scoped to this request.
    if override_keys:
        return LLMClient(override_keys)

    # Otherwise lazily create and reuse the env-based singleton.
    global _client
    if _client is None:
        _client = LLMClient()
    return _client
|