"""LLM handler with Groq/Gemini failover and per-client rate-limit cooldowns."""
import time
import json
from typing import Dict, List, Callable, Any, Union
import google.generativeai as genai
from groq import Groq, RateLimitError
# Assuming your config.py looks something like this for the example to be runnable
from dataclasses import dataclass
@dataclass
class APIKey:
    """Credentials and model selection for a single LLM provider account."""

    # Provider discriminator checked by LLMHandler: "groq" or "gemini".
    provider: str
    # Raw API key; only its last 4 characters appear in log/client ids.
    key: str
    # Model name passed through to the provider's SDK.
    model: str
# A simple callback for demonstration
def no_op_callback(message: str):
    """Default progress callback: echo the status message to stdout."""
    print(message)
class LLMHandler:
    """Multi-provider LLM client pool with rate-limit failover.

    Wraps a pool of Groq and Gemini clients built from ``APIKey`` entries.
    When a client raises (rate limit or any other error) it is placed on a
    short cooldown and the next client in the pool is tried.
    """

    def __init__(self, api_keys: List[APIKey], progress_callback: Callable[[str], None] = no_op_callback):
        """Configure one client per API key.

        Args:
            api_keys: Provider/key/model triples to build clients from.
            progress_callback: Receives human-readable status messages.
        """
        self.clients: List[Dict[str, Any]] = []
        self.progress_callback = progress_callback
        for key in api_keys:
            try:
                if key.provider == "groq":
                    # Disable the Groq library's internal retries so a rate
                    # limit surfaces immediately and our failover logic can
                    # rotate to the next client instead of blocking here.
                    client = Groq(api_key=key.key, max_retries=0)
                    self.clients.append({
                        "provider": "groq",
                        "client": client,
                        "model": key.model,
                        "id": f"groq_{key.key[-4:]}",
                    })
                elif key.provider == "gemini":
                    # Gemini's SDK is less explicit about HTTP retries in its
                    # standard configuration; the cooldown handling below
                    # covers it regardless.
                    genai.configure(api_key=key.key)
                    self.clients.append({
                        "provider": "gemini",
                        "client": genai.GenerativeModel(key.model),
                        "model": key.model,
                        "id": f"gemini_{key.key[-4:]}",
                    })
                else:
                    # BUG FIX: an unknown provider previously fell through to
                    # the success log below, which indexes self.clients[-1] —
                    # IndexError on an empty pool, or logging the wrong id.
                    self.progress_callback(f"Unknown provider '{key.provider}' for key ending in {key.key[-4:]}; skipping.")
                    continue
                self.progress_callback(f"Successfully configured client: {self.clients[-1]['id']}")
            except Exception as e:
                self.progress_callback(f"Failed to configure client for key ending in {key.key[-4:]}: {e}")
        if not self.clients:
            self.progress_callback("Warning: No LLM clients were successfully configured.")
        # Maps client id -> timestamp of the failure that started its cooldown.
        self.cooldowns: Dict[str, float] = {}
        self.cooldown_period = 30  # seconds

    @staticmethod
    def _strip_markdown_fences(text: str) -> str:
        """Remove a surrounding ```json ... ``` (or bare ```) code fence.

        BUG FIX: the original used str.lstrip("```json") / str.rstrip("```"),
        which strip *character sets* (any of `` ` j s o n ``), not prefixes —
        a payload beginning or ending with those characters would be mangled.
        """
        text = text.strip()
        if text.startswith("```"):
            text = text[3:]
            if text.startswith("json"):
                text = text[len("json"):]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    def _attempt_generation(self, generation_logic: Callable[[Dict], Any]) -> Any:
        """Run ``generation_logic`` against each available client in turn.

        Handles the shared iteration, cooldown bookkeeping, and error
        handling for all generation methods.

        Args:
            generation_logic: A function that takes a client_info dictionary
                and executes the specific LLM call, returning the processed
                content.

        Returns:
            Whatever ``generation_logic`` returns for the first client that
            succeeds.

        Raises:
            ValueError: If no clients are configured.
            RuntimeError: If every client fails or is on cooldown.
        """
        if not self.clients:
            raise ValueError("No LLM clients configured.")
        for client_info in self.clients:
            client_id = client_info["id"]
            # Skip clients still cooling down; clear expired cooldowns.
            if client_id in self.cooldowns:
                if time.time() - self.cooldowns[client_id] < self.cooldown_period:
                    self.progress_callback(f"Skipping {client_id} (on cooldown).")
                    continue
                else:
                    self.progress_callback(f"Cooldown expired for {client_id}.")
                    del self.cooldowns[client_id]
            try:
                # Execute the specific generation logic passed to this method.
                return generation_logic(client_info)
            except RateLimitError:
                self.progress_callback(f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
                self.cooldowns[client_id] = time.time()
                continue  # Try the next client.
            except Exception as e:
                # Catches other errors (bad API keys, parsing failures, ...).
                self.progress_callback(f"An error occurred with {client_id}: {e}. Placing on cooldown and trying next client.")
                self.cooldowns[client_id] = time.time()  # Faulty clients cool down too.
                continue
        # The loop completed without returning: every client failed.
        raise RuntimeError("Failed to get a response from any available LLM provider.")

    def generate_documentation(self, prompt: str) -> Dict:
        """Generate structured JSON documentation using available clients.

        Returns:
            The parsed JSON object.

        Raises:
            ValueError: If no clients are configured.
            RuntimeError: If every configured client fails.
        """
        def _generate(client_info: Dict) -> Dict:
            self.progress_callback(f"Attempting to generate JSON docs with {client_info['id']} ({client_info['model']})...")
            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.1,
                    response_format={"type": "json_object"},
                )
                content = response.choices[0].message.content
            elif client_info["provider"] == "gemini":
                # Gemini has no JSON response mode here: the prompt itself
                # must request JSON, and the reply may arrive wrapped in a
                # markdown code fence that we strip before parsing.
                response = client_info["client"].generate_content(prompt)
                content = self._strip_markdown_fences(response.text)
            return json.loads(content)

        return self._attempt_generation(_generate)

    def generate_text_response(self, prompt: str) -> str:
        """Generate a plain text response using available clients.

        Raises:
            ValueError: If no clients are configured.
            RuntimeError: If every configured client fails.
        """
        def _generate(client_info: Dict) -> str:
            self.progress_callback(f"Attempting to generate text with {client_info['id']} ({client_info['model']})...")
            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.2,
                )
                return response.choices[0].message.content
            elif client_info["provider"] == "gemini":
                response = client_info["client"].generate_content(prompt)
                return response.text.strip()

        return self._attempt_generation(_generate)