File size: 7,752 Bytes
af25c62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Groq LLM Client for RGB RAG Evaluation
Supports multiple Groq models for experimentation
"""

import os
import time
from typing import Optional, Dict, Any, List
from groq import Groq
from dotenv import load_dotenv
from collections import deque

# Load environment variables
load_dotenv()


class GroqLLMClient:
    """
    Client for interacting with Groq's free LLM API.

    Supports multiple models for comparative evaluation and enforces
    client-side rate limiting: a sliding-window RPM cap plus a minimum
    interval between consecutive requests.
    """

    # RPM limits (Groq free tier)
    RPM_LIMIT = 25  # Max requests per minute (safe margin below the 30/min tier cap)
    MIN_REQUEST_INTERVAL = 2.5  # Minimum seconds between consecutive requests

    # Available Groq models (free tier)
    AVAILABLE_MODELS = [
        "meta-llama/llama-4-maverick-17b-128e-instruct",  # Llama 4 Maverick 17B
        "meta-llama/llama-prompt-guard-2-86m",            # Llama Prompt Guard 2 86M
        "llama-3.1-8b-instant",                           # Llama 3.1 8B - Fast
        "openai/gpt-oss-120b",                            # GPT OSS 120B
        "moonshotai/kimi-k2-instruct-0905",               # Kimi K2 Instruct 0905
        "moonshotai/kimi-k2-instruct",                    # Kimi K2 Instruct
        "llama-3.3-70b-versatile",                        # Llama 3.3 70B
        "meta-llama/llama-4-scout-17b-16e-instruct",     # Llama 4 Scout 17B
        "qwen/qwen3-32b",                                 # Qwen 3 32B
    ]

    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b-versatile"):
        """
        Initialize the Groq client with RPM rate limiting.

        Args:
            api_key: Groq API key. If None, reads from GROQ_API_KEY env variable.
            model: The model to use for generation.

        Raises:
            ValueError: If no API key is provided or found in the environment.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key is required. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter. Get your free API key at https://console.groq.com/"
            )

        self.client = Groq(api_key=self.api_key)
        self.model = model

        # RPM tracking using a sliding window of request timestamps
        # (entries older than 60 seconds are pruned on each check).
        self.request_times = deque()
        self.last_request_time = 0.0
        self.total_requests = 0

    def set_model(self, model: str) -> None:
        """Switch to a different model (warns if it is not in the known list)."""
        if model not in self.AVAILABLE_MODELS:
            print(f"Warning: {model} not in known models list. Proceeding anyway.")
        self.model = model

    def _check_rpm_limit(self) -> None:
        """
        Enforce the RPM limit using a 60-second sliding window.

        Prunes timestamps older than 60 seconds; if the window is still
        full, sleeps until the oldest request expires and re-checks.
        Exactly ONE timestamp is recorded per call (the previous
        recursive implementation appended once per recursion level,
        double-counting requests after a wait).
        """
        while True:
            current_time = time.time()

            # Drop timestamps that have aged out of the 60-second window.
            while self.request_times and self.request_times[0] < current_time - 60:
                self.request_times.popleft()

            if len(self.request_times) < self.RPM_LIMIT:
                break

            # Window is full: wait until the oldest request falls out of it.
            wait_time = 60 - (current_time - self.request_times[0])
            print(f"  [RPM limit reached ({self.RPM_LIMIT}/min). Waiting {wait_time:.1f}s...]")
            # Small floor avoids a busy-spin when wait_time rounds to ~0.
            time.sleep(max(wait_time, 0.01))

        # Record this request exactly once.
        self.request_times.append(current_time)

    def _rate_limit(self) -> None:
        """
        Enforce rate limiting:
        1. Check the sliding-window RPM limit (RPM_LIMIT requests per minute).
        2. Enforce the minimum interval between consecutive requests.
        """
        # Check RPM limit
        self._check_rpm_limit()

        # Ensure minimum interval between requests
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.MIN_REQUEST_INTERVAL:
            time.sleep(self.MIN_REQUEST_INTERVAL - time_since_last)

        self.last_request_time = time.time()
        self.total_requests += 1

        # Log progress every 10 requests
        if self.total_requests % 10 == 0:
            print(f"  [Processed {self.total_requests} requests "
                  f"({len(self.request_times)}/{self.RPM_LIMIT} in last minute)]")

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 1024,
        retry_count: int = 3
    ) -> str:
        """
        Generate a response from the LLM.

        Args:
            prompt: The user prompt/question.
            system_prompt: Optional system prompt for context.
            temperature: Sampling temperature (0.0 for deterministic).
            max_tokens: Maximum tokens in response.
            retry_count: Number of retries on failure.

        Returns:
            The generated text response, or an "ERROR: ..." string if all
            attempts fail (callers check for this prefix rather than catch).
        """
        self._rate_limit()

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        for attempt in range(retry_count):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                # content can be None (e.g. tool/filter responses); guard
                # before .strip() to avoid AttributeError.
                content = response.choices[0].message.content
                return (content or "").strip()

            except Exception as e:
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 5  # Linear backoff: 5s, 10s, ...
                    print(f"  [API error: {e}. Retrying in {wait_time}s...]")
                    time.sleep(wait_time)
                else:
                    print(f"  [API error after {retry_count} attempts: {e}]")
                    return f"ERROR: {str(e)}"

        return "ERROR: Failed to generate response"

    def generate_for_rag(
        self,
        question: str,
        documents: List[str],
        prompt_template: str
    ) -> str:
        """
        Generate a response for RAG evaluation using the paper's prompt format.

        Args:
            question: The question to answer.
            documents: List of retrieved documents/passages.
            prompt_template: The prompt template from the research paper;
                must contain {question} and {documents} placeholders.

        Returns:
            The generated answer.
        """
        # Format documents as "Document [1]: ...", "Document [2]: ...", etc.
        docs_text = "\n".join([f"Document [{i+1}]: {doc}" for i, doc in enumerate(documents)])

        # Fill in the template
        prompt = prompt_template.format(
            question=question,
            documents=docs_text
        )

        return self.generate(prompt, temperature=0.0)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Return a copy of the available Groq model names."""
        return cls.AVAILABLE_MODELS.copy()


def test_client():
    """Smoke-test the Groq client: create it and ask a trivial question."""
    try:
        client = GroqLLMClient()
        print(f"Testing with model: {client.model}")
        answer = client.generate("What is 2+2? Answer with just the number.")
        print(f"Response: {answer}")
    except Exception as e:
        # Covers both missing-API-key setup failures and API call errors.
        print(f"Error testing client: {e}")
        return False
    return True


if __name__ == "__main__":
    # Run a quick connectivity smoke test when executed as a script.
    test_client()