File size: 6,674 Bytes
3e802a5
 
d803316
3e802a5
 
 
d803316
 
 
 
 
 
 
3e802a5
d803316
3e802a5
 
 
 
 
 
d803316
3e802a5
d803316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e802a5
d803316
 
 
3e802a5
 
 
d803316
3e802a5
d803316
 
 
 
 
3e802a5
 
 
 
d803316
3e802a5
 
 
d803316
3e802a5
 
 
 
 
d803316
3e802a5
 
 
d803316
 
 
3e802a5
 
 
d803316
3e802a5
d803316
 
 
3e802a5
 
d803316
 
 
 
3e802a5
d803316
3e802a5
d803316
3e802a5
d803316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e802a5
d803316
3e802a5
d803316
3e802a5
d803316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import time
import json
from typing import Dict, List, Callable, Any, Union
import google.generativeai as genai
from groq import Groq, RateLimitError

# Assuming your config.py looks something like this for the example to be runnable
from dataclasses import dataclass
@dataclass
class APIKey:
    """Credentials and model selection for a single LLM provider account."""
    provider: str  # dispatch tag matched by LLMHandler, e.g. "groq" or "gemini"
    key: str       # API secret; only its last 4 characters are ever logged
    model: str     # provider-specific model identifier

# A simple callback for demonstration
def no_op_callback(message: str) -> None:
    """Fallback progress callback: echo *message* to stdout."""
    print(message)

