#!/usr/bin/env python3
"""
SUPRA RAG System with CPU/MPS/CUDA Optimizations

Optimized for CPU (HF Spaces), MPS (Apple Silicon), and CUDA with
efficient memory management.
"""

import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List

import chromadb
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SupraRAG:
    """Retrieval-augmented generation over SUPRA knowledge stored in ChromaDB.

    Loads seed documents from a JSONL file into an in-memory ChromaDB
    collection and exposes retrieval / prompt-building / generation helpers.
    """

    def __init__(self, rag_data_path: str = None):
        """Initialize the RAG system.

        Args:
            rag_data_path: Path to the JSONL seed file. When None, a set of
                known locations is probed (HF Spaces deployment layout first).
        """
        if rag_data_path is None:
            # Try multiple possible locations (repo root, package-relative, cwd)
            possible_paths = [
                Path("data/processed/rag_seeds/rag_seeds.jsonl"),
                Path(__file__).parent.parent / "data/processed/rag_seeds/rag_seeds.jsonl",
                Path("rag_seeds.jsonl"),
            ]
            for path in possible_paths:
                if path.exists():
                    rag_data_path = str(path)
                    break
            else:
                # Default fallback (file may not exist; handled at load time)
                rag_data_path = "data/processed/rag_seeds/rag_seeds.jsonl"

        self.rag_data_path = Path(rag_data_path)

        # Device-specific optimizations (sets self.device)
        self._setup_device_optimizations()

        # Initialize ChromaDB (in-memory client) with device optimizations
        self.client = chromadb.Client()
        self.collection_name = "supra_knowledge"

        # Use efficient embedding model. The sentence-transformers model runs
        # on whatever device was detected; CPU is optimal for HF Spaces free tier.
        self.embedding_model = SentenceTransformer(
            'all-MiniLM-L6-v2',
            device=self.device
        )

        # Initialize or load collection
        try:
            self.collection = self.client.get_collection(self.collection_name)

            # Reload if the collection size no longer matches the JSONL file.
            current_count = self.collection.count()

            # Count expected documents (non-blank lines) from the JSONL file.
            expected_count = 0
            if self.rag_data_path.exists():
                with open(self.rag_data_path, 'r', encoding='utf-8') as f:
                    expected_count = sum(1 for line in f if line.strip())

            if current_count != expected_count:
                logger.info(f"🔄 Reloading RAG documents (current: {current_count}, expected: {expected_count})")
                # Delete and recreate collection to reload
                self.client.delete_collection(self.collection_name)
                self.collection = self.client.create_collection(self.collection_name)
                self._load_rag_documents()
            else:
                logger.info(f"✅ RAG knowledge base loaded ({current_count} documents)")
                # Removed UI success message - shown in sidebar instead
        except Exception:
            # Collection does not exist yet (or count/read failed) — build it.
            self.collection = self.client.create_collection(self.collection_name)
            self._load_rag_documents()

    def _setup_device_optimizations(self):
        """Detect the best available device (MPS > CUDA > CPU) and set env vars."""
        logger.info("🔧 Setting up device optimizations...")

        # Avoid tokenizer thread contention with the web server / Streamlit.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        if torch.backends.mps.is_available():
            logger.info("✅ MPS (Metal Performance Shaders) available - using MPS")
            self.device = "mps"
            # Fall back to CPU for ops not yet implemented on MPS.
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        elif torch.cuda.is_available():
            logger.info("✅ CUDA available - using GPU")
            self.device = "cuda"
        else:
            logger.info("💻 CPU detected - using CPU optimizations")
            self.device = "cpu"

        logger.info(f"🔧 Using device: {self.device}")

    def _load_rag_documents(self):
        """Load RAG documents from the JSONL file into the ChromaDB collection.

        Each line must be a JSON object with at least 'id' and 'content';
        malformed or incomplete lines are skipped with a warning. Content is
        truncated to 2000 characters for memory efficiency, and documents are
        added in batches of 50.
        """
        if not self.rag_data_path.exists():
            logger.warning("⚠️ RAG data file not found")
            st.warning("⚠️ RAG data file not found")
            return

        documents = []
        metadatas = []
        ids = []

        logger.info(f"📚 Loading RAG documents from {self.rag_data_path}")

        with open(self.rag_data_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    doc = json.loads(line)
                    if 'content' in doc and 'id' in doc:
                        # Truncate content for memory efficiency
                        content = doc['content']
                        if len(content) > 2000:  # Limit content length
                            content = content[:2000] + "..."
                        documents.append(content)
                        metadatas.append({
                            'title': doc.get('title', ''),
                            'type': doc.get('type', ''),
                            'source': doc.get('source', ''),
                            'word_count': len(content.split())
                        })
                        ids.append(doc['id'])
                    else:
                        logger.warning(f"⚠️ Skipping line {line_num}: missing required fields")
                except json.JSONDecodeError as e:
                    logger.warning(f"⚠️ Skipping line {line_num}: JSON decode error - {e}")

        if documents:
            # Add to ChromaDB with batch processing
            batch_size = 50  # Smaller batches for memory efficiency
            total_batches = (len(documents) - 1) // batch_size + 1
            for i in range(0, len(documents), batch_size):
                self.collection.add(
                    documents=documents[i:i + batch_size],
                    metadatas=metadatas[i:i + batch_size],
                    ids=ids[i:i + batch_size]
                )
                logger.info(f"📊 Processed batch {i // batch_size + 1}/{total_batches}")

            logger.info(f"✅ Loaded {len(documents)} RAG documents")
            # Removed UI success message - shown in sidebar instead
        else:
            logger.warning("⚠️ No valid documents found in RAG data file")
            st.warning("⚠️ No valid documents found in RAG data file")

    def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
        """Retrieve relevant context documents for a query.

        Args:
            query: Free-text query (truncated to 500 chars for efficiency).
            n_results: Desired number of results (capped at 5).

        Returns:
            A list of dicts with 'content' (truncated to 1500 chars),
            'metadata', and 'distance' keys; empty list on error.
        """
        try:
            # Limit query length for efficiency
            if len(query) > 500:
                query = query[:500]

            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, 5)  # Limit results for efficiency
            )

            context_docs = []
            for i, doc in enumerate(results['documents'][0]):
                # Truncate retrieved content for memory efficiency
                content = doc
                if len(content) > 1500:
                    content = content[:1500] + "..."
                context_docs.append({
                    'content': content,
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i]
                })

            logger.info(f"🔍 Retrieved {len(context_docs)} context documents")
            return context_docs

        except Exception as e:
            logger.error(f"RAG retrieval error: {e}")
            st.error(f"RAG retrieval error: {e}")
            return []

    def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
        """Build the full model prompt from the query, RAG context, and SUPRA facts.

        Args:
            user_query: The user's question.
            context_docs: Output of :meth:`retrieve_context`.

        Returns:
            A complete prompt string formatted for the detected chat template.
        """
        # Import SUPRA facts system (local import avoids circular dependency)
        from .supra_facts import build_supra_prompt, inject_facts_for_query

        # Extract RAG context chunks, bounded for memory efficiency.
        rag_context = None
        if context_docs:
            max_context_length = 2000  # Reduced for memory efficiency
            context_text = ""
            for doc in context_docs:
                doc_text = f"{doc['content'][:800]}"
                if len(context_text + doc_text) > max_context_length:
                    break
                context_text += doc_text + "\n\n"
            rag_context = [context_text] if context_text else None

        # Auto-detect relevant facts from query
        facts = inject_facts_for_query(user_query)

        # Get model name from model_loader to pick the right chat template
        from .model_loader import get_model_info
        try:
            model_info = get_model_info()
            # Get base model name to detect Llama vs Mistral
            base_model = model_info.get('base_model', '')
            if 'llama' in base_model.lower() or 'meta-llama' in base_model.lower():
                model_name = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'
            else:
                model_name = model_info.get('model_name', 'unsloth/mistral-7b-instruct-v0.3-bnb-4bit')
        except Exception:
            # Default to Llama since latest models use Llama
            model_name = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'

        # Build complete SUPRA prompt with system prompt, facts, and RAG context
        return build_supra_prompt(
            user_query=user_query,
            facts=facts,
            rag_context=rag_context,
            model_name=model_name
        )

    def generate_response(self, query: str, model, tokenizer, max_new_tokens: int = 800) -> str:
        """Generate a model response for *query* using RAG-enhanced prompting.

        Args:
            query: The user's question.
            model: The loaded language model (passed through to the generator).
            tokenizer: Matching tokenizer for *model*.
            max_new_tokens: Generation length cap.

        Returns:
            The generated response text, or an apologetic error string on failure.
        """
        try:
            logger.info(f"🤖 Generating response for query: {query[:50]}...")

            # Get RAG context and build the enhanced prompt
            context_docs = self.retrieve_context(query, n_results=3)
            enhanced_prompt = self.build_enhanced_prompt(query, context_docs)

            # Import the generation function (local import avoids circular dependency)
            from .model_loader import generate_response_optimized

            # Generate with enhanced model - tighter parameters for better quality
            response = generate_response_optimized(
                model=model,
                tokenizer=tokenizer,
                prompt=enhanced_prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.6,  # Lower temperature for more focused responses
                top_p=0.85        # Tighter sampling
            )

            logger.info(f"✅ Generated response ({len(response)} characters)")
            return response

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            st.error(f"Error generating response: {e}")
            return f"I apologize, but I encountered an error while generating a response: {e}"


# Global RAG instance with device-specific optimizations
@st.cache_resource
def get_supra_rag():
    """Get cached SUPRA RAG instance optimized for CPU/MPS/CUDA."""
    return SupraRAG()


# Backward compatibility (kept for compatibility with old imports)
def get_supra_rag_m2max():
    """Backward compatible function that returns device-optimized RAG."""
    return get_supra_rag()