# RAG10/llm_client.py
"""Groq LLM integration with rate limiting."""
from typing import List, Dict, Optional
import time
from groq import Groq
import asyncio
from datetime import datetime, timedelta
from collections import deque
import os
class RateLimiter:
"""Rate limiter for API calls."""
def __init__(self, max_requests_per_minute: int = 30):
"""Initialize rate limiter.
Args:
max_requests_per_minute: Maximum requests allowed per minute
"""
self.max_requests = max_requests_per_minute
self.request_times = deque()
self.lock = asyncio.Lock()
    async def acquire(self):
        """Acquire permission to make a request."""
        async with self.lock:
            while True:
                now = datetime.now()
                # Remove requests older than 1 minute
                while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
                    self.request_times.popleft()
                # If under the limit, record this request and proceed
                if len(self.request_times) < self.max_requests:
                    self.request_times.append(now)
                    return
                # At the limit: wait until the oldest request leaves the window.
                # A loop is used instead of a recursive call because asyncio.Lock
                # is not reentrant; re-entering acquire() while holding the lock
                # would deadlock.
                oldest_request = self.request_times[0]
                wait_time = 60 - (now - oldest_request).total_seconds()
                if wait_time > 0:
                    print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
                    await asyncio.sleep(wait_time)
    def acquire_sync(self):
        """Synchronous version of acquire."""
        while True:
            now = datetime.now()
            # Remove requests older than 1 minute
            while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
                self.request_times.popleft()
            # If under the limit, record this request and proceed
            if len(self.request_times) < self.max_requests:
                self.request_times.append(now)
                return
            # At the limit: wait until the oldest request leaves the window
            oldest_request = self.request_times[0]
            wait_time = 60 - (now - oldest_request).total_seconds()
            if wait_time > 0:
                print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
                time.sleep(wait_time)
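

def _demo_rate_limiter() -> None:
    """Illustrative sketch only, not part of the original module: exercises
    RateLimiter on its own. The 3-request budget is an arbitrary example value
    chosen so the wait path triggers quickly; the fourth call below may sleep
    for up to a minute."""
    limiter = RateLimiter(max_requests_per_minute=3)
    for i in range(4):
        limiter.acquire_sync()
        print(f"Request {i + 1} allowed at {datetime.now().isoformat()}")
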
class GroqLLMClient:
"""Client for Groq LLM API with rate limiting."""
def __init__(
self,
api_key: str,
model_name: str = "llama-3.1-8b-instant",
max_rpm: int = 30,
rate_limit_delay: float = 2.0
):
"""Initialize Groq client.
Args:
api_key: Groq API key
model_name: Name of the LLM model
max_rpm: Maximum requests per minute
rate_limit_delay: Additional delay between requests (seconds)
"""
self.client = Groq(api_key=api_key)
self.model_name = model_name
self.rate_limiter = RateLimiter(max_rpm)
self.rate_limit_delay = rate_limit_delay
# Available models
self.available_models = [
"meta-llama/llama-4-maverick-17b-128e-instruct",
"llama-3.1-8b-instant",
"openai/gpt-oss-120b"
]
def set_model(self, model_name: str):
"""Set the LLM model.
Args:
model_name: Name of the model
"""
if model_name not in self.available_models:
print(f"Warning: {model_name} not in available models. Using anyway...")
self.model_name = model_name
def generate(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 0.7,
system_prompt: Optional[str] = None
) -> str:
"""Generate text using Groq LLM.
Args:
prompt: Input prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
system_prompt: System prompt
Returns:
Generated text
"""
# Apply rate limiting
self.rate_limiter.acquire_sync()
# Prepare messages
messages = []
if system_prompt:
messages.append({
"role": "system",
"content": system_prompt
})
messages.append({
"role": "user",
"content": prompt
})
try:
# Make API call
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature
)
# Add delay
time.sleep(self.rate_limit_delay)
return response.choices[0].message.content
except Exception as e:
print(f"Error generating response: {str(e)}")
return f"Error: {str(e)}"
async def generate_async(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 0.7,
system_prompt: Optional[str] = None
) -> str:
"""Asynchronous version of generate.
Args:
prompt: Input prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
system_prompt: System prompt
Returns:
Generated text
"""
# Apply rate limiting
await self.rate_limiter.acquire()
# Prepare messages
messages = []
if system_prompt:
messages.append({
"role": "system",
"content": system_prompt
})
messages.append({
"role": "user",
"content": prompt
})
try:
            # Run the blocking client call in a worker thread so it does not
            # stall the event loop
            response = await asyncio.to_thread(
                self.client.chat.completions.create,
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )
# Add delay
await asyncio.sleep(self.rate_limit_delay)
return response.choices[0].message.content
except Exception as e:
print(f"Error generating response: {str(e)}")
return f"Error: {str(e)}"
def generate_with_context(
self,
query: str,
context_documents: List[str],
max_tokens: int = 1024,
temperature: float = 0.7
) -> str:
"""Generate response with retrieved context.
Args:
query: User query
context_documents: List of retrieved documents
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
Returns:
Generated response
"""
# Build context
context = "\n\n".join([
f"Document {i+1}: {doc}"
for i, doc in enumerate(context_documents)
])
# Build prompt
prompt = f"""Answer the following question based on the provided context.
Context:
{context}
Question: {query}
Answer:"""
system_prompt = "You are a helpful AI assistant. Answer questions based on the provided context. If the answer is not in the context, say so."
return self.generate(prompt, max_tokens, temperature, system_prompt)
def batch_generate(
self,
prompts: List[str],
max_tokens: int = 1024,
temperature: float = 0.7,
system_prompt: Optional[str] = None
) -> List[str]:
"""Generate responses for multiple prompts.
Args:
prompts: List of prompts
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
system_prompt: System prompt
Returns:
List of generated responses
"""
responses = []
for i, prompt in enumerate(prompts):
print(f"Processing prompt {i+1}/{len(prompts)}")
response = self.generate(prompt, max_tokens, temperature, system_prompt)
responses.append(response)
return responses
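

def _demo_llm_client() -> None:
    """Illustrative sketch only, not part of the original module: minimal use of
    GroqLLMClient.generate. Assumes a GROQ_API_KEY environment variable is set;
    the prompt text is just a placeholder."""
    client = GroqLLMClient(api_key=os.environ["GROQ_API_KEY"])
    answer = client.generate(
        "Explain retrieval-augmented generation in one sentence.",
        max_tokens=128,
        temperature=0.2
    )
    print(answer)
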
class RAGPipeline:
"""Complete RAG pipeline with LLM and vector store."""
def __init__(
self,
llm_client: GroqLLMClient,
vector_store_manager
):
"""Initialize RAG pipeline.
Args:
llm_client: Groq LLM client
vector_store_manager: ChromaDB manager
"""
self.llm = llm_client
self.vector_store = vector_store_manager
self.chat_history = []
def query(
self,
query: str,
n_results: int = 5,
max_tokens: int = 1024,
temperature: float = 0.7
) -> Dict:
"""Query the RAG system.
Args:
query: User query
n_results: Number of documents to retrieve
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
Returns:
Dictionary with response and retrieved documents
"""
# Retrieve documents
retrieved_docs = self.vector_store.get_retrieved_documents(query, n_results)
# Extract document texts
doc_texts = [doc["document"] for doc in retrieved_docs]
# Generate response
response = self.llm.generate_with_context(
query,
doc_texts,
max_tokens,
temperature
)
# Store in chat history
self.chat_history.append({
"query": query,
"response": response,
"retrieved_docs": retrieved_docs,
"timestamp": datetime.now().isoformat()
})
return {
"query": query,
"response": response,
"retrieved_documents": retrieved_docs
}
def get_chat_history(self) -> List[Dict]:
"""Get chat history.
Returns:
List of chat history entries
"""
return self.chat_history
def clear_history(self):
"""Clear chat history."""
self.chat_history = []
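

# Illustrative sketch only, not part of the original module: wires the pieces
# together end to end. _StubVectorStore is a hypothetical stand-in for the
# project's ChromaDB manager; it only mimics the
# get_retrieved_documents(query, n_results) interface that RAGPipeline expects.
if __name__ == "__main__":
    class _StubVectorStore:
        def __init__(self, documents: List[str]):
            self.documents = documents

        def get_retrieved_documents(self, query: str, n_results: int = 5) -> List[Dict]:
            # No real similarity search; just return the first n_results documents
            return [{"document": doc} for doc in self.documents[:n_results]]

    llm = GroqLLMClient(api_key=os.environ["GROQ_API_KEY"])
    pipeline = RAGPipeline(llm, _StubVectorStore([
        "Groq serves open-weight LLMs over a low-latency inference API."
    ]))
    result = pipeline.query("What does Groq provide?", n_results=1)
    print(result["response"])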