Spaces:
Sleeping
Sleeping
| """LLM integration with support for Groq API and local Ollama.""" | |
| from typing import List, Dict, Optional, AsyncIterator | |
| import time | |
| from groq import Groq | |
| import asyncio | |
| from datetime import datetime, timedelta | |
| from collections import deque | |
| import os | |
| import requests | |
| import json | |
class RateLimiter:
    """Sliding-window rate limiter for API calls to respect RPM limits.

    Keeps the timestamps of requests made in the last minute in a deque and
    blocks (async or sync) until one more request fits under the budget.
    """

    def __init__(self, max_requests_per_minute: int = 30):
        """Initialize rate limiter.

        Args:
            max_requests_per_minute: Maximum requests allowed per minute (default: 30)
        """
        self.max_requests = max_requests_per_minute
        # Timestamps of requests made within (roughly) the last minute.
        self.request_times = deque()
        # Guards request_times against concurrent async callers.
        self.lock = asyncio.Lock()
        # Calculate minimum delay between requests to avoid hitting limit
        # E.g., 30 RPM = 2.0 seconds minimum delay, but we add safety margin
        self.min_interval = 60.0 / max_requests_per_minute

    def _prune(self, now: datetime) -> None:
        """Drop timestamps older than one minute from the window."""
        one_minute = timedelta(minutes=1)
        while self.request_times and (now - self.request_times[0]) > one_minute:
            self.request_times.popleft()

    async def acquire(self):
        """Acquire permission to make a request (async version).

        Waits (without busy-waiting) until the request fits in the window.
        """
        # BUG FIX: the previous implementation recursed into acquire() while
        # still holding the non-reentrant asyncio.Lock, which deadlocks the
        # first time the RPM limit is actually reached. Loop instead.
        async with self.lock:
            while True:
                now = datetime.now()
                self._prune(now)
                if len(self.request_times) < self.max_requests:
                    break
                # Wait until the oldest request ages out of the window,
                # then re-check (the loop re-prunes after sleeping).
                wait_time = 60 - (now - self.request_times[0]).total_seconds()
                if wait_time > 0:
                    print(f"[RATE LIMIT] At {self.max_requests} RPM limit. Waiting {wait_time:.2f}s...")
                    await asyncio.sleep(wait_time)
            # Record this request
            self.request_times.append(datetime.now())

    def acquire_sync(self):
        """Synchronous version of acquire (for blocking code)."""
        # Loop rather than recurse: the previous recursive retry could hit the
        # interpreter recursion limit under sustained rate-limit pressure.
        while True:
            now = datetime.now()
            self._prune(now)
            # Current request rate (requests observed in the last minute).
            current_rpm = len(self.request_times)
            if current_rpm < self.max_requests:
                break
            oldest_request = self.request_times[0]
            wait_time = 60 - (now - oldest_request).total_seconds()
            if wait_time > 0:
                print(f"[RATE LIMIT] At {self.max_requests} RPM limit. Waiting {wait_time:.2f}s before next request...")
                time.sleep(wait_time)
        # Record this request
        self.request_times.append(datetime.now())
        # Log current rate
        if current_rpm > 0:
            print(f"[RATE LIMIT] Current: {current_rpm} requests in last minute (Limit: {self.max_requests} RPM)")
