File size: 12,041 Bytes
ea2a063
 
34fc1eb
 
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34fc1eb
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34fc1eb
 
ea2a063
34fc1eb
ea2a063
 
 
34fc1eb
 
 
ea2a063
 
34fc1eb
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34fc1eb
 
 
ea2a063
34fc1eb
ea2a063
 
34fc1eb
ea2a063
34fc1eb
ea2a063
34fc1eb
 
 
 
 
ea2a063
34fc1eb
ea2a063
 
 
 
 
34fc1eb
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34fc1eb
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34fc1eb
 
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3379400
ea2a063
3379400
ea2a063
 
 
 
 
3379400
ea2a063
 
 
 
3379400
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3379400
ea2a063
 
 
 
 
 
3379400
 
ea2a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34fc1eb
ea2a063
3379400
34fc1eb
3379400
ea2a063
3379400
 
34fc1eb
3379400
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
#!/usr/bin/env python3
"""
SUPRA RAG System with CPU/MPS/CUDA Optimizations
Optimized for CPU (HF Spaces), MPS (Apple Silicon), and CUDA with efficient memory management
"""

import json
import chromadb
import torch
import os
from sentence_transformers import SentenceTransformer
from pathlib import Path
from typing import List, Dict, Any
import streamlit as st
import logging

# Configure logging
# NOTE: basicConfig() is called at import time, so merely importing this
# module requests INFO-level logging on the root logger.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by SupraRAG and the helper functions below.
logger = logging.getLogger(__name__)

