"""
Query RAG System - Test Retrieval Quality
Pure LangChain, config-driven, no hardcoding
"""

import sys
from pathlib import Path
from typing import List, Dict, Any, Tuple
import json

# LangChain imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document

import numpy as np

# FIX: faiss was imported unconditionally (and `import json` was duplicated).
# A missing native faiss package would crash the whole module even though
# load_vectorstore() ships TF-IDF / bag-of-words fallbacks designed for
# exactly that situation — so the import is now guarded.
try:
    import faiss  # noqa: F401  (used indirectly by the LangChain FAISS store)
except ImportError:
    faiss = None

# Defer SentenceTransformer import to runtime to avoid hard torch dependency in some envs

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).resolve().parents[1]))

# Local imports
from utils import get_utils


# ============================================================================
# RAG RETRIEVER
# ============================================================================

class RAGRetriever:
    """Retrieve relevant context using RAG."""

    def __init__(self, utils):
        self.utils = utils
        self.config = utils.config
        self.logger = utils.logger
        self.device = utils.device_manager.device
        self.vectorstore = None   # set by load_vectorstore()
        self.embeddings = None    # set by load_vectorstore() when HF is usable

    def load_vectorstore(self):
        """
        Load the pre-created vector store, with graceful degradation.

        Attempts, in order:
          1. LangChain FAISS store (requires working HF embeddings + faiss)
          2. TF-IDF store over the raw fallback texts (requires sklearn)
          3. Naive bag-of-words overlap store (stdlib + numpy only)

        Returns:
            The loaded store (also assigned to ``self.vectorstore``).

        Raises:
            RuntimeError: if the FAISS load fails and no fallback texts exist.
        """
        vectorstore_root = self.config.get_path('paths', 'vectorstore', 'root')
        vectorstore_path = vectorstore_root / self.config.get('paths.vectorstore.db_name')

        if not vectorstore_path.exists():
            self.logger.warning(f"Vector store directory missing: {vectorstore_path}. Will attempt fallback index.")
        if vectorstore_path.exists():
            self.logger.info(f"Loading vector store from: {vectorstore_path}")

        # Create embeddings model (must match the one used for creation)
        use_hf = self.config.get('embeddings.use_hf', True)
        model_name = self.config.get('embeddings.model_name')
        normalize = self.config.get('embeddings.normalize', True)
        self.embeddings = None
        if use_hf and vectorstore_path.exists():
            try:
                self.embeddings = HuggingFaceEmbeddings(
                    model_name=model_name,
                    model_kwargs={'device': self.device},
                    encode_kwargs={'normalize_embeddings': normalize}
                )
            except Exception as e:
                self.logger.warning(f"HuggingFaceEmbeddings unavailable: {e}. Will use fallback index.")

        # Load vector store with graceful fallback
        allow_dangerous = self.config.get('vectorstore.persistence.allow_dangerous_deserialization', True)
        try:
            if self.embeddings is None:
                raise RuntimeError("Embeddings unavailable; skipping LangChain FAISS load.")
            self.vectorstore = FAISS.load_local(
                str(vectorstore_path),
                self.embeddings,
                allow_dangerous_deserialization=allow_dangerous
            )
            total_vectors = self.vectorstore.index.ntotal
            self.logger.info(f"Loaded {total_vectors} vectors (LangChain FAISS)")
            return self.vectorstore
        except Exception as e:
            self.logger.warning(f"Primary FAISS load failed: {e}. Attempting simple FAISS fallback.")
            fallback_dir = vectorstore_root / 'faiss_index'
            texts_file = fallback_dir / 'texts.json'
            if not (fallback_dir.exists() and texts_file.exists()):
                raise RuntimeError(
                    f"FAISS deserialization failed and no fallback texts found at: {fallback_dir}"
                )
            texts = json.loads(texts_file.read_text(encoding='utf-8'))

            try:
                from sklearn.feature_extraction.text import TfidfVectorizer
                from sklearn.metrics.pairwise import cosine_similarity

                class SimpleTFIDFStore:
                    """Duck-typed mini vector store backed by TF-IDF cosine similarity."""

                    def __init__(self, texts):
                        self.texts = texts
                        self.vectorizer = TfidfVectorizer(stop_words='english')
                        self.matrix = self.vectorizer.fit_transform(texts)

                    def similarity_search_with_score(self, query, k=5):
                        # Scores are returned as distances (1 - cosine sim) so the
                        # caller's `score <= threshold` filtering still applies.
                        q_vec = self.vectorizer.transform([query])
                        sims = cosine_similarity(q_vec, self.matrix).ravel()
                        top_idx = np.argsort(-sims)[:k]
                        results = []
                        for idx in top_idx:
                            doc = Document(page_content=self.texts[idx], metadata={})
                            results.append((doc, float(1 - sims[idx])))
                        return results

                    def max_marginal_relevance_search(self, query, k=5, lambda_mult=0.5):
                        # Crude MMR stand-in: over-fetch, then greedily keep
                        # distinct documents in relevance order.
                        base = self.similarity_search_with_score(query, k * 2)
                        selected = []
                        if base:
                            selected.append(base[0][0])
                        while len(selected) < k and len(selected) < len(base):
                            remaining = [doc for doc, _ in base if doc not in selected]
                            if not remaining:
                                break
                            selected.append(remaining[0])
                        return selected

                self.vectorstore = SimpleTFIDFStore(texts)
                self.logger.info(f"Loaded {len(texts)} texts (TF-IDF fallback)")
                return self.vectorstore
            except Exception as e_sklearn:
                self.logger.warning(f"TF-IDF fallback unavailable: {e_sklearn}. Using naive bag-of-words.")

                def tokenize(s):
                    # Lowercase alphanumeric tokens; everything else is a separator.
                    return [t for t in ''.join(c.lower() if c.isalnum() else ' ' for c in s).split() if t]

                docs_tokens = []
                for t in texts:
                    toks = tokenize(t)
                    docs_tokens.append(toks)

                class SimpleBoWStore:
                    """Last-resort store ranking documents by token overlap with the query."""

                    def __init__(self, texts, docs_tokens):
                        self.texts = texts
                        self.docs_tokens = docs_tokens

                    def similarity_search_with_score(self, query, k=5):
                        q_toks = set(tokenize(query))
                        scores = []
                        for idx, toks in enumerate(self.docs_tokens):
                            overlap = len(q_toks.intersection(set(toks)))
                            scores.append((idx, overlap))
                        top = sorted(scores, key=lambda x: -x[1])[:k]
                        results = []
                        for idx, score in top:
                            doc = Document(page_content=self.texts[idx], metadata={})
                            # Map overlap count to a distance-like score in (0, 1].
                            dist = float(1.0 / (1 + score))
                            results.append((doc, dist))
                        return results

                    def max_marginal_relevance_search(self, query, k=5, lambda_mult=0.5):
                        base = self.similarity_search_with_score(query, k * 2)
                        selected = []
                        if base:
                            selected.append(base[0][0])
                        while len(selected) < k and len(selected) < len(base):
                            remaining = [doc for doc, _ in base if doc not in selected]
                            if not remaining:
                                break
                            selected.append(remaining[0])
                        return selected

                self.vectorstore = SimpleBoWStore(texts, docs_tokens)
                self.logger.info(f"Loaded {len(texts)} texts (Naive BoW fallback)")
                return self.vectorstore

    def retrieve(self, query: str, **kwargs) -> List[Tuple[Document, float]]:
        """
        Retrieve relevant documents for a query.

        Args:
            query: User's question
            **kwargs: Override config parameters (top_k, search_type,
                threshold, fallback_top_n)

        Returns:
            List of (document, score) tuples; scores are distances
            (lower is better). MMR results carry a placeholder score of 0.0.

        Raises:
            RuntimeError: if load_vectorstore() has not been called.
        """
        if not self.vectorstore:
            raise RuntimeError("Vector store not loaded. Call load_vectorstore() first.")

        # Get retrieval config
        top_k = kwargs.get('top_k', self.config.get('rag.retrieval.top_k', 5))
        search_type = kwargs.get('search_type', self.config.get('rag.retrieval.search_type', 'similarity'))

        self.logger.info(f"\nQuery: '{query}'")
        self.logger.info(f"Retrieving top {top_k} results using {search_type} search")

        # Perform search
        if search_type == 'similarity':
            results = self.vectorstore.similarity_search_with_score(query, k=top_k)
        elif search_type == 'mmr':
            # Maximum Marginal Relevance for diversity
            diversity_score = self.config.get('rag.retrieval.mmr_diversity_score', 0.3)
            docs = self.vectorstore.max_marginal_relevance_search(query, k=top_k, lambda_mult=diversity_score)
            results = [(doc, 0.0) for doc in docs]  # MMR doesn't return scores
        else:
            results = self.vectorstore.similarity_search_with_score(query, k=top_k)

        # Filter by threshold (scores are distances, so keep <= threshold)
        threshold = kwargs.get('threshold', self.config.get('rag.retrieval.similarity_threshold', 0.5))
        filtered_results = [(doc, score) for doc, score in results if score <= threshold]

        if not filtered_results:
            self.logger.warning(f"No results below threshold {threshold}")
            # FIX: the "return top 3 anyway" fallback was hardcoded, against
            # the module's config-driven contract; now configurable with the
            # same default.
            fallback_top_n = kwargs.get('fallback_top_n', self.config.get('rag.retrieval.fallback_top_n', 3))
            filtered_results = results[:fallback_top_n]

        self.logger.info(f"Retrieved {len(filtered_results)} relevant documents")
        return filtered_results

    def format_context(self, results: List[Tuple[Document, float]], **kwargs) -> str:
        """
        Format retrieved documents into a context string for the LLM.

        Args:
            results: List of (document, score) tuples
            **kwargs: Override config parameters (max_tokens, max_chunks,
                include_metadata, metadata_fields)

        Returns:
            Guidance preamble plus numbered passages, truncated to the
            approximate token budget.
        """
        max_tokens = kwargs.get('max_tokens', self.config.get('rag.context.max_tokens', 2000))
        max_chunks = kwargs.get('max_chunks', self.config.get('rag.context.max_chunks', 5))
        include_metadata = kwargs.get('include_metadata', self.config.get('rag.context.include_metadata', True))
        metadata_fields = kwargs.get('metadata_fields',
                                     self.config.get('rag.context.metadata_fields',
                                                     ['source', 'work_type', 'topic']))

        # Add synthesis guidance to discourage verbatim copying
        guidance = (
            "Use the following passages as references. Speak directly to the reader as Swami Vivekananda—"
            "fearless, compassionate, and practical. Summarize the core idea in 1–2 lines, then weave insights "
            "into a cohesive response with 3–5 practical steps. Do not copy long passages verbatim; paraphrase and "
            "synthesize. Include at most one short quote if essential, avoid bracketed numeric citations, and "
            "conclude with an uplifting benediction (e.g., Om Shanti)."
        )

        context_parts = [f"[Guidance] {guidance}"]
        current_length = 0
        max_chars = max_tokens * 4  # Rough estimate: 1 token ≈ 4 chars

        for idx, (doc, score) in enumerate(results[:max_chunks], 1):
            chunk_text = doc.page_content.strip()

            # Build metadata string
            metadata_str = ""
            if include_metadata:
                meta_parts = []
                for field in metadata_fields:
                    value = doc.metadata.get(field)
                    if value:
                        meta_parts.append(f"{field.title()}: {value}")
                if meta_parts:
                    metadata_str = f"[{', '.join(meta_parts)}]"

            # Format chunk
            chunk_formatted = f"--- Passage {idx} ---\n"
            if metadata_str:
                chunk_formatted += f"{metadata_str}\n"
            chunk_formatted += f"{chunk_text}\n"

            # Check length
            if current_length + len(chunk_formatted) > max_chars:
                break

            context_parts.append(chunk_formatted)
            current_length += len(chunk_formatted)

        return "\n".join(context_parts)


