| """Groq LLM integration with rate limiting.""" |
| from typing import List, Dict, Optional, AsyncIterator |
| import time |
| from groq import Groq |
| import asyncio |
| from datetime import datetime, timedelta |
| from collections import deque |
| import os |
|
|
|
|
class RateLimiter:
    """Rate limiter for API calls."""

    def __init__(self, max_requests_per_minute: int = 30):
        """Initialize rate limiter.

        Args:
            max_requests_per_minute: Maximum requests allowed per minute
        """
        self.max_requests = max_requests_per_minute
        self.request_times = deque()
        self.lock = asyncio.Lock()

    async def acquire(self):
        """Acquire permission to make a request."""
        while True:
            async with self.lock:
                now = datetime.now()

                # Drop timestamps that have aged out of the one-minute window.
                while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
                    self.request_times.popleft()

                # Room in the window: record this request and proceed.
                if len(self.request_times) < self.max_requests:
                    self.request_times.append(now)
                    return

                # Window is full: compute the wait until the oldest entry ages out.
                oldest_request = self.request_times[0]
                wait_time = 60 - (now - oldest_request).total_seconds()

            # Sleep outside the non-reentrant lock so other coroutines can make
            # progress, then loop to re-check the window.
            if wait_time > 0:
                print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
                await asyncio.sleep(wait_time)

    def acquire_sync(self):
        """Synchronous version of acquire."""
        while True:
            now = datetime.now()

            # Drop timestamps that have aged out of the one-minute window.
            while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
                self.request_times.popleft()

            # Room in the window: record this request and proceed.
            if len(self.request_times) < self.max_requests:
                self.request_times.append(now)
                return

            # Window is full: wait until the oldest entry ages out, then re-check.
            oldest_request = self.request_times[0]
            wait_time = 60 - (now - oldest_request).total_seconds()
            if wait_time > 0:
                print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
                time.sleep(wait_time)
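

# Usage sketch (illustrative, with made-up numbers): both entry points share
# the same sliding window, so sync and async call sites draw on one budget.
#
#     limiter = RateLimiter(max_requests_per_minute=30)
#     limiter.acquire_sync()       # before a blocking request
#     await limiter.acquire()      # before an awaited request, in a coroutine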


class GroqLLMClient:
    """Client for Groq LLM API with rate limiting."""

    def __init__(
        self,
        api_key: str,
        model_name: str = "llama-3.1-8b-instant",
        max_rpm: int = 30,
        rate_limit_delay: float = 2.0
    ):
        """Initialize Groq client.

        Args:
            api_key: Groq API key
            model_name: Name of the LLM model
            max_rpm: Maximum requests per minute
            rate_limit_delay: Additional delay between requests (seconds)
        """
        self.client = Groq(api_key=api_key)
        self.model_name = model_name
        self.rate_limiter = RateLimiter(max_rpm)
        self.rate_limit_delay = rate_limit_delay

        # Known model names; set_model warns on anything else but still
        # accepts it.
        self.available_models = [
            "meta-llama/llama-4-maverick-17b-128e-instruct",
            "llama-3.1-8b-instant",
            "openai/gpt-oss-120b"
        ]

    def set_model(self, model_name: str):
        """Set the LLM model.

        Args:
            model_name: Name of the model
        """
        if model_name not in self.available_models:
            print(f"Warning: {model_name} not in available models. Using anyway...")
        self.model_name = model_name

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> str:
        """Generate text using Groq LLM.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt

        Returns:
            Generated text
        """
        # Block until the sliding window has room for another request.
        self.rate_limiter.acquire_sync()

        # Build the chat message list, optional system prompt first.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            # Fixed extra delay to stay comfortably under the rate limit.
            time.sleep(self.rate_limit_delay)

            return response.choices[0].message.content

        except Exception as e:
            print(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    async def generate_async(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> str:
        """Asynchronous version of generate.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt

        Returns:
            Generated text
        """
        await self.rate_limiter.acquire()

        # Build the chat message list, optional system prompt first.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            # The Groq client here is synchronous, so run the blocking call in
            # a worker thread to avoid stalling the event loop.
            response = await asyncio.to_thread(
                self.client.chat.completions.create,
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            # Fixed extra delay to stay comfortably under the rate limit.
            await asyncio.sleep(self.rate_limit_delay)

            return response.choices[0].message.content

        except Exception as e:
            print(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def generate_with_context(
        self,
        query: str,
        context_documents: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> str:
        """Generate response with retrieved context.

        Args:
            query: User query
            context_documents: List of retrieved documents
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        # Number the retrieved documents so the model can refer to them.
        context = "\n\n".join(
            f"Document {i+1}: {doc}"
            for i, doc in enumerate(context_documents)
        )

        prompt = f"""Answer the following question based on the provided context.

Context:
{context}

Question: {query}

Answer:"""

        system_prompt = (
            "You are a helpful AI assistant. Answer questions based on the "
            "provided context. If the answer is not in the context, say so."
        )

        return self.generate(prompt, max_tokens, temperature, system_prompt)

    def batch_generate(
        self,
        prompts: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> List[str]:
        """Generate responses for multiple prompts.

        Args:
            prompts: List of prompts
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt

        Returns:
            List of generated responses
        """
        responses = []
        for i, prompt in enumerate(prompts):
            print(f"Processing prompt {i+1}/{len(prompts)}")
            response = self.generate(prompt, max_tokens, temperature, system_prompt)
            responses.append(response)

        return responses
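

# Example usage (sketch; assumes a GROQ_API_KEY environment variable):
#
#     client = GroqLLMClient(api_key=os.environ["GROQ_API_KEY"], max_rpm=10)
#     answer = client.generate("What is retrieval-augmented generation?")
#     answers = client.batch_generate(["What is Groq?", "What is RAG?"])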


class RAGPipeline:
    """Complete RAG pipeline with LLM and vector store."""

    def __init__(
        self,
        llm_client: GroqLLMClient,
        vector_store_manager
    ):
        """Initialize RAG pipeline.

        Args:
            llm_client: Groq LLM client
            vector_store_manager: ChromaDB manager
        """
        self.llm = llm_client
        self.vector_store = vector_store_manager
        self.chat_history: List[Dict] = []

    def query(
        self,
        query: str,
        n_results: int = 5,
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> Dict:
        """Query the RAG system.

        Args:
            query: User query
            n_results: Number of documents to retrieve
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Dictionary with response and retrieved documents
        """
        # Retrieve the most relevant documents from the vector store.
        retrieved_docs = self.vector_store.get_retrieved_documents(query, n_results)

        # Pull out the raw text to feed the LLM as context.
        doc_texts = [doc["document"] for doc in retrieved_docs]

        # Generate a grounded answer.
        response = self.llm.generate_with_context(
            query,
            doc_texts,
            max_tokens,
            temperature
        )

        # Record the exchange for later inspection.
        self.chat_history.append({
            "query": query,
            "response": response,
            "retrieved_docs": retrieved_docs,
            "timestamp": datetime.now().isoformat()
        })

        return {
            "query": query,
            "response": response,
            "retrieved_documents": retrieved_docs
        }

    def get_chat_history(self) -> List[Dict]:
        """Get chat history.

        Returns:
            List of chat history entries
        """
        return self.chat_history

    def clear_history(self):
        """Clear chat history."""
        self.chat_history = []
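

# Wiring sketch (illustrative, not part of the original module). Any vector
# store object exposing get_retrieved_documents(query, n_results) -> List[Dict]
# (each dict holding a "document" key) can back the pipeline:
#
#     pipeline = RAGPipeline(llm_client=client, vector_store_manager=store)
#     result = pipeline.query("What is in the corpus?", n_results=3)
#     print(result["response"])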


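if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming a GROQ_API_KEY environment
    # variable; it exercises the rate-limited client directly. Running the
    # full RAGPipeline additionally needs a vector store manager as above.
    client = GroqLLMClient(api_key=os.environ["GROQ_API_KEY"], max_rpm=10)
    print(client.generate("Say hello in one short sentence.", max_tokens=64))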