class SupraRAG:
    def __init__(self, rag_data_path: str = None):
        """Prepare the device, the embedding model and the ChromaDB collection.

        Args:
            rag_data_path: Location of the JSONL seed file. When ``None`` the
                first existing candidate location is used; if none of them
                exists, ``data/processed/rag_seeds/rag_seeds.jsonl`` is kept
                as the path even though the file is missing.

        Raises:
            Exception: Errors raised while counting, deleting, recreating or
                filling an already existing collection are propagated
                unchanged instead of being hidden behind a second
                ``create_collection`` attempt.
        """
        # Default RAG data path (for HF Spaces deployment)
        if rag_data_path is None:
            default_path = "data/processed/rag_seeds/rag_seeds.jsonl"
            candidates = [
                Path(default_path),
                Path(__file__).parent.parent / default_path,
                Path("rag_seeds.jsonl"),
            ]
            rag_data_path = next(
                (str(candidate) for candidate in candidates if candidate.exists()),
                default_path,
            )
        self.rag_data_path = Path(rag_data_path)

        # Sets self.device, which the embedding model below is placed on.
        self._setup_device_optimizations()

        self.client = chromadb.Client()
        self.collection_name = "supra_knowledge"

        # Small sentence-transformers model, loaded on the detected device.
        self.embedding_model = SentenceTransformer(
            'all-MiniLM-L6-v2',
            device=self.device,
        )

        try:
            self.collection = self.client.get_collection(self.collection_name)
        except Exception:
            # The collection does not exist yet. The concrete exception type
            # differs between chromadb releases, hence the broad (but no
            # longer bare) handler around this single call.
            self.collection = self.client.create_collection(self.collection_name)
            self._load_rag_documents()
            return

        # An existing collection is reused only when its size matches the
        # number of non-blank lines in the seed file.
        current_count = self.collection.count()
        expected_count = 0
        if self.rag_data_path.exists():
            # Context manager so the file handle is closed deterministically.
            with open(self.rag_data_path, 'r', encoding='utf-8') as seed_file:
                expected_count = sum(1 for line in seed_file if line.strip())

        if current_count != expected_count:
            logger.info(f"πŸ”„ Reloading RAG documents (current: {current_count}, expected: {expected_count})")
            # Delete and recreate collection to reload
            self.client.delete_collection(self.collection_name)
            self.collection = self.client.create_collection(self.collection_name)
            self._load_rag_documents()
        else:
            logger.info(f"βœ… RAG knowledge base loaded ({current_count} documents)")
            # No UI success message here - it is shown in the sidebar instead
    
    def _setup_device_optimizations(self):
        """Configure optimizations for CPU/MPS/CUDA."""
        logger.info("πŸ”§ Setting up device optimizations...")
        
        # Environment variables
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        
        # Detect device: MPS > CUDA > CPU
        if torch.backends.mps.is_available():
            logger.info("βœ… MPS (Metal Performance Shaders) available - using MPS")
            self.device = "mps"
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
            torch.backends.mps.is_built()
        elif torch.cuda.is_available():
            logger.info("βœ… CUDA available - using GPU")
            self.device = "cuda"
        else:
            logger.info("πŸ’» CPU detected - using CPU optimizations")
            self.device = "cpu"
        
        logger.info(f"πŸ”§ Using device: {self.device}")
    
    def _load_rag_documents(self):
        """Load RAG documents from JSONL file with device optimizations."""
        if not self.rag_data_path.exists():
            logger.warning("⚠️ RAG data file not found")
            if st:
                st.warning("⚠️ RAG data file not found")
            return
        
        documents = []
        metadatas = []
        ids = []
        
        logger.info(f"πŸ“š Loading RAG documents from {self.rag_data_path}")
        
        with open(self.rag_data_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                if line.strip():
                    try:
                        doc = json.loads(line)
                        if 'content' in doc and 'id' in doc:
                            # Truncate content for memory efficiency
                            content = doc['content']
                            if len(content) > 2000:  # Limit content length
                                content = content[:2000] + "..."
                            
                            documents.append(content)
                            metadatas.append({
                                'title': doc.get('title', ''),
                                'type': doc.get('type', ''),
                                'source': doc.get('source', ''),
                                'word_count': len(content.split())
                            })
                            ids.append(doc['id'])
                        else:
                            logger.warning(f"⚠️ Skipping line {line_num}: missing required fields")
                    except json.JSONDecodeError as e:
                        logger.warning(f"⚠️ Skipping line {line_num}: JSON decode error - {e}")
        
        if documents:
            # Add to ChromaDB with batch processing
            batch_size = 50  # Smaller batches for memory efficiency
            for i in range(0, len(documents), batch_size):
                batch_docs = documents[i:i+batch_size]
                batch_metadatas = metadatas[i:i+batch_size]
                batch_ids = ids[i:i+batch_size]
                
                self.collection.add(
                    documents=batch_docs,
                    metadatas=batch_metadatas,
                    ids=batch_ids
                )
                
                logger.info(f"πŸ“Š Processed batch {i//batch_size + 1}/{(len(documents)-1)//batch_size + 1}")
            
            logger.info(f"βœ… Loaded {len(documents)} RAG documents")
            # Removed UI success message - shown in sidebar instead
        else:
            logger.warning("⚠️ No valid documents found in RAG data file")
            if st:
                st.warning("⚠️ No valid documents found in RAG data file")
    
    def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
        """Retrieve relevant context for a query with device optimizations."""
        try:
            # Limit query length for efficiency
            if len(query) > 500:
                query = query[:500]
            
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, 5)  # Limit results for efficiency
            )
            
            context_docs = []
            for i, doc in enumerate(results['documents'][0]):
                # Truncate retrieved content for memory efficiency
                content = doc
                if len(content) > 1500:
                    content = content[:1500] + "..."
                
                context_docs.append({
                    'content': content,
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i]
                })
            
            logger.info(f"πŸ” Retrieved {len(context_docs)} context documents")
            return context_docs
            
        except Exception as e:
            logger.error(f"RAG retrieval error: {e}")
            if st:
                st.error(f"RAG retrieval error: {e}")
            return []
    
    def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
        """Assemble the final prompt from the query, SUPRA facts and RAG context.

        Args:
            user_query: The user's question; also passed to
                ``inject_facts_for_query`` to pick the facts.
            context_docs: Retrieved documents, each a dict with a ``content``
                key (the shape produced by ``retrieve_context``). May be empty.

        Returns:
            Whatever ``build_supra_prompt`` returns for the query, the facts,
            the RAG context and the chosen model name.
        """
        # Import SUPRA facts system
        from .supra_facts import build_supra_prompt, inject_facts_for_query

        # Collect RAG context: up to 800 characters per document, stopping
        # before the running total would pass the overall budget.
        rag_context = None
        if context_docs:
            max_context_length = 2000  # Reduced for memory efficiency
            separator = "\n\n"
            pieces = []
            used = 0

            for doc in context_docs:
                doc_text = f"{doc['content'][:800]}"
                if used + len(doc_text) > max_context_length:
                    break
                pieces.append(doc_text + separator)
                used += len(doc_text) + len(separator)

            context_text = "".join(pieces)
            rag_context = [context_text] if context_text else None

        # Auto-detect relevant facts from query
        facts = inject_facts_for_query(user_query)

        # The model name is handed to build_supra_prompt; presumably it selects
        # the chat template (Llama vs. Mistral) - confirm in supra_facts.
        from .model_loader import get_model_info
        llama_model = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'
        try:
            model_info = get_model_info()
            base_model = model_info.get('base_model', '')
            # 'llama' also matches 'meta-llama', so a single test is enough.
            if 'llama' in base_model.lower():
                model_name = llama_model
            else:
                model_name = model_info.get('model_name', 'unsloth/mistral-7b-instruct-v0.3-bnb-4bit')
        except Exception:
            # Best effort: fall back to Llama when the model info is missing
            # or malformed, without swallowing KeyboardInterrupt/SystemExit.
            model_name = llama_model

        # Build complete SUPRA prompt with system prompt, facts, and RAG context
        enhanced_prompt = build_supra_prompt(
            user_query=user_query,
            facts=facts,
            rag_context=rag_context,
            model_name=model_name
        )

        return enhanced_prompt
    
    def generate_response(self, query: str, model, tokenizer, max_new_tokens: int = 800) -> str:
        """Answer *query* with the given model, using retrieved RAG context.

        Retrieves up to three context documents, builds the enhanced prompt
        and passes it to ``generate_response_optimized`` with a temperature of
        0.6 and a top_p of 0.85. On any exception the error is logged, shown
        through ``st.error`` and an apology string containing the error is
        returned instead of raising.
        """
        try:
            preview = query[:50]
            logger.info(f"πŸ€– Generating response for query: {preview}...")

            # Retrieval first, then prompt assembly on top of its result
            enhanced_prompt = self.build_enhanced_prompt(
                query,
                self.retrieve_context(query, n_results=3),
            )

            # Imported lazily, right before it is needed
            from .model_loader import generate_response_optimized

            # Lower temperature and tighter nucleus sampling for focused output
            sampling = {'temperature': 0.6, 'top_p': 0.85}
            answer = generate_response_optimized(
                model=model,
                tokenizer=tokenizer,
                prompt=enhanced_prompt,
                max_new_tokens=max_new_tokens,
                **sampling,
            )

            logger.info(f"βœ… Generated response ({len(answer)} characters)")
            return answer

        except Exception as err:
            failure = f"Error generating response: {err}"
            logger.error(failure)
            if st:
                st.error(failure)
            return f"I apologize, but I encountered an error while generating a response: {err}"

# Global RAG instance with device-specific optimizations
@st.cache_resource
def get_supra_rag() -> SupraRAG:
    """Build a SupraRAG with default arguments (CPU/MPS/CUDA auto-detected).

    The function is wrapped in Streamlit's ``st.cache_resource``, which is
    meant to hand back the already built instance on later calls instead of
    constructing a new one.
    """
    return SupraRAG()

# Backward compatibility (kept for compatibility with old imports)
def get_supra_rag_m2max() -> SupraRAG:
    """Backward-compatible alias: returns the result of ``get_supra_rag()``."""
    return get_supra_rag()