# ============================================================================
# RESULT DISPLAY
# ============================================================================

class ResultDisplay:
    """Display retrieval results."""

    def __init__(self, utils):
        self.utils = utils
        self.logger = utils.logger
        self.config = utils.config

    def display_results(self, query: str, results: List[Tuple[Document, float]], context: str):
        """Display formatted results: per-document scores/metadata/preview, then the LLM context."""
        print("\n" + "="*80)
        print("RAG RETRIEVAL RESULTS")
        print("="*80)
        print(f"\nQuery: {query}")
        print(f"Results: {len(results)}")
        print("\n" + "-"*80)
        print("RETRIEVED DOCUMENTS")
        print("-"*80)

        for idx, (doc, score) in enumerate(results, 1):
            print(f"\n[{idx}] Similarity Score: {score:.4f}")

            # Metadata (skip bookkeeping fields)
            print("Metadata:")
            for key, value in doc.metadata.items():
                if value and key not in ['chunk_index', 'char_count', 'word_count', 'file_path']:
                    print(f"  - {key}: {value}")

            # Content preview
            print("\nContent:")
            content = doc.page_content.strip()
            if len(content) > 300:
                print(f"  {content[:300]}...")
            else:
                print(f"  {content}")
            print("-"*80)

        print("\n" + "="*80)
        print("FORMATTED CONTEXT (for LLM)")
        print("="*80)
        print(context)
        print("="*80)

    def display_prompt(self, query: str, context: str):
        """Display the complete prompt that would be sent to the LLM."""
        prompt = self.utils.prompt_builder.get_full_prompt(query, context)
        print("\n" + "="*80)
        print("COMPLETE PROMPT FOR LLM")
        print("="*80)
        print(prompt)
        print("="*80)