class LLMHandler:
    """Failover pool of LLM clients (Groq / Gemini) with per-client cooldowns.

    Clients are tried in configuration order; a client that rate-limits or
    errors is placed on a cooldown and the next client is tried.
    """

    def __init__(self, api_keys: List[APIKey], progress_callback: Callable[[str], None] = no_op_callback):
        """Build one client per API key.

        Args:
            api_keys: Provider credentials; entries with an unknown provider
                      are skipped (with a progress message).
            progress_callback: Receives human-readable status messages.
        """
        self.clients = []
        self.progress_callback = progress_callback
        for key in api_keys:
            try:
                client_info = None
                if key.provider == "groq":
                    # Disable the library's internal retries. Let our handler
                    # manage failovers — this gives us immediate control when
                    # a rate limit is hit.
                    client = Groq(api_key=key.key, max_retries=0)
                    client_info = {
                        "provider": "groq",
                        "client": client,
                        "model": key.model,
                        "id": f"groq_{key.key[-4:]}",
                    }
                elif key.provider == "gemini":
                    # Gemini's library is less explicit about HTTP retries in
                    # its standard configuration; the failover loop below
                    # applies all the same.
                    genai.configure(api_key=key.key)
                    client_info = {
                        "provider": "gemini",
                        "client": genai.GenerativeModel(key.model),
                        "model": key.model,
                        "id": f"gemini_{key.key[-4:]}",
                    }

                if client_info is None:
                    # BUG FIX: an unrecognized provider previously fell
                    # through to the success log below, which reported the
                    # PREVIOUS client's id via self.clients[-1] — or raised
                    # IndexError when the list was still empty.
                    self.progress_callback(
                        f"Unknown provider '{key.provider}' for key ending in {key.key[-4:]}; skipping."
                    )
                    continue

                self.clients.append(client_info)
                self.progress_callback(f"Successfully configured client: {client_info['id']}")
            except Exception as e:
                self.progress_callback(f"Failed to configure client for key ending in {key.key[-4:]}: {e}")

        if not self.clients:
            self.progress_callback("Warning: No LLM clients were successfully configured.")

        # client id -> time.time() of the failure that started its cooldown
        self.cooldowns: Dict[str, float] = {}
        self.cooldown_period = 30  # 30 seconds

    @staticmethod
    def _strip_json_fence(text: str) -> str:
        """Strip a surrounding Markdown ```json fence from *text*, if present.

        BUG FIX: the old code used lstrip("```json") / rstrip("```"), which
        treat the argument as a CHARACTER SET rather than a prefix/suffix and
        could therefore eat leading payload characters ('j', 's', 'o', 'n',
        backticks) after the fence.
        """
        stripped = text.strip()
        if stripped.startswith("```json"):
            stripped = stripped[len("```json"):]
        elif stripped.startswith("```"):
            stripped = stripped[len("```"):]
        if stripped.endswith("```"):
            stripped = stripped[:-len("```")]
        return stripped.strip()

    def _attempt_generation(self, generation_logic: Callable[[Dict], Any]) -> Any:
        """
        A private generic method to handle the client iteration, cooldown, and error handling logic.

        Args:
            generation_logic: A function that takes a client_info dictionary and executes
                              the specific LLM call, returning the processed content.

        Raises:
            ValueError: if no clients are configured at all.
            RuntimeError: if every configured client failed or is on cooldown.
        """
        if not self.clients:
            raise ValueError("No LLM clients configured.")

        for client_info in self.clients:
            client_id = client_info["id"]

            # Check and manage cooldown: skip clients still cooling down,
            # clear the entry once the period has elapsed.
            if client_id in self.cooldowns:
                if time.time() - self.cooldowns[client_id] < self.cooldown_period:
                    self.progress_callback(f"Skipping {client_id} (on cooldown).")
                    continue
                else:
                    self.progress_callback(f"Cooldown expired for {client_id}.")
                    del self.cooldowns[client_id]

            try:
                # Execute the specific generation logic passed to this method
                return generation_logic(client_info)

            except RateLimitError:
                self.progress_callback(f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
                self.cooldowns[client_id] = time.time()
                continue  # Try the next client
            except Exception as e:
                # This catches other errors like API key issues, parsing errors, etc.
                self.progress_callback(f"An error occurred with {client_id}: {e}. Placing on cooldown and trying next client.")
                self.cooldowns[client_id] = time.time()  # Put faulty clients on cooldown too
                continue

        # If the loop completes without returning, all clients have failed.
        raise RuntimeError("Failed to get a response from any available LLM provider.")

    def generate_documentation(self, prompt: str) -> Dict:
        """
        Generates structured JSON documentation using available clients.

        Returns the parsed JSON object; raises via _attempt_generation when
        all clients fail.
        """
        def _generate(client_info: Dict) -> Dict:
            client_id = client_info["id"]
            self.progress_callback(f"Attempting to generate JSON docs with {client_id} ({client_info['model']})...")

            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.1,
                    response_format={"type": "json_object"},
                )
                content = response.choices[0].message.content

            elif client_info["provider"] == "gemini":
                # For Gemini, you must explicitly ask for JSON in the prompt
                # e.g., prompt = "Generate JSON... " + original_prompt
                response = client_info["client"].generate_content(prompt)
                content = self._strip_json_fence(response.text)

            else:
                # BUG FIX: previously fell through with `content` unbound,
                # raising an opaque UnboundLocalError. The ValueError is still
                # caught by _attempt_generation's generic handler (cooldown +
                # failover), so behavior for configured clients is unchanged.
                raise ValueError(f"Unsupported provider: {client_info['provider']}")

            return json.loads(content)

        return self._attempt_generation(_generate)

    def generate_text_response(self, prompt: str) -> str:
        """
        Generates a plain text response using available clients.
        """
        def _generate(client_info: Dict) -> str:
            client_id = client_info["id"]
            self.progress_callback(f"Attempting to generate text with {client_id} ({client_info['model']})...")

            if client_info["provider"] == "groq":
                response = client_info["client"].chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=client_info["model"],
                    temperature=0.2,
                )
                return response.choices[0].message.content

            elif client_info["provider"] == "gemini":
                response = client_info["client"].generate_content(prompt)
                return response.text.strip()

            # BUG FIX: previously returned None for an unknown provider
            # despite the `-> str` contract; fail loudly instead (handled by
            # _attempt_generation's failover loop).
            raise ValueError(f"Unsupported provider: {client_info['provider']}")

        return self._attempt_generation(_generate)