Jan Biermeyer
cpu
3379400
#!/usr/bin/env python3
"""
SUPRA RAG System with CPU/MPS/CUDA Optimizations
Optimized for CPU (HF Spaces), MPS (Apple Silicon), and CUDA with efficient memory management
"""
import json
import chromadb
import torch
import os
from sentence_transformers import SentenceTransformer
from pathlib import Path
from typing import List, Dict, Any
import streamlit as st
import logging
# Configure logging.
# NOTE: basicConfig() runs at import time and configures the root logger at
# INFO level; per the stdlib it does nothing if the root logger already has
# handlers installed.
logging.basicConfig(level=logging.INFO)
# Module-level logger (named after this module) used by the code below.
logger = logging.getLogger(__name__)
class SupraRAG:
    """Retrieval-augmented generation helper backed by a ChromaDB collection.

    On construction the instance:

    1. resolves the JSONL seed file (``rag_data_path``),
    2. picks a compute device (``self.device``: ``"mps"``, ``"cuda"`` or ``"cpu"``),
    3. creates a ``chromadb.Client()`` and a ``SentenceTransformer`` model,
    4. opens the ``supra_knowledge`` collection, (re)loading it from the seed
       file when it is missing or its size differs from the number of valid
       seed records.

    Attributes:
        rag_data_path: ``Path`` of the JSONL seed file.
        device: Device string chosen by ``_setup_device_optimizations``.
        client: The ChromaDB client.
        collection_name: Name of the collection (``"supra_knowledge"``).
        embedding_model: ``SentenceTransformer('all-MiniLM-L6-v2')`` on ``device``.
        collection: The ChromaDB collection queried by ``retrieve_context``.
    """

    def __init__(self, rag_data_path: str = None):
        """Set up device, ChromaDB client, embedding model and collection.

        Args:
            rag_data_path: Path to the JSONL seed file. When ``None`` a few
                conventional locations are probed (see ``_default_rag_path``).
        """
        if rag_data_path is None:
            rag_data_path = self._default_rag_path()
        self.rag_data_path = Path(rag_data_path)

        # Device-specific optimizations (sets self.device)
        self._setup_device_optimizations()

        self.client = chromadb.Client()
        self.collection_name = "supra_knowledge"

        # NOTE(review): this model is not handed to the collection anywhere in
        # this class -- collections are created without an embedding_function,
        # so ChromaDB's default applies. Confirm whether that is intended.
        self.embedding_model = SentenceTransformer(
            'all-MiniLM-L6-v2',
            device=self.device
        )

        self._init_collection()

    @staticmethod
    def _default_rag_path() -> str:
        """Return the first existing default seed-file location, else the fallback."""
        candidates = (
            Path("data/processed/rag_seeds/rag_seeds.jsonl"),
            Path(__file__).parent.parent / "data/processed/rag_seeds/rag_seeds.jsonl",
            Path("rag_seeds.jsonl"),
        )
        for candidate in candidates:
            if candidate.exists():
                return str(candidate)
        # Default fallback (may not exist; the loader reports that later)
        return "data/processed/rag_seeds/rag_seeds.jsonl"

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters plus ``"..."`` if it was longer."""
        if len(text) > limit:
            return text[:limit] + "..."
        return text

    @staticmethod
    def _record_from_line(line: str):
        """Parse one JSONL line into ``(content, metadata, doc_id)``.

        Returns ``None`` when the line is valid JSON but is not an object
        carrying both ``content`` and ``id``.

        Raises:
            json.JSONDecodeError: If the line is not valid JSON.
        """
        doc = json.loads(line)
        # Guard against non-object JSON (lists, strings, numbers), where the
        # ``in`` checks below would misbehave or raise.
        if not isinstance(doc, dict) or 'content' not in doc or 'id' not in doc:
            return None
        # Truncate content for memory efficiency
        content = SupraRAG._truncate(doc['content'], 2000)
        metadata = {
            'title': doc.get('title', ''),
            'type': doc.get('type', ''),
            'source': doc.get('source', ''),
            'word_count': len(content.split())
        }
        return content, metadata, doc['id']

    def _read_rag_records(self, log_skipped: bool = True):
        """Read the seed file and return ``(documents, metadatas, ids)`` lists.

        Blank lines are ignored. Lines that fail to parse or lack required
        fields are skipped, with a warning when ``log_skipped`` is true.
        A missing file yields three empty lists.
        """
        documents = []
        metadatas = []
        ids = []
        if not self.rag_data_path.exists():
            return documents, metadatas, ids

        with open(self.rag_data_path, 'r', encoding='utf-8') as handle:
            for line_num, line in enumerate(handle, 1):
                if not line.strip():
                    continue
                try:
                    record = self._record_from_line(line)
                except json.JSONDecodeError as e:
                    if log_skipped:
                        logger.warning(f"⚠️ Skipping line {line_num}: JSON decode error - {e}")
                    continue
                if record is None:
                    if log_skipped:
                        logger.warning(f"⚠️ Skipping line {line_num}: missing required fields")
                    continue
                content, metadata, doc_id = record
                documents.append(content)
                metadatas.append(metadata)
                ids.append(doc_id)
        return documents, metadatas, ids

    def _init_collection(self):
        """Open the collection, (re)building it from the seed file when needed."""
        try:
            self.collection = self.client.get_collection(self.collection_name)
        except Exception as exc:
            # The "collection not found" exception class differs between
            # ChromaDB releases, hence the broad catch around this one call.
            logger.debug(f"Collection lookup failed, creating a new one: {exc}")
            self.collection = self.client.create_collection(self.collection_name)
            self._load_rag_documents()
            return

        # Compare against the number of records the loader would actually
        # ingest; counting raw lines would force a rebuild on every start
        # whenever the file contains a single malformed line.
        current_count = self.collection.count()
        expected_count = len(self._read_rag_records(log_skipped=False)[2])
        if current_count == expected_count:
            logger.info(f"✅ RAG knowledge base loaded ({current_count} documents)")
            return

        logger.info(f"🔄 Reloading RAG documents (current: {current_count}, expected: {expected_count})")
        # Delete and recreate collection to reload
        self.client.delete_collection(self.collection_name)
        self.collection = self.client.create_collection(self.collection_name)
        self._load_rag_documents()

    def _setup_device_optimizations(self):
        """Configure optimizations for CPU/MPS/CUDA and set ``self.device``.

        Also sets ``TOKENIZERS_PARALLELISM=false`` and, on MPS,
        ``PYTORCH_ENABLE_MPS_FALLBACK=1`` in the process environment.
        """
        logger.info("🔧 Setting up device optimizations...")
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # Detect device: MPS > CUDA > CPU. Older torch builds have no
        # ``torch.backends.mps`` attribute, so look it up defensively.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            logger.info("✅ MPS (Metal Performance Shaders) available - using MPS")
            self.device = "mps"
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        elif torch.cuda.is_available():
            logger.info("✅ CUDA available - using GPU")
            self.device = "cuda"
        else:
            logger.info("💻 CPU detected - using CPU optimizations")
            self.device = "cpu"
        logger.info(f"🔧 Using device: {self.device}")

    def _load_rag_documents(self):
        """Load RAG documents from the JSONL file into ``self.collection``.

        Warns (log and Streamlit) and returns without adding anything when the
        file is missing or contains no valid records.
        """
        if not self.rag_data_path.exists():
            logger.warning("⚠️ RAG data file not found")
            if st:
                st.warning("⚠️ RAG data file not found")
            return

        logger.info(f"📚 Loading RAG documents from {self.rag_data_path}")
        documents, metadatas, ids = self._read_rag_records()

        if not documents:
            logger.warning("⚠️ No valid documents found in RAG data file")
            if st:
                st.warning("⚠️ No valid documents found in RAG data file")
            return

        # Add to ChromaDB in small batches for memory efficiency
        batch_size = 50
        total_batches = (len(documents) - 1) // batch_size + 1
        batch_starts = range(0, len(documents), batch_size)
        for batch_number, start in enumerate(batch_starts, 1):
            stop = start + batch_size
            self.collection.add(
                documents=documents[start:stop],
                metadatas=metadatas[start:stop],
                ids=ids[start:stop]
            )
            logger.info(f"📊 Processed batch {batch_number}/{total_batches}")
        logger.info(f"✅ Loaded {len(documents)} RAG documents")

    def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
        """Retrieve relevant context documents for a query.

        Args:
            query: Search text; only the first 500 characters are used.
            n_results: Number of hits requested, capped at 5.

        Returns:
            A list of dicts with ``content`` (truncated to 1500 characters),
            ``metadata`` and ``distance``. An empty list is returned if the
            lookup fails; the error is logged and shown via Streamlit.
        """
        try:
            # Limit query length for efficiency
            query = query[:500]
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, 5)  # Limit results for efficiency
            )
            # Results are nested per query text; index 0 is our single query.
            context_docs = [
                {
                    'content': self._truncate(text, 1500),
                    'metadata': metadata,
                    'distance': distance
                }
                for text, metadata, distance in zip(
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0],
                )
            ]
            logger.info(f"🔍 Retrieved {len(context_docs)} context documents")
            return context_docs
        except Exception as e:
            logger.error(f"RAG retrieval error: {e}")
            if st:
                st.error(f"RAG retrieval error: {e}")
            return []

    def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
        """Build the full prompt from the query, detected facts and RAG context.

        Args:
            user_query: The user's question.
            context_docs: Documents as returned by ``retrieve_context``; each
                contributes at most 800 characters, and accumulation stops
                once the combined text would exceed 2000 characters.

        Returns:
            The prompt string produced by ``build_supra_prompt``.
        """
        # Import SUPRA facts system
        from .supra_facts import build_supra_prompt, inject_facts_for_query

        # Extract RAG context chunks
        rag_context = None
        if context_docs:
            max_context_length = 2000  # Reduced for memory efficiency
            context_text = ""
            for doc in context_docs:
                doc_text = str(doc['content'][:800])
                if len(context_text) + len(doc_text) > max_context_length:
                    break
                context_text += doc_text + "\n\n"
            rag_context = [context_text] if context_text else None

        # Auto-detect relevant facts from query
        facts = inject_facts_for_query(user_query)

        # Get model name from model_loader to detect chat template
        from .model_loader import get_model_info
        try:
            model_info = get_model_info()
            # Get base model name to detect Llama vs Mistral
            base_model = model_info.get('base_model', '')
            if 'llama' in base_model.lower():
                model_name = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'
            else:
                model_name = model_info.get('model_name', 'unsloth/mistral-7b-instruct-v0.3-bnb-4bit')
        except Exception as exc:
            # Best-effort lookup: default to Llama since latest models use Llama
            logger.debug(f"Model info unavailable, assuming Llama template: {exc}")
            model_name = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'

        # Build complete SUPRA prompt with system prompt, facts, and RAG context
        return build_supra_prompt(
            user_query=user_query,
            facts=facts,
            rag_context=rag_context,
            model_name=model_name
        )

    def generate_response(self, query: str, model, tokenizer, max_new_tokens: int = 800) -> str:
        """Generate a response to ``query`` using RAG context and the given model.

        Args:
            query: The user's question.
            model: Model object forwarded to ``generate_response_optimized``.
            tokenizer: Tokenizer forwarded to ``generate_response_optimized``.
            max_new_tokens: Generation length limit forwarded to the generator.

        Returns:
            The generated text, or an apology message containing the error
            text if anything in the pipeline raises.
        """
        try:
            logger.info(f"🤖 Generating response for query: {query[:50]}...")

            # Get RAG context
            context_docs = self.retrieve_context(query, n_results=3)
            enhanced_prompt = self.build_enhanced_prompt(query, context_docs)

            # Import the generation function
            from .model_loader import generate_response_optimized

            # Generate with enhanced model - tighter parameters for better quality
            response = generate_response_optimized(
                model=model,
                tokenizer=tokenizer,
                prompt=enhanced_prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.6,  # Lower temperature for more focused responses
                top_p=0.85  # Tighter sampling
            )
            logger.info(f"✅ Generated response ({len(response)} characters)")
            return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            if st:
                st.error(f"Error generating response: {e}")
            return f"I apologize, but I encountered an error while generating a response: {e}"
# Global RAG instance with device-specific optimizations
@st.cache_resource
def get_supra_rag():
    """Return the shared SupraRAG instance (device auto-detected: CPU/MPS/CUDA).

    The ``st.cache_resource`` decorator caches the constructed object, so the
    knowledge base is not rebuilt on every call.
    """
    rag_instance = SupraRAG()
    return rag_instance
# Backward compatibility (kept for compatibility with old imports)
def get_supra_rag_m2max():
    """Legacy entry point kept for old imports; delegates to ``get_supra_rag()``."""
    rag_instance = get_supra_rag()
    return rag_instance