# ============================================================================
# QUERY ANALYZER
# ============================================================================

class QueryAnalyzer:
    """Analyze query quality and results."""

    def __init__(self, utils):
        self.utils = utils
        self.logger = utils.logger

    def analyze_results(self, query: str, results: List[Tuple[Document, float]]) -> Dict[str, Any]:
        """Summarize retrieval quality: scores, and distinct sources/work types/topics."""
        analysis = {
            'query': query,
            'num_results': len(results),
            'scores': [score for _, score in results],
            'avg_score': sum(score for _, score in results) / len(results) if results else 0,
            'sources': list(set(doc.metadata.get('source', 'Unknown') for doc, _ in results)),
            'work_types': list(set(doc.metadata.get('work_type', 'Unknown') for doc, _ in results)),
            'topics': list(set(doc.metadata.get('topic', 'Unknown')
                               for doc, _ in results if doc.metadata.get('topic')))
        }
        return analysis

    def display_analysis(self, analysis: Dict[str, Any]):
        """Pretty-print an analysis dict produced by analyze_results()."""
        print("\n" + "="*80)
        print("RETRIEVAL ANALYSIS")
        print("="*80)
        print(f"\nQuery: {analysis['query']}")
        print(f"Results found: {analysis['num_results']}")
        print(f"Average similarity: {analysis['avg_score']:.4f}")
        print(f"\nSources: {', '.join(analysis['sources'])}")
        print(f"Work types: {', '.join(analysis['work_types'])}")
        if analysis['topics']:
            print(f"Topics: {', '.join(analysis['topics'])}")
        print("\nScore distribution:")
        for idx, score in enumerate(analysis['scores'], 1):
            print(f"  {idx}. {score:.4f}")


