# RGBMetrics/src/llm_client.py — RGB Evaluation
# (commit af25c62: "feat: Add separate grid layout for 4 RAG abilities in Streamlit UI")
"""
Groq LLM Client for RGB RAG Evaluation
Supports multiple Groq models for experimentation
"""
import os
import time
from typing import Optional, Dict, Any, List
from groq import Groq
from dotenv import load_dotenv
from collections import deque
# Load environment variables
load_dotenv()
class GroqLLMClient:
    """
    Client for interacting with Groq's free LLM API.

    Supports multiple models for comparative evaluation and enforces
    RPM (Requests Per Minute) rate limiting with a sliding 60-second
    window plus a minimum interval between consecutive requests.
    """
    # RPM limits (Groq free tier)
    RPM_LIMIT = 25  # Maximum 25 requests per minute (safe margin below 30)
    MIN_REQUEST_INTERVAL = 2.5  # Minimum 2.5 seconds between requests for safety
    # Available Groq models (free tier)
    AVAILABLE_MODELS = [
        "meta-llama/llama-4-maverick-17b-128e-instruct",  # Llama 4 Maverick 17B
        "meta-llama/llama-prompt-guard-2-86m",  # Llama Prompt Guard 2 86M
        "llama-3.1-8b-instant",  # Llama 3.1 8B - Fast
        "openai/gpt-oss-120b",  # GPT OSS 120B
        "moonshotai/kimi-k2-instruct-0905",  # Kimi K2 Instruct 0905
        "moonshotai/kimi-k2-instruct",  # Kimi K2 Instruct
        "llama-3.3-70b-versatile",  # Llama 3.3 70B
        "meta-llama/llama-4-scout-17b-16e-instruct",  # Llama 4 Scout 17B
        "qwen/qwen3-32b",  # Qwen 3 32B
    ]

    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b-versatile"):
        """
        Initialize the Groq client with RPM rate limiting.

        Args:
            api_key: Groq API key. If None, reads from GROQ_API_KEY env variable.
            model: The model to use for generation.

        Raises:
            ValueError: If no API key is provided or found in the environment.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key is required. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter. Get your free API key at https://console.groq.com/"
            )
        self.client = Groq(api_key=self.api_key)
        self.model = model
        # RPM tracking using sliding window (timestamps of last 60 seconds)
        self.request_times: deque = deque()
        self.last_request_time: float = 0
        self.total_requests: int = 0

    def set_model(self, model: str) -> None:
        """Switch to a different model (warns, but proceeds, on unknown names)."""
        if model not in self.AVAILABLE_MODELS:
            print(f"Warning: {model} not in known models list. Proceeding anyway.")
        self.model = model

    def _check_rpm_limit(self) -> None:
        """
        Check and enforce the RPM limit using a sliding 60-second window.

        Drops timestamps older than 60 seconds, sleeps until a slot frees
        up if the window is full, then records the current request time.
        """
        # Loop (rather than recurse) so that after sleeping we re-evaluate
        # the window with a fresh clock reading. The previous recursive
        # version appended twice per wait — once inside the recursive call
        # and once with a stale timestamp in the outer frame — inflating
        # the window and causing unnecessary extra waits.
        while True:
            now = time.time()
            # Remove timestamps older than 60 seconds
            while self.request_times and self.request_times[0] < now - 60:
                self.request_times.popleft()
            if len(self.request_times) < self.RPM_LIMIT:
                break
            # Window is full: wait until the oldest request ages past 60s.
            wait_time = 60 - (now - self.request_times[0])
            if wait_time > 0:
                print(f" [RPM limit reached ({self.RPM_LIMIT}/min). Waiting {wait_time:.1f}s...]")
                time.sleep(wait_time)
        # Record this request with a fresh (post-wait) timestamp.
        self.request_times.append(time.time())

    def _rate_limit(self) -> None:
        """
        Enforce rate limiting:
        1. Check RPM limit (RPM_LIMIT requests per minute)
        2. Enforce minimum interval between requests
        """
        # Check RPM limit
        self._check_rpm_limit()
        # Ensure minimum interval between requests
        current_time = time.time()
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.MIN_REQUEST_INTERVAL:
            sleep_time = self.MIN_REQUEST_INTERVAL - time_since_last
            time.sleep(sleep_time)
        self.last_request_time = time.time()
        self.total_requests += 1
        # Log progress every 10 requests
        if self.total_requests % 10 == 0:
            print(f" [Processed {self.total_requests} requests ({len(self.request_times)}/{self.RPM_LIMIT} in last minute)]")

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 1024,
        retry_count: int = 3
    ) -> str:
        """
        Generate a response from the LLM.

        Args:
            prompt: The user prompt/question.
            system_prompt: Optional system prompt for context.
            temperature: Sampling temperature (0.0 for deterministic).
            max_tokens: Maximum tokens in response.
            retry_count: Number of retries on failure.

        Returns:
            The generated text response, or an "ERROR: ..." string if all
            retries fail (callers check for this sentinel rather than
            catching exceptions).
        """
        self._rate_limit()
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        for attempt in range(retry_count):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                # content may be None (e.g. when the model returns no text);
                # guard so .strip() never raises AttributeError.
                content = response.choices[0].message.content
                return (content or "").strip()
            except Exception as e:
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 5  # Linear backoff: 5s, 10s, 15s, ...
                    print(f" [API error: {e}. Retrying in {wait_time}s...]")
                    time.sleep(wait_time)
                else:
                    print(f" [API error after {retry_count} attempts: {e}]")
                    return f"ERROR: {str(e)}"
        return "ERROR: Failed to generate response"

    def generate_for_rag(
        self,
        question: str,
        documents: List[str],
        prompt_template: str
    ) -> str:
        """
        Generate a response for RAG evaluation using the paper's prompt format.

        Args:
            question: The question to answer.
            documents: List of retrieved documents/passages.
            prompt_template: The prompt template from the research paper.
                Must contain `{question}` and `{documents}` placeholders.

        Returns:
            The generated answer.
        """
        # Format documents as "Document [1]: ...", one per line.
        docs_text = "\n".join([f"Document [{i+1}]: {doc}" for i, doc in enumerate(documents)])
        # Fill in the template
        prompt = prompt_template.format(
            question=question,
            documents=docs_text
        )
        return self.generate(prompt, temperature=0.0)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Return a copy of the list of available Groq models."""
        return cls.AVAILABLE_MODELS.copy()
def test_client():
    """Smoke-test the Groq client with a trivial arithmetic prompt."""
    try:
        client = GroqLLMClient()
        print(f"Testing with model: {client.model}")
        answer = client.generate("What is 2+2? Answer with just the number.")
        print(f"Response: {answer}")
    except Exception as err:
        print(f"Error testing client: {err}")
        return False
    return True
# Run a quick connectivity smoke test when executed as a script.
if __name__ == "__main__":
    test_client()