import time
import json
from typing import Callable, Dict, List

import google.generativeai as genai
from groq import Groq, RateLimitError

from .config import APIKey

def _print_progress(message: str) -> None:
    """Default progress callback: print progress messages to stdout."""
    print(message)


class LLMHandler:
    def __init__(self, api_keys: List[APIKey], progress_callback: Callable[[str], None] = _print_progress):
        self.clients = []
        self.progress_callback = progress_callback
        for key in api_keys:
            if key.provider == "groq":
                self.clients.append({
                    "provider": "groq",
                    "client": Groq(api_key=key.key),
                    "model": key.model,
                    "id": f"groq_{key.key[-4:]}"
                })
            elif key.provider == "gemini":
                # NOTE: genai.configure sets the API key globally, so if
                # several Gemini keys are supplied, the last one configured
                # wins for all Gemini clients.
                genai.configure(api_key=key.key)
                self.clients.append({
                    "provider": "gemini",
                    "client": genai.GenerativeModel(key.model),
                    "model": key.model,
                    "id": f"gemini_{key.key[-4:]}"
                })
        
        self.cooldowns: Dict[str, float] = {}
        self.cooldown_period = 30  # seconds

    def generate_documentation(self, prompt: str) -> Dict:
        """
        Generate documentation as a parsed JSON object.

        Clients are tried in order; one on cooldown is skipped, one that hits
        a rate limit is placed on cooldown, and the next client is tried.
        Raises RuntimeError if every client fails.
        """
        if not self.clients:
            raise ValueError("No LLM clients configured.")

        for client_info in self.clients:
            client_id = client_info["id"]

            # Check if the client is on cooldown
            if client_id in self.cooldowns:
                if time.time() - self.cooldowns[client_id] < self.cooldown_period:
                    self.progress_callback(f"Skipping {client_id} (on cooldown).")
                    continue
                else:
                    # Cooldown has expired
                    del self.cooldowns[client_id]
            
            try:
                self.progress_callback(f"Attempting to generate docs with {client_id} ({client_info['model']})...")
                if client_info["provider"] == "groq":
                    response = client_info["client"].chat.completions.create(
                        messages=[{"role": "user", "content": prompt}],
                        model=client_info["model"],
                        temperature=0.1,
                        response_format={"type": "json_object"},
                    )
                    content = response.choices[0].message.content
                
                elif client_info["provider"] == "gemini":
                    response = client_info["client"].generate_content(prompt)
                    # Gemini may wrap JSON in a ```json ... ``` fence; strip it.
                    content = response.text.strip().replace("```json", "").replace("```", "").strip()

                # Malformed JSON raises here and falls through to the generic
                # handler below, which moves on to the next client.
                return json.loads(content)

            except RateLimitError:
                # Groq-specific exception; Gemini quota errors surface as
                # generic exceptions and are handled by the clause below.
                self.progress_callback(f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
                self.cooldowns[client_id] = time.time()
                continue
            except Exception as e:
                self.progress_callback(f"An error occurred with {client_id}: {e}. Trying next client.")
                continue

        raise RuntimeError("Failed to generate documentation from all available LLM providers.")

    def generate_text_response(self, prompt: str) -> str:
        """
        Generate a plain-text response, using the same cooldown and failover
        behaviour as generate_documentation.
        """
        if not self.clients:
            raise ValueError("No LLM clients configured.")

        for client_info in self.clients:
            client_id = client_info["id"]

            # Check if the client is on cooldown
            if client_id in self.cooldowns:
                if time.time() - self.cooldowns[client_id] < self.cooldown_period:
                    self.progress_callback(f"Skipping {client_id} (on cooldown).")
                    continue
                else:
                    # Cooldown has expired
                    del self.cooldowns[client_id]

            try:
                self.progress_callback(f"Attempting to generate text with {client_id} ({client_info['model']})...")
                if client_info["provider"] == "groq":
                    response = client_info["client"].chat.completions.create(
                        messages=[{"role": "user", "content": prompt}],
                        model=client_info["model"],
                        temperature=0.2,
                    )
                    return response.choices[0].message.content
                
                elif client_info["provider"] == "gemini":
                    response = client_info["client"].generate_content(prompt)
                    return response.text.strip()

            except RateLimitError:
                self.progress_callback(f"Rate limit hit for {client_id}. Placing it on a {self.cooldown_period}s cooldown.")
                self.cooldowns[client_id] = time.time()
                continue
            except Exception as e:
                self.progress_callback(f"An error occurred with {client_id}: {e}. Trying next client.")
                continue

        raise RuntimeError("Failed to generate text response from all available LLM providers.")
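

# --- Usage sketch (illustrative only) ---
# A minimal example of driving LLMHandler. It assumes APIKey accepts
# `provider`, `key`, and `model` keyword arguments (matching the fields read
# above); the key strings are placeholders, not real credentials, and the
# model names are examples. Because of the relative import of APIKey, run
# this as a module, e.g. `python -m yourpackage.llm_handler`.
if __name__ == "__main__":
    keys = [
        APIKey(provider="groq", key="gsk_PLACEHOLDER", model="llama-3.1-8b-instant"),
        APIKey(provider="gemini", key="AIza_PLACEHOLDER", model="gemini-1.5-flash"),
    ]
    handler = LLMHandler(keys)
    try:
        # JSON-mode call: the prompt should ask for a JSON object explicitly.
        docs = handler.generate_documentation(
            'Return a JSON object with keys "summary" and "params" documenting add(a, b).'
        )
        print(json.dumps(docs, indent=2))

        # Plain-text call with the same failover behaviour.
        summary = handler.generate_text_response("In one sentence, what does add(a, b) do?")
        print(summary)
    except RuntimeError as exc:
        print(f"All providers failed: {exc}")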