# ============================================================================
# BATCH QUERY
# ============================================================================

class BatchQueryProcessor:
    """Process multiple queries."""

    def __init__(self, utils, retriever, analyzer):
        self.utils = utils
        self.logger = utils.logger
        self.retriever = retriever
        self.analyzer = analyzer

    def process_batch(self, queries: List[str]) -> List[Dict[str, Any]]:
        """Retrieve + analyze each query; returns a list of {query, results, analysis} dicts."""
        self.logger.info(f"\nProcessing {len(queries)} queries...")

        all_results = []
        for query in queries:
            print("\n" + "="*80)
            self.logger.info(f"Processing: {query}")

            # Retrieve
            results = self.retriever.retrieve(query)

            # Analyze
            analysis = self.analyzer.analyze_results(query, results)

            # Store
            all_results.append({
                'query': query,
                'results': results,
                'analysis': analysis
            })

            # Display summary
            self.analyzer.display_analysis(analysis)

        return all_results

    def save_results(self, all_results: List[Dict[str, Any]], output_file: str):
        """Save batch results to JSON under the configured outputs/results path."""
        # Convert to serializable format (Documents -> plain dicts)
        serializable_results = []
        for item in all_results:
            serializable_results.append({
                'query': item['query'],
                'analysis': item['analysis'],
                'results': [
                    {
                        'content': doc.page_content,
                        'metadata': doc.metadata,
                        'score': float(score)
                    }
                    for doc, score in item['results']
                ]
            })

        output_path = self.utils.config.get_path('paths', 'outputs', 'results') / output_file
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2, ensure_ascii=False)

        self.logger.info(f"\nResults saved to: {output_path}")


