Spaces:
Sleeping
Sleeping
| """ | |
| Groq LLM Client for RGB RAG Evaluation | |
| Supports multiple Groq models for experimentation | |
| """ | |
| import os | |
| import time | |
| from typing import Optional, Dict, Any, List | |
| from groq import Groq | |
| from dotenv import load_dotenv | |
| from collections import deque | |
| # Load environment variables | |
| load_dotenv() | |
class GroqLLMClient:
    """
    Client for interacting with Groq's free LLM API.

    Supports multiple models for comparative evaluation and enforces the
    free tier's RPM (Requests Per Minute) limit with a sliding 60-second
    window plus a minimum interval between consecutive requests.
    """

    # RPM limits (Groq free tier)
    RPM_LIMIT = 25  # Maximum 25 requests per minute (safe margin below 30)
    MIN_REQUEST_INTERVAL = 2.5  # Minimum 2.5 seconds between requests for safety

    # Available Groq models (free tier)
    AVAILABLE_MODELS = [
        "meta-llama/llama-4-maverick-17b-128e-instruct",  # Llama 4 Maverick 17B
        "meta-llama/llama-prompt-guard-2-86m",  # Llama Prompt Guard 2 86M
        "llama-3.1-8b-instant",  # Llama 3.1 8B - Fast
        "openai/gpt-oss-120b",  # GPT OSS 120B
        "moonshotai/kimi-k2-instruct-0905",  # Kimi K2 Instruct 0905
        "moonshotai/kimi-k2-instruct",  # Kimi K2 Instruct
        "llama-3.3-70b-versatile",  # Llama 3.3 70B
        "meta-llama/llama-4-scout-17b-16e-instruct",  # Llama 4 Scout 17B
        "qwen/qwen3-32b",  # Qwen 3 32B
    ]

    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b-versatile"):
        """
        Initialize the Groq client with RPM rate limiting.

        Args:
            api_key: Groq API key. If None, reads from GROQ_API_KEY env variable.
            model: The model to use for generation.

        Raises:
            ValueError: If no API key is provided or found in the environment.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key is required. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter. Get your free API key at https://console.groq.com/"
            )
        self.client = Groq(api_key=self.api_key)
        self.model = model
        # RPM tracking using sliding window (timestamps of last 60 seconds)
        self.request_times = deque()
        self.last_request_time = 0.0
        self.total_requests = 0

    def set_model(self, model: str) -> None:
        """Switch to a different model (warns, but proceeds, if it is unknown)."""
        if model not in self.AVAILABLE_MODELS:
            print(f"Warning: {model} not in known models list. Proceeding anyway.")
        self.model = model

    def _check_rpm_limit(self) -> None:
        """
        Check and enforce the RPM limit using a sliding 60-second window.

        Removes timestamps older than 60 seconds; if the window is full,
        waits until the oldest request ages out, then records the current
        request exactly once with a fresh timestamp.

        NOTE: the previous recursive implementation fell through after the
        recursive call and appended the *stale* pre-sleep timestamp a second
        time, double-counting the request and breaking the deque's time
        ordering. The loop below records each request exactly once.
        """
        while True:
            current_time = time.time()
            # Remove timestamps older than 60 seconds
            while self.request_times and self.request_times[0] < current_time - 60:
                self.request_times.popleft()
            if len(self.request_times) < self.RPM_LIMIT:
                break
            # Window full: wait until the oldest request is older than 60 seconds
            wait_time = 60 - (current_time - self.request_times[0])
            if wait_time > 0:
                print(f" [RPM limit reached ({self.RPM_LIMIT}/min). Waiting {wait_time:.1f}s...]")
                time.sleep(wait_time)
        # Record the current request with a fresh timestamp
        self.request_times.append(time.time())

    def _rate_limit(self) -> None:
        """
        Enforce rate limiting:
          1. Check the sliding-window RPM limit (RPM_LIMIT requests per minute).
          2. Enforce a minimum interval between consecutive requests.
        """
        # Check RPM limit
        self._check_rpm_limit()
        # Ensure minimum interval between requests
        current_time = time.time()
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.MIN_REQUEST_INTERVAL:
            time.sleep(self.MIN_REQUEST_INTERVAL - time_since_last)
        self.last_request_time = time.time()
        self.total_requests += 1
        # Log progress every 10 requests
        if self.total_requests % 10 == 0:
            print(f" [Processed {self.total_requests} requests ({len(self.request_times)}/{self.RPM_LIMIT} in last minute)]")

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 1024,
        retry_count: int = 3
    ) -> str:
        """
        Generate a response from the LLM.

        Args:
            prompt: The user prompt/question.
            system_prompt: Optional system prompt for context.
            temperature: Sampling temperature (0.0 for deterministic).
            max_tokens: Maximum tokens in response.
            retry_count: Number of retries on failure.

        Returns:
            The generated text response, or an "ERROR: ..." string if every
            attempt fails (callers are not expected to catch exceptions).
        """
        self._rate_limit()
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        for attempt in range(retry_count):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                return response.choices[0].message.content.strip()
            except Exception as e:
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 5  # Linear backoff: 5s, 10s, ...
                    print(f" [API error: {e}. Retrying in {wait_time}s...]")
                    time.sleep(wait_time)
                else:
                    print(f" [API error after {retry_count} attempts: {e}]")
                    return f"ERROR: {str(e)}"
        return "ERROR: Failed to generate response"

    def generate_for_rag(
        self,
        question: str,
        documents: List[str],
        prompt_template: str
    ) -> str:
        """
        Generate a response for RAG evaluation using the paper's prompt format.

        Args:
            question: The question to answer.
            documents: List of retrieved documents/passages.
            prompt_template: The prompt template from the research paper; must
                contain `{question}` and `{documents}` placeholders.

        Returns:
            The generated answer.
        """
        # Format documents as "Document [1]: ...", one per line
        docs_text = "\n".join([f"Document [{i+1}]: {doc}" for i, doc in enumerate(documents)])
        # Fill in the template
        prompt = prompt_template.format(
            question=question,
            documents=docs_text
        )
        return self.generate(prompt, temperature=0.0)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Return a copy of the list of available Groq models."""
        return cls.AVAILABLE_MODELS.copy()
def test_client():
    """Smoke-test the Groq client: ask a trivial question and report success."""
    try:
        client = GroqLLMClient()
        print(f"Testing with model: {client.model}")
        answer = client.generate("What is 2+2? Answer with just the number.")
        print(f"Response: {answer}")
    except Exception as exc:
        print(f"Error testing client: {exc}")
        return False
    return True
| if __name__ == "__main__": | |
| test_client() | |