class GroqLLMClient:
    """Client for Groq LLM API with rate limiting and API key rotation.

    Wraps the synchronous Groq SDK client, enforcing an RPM budget through a
    RateLimiter and rotating across multiple API keys when the API reports a
    rate-limit / quota error.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "llama-3.1-8b-instant",
        max_rpm: int = 30,
        rate_limit_delay: float = 2.0,
        api_keys: Optional[List[str]] = None,
        max_retries: int = 3,
        retry_delay: float = 60.0
    ):
        """Initialize Groq client with optional API key rotation.

        Args:
            api_key: Primary Groq API key
            model_name: Name of the LLM model
            max_rpm: Maximum requests per minute
            rate_limit_delay: Additional delay between requests (seconds)
            api_keys: List of API keys for rotation (optional)
            max_retries: Maximum retries on rate limit errors
            retry_delay: Delay before retry on rate limit error (seconds)
        """
        # Setup API key rotation; fall back to the single primary key.
        self.api_keys = api_keys if api_keys else [api_key]
        self.current_key_index = 0
        self.api_key = self.api_keys[self.current_key_index]
        self.client = Groq(api_key=self.api_key)
        self.model_name = model_name
        self.rate_limiter = RateLimiter(max_rpm)
        self.rate_limit_delay = rate_limit_delay
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Track successful requests per key (diagnostics / smart rotation).
        self.requests_per_key = {key: 0 for key in self.api_keys}
        # Models known to work with this client.
        self.available_models = [
            "meta-llama/llama-4-maverick-17b-128e-instruct",
            "llama-3.1-8b-instant",
            "openai/gpt-oss-120b"
        ]
        if len(self.api_keys) > 1:
            print(f"[API KEYS] Initialized with {len(self.api_keys)} API keys for rotation")

    def rotate_api_key(self):
        """Rotate to the next API key (round-robin).

        Returns:
            True if rotation happened, False when only one key is configured.
        """
        if len(self.api_keys) <= 1:
            return False
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        self.api_key = self.api_keys[self.current_key_index]
        # A fresh SDK client is required because the key is fixed per client.
        self.client = Groq(api_key=self.api_key)
        print(f"[API KEYS] Rotated to API key {self.current_key_index + 1}/{len(self.api_keys)}")
        return True

    def set_model(self, model_name: str):
        """Set the LLM model.

        Args:
            model_name: Name of the model (warned if unknown, still accepted)
        """
        if model_name not in self.available_models:
            print(f"Warning: {model_name} not in available models. Using anyway...")
        self.model_name = model_name

    def _build_messages(self, prompt: str, system_prompt: Optional[str]) -> List[Dict[str, str]]:
        """Build the chat-completions messages list (shared by sync/async paths)."""
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })
        messages.append({
            "role": "user",
            "content": prompt
        })
        return messages

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> str:
        """Generate text using Groq LLM with rate limiting and retry logic.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt

        Returns:
            Generated text, or an "Error: ..." string on failure.
        """
        # Apply rate limiting to respect the configured RPM limit.
        print(f"[RATE LIMIT] Applying rate limiting (RPM limit: {self.rate_limiter.max_requests}, delay: {self.rate_limit_delay}s)")
        self.rate_limiter.acquire_sync()
        messages = self._build_messages(prompt, system_prompt)
        # Retry loop: on a rate-limit error, first rotate keys; once rotation
        # is exhausted, back off for retry_delay before trying again.
        last_error = None
        for attempt in range(self.max_retries):
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature
                )
                # Track successful request per key.
                self.requests_per_key[self.api_key] = self.requests_per_key.get(self.api_key, 0) + 1
                # Additional delay keeps us safely below the RPM limit.
                print(f"[RATE LIMIT] Adding safety delay: {self.rate_limit_delay}s")
                time.sleep(self.rate_limit_delay)
                return response.choices[0].message.content
            except Exception as e:
                last_error = e
                error_str = str(e).lower()
                # Heuristic rate-limit detection from the error message/code.
                if "rate" in error_str or "limit" in error_str or "429" in error_str or "quota" in error_str:
                    print(f"[RATE LIMIT ERROR] Hit rate limit on attempt {attempt + 1}/{self.max_retries}")
                    if self.rotate_api_key():
                        print(f"[API KEYS] Trying with different API key...")
                        continue
                    if attempt < self.max_retries - 1:
                        print(f"[RATE LIMIT] Waiting {self.retry_delay}s before retry...")
                        time.sleep(self.retry_delay)
                        continue
                else:
                    # Non-rate-limit errors are not retried.
                    print(f"[ERROR] API error: {str(e)}")
                    break
        print(f"[ERROR] Failed after {self.max_retries} attempts: {str(last_error)}")
        return f"Error: {str(last_error)}"

    async def generate_async(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> str:
        """Asynchronous version of generate.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt

        Returns:
            Generated text, or an "Error: ..." string on failure.
        """
        await self.rate_limiter.acquire()
        messages = self._build_messages(prompt, system_prompt)
        try:
            # BUG FIX: the Groq SDK call is blocking; running it directly here
            # froze the event loop for the whole HTTP round-trip. Off-load it
            # to the default thread-pool executor instead.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature
                )
            )
            # Safety-margin delay (non-blocking).
            await asyncio.sleep(self.rate_limit_delay)
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def generate_with_context(
        self,
        query: str,
        context_documents: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> str:
        """Generate response with retrieved context.

        Args:
            query: User query
            context_documents: List of retrieved documents
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        # Number the documents so the model can reference them.
        context = "\n\n".join([
            f"Document {i+1}: {doc}"
            for i, doc in enumerate(context_documents)
        ])
        prompt = f"""Answer the following question based on the provided context.
Context:
{context}
Question: {query}
Answer:"""
        system_prompt = "You are a helpful AI assistant. Answer questions based on the provided context. If the answer is not in the context, say so."
        return self.generate(prompt, max_tokens, temperature, system_prompt)

    def batch_generate(
        self,
        prompts: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> List[str]:
        """Generate responses for multiple prompts (sequentially, rate-limited).

        Args:
            prompts: List of prompts
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt

        Returns:
            List of generated responses
        """
        responses = []
        for i, prompt in enumerate(prompts):
            print(f"Processing prompt {i+1}/{len(prompts)}")
            response = self.generate(prompt, max_tokens, temperature, system_prompt)
            responses.append(response)
        return responses
class OllamaLLMClient:
    """Client for local Ollama LLM - no rate limits, unlimited usage."""

    def __init__(
        self,
        host: str = "http://localhost:11434",
        model_name: str = "gemma3:12b"
    ):
        """Initialize Ollama client.

        Args:
            host: Ollama server URL (default: http://localhost:11434)
            model_name: Name of the model (e.g., gemma3:12b, llama3.3)
        """
        self.host = host.rstrip("/")
        self.model_name = model_name
        # Models known to work with this client.
        self.available_models = ["gemma3:12b", "llama3.3"]
        print(f"[OLLAMA] Initialized with model: {model_name} at {host}")

    def check_connection(self) -> bool:
        """Return True if the Ollama server answers on /api/tags."""
        try:
            return requests.get(f"{self.host}/api/tags", timeout=5).status_code == 200
        except Exception as e:
            print(f"[OLLAMA] Connection error: {e}")
            return False

    def list_models(self) -> List[str]:
        """Return the names of the models installed on the Ollama server.

        Returns:
            List of model names; empty on any error.
        """
        try:
            resp = requests.get(f"{self.host}/api/tags", timeout=10)
            if resp.status_code != 200:
                return []
            payload = resp.json()
            return [entry["name"] for entry in payload.get("models", [])]
        except Exception as e:
            print(f"[OLLAMA] Error listing models: {e}")
            return []

    def set_model(self, model_name: str):
        """Set the LLM model.

        Args:
            model_name: Name of the model
        """
        self.model_name = model_name
        print(f"[OLLAMA] Model set to: {model_name}")

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> str:
        """Generate text using local Ollama LLM.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt (prepended to the user prompt)

        Returns:
            Generated text, or an "Error: ..." string on failure.
        """
        print(f"[OLLAMA] Generating with model: {self.model_name} (no rate limits)")
        # Ollama's /api/generate takes one flat prompt, so the system prompt
        # is simply prepended.
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"
        else:
            full_prompt = prompt
        payload = {
            "model": self.model_name,
            "prompt": full_prompt,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature
            },
            "stream": False
        }
        try:
            # Long timeout: local inference can be slow.
            resp = requests.post(f"{self.host}/api/generate", json=payload, timeout=600)
            if resp.status_code == 200:
                return resp.json().get("response", "")
            error_msg = f"Ollama API error: {resp.status_code} - {resp.text}"
            print(f"[OLLAMA ERROR] {error_msg}")
            return f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = f"Cannot connect to Ollama at {self.host}. Is Ollama running?"
            print(f"[OLLAMA ERROR] {error_msg}")
            return f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"Ollama error: {str(e)}"
            print(f"[OLLAMA ERROR] {error_msg}")
            return f"Error: {error_msg}"

    def generate_with_context(
        self,
        query: str,
        context_documents: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> str:
        """Generate response with retrieved context.

        Args:
            query: User query
            context_documents: List of retrieved documents
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        # Number the documents so the model can reference them.
        numbered = [f"Document {i+1}: {doc}" for i, doc in enumerate(context_documents)]
        context = "\n\n".join(numbered)
        prompt = f"""Answer the following question based on the provided context.
Context:
{context}
Question: {query}
Answer:"""
        system_prompt = "You are a helpful AI assistant. Answer questions based on the provided context. If the answer is not in the context, say so."
        return self.generate(prompt, max_tokens, temperature, system_prompt)

    def batch_generate(
        self,
        prompts: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> List[str]:
        """Generate responses for multiple prompts, sequentially.

        Args:
            prompts: List of prompts
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: System prompt

        Returns:
            List of generated responses
        """
        results = []
        total = len(prompts)
        for idx, single_prompt in enumerate(prompts):
            print(f"[OLLAMA] Processing prompt {idx+1}/{total}")
            results.append(self.generate(single_prompt, max_tokens, temperature, system_prompt))
        return results
def create_llm_client(
    provider: str = "groq",
    api_key: str = "",
    api_keys: list = None,
    model_name: str = None,
    ollama_host: str = "http://localhost:11434",
    max_rpm: int = 30,
    rate_limit_delay: float = 2.0,
    max_retries: int = 3,
    retry_delay: float = 60.0
):
    """Factory function to create LLM client based on provider.

    Args:
        provider: "groq" or "ollama" (anything other than "ollama" means groq)
        api_key: Groq API key (for groq provider)
        api_keys: List of Groq API keys for rotation
        model_name: Model name (provider default used when None)
        ollama_host: Ollama server URL
        max_rpm: Maximum requests per minute (for groq)
        rate_limit_delay: Delay between requests (for groq)
        max_retries: Max retries on error (for groq)
        retry_delay: Delay before retry (for groq)

    Returns:
        LLM client instance (GroqLLMClient or OllamaLLMClient)
    """
    if provider.lower() == "ollama":
        # Fall back to the Ollama default model only when none was given.
        chosen_model = "gemma3:12b" if model_name is None else model_name
        print(f"[LLM FACTORY] Creating Ollama client with model: {chosen_model}")
        return OllamaLLMClient(host=ollama_host, model_name=chosen_model)
    # Default path: Groq.
    chosen_model = "llama-3.1-8b-instant" if model_name is None else model_name
    print(f"[LLM FACTORY] Creating Groq client with model: {chosen_model}")
    return GroqLLMClient(
        api_key=api_key,
        model_name=chosen_model,
        max_rpm=max_rpm,
        rate_limit_delay=rate_limit_delay,
        api_keys=api_keys,
        max_retries=max_retries,
        retry_delay=retry_delay
    )
class RAGPipeline:
    """Complete RAG pipeline with LLM and vector store."""

    def __init__(
        self,
        llm_client,
        vector_store_manager
    ):
        """Initialize RAG pipeline.

        Args:
            llm_client: LLM client (GroqLLMClient or OllamaLLMClient)
            vector_store_manager: ChromaDB manager
        """
        self.llm = llm_client
        self.vector_store = vector_store_manager
        self.chat_history = []

    def query(
        self,
        query: str,
        n_results: int = 5,
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> Dict:
        """Query the RAG system: retrieve, generate, and record the exchange.

        Args:
            query: User query
            n_results: Number of documents to retrieve
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Dictionary with the query, response and retrieved documents
        """
        # Retrieve, then pass only the raw text of each hit to the LLM.
        hits = self.vector_store.get_retrieved_documents(query, n_results)
        texts = [hit["document"] for hit in hits]
        answer = self.llm.generate_with_context(
            query,
            texts,
            max_tokens,
            temperature
        )
        # Keep the full exchange (including retrieval results) in history.
        self.chat_history.append({
            "query": query,
            "response": answer,
            "retrieved_docs": hits,
            "timestamp": datetime.now().isoformat()
        })
        return {
            "query": query,
            "response": answer,
            "retrieved_documents": hits
        }

    def get_chat_history(self) -> List[Dict]:
        """Return the accumulated chat history entries."""
        return self.chat_history

    def clear_history(self):
        """Clear chat history."""
        self.chat_history = []