"""
Groq LLM Client for RGB RAG Evaluation
Supports multiple Groq models for experimentation
"""

import os
import time
from typing import Optional, Dict, Any, List
from groq import Groq
from dotenv import load_dotenv
from collections import deque

# Load environment variables
load_dotenv()


class GroqLLMClient:
    """
    Client for interacting with Groq's free LLM API.

    Supports multiple models for comparative evaluation and applies two
    client-side throttles before every API call (including retries):

    * a sliding-window cap of ``RPM_LIMIT`` requests per ``RPM_WINDOW_SECONDS``;
    * a minimum gap of ``MIN_REQUEST_INTERVAL`` seconds between requests.

    All throttle timing uses ``time.monotonic()`` so wall-clock adjustments
    cannot stall or bypass the limiter.
    """

    # RPM limits (Groq free tier)
    RPM_LIMIT = 25  # Maximum 25 requests per minute (safe margin below 30)
    RPM_WINDOW_SECONDS = 60  # Length of the sliding window RPM_LIMIT applies to
    # NOTE: 60 / 2.5 = 24 requests per window, so for a single client this
    # interval alone already keeps traffic under RPM_LIMIT.
    MIN_REQUEST_INTERVAL = 2.5  # Minimum 2.5 seconds between requests for safety

    # Available Groq models (free tier)
    AVAILABLE_MODELS = [
        "meta-llama/llama-4-maverick-17b-128e-instruct",  # Llama 4 Maverick 17B
        "meta-llama/llama-prompt-guard-2-86m",  # Llama Prompt Guard 2 86M
        "llama-3.1-8b-instant",  # Llama 3.1 8B - Fast
        "openai/gpt-oss-120b",  # GPT OSS 120B
        "moonshotai/kimi-k2-instruct-0905",  # Kimi K2 Instruct 0905
        "moonshotai/kimi-k2-instruct",  # Kimi K2 Instruct
        "llama-3.3-70b-versatile",  # Llama 3.3 70B
        "meta-llama/llama-4-scout-17b-16e-instruct",  # Llama 4 Scout 17B
        "qwen/qwen3-32b",  # Qwen 3 32B
    ]

    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b-versatile"):
        """
        Initialize the Groq client with RPM rate limiting.

        Args:
            api_key: Groq API key. If None, reads from GROQ_API_KEY env variable.
            model: The model to use for generation.

        Raises:
            ValueError: If no API key is passed and GROQ_API_KEY is unset/empty.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key is required. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter. "
                "Get your free API key at https://console.groq.com/"
            )

        self.client = Groq(api_key=self.api_key)
        self.model = model

        # Monotonic timestamps of requests sent within the last RPM window,
        # oldest first (sliding window).
        self.request_times: deque = deque()
        # Monotonic time of the most recent request; -inf means "none yet",
        # so the very first request never waits on the minimum interval.
        self.last_request_time: float = float("-inf")
        self.total_requests = 0

    def set_model(self, model: str) -> None:
        """Switch to a different model (warns, but proceeds, if it is not in AVAILABLE_MODELS)."""
        if model not in self.AVAILABLE_MODELS:
            print(f"Warning: {model} not in known models list. Proceeding anyway.")
        self.model = model

    def _check_rpm_limit(self) -> float:
        """
        Block until a slot is free in the sliding RPM window, then claim it.

        Timestamps at least ``RPM_WINDOW_SECONDS`` old are discarded. If
        ``RPM_LIMIT`` requests are still inside the window, sleep until the
        oldest one ages out and re-check.

        Returns:
            The monotonic timestamp recorded for this request.
        """
        # Loop rather than recurse so exactly one timestamp is recorded per request.
        while True:
            now = time.monotonic()
            cutoff = now - self.RPM_WINDOW_SECONDS
            # An entry exactly at the boundary counts as expired; this guarantees
            # the wait computed below is strictly positive.
            while self.request_times and self.request_times[0] <= cutoff:
                self.request_times.popleft()

            if len(self.request_times) < self.RPM_LIMIT:
                break

            wait_time = self.request_times[0] - cutoff
            print(f" [RPM limit reached ({self.RPM_LIMIT}/min). Waiting {wait_time:.1f}s...]")
            time.sleep(wait_time)

        self.request_times.append(now)
        return now

    def _rate_limit(self) -> None:
        """
        Throttle before a single API call.

        1. Enforce the minimum interval (``MIN_REQUEST_INTERVAL``) since the last request.
        2. Enforce the RPM limit (``RPM_LIMIT`` per ``RPM_WINDOW_SECONDS``).

        The interval wait runs first so the timestamp stored in the RPM window
        (and as ``last_request_time``) is taken after all sleeping, i.e. as
        close as possible to when the request is actually sent.
        """
        time_since_last = time.monotonic() - self.last_request_time
        if time_since_last < self.MIN_REQUEST_INTERVAL:
            time.sleep(self.MIN_REQUEST_INTERVAL - time_since_last)

        self.last_request_time = self._check_rpm_limit()
        self.total_requests += 1

        # Log progress every 10 requests
        if self.total_requests % 10 == 0:
            print(
                f" [Processed {self.total_requests} requests "
                f"({len(self.request_times)}/{self.RPM_LIMIT} in last minute)]"
            )

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 1024,
        retry_count: int = 3
    ) -> str:
        """
        Generate a response from the LLM.

        Args:
            prompt: The user prompt/question.
            system_prompt: Optional system prompt for context.
            temperature: Sampling temperature (0.0 for deterministic).
            max_tokens: Maximum tokens in response.
            retry_count: Total number of attempts; each one is a rate-limited
                API call. Values <= 0 make no call at all.

        Returns:
            The generated text (stripped; "" if the API returned no content).
            On failure, an error string starting with "ERROR: " rather than
            an exception.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        for attempt in range(retry_count):
            # Every attempt, including a retry, is a real request against the
            # provider's quota, so each one goes through the limiter.
            self._rate_limit()
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                # message.content may be None; treat that as an empty reply
                # instead of raising and wasting retries on it.
                content = response.choices[0].message.content
                return (content or "").strip()
            except Exception as e:
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 5  # Linear backoff: 5s, 10s, ...
                    print(f" [API error: {e}. Retrying in {wait_time}s...]")
                    time.sleep(wait_time)
                else:
                    print(f" [API error after {retry_count} attempts: {e}]")
                    return f"ERROR: {str(e)}"

        # Only reachable when retry_count <= 0 (no attempt was made).
        return "ERROR: Failed to generate response"

    def generate_for_rag(
        self,
        question: str,
        documents: List[str],
        prompt_template: str
    ) -> str:
        """
        Generate a response for RAG evaluation using the paper's prompt format.

        Args:
            question: The question to answer.
            documents: List of retrieved documents/passages.
            prompt_template: The prompt template from the research paper; it is
                filled with ``str.format(question=..., documents=...)``.

        Returns:
            The generated answer (see ``generate`` for the error-string convention).

        Raises:
            KeyError / IndexError: If ``prompt_template`` has placeholders other
                than ``{question}`` and ``{documents}`` (from ``str.format``).
        """
        # Number documents from 1, one per line.
        docs_text = "\n".join(f"Document [{i + 1}]: {doc}" for i, doc in enumerate(documents))

        # Fill in the template
        prompt = prompt_template.format(
            question=question,
            documents=docs_text
        )

        return self.generate(prompt, temperature=0.0)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Return a copy of the known Groq model list (safe for callers to mutate)."""
        return cls.AVAILABLE_MODELS.copy()


def test_client():
    """
    Smoke-test the Groq client with a simple query.

    Returns:
        True if the client was constructed and generate() returned without
        raising; False otherwise. Note that generate() reports API failures as
        an "ERROR: ..." string, which still counts as True here.
    """
    try:
        client = GroqLLMClient()
        print(f"Testing with model: {client.model}")

        response = client.generate("What is 2+2? Answer with just the number.")
        print(f"Response: {response}")
        return True
    except Exception as e:
        print(f"Error testing client: {e}")
        return False


if __name__ == "__main__":
    test_client()