"""LLM integration with support for Groq API and local Ollama.""" from typing import List, Dict, Optional, AsyncIterator import time from groq import Groq import asyncio from datetime import datetime, timedelta from collections import deque import os import requests import json class RateLimiter: """Rate limiter for API calls to respect RPM limits.""" def __init__(self, max_requests_per_minute: int = 30): """Initialize rate limiter. Args: max_requests_per_minute: Maximum requests allowed per minute (default: 30) """ self.max_requests = max_requests_per_minute self.request_times = deque() self.lock = asyncio.Lock() # Calculate minimum delay between requests to avoid hitting limit # E.g., 30 RPM = 2.0 seconds minimum delay, but we add safety margin self.min_interval = 60.0 / max_requests_per_minute async def acquire(self): """Acquire permission to make a request (async version).""" async with self.lock: now = datetime.now() # Remove requests older than 1 minute while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1): self.request_times.popleft() # If at limit, wait if len(self.request_times) >= self.max_requests: # Calculate how long to wait oldest_request = self.request_times[0] wait_time = 60 - (now - oldest_request).total_seconds() if wait_time > 0: print(f"[RATE LIMIT] At {self.max_requests} RPM limit. Waiting {wait_time:.2f}s...") await asyncio.sleep(wait_time) # Recursive call after waiting return await self.acquire() # Record this request self.request_times.append(now) def acquire_sync(self): """Synchronous version of acquire (for blocking code).""" now = datetime.now() # Remove requests older than 1 minute while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1): self.request_times.popleft() # Calculate current request rate current_rpm = len(self.request_times) # If at limit, wait if len(self.request_times) >= self.max_requests: oldest_request = self.request_times[0] wait_time = 60 - (now - oldest_request).total_seconds() if wait_time > 0: print(f"[RATE LIMIT] At {self.max_requests} RPM limit. Waiting {wait_time:.2f}s before next request...") time.sleep(wait_time) return self.acquire_sync() # Record this request self.request_times.append(now) # Log current rate if current_rpm > 0: print(f"[RATE LIMIT] Current: {current_rpm} requests in last minute (Limit: {self.max_requests} RPM)") class GroqLLMClient: """Client for Groq LLM API with rate limiting and API key rotation.""" def __init__( self, api_key: str, model_name: str = "llama-3.1-8b-instant", max_rpm: int = 30, rate_limit_delay: float = 2.0, api_keys: list = None, max_retries: int = 3, retry_delay: float = 60.0 ): """Initialize Groq client with optional API key rotation. Args: api_key: Primary Groq API key model_name: Name of the LLM model max_rpm: Maximum requests per minute rate_limit_delay: Additional delay between requests (seconds) api_keys: List of API keys for rotation (optional) max_retries: Maximum retries on rate limit errors retry_delay: Delay before retry on rate limit error """ # Setup API key rotation self.api_keys = api_keys if api_keys else [api_key] self.current_key_index = 0 self.api_key = self.api_keys[self.current_key_index] self.client = Groq(api_key=self.api_key) self.model_name = model_name self.rate_limiter = RateLimiter(max_rpm) self.rate_limit_delay = rate_limit_delay self.max_retries = max_retries self.retry_delay = retry_delay # Track requests per key for smart rotation self.requests_per_key = {key: 0 for key in self.api_keys} # Available models self.available_models = [ "meta-llama/llama-4-maverick-17b-128e-instruct", "llama-3.1-8b-instant", "openai/gpt-oss-120b" ] if len(self.api_keys) > 1: print(f"[API KEYS] Initialized with {len(self.api_keys)} API keys for rotation") def rotate_api_key(self): """Rotate to the next API key.""" if len(self.api_keys) <= 1: return False self.current_key_index = (self.current_key_index + 1) % len(self.api_keys) self.api_key = self.api_keys[self.current_key_index] self.client = Groq(api_key=self.api_key) print(f"[API KEYS] Rotated to API key {self.current_key_index + 1}/{len(self.api_keys)}") return True def set_model(self, model_name: str): """Set the LLM model. Args: model_name: Name of the model """ if model_name not in self.available_models: print(f"Warning: {model_name} not in available models. Using anyway...") self.model_name = model_name def generate( self, prompt: str, max_tokens: int = 1024, temperature: float = 0.7, system_prompt: Optional[str] = None ) -> str: """Generate text using Groq LLM with rate limiting and retry logic. Args: prompt: Input prompt max_tokens: Maximum tokens to generate temperature: Sampling temperature system_prompt: System prompt Returns: Generated text """ # Apply rate limiting to respect 30 RPM limit print(f"[RATE LIMIT] Applying rate limiting (RPM limit: {self.rate_limiter.max_requests}, delay: {self.rate_limit_delay}s)") self.rate_limiter.acquire_sync() # Prepare messages messages = [] if system_prompt: messages.append({ "role": "system", "content": system_prompt }) messages.append({ "role": "user", "content": prompt }) # Retry logic with API key rotation last_error = None for attempt in range(self.max_retries): try: # Make API call response = self.client.chat.completions.create( model=self.model_name, messages=messages, max_tokens=max_tokens, temperature=temperature ) # Track successful request self.requests_per_key[self.api_key] = self.requests_per_key.get(self.api_key, 0) + 1 # Add additional delay for safety margin below RPM limit print(f"[RATE LIMIT] Adding safety delay: {self.rate_limit_delay}s") time.sleep(self.rate_limit_delay) return response.choices[0].message.content except Exception as e: last_error = e error_str = str(e).lower() # Check if it's a rate limit error if "rate" in error_str or "limit" in error_str or "429" in error_str or "quota" in error_str: print(f"[RATE LIMIT ERROR] Hit rate limit on attempt {attempt + 1}/{self.max_retries}") # Try rotating to another API key if self.rotate_api_key(): print(f"[API KEYS] Trying with different API key...") continue # If no more keys or rotation failed, wait and retry if attempt < self.max_retries - 1: print(f"[RATE LIMIT] Waiting {self.retry_delay}s before retry...") time.sleep(self.retry_delay) continue else: # Non-rate-limit error print(f"[ERROR] API error: {str(e)}") break print(f"[ERROR] Failed after {self.max_retries} attempts: {str(last_error)}") return f"Error: {str(last_error)}" async def generate_async( self, prompt: str, max_tokens: int = 1024, temperature: float = 0.7, system_prompt: Optional[str] = None ) -> str: """Asynchronous version of generate. Args: prompt: Input prompt max_tokens: Maximum tokens to generate temperature: Sampling temperature system_prompt: System prompt Returns: Generated text """ # Apply rate limiting await self.rate_limiter.acquire() # Prepare messages messages = [] if system_prompt: messages.append({ "role": "system", "content": system_prompt }) messages.append({ "role": "user", "content": prompt }) try: # Make API call (synchronous client used in async context) response = self.client.chat.completions.create( model=self.model_name, messages=messages, max_tokens=max_tokens, temperature=temperature ) # Add delay await asyncio.sleep(self.rate_limit_delay) return response.choices[0].message.content except Exception as e: print(f"Error generating response: {str(e)}") return f"Error: {str(e)}" def generate_with_context( self, query: str, context_documents: List[str], max_tokens: int = 1024, temperature: float = 0.7 ) -> str: """Generate response with retrieved context. Args: query: User query context_documents: List of retrieved documents max_tokens: Maximum tokens to generate temperature: Sampling temperature Returns: Generated response """ # Build context context = "\n\n".join([ f"Document {i+1}: {doc}" for i, doc in enumerate(context_documents) ]) # Build prompt prompt = f"""Answer the following question based on the provided context. Context: {context} Question: {query} Answer:""" system_prompt = "You are a helpful AI assistant. Answer questions based on the provided context. If the answer is not in the context, say so." return self.generate(prompt, max_tokens, temperature, system_prompt) def batch_generate( self, prompts: List[str], max_tokens: int = 1024, temperature: float = 0.7, system_prompt: Optional[str] = None ) -> List[str]: """Generate responses for multiple prompts. Args: prompts: List of prompts max_tokens: Maximum tokens to generate temperature: Sampling temperature system_prompt: System prompt Returns: List of generated responses """ responses = [] for i, prompt in enumerate(prompts): print(f"Processing prompt {i+1}/{len(prompts)}") response = self.generate(prompt, max_tokens, temperature, system_prompt) responses.append(response) return responses class OllamaLLMClient: """Client for local Ollama LLM - no rate limits, unlimited usage.""" def __init__( self, host: str = "http://localhost:11434", model_name: str = "gemma3:12b" ): """Initialize Ollama client. Args: host: Ollama server URL (default: http://localhost:11434) model_name: Name of the model (e.g., gemma3:12b, llama3.3) """ self.host = host.rstrip("/") self.model_name = model_name # Available models for Ollama self.available_models = [ "gemma3:12b", "llama3.3" ] print(f"[OLLAMA] Initialized with model: {model_name} at {host}") def check_connection(self) -> bool: """Check if Ollama server is running. Returns: True if connected, False otherwise """ try: response = requests.get(f"{self.host}/api/tags", timeout=5) return response.status_code == 200 except Exception as e: print(f"[OLLAMA] Connection error: {e}") return False def list_models(self) -> List[str]: """List available models on Ollama server. Returns: List of model names """ try: response = requests.get(f"{self.host}/api/tags", timeout=10) if response.status_code == 200: data = response.json() return [model["name"] for model in data.get("models", [])] return [] except Exception as e: print(f"[OLLAMA] Error listing models: {e}") return [] def set_model(self, model_name: str): """Set the LLM model. Args: model_name: Name of the model """ self.model_name = model_name print(f"[OLLAMA] Model set to: {model_name}") def generate( self, prompt: str, max_tokens: int = 1024, temperature: float = 0.7, system_prompt: Optional[str] = None ) -> str: """Generate text using local Ollama LLM. Args: prompt: Input prompt max_tokens: Maximum tokens to generate temperature: Sampling temperature system_prompt: System prompt Returns: Generated text """ print(f"[OLLAMA] Generating with model: {self.model_name} (no rate limits)") # Build full prompt with system prompt full_prompt = prompt if system_prompt: full_prompt = f"{system_prompt}\n\n{prompt}" try: # Make API call to Ollama response = requests.post( f"{self.host}/api/generate", json={ "model": self.model_name, "prompt": full_prompt, "options": { "num_predict": max_tokens, "temperature": temperature }, "stream": False }, timeout=600 # Longer timeout for local inference ) if response.status_code == 200: data = response.json() return data.get("response", "") else: error_msg = f"Ollama API error: {response.status_code} - {response.text}" print(f"[OLLAMA ERROR] {error_msg}") return f"Error: {error_msg}" except requests.exceptions.ConnectionError: error_msg = f"Cannot connect to Ollama at {self.host}. Is Ollama running?" print(f"[OLLAMA ERROR] {error_msg}") return f"Error: {error_msg}" except Exception as e: error_msg = f"Ollama error: {str(e)}" print(f"[OLLAMA ERROR] {error_msg}") return f"Error: {error_msg}" def generate_with_context( self, query: str, context_documents: List[str], max_tokens: int = 1024, temperature: float = 0.7 ) -> str: """Generate response with retrieved context. Args: query: User query context_documents: List of retrieved documents max_tokens: Maximum tokens to generate temperature: Sampling temperature Returns: Generated response """ # Build context context = "\n\n".join([ f"Document {i+1}: {doc}" for i, doc in enumerate(context_documents) ]) # Build prompt prompt = f"""Answer the following question based on the provided context. Context: {context} Question: {query} Answer:""" system_prompt = "You are a helpful AI assistant. Answer questions based on the provided context. If the answer is not in the context, say so." return self.generate(prompt, max_tokens, temperature, system_prompt) def batch_generate( self, prompts: List[str], max_tokens: int = 1024, temperature: float = 0.7, system_prompt: Optional[str] = None ) -> List[str]: """Generate responses for multiple prompts. Args: prompts: List of prompts max_tokens: Maximum tokens to generate temperature: Sampling temperature system_prompt: System prompt Returns: List of generated responses """ responses = [] for i, prompt in enumerate(prompts): print(f"[OLLAMA] Processing prompt {i+1}/{len(prompts)}") response = self.generate(prompt, max_tokens, temperature, system_prompt) responses.append(response) return responses def create_llm_client( provider: str = "groq", api_key: str = "", api_keys: list = None, model_name: str = None, ollama_host: str = "http://localhost:11434", max_rpm: int = 30, rate_limit_delay: float = 2.0, max_retries: int = 3, retry_delay: float = 60.0 ): """Factory function to create LLM client based on provider. Args: provider: "groq" or "ollama" api_key: Groq API key (for groq provider) api_keys: List of Groq API keys for rotation model_name: Model name (auto-detected if not provided) ollama_host: Ollama server URL max_rpm: Maximum requests per minute (for groq) rate_limit_delay: Delay between requests (for groq) max_retries: Max retries on error (for groq) retry_delay: Delay before retry (for groq) Returns: LLM client instance (GroqLLMClient or OllamaLLMClient) """ if provider.lower() == "ollama": if model_name is None: model_name = "gemma3:12b" print(f"[LLM FACTORY] Creating Ollama client with model: {model_name}") return OllamaLLMClient(host=ollama_host, model_name=model_name) else: if model_name is None: model_name = "llama-3.1-8b-instant" print(f"[LLM FACTORY] Creating Groq client with model: {model_name}") return GroqLLMClient( api_key=api_key, model_name=model_name, max_rpm=max_rpm, rate_limit_delay=rate_limit_delay, api_keys=api_keys, max_retries=max_retries, retry_delay=retry_delay ) class RAGPipeline: """Complete RAG pipeline with LLM and vector store.""" def __init__( self, llm_client, vector_store_manager ): """Initialize RAG pipeline. Args: llm_client: LLM client (GroqLLMClient or OllamaLLMClient) vector_store_manager: ChromaDB manager """ self.llm = llm_client self.vector_store = vector_store_manager self.chat_history = [] def query( self, query: str, n_results: int = 5, max_tokens: int = 1024, temperature: float = 0.7 ) -> Dict: """Query the RAG system. Args: query: User query n_results: Number of documents to retrieve max_tokens: Maximum tokens to generate temperature: Sampling temperature Returns: Dictionary with response and retrieved documents """ # Retrieve documents retrieved_docs = self.vector_store.get_retrieved_documents(query, n_results) # Extract document texts doc_texts = [doc["document"] for doc in retrieved_docs] # Generate response response = self.llm.generate_with_context( query, doc_texts, max_tokens, temperature ) # Store in chat history self.chat_history.append({ "query": query, "response": response, "retrieved_docs": retrieved_docs, "timestamp": datetime.now().isoformat() }) return { "query": query, "response": response, "retrieved_documents": retrieved_docs } def get_chat_history(self) -> List[Dict]: """Get chat history. Returns: List of chat history entries """ return self.chat_history def clear_history(self): """Clear chat history.""" self.chat_history = []