import time
import json
from typing import Dict, List, Callable
import google.generativeai as genai
from groq import Groq, RateLimitError
from .config import APIKey
def no_op_callback(message: str) -> None:
    """Default progress reporter: echo *message* straight to stdout."""
    print(message)
class LLMHandler:
    """Fail-over dispatcher across multiple LLM provider clients.

    One client is built per configured API key (Groq or Gemini). Requests
    are attempted against each client in order; a client that hits a rate
    limit is placed on a fixed cooldown and skipped until it expires, and
    any other per-client error simply advances to the next client.
    """

    def __init__(self, api_keys: List[APIKey], progress_callback: Callable[[str], None] = no_op_callback):
        """Build provider clients from *api_keys*.

        Args:
            api_keys: provider/key/model records; entries whose provider is
                neither "groq" nor "gemini" are silently ignored.
            progress_callback: receives human-readable status messages
                (defaults to printing them).
        """
        self.clients: List[Dict] = []
        self.progress_callback = progress_callback
        for key in api_keys:
            if key.provider == "groq":
                self.clients.append({
                    "provider": "groq",
                    "client": Groq(api_key=key.key),
                    "model": key.model,
                    "id": f"groq_{key.key[-4:]}",  # last 4 chars identify the key in logs
                })
            elif key.provider == "gemini":
                # NOTE(review): genai.configure is process-global, so if
                # several gemini keys are supplied every GenerativeModel ends
                # up using the LAST configured key — confirm whether multiple
                # gemini keys are actually expected to work here.
                genai.configure(api_key=key.key)
                self.clients.append({
                    "provider": "gemini",
                    "client": genai.GenerativeModel(key.model),
                    "model": key.model,
                    "id": f"gemini_{key.key[-4:]}",
                })
        # client id -> epoch seconds at which its cooldown started
        self.cooldowns: Dict[str, float] = {}
        self.cooldown_period = 30  # 30 seconds

    def _on_cooldown(self, client_id: str) -> bool:
        """Return True if *client_id* is still cooling down.

        Expired cooldown entries are removed as a side effect, matching the
        original inline behavior of both public methods.
        """
        started = self.cooldowns.get(client_id)
        if started is None:
            return False
        if time.time() - started < self.cooldown_period:
            return True
        # Cooldown has expired
        del self.cooldowns[client_id]
        return False

    def _failover(self, attempt_msg: str, failure_msg: str,
                  call: Callable[[Dict], object]) -> object:
        """Run ``call(client_info)`` on each usable client until one succeeds.

        Args:
            attempt_msg: ``str.format`` template taking ``client_id`` and
                ``model``, reported before each attempt.
            failure_msg: message for the final RuntimeError.
            call: per-client request; its return value is returned verbatim.

        Raises:
            ValueError: no clients were configured.
            RuntimeError: every client failed or was on cooldown.
        """
        if not self.clients:
            raise ValueError("No LLM clients configured.")
        for client_info in self.clients:
            client_id = client_info["id"]
            if self._on_cooldown(client_id):
                self.progress_callback(f"Skipping {client_id} (on cooldown).")
                continue
            try:
                self.progress_callback(
                    attempt_msg.format(client_id=client_id, model=client_info["model"]))
                return call(client_info)
            except RateLimitError:
                self.progress_callback(
                    f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
                self.cooldowns[client_id] = time.time()
                continue
            except Exception as e:
                self.progress_callback(f"An error occurred with {client_id}: {e}. Trying next client.")
                continue
        raise RuntimeError(failure_msg)

    def generate_documentation(self, prompt: str) -> Dict:
        """
        Tries to generate documentation using available clients, handling rate limits and failovers.
        """
        def call(client_info: Dict) -> Dict:
            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.1,  # low temperature: deterministic structured output
                    response_format={"type": "json_object"},
                )
                content = response.choices[0].message.content
            elif client_info["provider"] == "gemini":
                response = client_info["client"].generate_content(prompt)
                # Gemini might wrap JSON in ```json ... ```
                content = response.text.strip().replace("```json", "").replace("```", "").strip()
            else:
                # Fix: the original left `content` unbound here, raising an
                # UnboundLocalError that was swallowed by the broad except.
                raise ValueError(f"Unknown provider: {client_info['provider']}")
            # Parse inside the failover so a malformed response advances to
            # the next client instead of propagating (original behavior).
            return json.loads(content)

        return self._failover(
            "Attempting to generate docs with {client_id} ({model})...",
            "Failed to generate documentation from all available LLM providers.",
            call,
        )

    def generate_text_response(self, prompt: str) -> str:
        """
        Generates a plain text response from LLMs, handling failovers.
        """
        def call(client_info: Dict) -> str:
            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.2,
                )
                return response.choices[0].message.content
            elif client_info["provider"] == "gemini":
                response = client_info["client"].generate_content(prompt)
                return response.text.strip()
            raise ValueError(f"Unknown provider: {client_info['provider']}")

        return self._failover(
            "Attempting to generate text with {client_id} ({model})...",
            "Failed to generate text response from all available LLM providers.",
            call,
        )