# ============================================================================
# INTERACTIVE MODE
# ============================================================================

def interactive_mode(retriever, display, analyzer):
    """Interactive query mode: read questions from stdin until quit/exit."""
    print("\n" + "="*80)
    print("🕉️ VIVEKANANDA AI - INTERACTIVE RAG QUERY")
    print("="*80)
    print("Ask questions to test retrieval quality")
    print("Commands:")
    print("  - Type your question")
    print("  - 'quit' or 'exit' to stop")
    print("  - 'config' to see current settings")
    print("="*80)

    while True:
        try:
            query = input("\n[YOU] ").strip()

            if not query:
                continue

            if query.lower() in ['quit', 'exit', 'q']:
                print("\n[INFO] Om Shanti Shanti Shanti 🕉️")
                break

            if query.lower() == 'config':
                print("\nCurrent RAG Configuration:")
                print(f"  Top-K: {retriever.config.get('rag.retrieval.top_k')}")
                print(f"  Threshold: {retriever.config.get('rag.retrieval.similarity_threshold')}")
                print(f"  Search Type: {retriever.config.get('rag.retrieval.search_type')}")
                print(f"  Max Context Tokens: {retriever.config.get('rag.context.max_tokens')}")
                continue

            # Retrieve
            results = retriever.retrieve(query)

            # Format context
            context = retriever.format_context(results)

            # Display
            display.display_results(query, results, context)

            # Analyze
            analysis = analyzer.analyze_results(query, results)
            analyzer.display_analysis(analysis)

            # Show prompt
            display.display_prompt(query, context)

        except KeyboardInterrupt:
            print("\n\n[INFO] Interrupted by user")
            break
        except Exception as e:
            print(f"\n[ERROR] {e}")
            import traceback
            traceback.print_exc()


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Main execution: parse args, load the store, dispatch to batch/single/interactive mode."""
    # Parse arguments
    import argparse
    parser = argparse.ArgumentParser(description="Query RAG System")
    parser.add_argument('query', nargs='*', help='Query to search')
    parser.add_argument('--batch', action='store_true', help='Batch mode with test queries')
    parser.add_argument('--save', type=str, help='Save results to file')
    parser.add_argument('--top-k', type=int, help='Override top-k')
    parser.add_argument('--threshold', type=float, help='Override similarity threshold')
    args = parser.parse_args()

    # Initialize
    utils = get_utils()
    logger = utils.logger

    logger.info("="*80)
    logger.info("VIVEKANANDA AI - RAG QUERY SYSTEM")
    logger.info("="*80)

    try:
        # Initialize components
        retriever = RAGRetriever(utils)
        display = ResultDisplay(utils)
        analyzer = QueryAnalyzer(utils)

        # Load vector store
        retriever.load_vectorstore()

        # Prepare kwargs
        # FIX: truthiness checks (`if args.top_k:`) silently dropped explicit
        # zero values, e.g. `--threshold 0.0` (a legal similarity threshold);
        # compare against None instead.
        kwargs = {}
        if args.top_k is not None:
            kwargs['top_k'] = args.top_k
        if args.threshold is not None:
            kwargs['threshold'] = args.threshold

        if args.batch:
            # Batch mode
            test_queries = utils.config.get('evaluation.test_queries', [
                "What is Karma Yoga?",
                "How can I overcome fear?",
                "What is the purpose of meditation?",
                "What is true knowledge?",
                "How to develop spiritual strength?"
            ])
            batch_processor = BatchQueryProcessor(utils, retriever, analyzer)
            results = batch_processor.process_batch(test_queries)

            if args.save:
                batch_processor.save_results(results, args.save)

        elif args.query:
            # Single query mode
            query = ' '.join(args.query)

            # Retrieve
            results = retriever.retrieve(query, **kwargs)

            # Format context
            context = retriever.format_context(results)

            # Display
            display.display_results(query, results, context)

            # Analyze
            analysis = analyzer.analyze_results(query, results)
            analyzer.display_analysis(analysis)

            # Show prompt
            display.display_prompt(query, context)

        else:
            # Interactive mode
            interactive_mode(retriever, display, analyzer)

        return 0

    except Exception as e:
        logger.error(f"\nFATAL ERROR: {e}", exc_info=True)
        return 1


if __name__ == "__main__":
    sys.exit(main())