Rahul-Samedavar committed on
Commit
d803316
·
1 Parent(s): faf02df

fixed rate limits

Browse files
Files changed (1) hide show
  1. codescribe/llm_handler.py +100 -78
codescribe/llm_handler.py CHANGED
@@ -1,123 +1,145 @@
1
  import time
2
  import json
3
- from typing import Dict, List, Callable
4
  import google.generativeai as genai
5
  from groq import Groq, RateLimitError
6
 
 
 
 
 
 
 
 
7
 
8
-
9
- from .config import APIKey
10
-
11
  def no_op_callback(message: str):
12
  print(message)
13
 
14
  class LLMHandler:
15
  def __init__(self, api_keys: List[APIKey], progress_callback: Callable[[str], None] = no_op_callback):
16
  self.clients = []
17
- self.progress_callback = progress_callback # NEW
18
  for key in api_keys:
19
- if key.provider == "groq":
20
- self.clients.append({
21
- "provider": "groq",
22
- "client": Groq(api_key=key.key),
23
- "model": key.model,
24
- "id": f"groq_{key.key[-4:]}"
25
- })
26
- elif key.provider == "gemini":
27
- genai.configure(api_key=key.key)
28
- self.clients.append({
29
- "provider": "gemini",
30
- "client": genai.GenerativeModel(key.model),
31
- "model": key.model,
32
- "id": f"gemini_{key.key[-4:]}"
33
- })
 
 
 
 
 
 
 
 
 
 
 
34
 
 
 
 
35
  self.cooldowns: Dict[str, float] = {}
36
  self.cooldown_period = 30 # 30 seconds
37
 
38
- def generate_documentation(self, prompt: str) -> Dict:
39
  """
40
- Tries to generate documentation using available clients, handling rate limits and failovers.
 
 
 
 
41
  """
42
  if not self.clients:
43
  raise ValueError("No LLM clients configured.")
44
 
 
45
  for client_info in self.clients:
46
  client_id = client_info["id"]
47
 
48
- # Check if the client is on cooldown
49
  if client_id in self.cooldowns:
50
  if time.time() - self.cooldowns[client_id] < self.cooldown_period:
51
  self.progress_callback(f"Skipping {client_id} (on cooldown).")
52
  continue
53
  else:
54
- # Cooldown has expired
55
  del self.cooldowns[client_id]
56
 
57
  try:
58
- self.progress_callback(f"Attempting to generate docs with {client_id} ({client_info['model']})...")
59
- if client_info["provider"] == "groq":
60
- response = client_info["client"].chat.completions.create(
61
- messages=[{"role": "user", "content": prompt}],
62
- model=client_info["model"],
63
- temperature=0.1,
64
- response_format={"type": "json_object"},
65
- )
66
- content = response.choices[0].message.content
67
-
68
- elif client_info["provider"] == "gemini":
69
- response = client_info["client"].generate_content(prompt)
70
- # Gemini might wrap JSON in ```json ... ```
71
- content = response.text.strip().replace("```json", "").replace("```", "").strip()
72
-
73
- return json.loads(content)
74
-
75
  except RateLimitError:
