#!/usr/bin/env python3
"""
SUPRA RAG System with CPU/MPS/CUDA Optimizations
Optimized for CPU (HF Spaces), MPS (Apple Silicon), and CUDA with efficient memory management
"""
import json
import logging
import os
from pathlib import Path
from typing import List, Dict, Any

import chromadb
import torch
from sentence_transformers import SentenceTransformer

try:
    import streamlit as st
except ImportError:
    # Allow import outside a Streamlit app; UI messages are skipped when st is None
    st = None

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SupraRAG:
    def __init__(self, rag_data_path: str = None):
        # Default RAG data path (for HF Spaces deployment)
        if rag_data_path is None:
            # Try multiple possible locations
            possible_paths = [
                Path("data/processed/rag_seeds/rag_seeds.jsonl"),
                Path(__file__).parent.parent / "data/processed/rag_seeds/rag_seeds.jsonl",
                Path("rag_seeds.jsonl"),
            ]
            for path in possible_paths:
                if path.exists():
                    rag_data_path = str(path)
                    break
            else:
                # Default fallback
                rag_data_path = "data/processed/rag_seeds/rag_seeds.jsonl"
        self.rag_data_path = Path(rag_data_path)

        # Device-specific optimizations
        self._setup_device_optimizations()

        # Initialize ChromaDB
        self.client = chromadb.Client()
        self.collection_name = "supra_knowledge"

        # Use an efficient embedding model; runs on CPU for the HF Spaces
        # free tier and on MPS/CUDA when available
        self.embedding_model = SentenceTransformer(
            'all-MiniLM-L6-v2',
            device=self.device
        )
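        # Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings and is a
        # deliberately small model, which keeps indexing and queries fast on CPU.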

        # Initialize or load the collection
        try:
            self.collection = self.client.get_collection(self.collection_name)
            # Reload if the collection size no longer matches the JSONL file
            current_count = self.collection.count()
            expected_count = 0
            if self.rag_data_path.exists():
                with open(self.rag_data_path, 'r', encoding='utf-8') as f:
                    expected_count = sum(1 for line in f if line.strip())
            if current_count != expected_count:
                logger.info(f"🔄 Reloading RAG documents (current: {current_count}, expected: {expected_count})")
                # Delete and recreate the collection to force a reload
                self.client.delete_collection(self.collection_name)
                self.collection = self.client.create_collection(self.collection_name)
                self._load_rag_documents()
            else:
                logger.info(f"✅ RAG knowledge base loaded ({current_count} documents)")
                # UI success message intentionally omitted; it is shown in the sidebar instead
        except Exception:
            # Collection does not exist yet - create and populate it
            self.collection = self.client.create_collection(self.collection_name)
            self._load_rag_documents()

    def _setup_device_optimizations(self):
        """Configure optimizations for CPU/MPS/CUDA."""
        logger.info("🔧 Setting up device optimizations...")
        # Avoid tokenizer fork warnings/deadlocks in worker processes
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        # Detect device: MPS > CUDA > CPU
        if torch.backends.mps.is_available():
            logger.info("✅ MPS (Metal Performance Shaders) available - using MPS")
            self.device = "mps"
            # Fall back to CPU for ops not yet implemented on MPS
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        elif torch.cuda.is_available():
            logger.info("✅ CUDA available - using GPU")
            self.device = "cuda"
        else:
            logger.info("💻 CPU detected - using CPU optimizations")
            self.device = "cpu"
        logger.info(f"🔧 Using device: {self.device}")

    def _load_rag_documents(self):
        """Load RAG documents from JSONL file with device optimizations."""
        if not self.rag_data_path.exists():
            logger.warning("⚠️ RAG data file not found")
            if st:
                st.warning("⚠️ RAG data file not found")
            return

        documents = []
        metadatas = []
        ids = []
        logger.info(f"📚 Loading RAG documents from {self.rag_data_path}")
        with open(self.rag_data_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                if line.strip():
                    try:
                        doc = json.loads(line)
                        if 'content' in doc and 'id' in doc:
                            # Truncate content for memory efficiency
                            content = doc['content']
                            if len(content) > 2000:  # Limit content length
                                content = content[:2000] + "..."
                            documents.append(content)
                            metadatas.append({
                                'title': doc.get('title', ''),
                                'type': doc.get('type', ''),
                                'source': doc.get('source', ''),
                                'word_count': len(content.split())
                            })
                            ids.append(doc['id'])
                        else:
                            logger.warning(f"⚠️ Skipping line {line_num}: missing required fields")
                    except json.JSONDecodeError as e:
                        logger.warning(f"⚠️ Skipping line {line_num}: JSON decode error - {e}")

        if documents:
            # Add to ChromaDB with batch processing
            batch_size = 50  # Smaller batches for memory efficiency
            for i in range(0, len(documents), batch_size):
                batch_docs = documents[i:i + batch_size]
                batch_metadatas = metadatas[i:i + batch_size]
                batch_ids = ids[i:i + batch_size]
                self.collection.add(
                    documents=batch_docs,
                    metadatas=batch_metadatas,
                    ids=batch_ids
                )
                logger.info(f"📊 Processed batch {i // batch_size + 1}/{(len(documents) - 1) // batch_size + 1}")
            logger.info(f"✅ Loaded {len(documents)} RAG documents")
            # UI success message intentionally omitted; it is shown in the sidebar instead
        else:
            logger.warning("⚠️ No valid documents found in RAG data file")
            if st:
                st.warning("⚠️ No valid documents found in RAG data file")

    def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
        """Retrieve relevant context for a query with device optimizations."""
        try:
            # Limit query length for efficiency
            if len(query) > 500:
                query = query[:500]
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, 5)  # Limit results for efficiency
            )
            context_docs = []
            for i, doc in enumerate(results['documents'][0]):
                # Truncate retrieved content for memory efficiency
                content = doc
                if len(content) > 1500:
                    content = content[:1500] + "..."
                context_docs.append({
                    'content': content,
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i]
                })
            logger.info(f"🔍 Retrieved {len(context_docs)} context documents")
            return context_docs
        except Exception as e:
            logger.error(f"RAG retrieval error: {e}")
            if st:
                st.error(f"RAG retrieval error: {e}")
            return []

    def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
        """Build an enhanced prompt combining RAG context and SUPRA facts."""
        # Import SUPRA facts system
        from .supra_facts import build_supra_prompt, inject_facts_for_query

        # Extract RAG context chunks
        rag_context = None
        if context_docs:
            # Limit context length for memory efficiency
            max_context_length = 2000
            context_text = ""
            for doc in context_docs:
                doc_text = doc['content'][:800]
                if len(context_text + doc_text) > max_context_length:
                    break
                context_text += doc_text + "\n\n"
            rag_context = [context_text] if context_text else None

        # Auto-detect relevant facts from the query
        facts = inject_facts_for_query(user_query)

        # Get the model name from model_loader to pick the right chat template
        from .model_loader import get_model_info
        try:
            model_info = get_model_info()
            # Use the base model name to distinguish Llama from Mistral
            base_model = model_info.get('base_model', '')
            if 'llama' in base_model.lower():
                model_name = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'
            else:
                model_name = model_info.get('model_name', 'unsloth/mistral-7b-instruct-v0.3-bnb-4bit')
        except Exception:
            # Default to Llama since the latest models use Llama
            model_name = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'

        # Build the complete SUPRA prompt with system prompt, facts, and RAG context
        enhanced_prompt = build_supra_prompt(
            user_query=user_query,
            facts=facts,
            rag_context=rag_context,
            model_name=model_name
        )
        return enhanced_prompt

    def generate_response(self, query: str, model, tokenizer, max_new_tokens: int = 800) -> str:
        """Generate response using the enhanced model with RAG context."""
        try:
            logger.info(f"🤖 Generating response for query: {query[:50]}...")
            # Get RAG context
            context_docs = self.retrieve_context(query, n_results=3)
            enhanced_prompt = self.build_enhanced_prompt(query, context_docs)

            # Import the generation function
            from .model_loader import generate_response_optimized

            # Generate with the enhanced model - tighter parameters for better quality
            response = generate_response_optimized(
                model=model,
                tokenizer=tokenizer,
                prompt=enhanced_prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.6,  # Lower temperature for more focused responses
                top_p=0.85  # Tighter sampling
            )
            logger.info(f"✅ Generated response ({len(response)} characters)")
            return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            if st:
                st.error(f"Error generating response: {e}")
            return f"I apologize, but I encountered an error while generating a response: {e}"


# Global RAG instance with device-specific optimizations
_supra_rag_instance = None


def get_supra_rag():
    """Get the cached SUPRA RAG instance optimized for CPU/MPS/CUDA."""
    global _supra_rag_instance
    if _supra_rag_instance is None:
        _supra_rag_instance = SupraRAG()
    return _supra_rag_instance


# Backward compatibility (kept for older imports)
def get_supra_rag_m2max():
    """Backward-compatible alias that returns the device-optimized RAG."""
    return get_supra_rag()
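

# A minimal smoke-test sketch, not part of the app entry point. Assumes a
# rag_seeds.jsonl is reachable from one of the paths probed in __init__;
# retrieval alone does not touch the .supra_facts / .model_loader modules,
# so this can run without the rest of the package.
if __name__ == "__main__":
    rag = get_supra_rag()
    for hit in rag.retrieve_context("What is SUPRA?", n_results=2):
        title = hit['metadata'].get('title', '<untitled>')
        print(f"{title} (distance={hit['distance']:.3f})")
        print(hit['content'][:200])
        print("---")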