76
  self.progress_callback(f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
77
  self.cooldowns[client_id] = time.time()
78
- continue
79
  except Exception as e:
80
- self.progress_callback(f"An error occurred with {client_id}: {e}. Trying next client.")
 
 
81
  continue
82
 
83
- raise RuntimeError("Failed to generate documentation from all available LLM providers.")
84
-
85
-
86
- def generate_text_response(self, prompt: str) -> str:
87
  """
88
- Generates a plain text response from LLMs, handling failovers.
89
  """
90
- if not self.clients:
91
- raise ValueError("No LLM clients configured.")
92
-
93
- for client_info in self.clients:
94
  client_id = client_info["id"]
95
- if client_id in self.cooldowns and time.time() - self.cooldowns[client_id] < self.cooldown_period:
96
- self.progress_callback(f"Skipping {client_id} (on cooldown).")
97
- continue
98
- elif client_id in self.cooldowns:
99
- del self.cooldowns[client_id]
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- try:
102
- self.progress_callback(f"Attempting to generate text with {client_id} ({client_info['model']})...")
103
- if client_info["provider"] == "groq":
104
- response = client_info["client"].chat.completions.create(
105
- messages=[{"role": "user", "content": prompt}],
106
- model=client_info["model"],
107
- temperature=0.2,
108
- )
109
- return response.choices[0].message.content
110
-
111
- elif client_info["provider"] == "gemini":
112
- response = client_info["client"].generate_content(prompt)
113
- return response.text.strip()
114
 
115
- except RateLimitError:
116
- self.progress_callback(f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
117
- self.cooldowns[client_id] = time.time()
118
- continue
119
- except Exception as e:
120
- self.progress_callback(f"An error occurred with {client_id}: {e}. Trying next client.")
121
- continue
122
 
123
- raise RuntimeError("Failed to generate text response from all available LLM providers.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import time
2
  import json
3
+ from typing import Dict, List, Callable, Any, Union
4
  import google.generativeai as genai
5
  from groq import Groq, RateLimitError
6
 
7
# NOTE(review): this inline APIKey mirrors what the previous revision imported
# from .config — confirm it stays in sync with the project's config module.
from dataclasses import dataclass


@dataclass
class APIKey:
    """A single provider credential: which backend, the raw key, and the model."""
    provider: str  # e.g. "groq" or "gemini"
    key: str       # raw API key (handlers only ever log the last 4 characters)
    model: str     # provider-specific model identifier
14
 
15
def no_op_callback(message: str):
    """Default progress callback: echo *message* to stdout."""
    print(message)
18
 
19
class LLMHandler:
    """Failover pool of LLM clients with per-client cooldowns.

    Iterates over configured Groq / Gemini clients in order, skipping any
    that are on cooldown, and fails over to the next client when a rate
    limit or other error occurs.
    """

    def __init__(self, api_keys: List[APIKey], progress_callback: Callable[[str], None] = no_op_callback):
        self.clients = []
        self.progress_callback = progress_callback
        for key in api_keys:
            try:
                entry = None
                if key.provider == "groq":
                    # Disable the Groq library's internal retries so a rate
                    # limit surfaces immediately and our own failover/cooldown
                    # logic takes over.
                    entry = {
                        "provider": "groq",
                        "client": Groq(api_key=key.key, max_retries=0),
                        "model": key.model,
                        "id": f"groq_{key.key[-4:]}",
                    }
                elif key.provider == "gemini":
                    # The Gemini SDK does not expose an equivalent HTTP-retry
                    # knob here; the main offender is HTTP-based SDKs like Groq's.
                    genai.configure(api_key=key.key)
                    entry = {
                        "provider": "gemini",
                        "client": genai.GenerativeModel(key.model),
                        "model": key.model,
                        "id": f"gemini_{key.key[-4:]}",
                    }

                if entry is not None:
                    self.clients.append(entry)
                    self.progress_callback(f"Successfully configured client: {entry['id']}")
                else:
                    # Previously an unknown provider hit self.clients[-1]:
                    # IndexError on an empty list, or mislogged the previous
                    # client as "successfully configured".
                    self.progress_callback(
                        f"Unknown provider '{key.provider}' for key ending in {key.key[-4:]}; skipping."
                    )
            except Exception as e:
                self.progress_callback(f"Failed to configure client for key ending in {key.key[-4:]}: {e}")

        if not self.clients:
            self.progress_callback("Warning: No LLM clients were successfully configured.")

        # client_id -> timestamp of the last failure; an entry is honored for
        # cooldown_period seconds and cleared on the next attempt after that.
        self.cooldowns: Dict[str, float] = {}
        self.cooldown_period = 30  # seconds

    def _attempt_generation(self, generation_logic: Callable[[Dict], Any]) -> Any:
        """Run *generation_logic* against each usable client until one succeeds.

        Args:
            generation_logic: A function that takes a client_info dictionary,
                executes the provider-specific LLM call, and returns the
                processed content.

        Returns:
            Whatever *generation_logic* returns for the first client that succeeds.

        Raises:
            ValueError: If no clients are configured.
            RuntimeError: If every client is on cooldown or fails.
        """
        if not self.clients:
            raise ValueError("No LLM clients configured.")

        for client_info in self.clients:
            client_id = client_info["id"]

            # Skip clients still cooling down; clear cooldowns that expired.
            if client_id in self.cooldowns:
                if time.time() - self.cooldowns[client_id] < self.cooldown_period:
                    self.progress_callback(f"Skipping {client_id} (on cooldown).")
                    continue
                else:
                    self.progress_callback(f"Cooldown expired for {client_id}.")
                    del self.cooldowns[client_id]

            try:
                return generation_logic(client_info)
            except RateLimitError:
                self.progress_callback(f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
                self.cooldowns[client_id] = time.time()
                continue  # Try the next client
            except Exception as e:
                # Other failures (auth, parsing, network) also cool the client
                # down so a persistently broken client is not retried every call.
                self.progress_callback(f"An error occurred with {client_id}: {e}. Placing on cooldown and trying next client.")
                self.cooldowns[client_id] = time.time()
                continue

        # The loop completed without returning: every client failed or was skipped.
        raise RuntimeError("Failed to get a response from any available LLM provider.")

    def generate_documentation(self, prompt: str) -> Dict:
        """Generate structured JSON documentation for *prompt*, with failover.

        Returns the parsed JSON object; raises RuntimeError if all clients fail.
        """
        def _generate(client_info: Dict) -> Dict:
            client_id = client_info["id"]
            self.progress_callback(f"Attempting to generate JSON docs with {client_id} ({client_info['model']})...")

            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.1,
                    response_format={"type": "json_object"},
                )
                content = response.choices[0].message.content
            elif client_info["provider"] == "gemini":
                # Gemini commonly wraps JSON in a ```json fenced block.  Strip
                # the fence as a literal prefix/suffix: str.lstrip/rstrip take
                # character *sets*, so lstrip("```json") would eat leading
                # 'j'/'s'/'o'/'n' characters from unfenced content (e.g. "null").
                content = client_info["client"].generate_content(prompt).text.strip()
                if content.startswith("```json"):
                    content = content[len("```json"):]
                elif content.startswith("```"):
                    content = content[len("```"):]
                if content.endswith("```"):
                    content = content[:-len("```")]
                content = content.strip()
            else:
                # Routed through _attempt_generation's generic handler.
                raise ValueError(f"Unsupported provider: {client_info['provider']}")

            return json.loads(content)

        return self._attempt_generation(_generate)

    def generate_text_response(self, prompt: str) -> str:
        """Generate a plain-text response for *prompt*, with failover.

        Returns the provider's text; raises RuntimeError if all clients fail.
        """
        def _generate(client_info: Dict) -> str:
            client_id = client_info["id"]
            self.progress_callback(f"Attempting to generate text with {client_id} ({client_info['model']})...")

            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.2,
                )
                return response.choices[0].message.content
            elif client_info["provider"] == "gemini":
                response = client_info["client"].generate_content(prompt)
                return response.text.strip()
            # Previously fell through and returned None for unknown providers.
            raise ValueError(f"Unsupported provider: {client_info['provider']}")

        return self._attempt_generation(_generate)