| """ |
| Query RAG System - Test Retrieval Quality |
| Pure LangChain, config-driven, no hardcoding |
| """ |
|
|
| import sys |
| from pathlib import Path |
| from typing import List, Dict, Any, Tuple |
| import json |
|
|
| |
| from langchain_community.vectorstores import FAISS |
| from langchain_huggingface import HuggingFaceEmbeddings |
| from langchain_core.documents import Document |
| import faiss |
| import json |
| import numpy as np |
| |
|
|
| |
| sys.path.append(str(Path(__file__).resolve().parents[1])) |
|
|
| |
| from utils import get_utils |
|
|
| |
| |
| |
|
|
class RAGRetriever:
    """Retrieve relevant context using RAG.

    Wraps a LangChain FAISS vector store with two progressively simpler
    fallbacks (sklearn TF-IDF, then naive bag-of-words) so retrieval still
    works when the embedding model or the serialized index is unavailable.
    """

    def __init__(self, utils):
        # `utils` is the project service bundle; only config, logger and the
        # device are read here — presumably built by get_utils(); confirm.
        self.utils = utils
        self.config = utils.config
        self.logger = utils.logger
        self.device = utils.device_manager.device
        self.vectorstore = None   # set by load_vectorstore()
        self.embeddings = None    # set only when HF embeddings load OK

    def load_vectorstore(self):
        """Load pre-created vector store.

        Resolution order:
          1. LangChain FAISS store at <root>/<db_name> (needs HF embeddings).
          2. TF-IDF index built from <root>/faiss_index/texts.json (sklearn).
          3. Naive bag-of-words token-overlap index over the same texts.

        Returns:
            The loaded store (also kept on self.vectorstore).

        Raises:
            RuntimeError: when the primary load fails and the fallback
                texts.json is missing as well.
        """
        # NOTE(review): config exposes both get_path(*parts) and dotted-key
        # get('a.b.c') lookups — assumed project convention, confirm in utils.
        vectorstore_root = self.config.get_path('paths', 'vectorstore', 'root')
        vectorstore_path = vectorstore_root / self.config.get('paths.vectorstore.db_name')

        if not vectorstore_path.exists():
            self.logger.warning(f"Vector store directory missing: {vectorstore_path}. Will attempt fallback index.")

        if vectorstore_path.exists():
            self.logger.info(f"Loading vector store from: {vectorstore_path}")

        use_hf = self.config.get('embeddings.use_hf', True)
        model_name = self.config.get('embeddings.model_name')
        normalize = self.config.get('embeddings.normalize', True)

        # Build HF embeddings only when the main store exists; any failure
        # (missing package, model download error) falls through to fallback.
        self.embeddings = None
        if use_hf and vectorstore_path.exists():
            try:
                self.embeddings = HuggingFaceEmbeddings(
                    model_name=model_name,
                    model_kwargs={'device': self.device},
                    encode_kwargs={'normalize_embeddings': normalize}
                )
            except Exception as e:
                self.logger.warning(f"HuggingFaceEmbeddings unavailable: {e}. Will use fallback index.")

        # FAISS.load_local unpickles the docstore; this flag acknowledges it.
        allow_dangerous = self.config.get('vectorstore.persistence.allow_dangerous_deserialization', True)

        try:
            if self.embeddings is None:
                # Raising here funnels execution into the fallback path below.
                raise RuntimeError("Embeddings unavailable; skipping LangChain FAISS load.")
            self.vectorstore = FAISS.load_local(
                str(vectorstore_path),
                self.embeddings,
                allow_dangerous_deserialization=allow_dangerous
            )
            total_vectors = self.vectorstore.index.ntotal
            self.logger.info(f"Loaded {total_vectors} vectors (LangChain FAISS)")
            return self.vectorstore
        except Exception as e:
            self.logger.warning(f"Primary FAISS load failed: {e}. Attempting simple FAISS fallback.")

        # Fallback: plain-text chunks dumped alongside a raw FAISS index.
        fallback_dir = vectorstore_root / 'faiss_index'
        index_file = fallback_dir / 'index.faiss'  # NOTE(review): unused below
        texts_file = fallback_dir / 'texts.json'
        if not (fallback_dir.exists() and texts_file.exists()):
            raise RuntimeError(
                f"FAISS deserialization failed and no fallback texts found at: {fallback_dir}"
            )

        texts = json.loads(texts_file.read_text(encoding='utf-8'))

        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity

            class SimpleTFIDFStore:
                """Minimal stand-in exposing the two search methods retrieve() uses."""

                def __init__(self, texts):
                    self.texts = texts
                    self.vectorizer = TfidfVectorizer(stop_words='english')
                    self.matrix = self.vectorizer.fit_transform(texts)

                def similarity_search_with_score(self, query, k=5):
                    # Score is returned as a distance (1 - cosine similarity)
                    # so lower-is-better matches the convention assumed by
                    # retrieve()'s threshold filter.
                    q_vec = self.vectorizer.transform([query])
                    sims = cosine_similarity(q_vec, self.matrix).ravel()
                    top_idx = np.argsort(-sims)[:k]
                    results = []
                    for idx in top_idx:
                        doc = Document(page_content=self.texts[idx], metadata={})
                        results.append((doc, float(1 - sims[idx])))
                    return results

                def max_marginal_relevance_search(self, query, k=5, lambda_mult=0.5):
                    # Not true MMR: keeps the top hit, then fills with remaining
                    # candidates in relevance order (lambda_mult is ignored).
                    base = self.similarity_search_with_score(query, k * 2)
                    selected = []
                    if base:
                        selected.append(base[0][0])
                    while len(selected) < k and len(selected) < len(base):
                        remaining = [doc for doc, _ in base if doc not in selected]
                        if not remaining:
                            break
                        selected.append(remaining[0])
                    return selected

            self.vectorstore = SimpleTFIDFStore(texts)
            self.logger.info(f"Loaded {len(texts)} texts (TF-IDF fallback)")
            return self.vectorstore
        except Exception as e_sklearn:
            self.logger.warning(f"TF-IDF fallback unavailable: {e_sklearn}. Using naive bag-of-words.")

        def tokenize(s):
            # Lowercase, replace non-alphanumerics with spaces, split on whitespace.
            return [t for t in ''.join(c.lower() if c.isalnum() else ' ' for c in s).split() if t]

        docs_tokens = []
        for t in texts:
            toks = tokenize(t)
            docs_tokens.append(toks)

        class SimpleBoWStore:
            """Last-resort index scoring documents by token overlap with the query."""

            def __init__(self, texts, docs_tokens):
                self.texts = texts
                self.docs_tokens = docs_tokens

            def similarity_search_with_score(self, query, k=5):
                q_toks = set(tokenize(query))
                scores = []
                for idx, toks in enumerate(self.docs_tokens):
                    overlap = len(q_toks.intersection(set(toks)))
                    scores.append((idx, overlap))
                top = sorted(scores, key=lambda x: -x[1])[:k]
                results = []
                for idx, score in top:
                    doc = Document(page_content=self.texts[idx], metadata={})
                    # Map overlap count to a pseudo-distance in (0, 1]:
                    # more overlap -> smaller score (lower is better).
                    dist = float(1.0 / (1 + score))
                    results.append((doc, dist))
                return results

            def max_marginal_relevance_search(self, query, k=5, lambda_mult=0.5):
                # Same simplified selection scheme as the TF-IDF fallback:
                # top hit first, then relevance order; lambda_mult is ignored.
                base = self.similarity_search_with_score(query, k * 2)
                selected = []
                if base:
                    selected.append(base[0][0])
                while len(selected) < k and len(selected) < len(base):
                    remaining = [doc for doc, _ in base if doc not in selected]
                    if not remaining:
                        break
                    selected.append(remaining[0])
                return selected

        self.vectorstore = SimpleBoWStore(texts, docs_tokens)
        self.logger.info(f"Loaded {len(texts)} texts (Naive BoW fallback)")
        return self.vectorstore

    def retrieve(self, query: str, **kwargs) -> List[Tuple[Document, float]]:
        """
        Retrieve relevant documents for a query.

        Args:
            query: User's question
            **kwargs: Override config parameters (top_k, search_type, threshold)

        Returns:
            List of (Document, score) tuples; scores are distances
            (lower is better), or 0.0 placeholders for MMR results.

        Raises:
            RuntimeError: if load_vectorstore() has not been called.
        """
        if not self.vectorstore:
            raise RuntimeError("Vector store not loaded. Call load_vectorstore() first.")

        top_k = kwargs.get('top_k', self.config.get('rag.retrieval.top_k', 5))
        search_type = kwargs.get('search_type', self.config.get('rag.retrieval.search_type', 'similarity'))

        self.logger.info(f"\nQuery: '{query}'")
        self.logger.info(f"Retrieving top {top_k} results using {search_type} search")

        if search_type == 'similarity':
            results = self.vectorstore.similarity_search_with_score(query, k=top_k)
        elif search_type == 'mmr':
            # MMR returns bare documents; pad with 0.0 so the (doc, score)
            # shape matches the similarity branch.
            diversity_score = self.config.get('rag.retrieval.mmr_diversity_score', 0.3)
            docs = self.vectorstore.max_marginal_relevance_search(query, k=top_k, lambda_mult=diversity_score)
            results = [(doc, 0.0) for doc in docs]
        else:
            # Unknown search_type: silently fall back to similarity search.
            results = self.vectorstore.similarity_search_with_score(query, k=top_k)

        # Scores are distances (lower is better): keep results at or below
        # the threshold. MMR placeholders (0.0) always pass.
        threshold = kwargs.get('threshold', self.config.get('rag.retrieval.similarity_threshold', 0.5))
        filtered_results = [(doc, score) for doc, score in results if score <= threshold]

        if not filtered_results:
            self.logger.warning(f"No results below threshold {threshold}")
            # Best-effort: keep the top 3 raw hits rather than return nothing.
            filtered_results = results[:3]

        self.logger.info(f"Retrieved {len(filtered_results)} relevant documents")

        return filtered_results

    def format_context(self, results: List[Tuple[Document, float]], **kwargs) -> str:
        """
        Format retrieved documents into a context string for the LLM.

        Args:
            results: List of (document, score) tuples
            **kwargs: Override config parameters (max_tokens, max_chunks,
                include_metadata, metadata_fields)

        Returns:
            A newline-joined string: a [Guidance] preamble followed by up to
            max_chunks "--- Passage N ---" sections within the token budget.
        """
        max_tokens = kwargs.get('max_tokens', self.config.get('rag.context.max_tokens', 2000))
        max_chunks = kwargs.get('max_chunks', self.config.get('rag.context.max_chunks', 5))
        include_metadata = kwargs.get('include_metadata', self.config.get('rag.context.include_metadata', True))
        metadata_fields = kwargs.get('metadata_fields', self.config.get('rag.context.metadata_fields', ['source', 'work_type', 'topic']))

        # Persona/style instructions prepended so the downstream LLM answers
        # in voice instead of echoing the passages.
        guidance = (
            "Use the following passages as references. Speak directly to the reader as Swami Vivekananda—"
            "fearless, compassionate, and practical. Summarize the core idea in 1–2 lines, then weave insights "
            "into a cohesive response with 3–5 practical steps. Do not copy long passages verbatim; paraphrase and "
            "synthesize. Include at most one short quote if essential, avoid bracketed numeric citations, and "
            "conclude with an uplifting benediction (e.g., Om Shanti)."
        )

        context_parts = [f"[Guidance] {guidance}"]
        current_length = 0
        # Rough budget: ~4 characters per token (heuristic, not tokenizer-exact).
        max_chars = max_tokens * 4

        for idx, (doc, score) in enumerate(results[:max_chunks], 1):
            chunk_text = doc.page_content.strip()

            # Optional one-line metadata header, e.g. "[Source: x, Topic: y]".
            metadata_str = ""
            if include_metadata:
                meta_parts = []
                for field in metadata_fields:
                    value = doc.metadata.get(field)
                    if value:
                        meta_parts.append(f"{field.title()}: {value}")

                if meta_parts:
                    metadata_str = f"[{', '.join(meta_parts)}]"

            chunk_formatted = f"--- Passage {idx} ---\n"
            if metadata_str:
                chunk_formatted += f"{metadata_str}\n"
            chunk_formatted += f"{chunk_text}\n"

            # Stop before exceeding the character budget; the chunk that
            # would overflow is dropped entirely (not truncated).
            if current_length + len(chunk_formatted) > max_chars:
                break

            context_parts.append(chunk_formatted)
            current_length += len(chunk_formatted)

        return "\n".join(context_parts)
|
|
| |
| |
| |
|
|
class ResultDisplay:
    """Display retrieval results"""

    def __init__(self, utils):
        self.utils = utils
        self.logger = utils.logger
        self.config = utils.config

    def display_results(self, query: str, results: List[Tuple[Document, float]], context: str):
        """Print the query, each retrieved document, and the final LLM context."""
        heavy = "=" * 80
        light = "-" * 80
        hidden = {'chunk_index', 'char_count', 'word_count', 'file_path'}

        print("\n" + heavy)
        print("RAG RETRIEVAL RESULTS")
        print(heavy)
        print(f"\nQuery: {query}")
        print(f"Results: {len(results)}")

        print("\n" + light)
        print("RETRIEVED DOCUMENTS")
        print(light)

        for rank, (doc, score) in enumerate(results, 1):
            print(f"\n[{rank}] Similarity Score: {score:.4f}")

            # Skip bookkeeping metadata fields and empty values.
            print("Metadata:")
            for key, value in doc.metadata.items():
                if value and key not in hidden:
                    print(f"  - {key}: {value}")

            # Preview only the first 300 characters of long passages.
            print("\nContent:")
            content = doc.page_content.strip()
            preview = f"{content[:300]}..." if len(content) > 300 else content
            print(f"  {preview}")

            print(light)

        print("\n" + heavy)
        print("FORMATTED CONTEXT (for LLM)")
        print(heavy)
        print(context)
        print(heavy)

    def display_prompt(self, query: str, context: str):
        """Print the fully assembled prompt that would be sent to the LLM."""
        heavy = "=" * 80
        full_prompt = self.utils.prompt_builder.get_full_prompt(query, context)

        print("\n" + heavy)
        print("COMPLETE PROMPT FOR LLM")
        print(heavy)
        print(full_prompt)
        print(heavy)
|
|
| |
| |
| |
|
|
class QueryAnalyzer:
    """Analyze query quality and results"""

    def __init__(self, utils):
        self.utils = utils
        self.logger = utils.logger

    def analyze_results(self, query: str, results: List[Tuple[Document, float]]) -> Dict[str, Any]:
        """Summarize retrieval quality: scores, average, and distinct metadata values."""
        scores = [score for _, score in results]
        metadatas = [doc.metadata for doc, _ in results]

        return {
            'query': query,
            'num_results': len(results),
            'scores': scores,
            # Guard against division by zero on an empty result set.
            'avg_score': sum(scores) / len(scores) if scores else 0,
            'sources': list({m.get('source', 'Unknown') for m in metadatas}),
            'work_types': list({m.get('work_type', 'Unknown') for m in metadatas}),
            # Only documents that actually carry a truthy topic contribute here.
            'topics': list({m.get('topic', 'Unknown') for m in metadatas if m.get('topic')}),
        }

    def display_analysis(self, analysis: Dict[str, Any]):
        """Pretty-print an analysis dict produced by analyze_results()."""
        heavy = "=" * 80

        print("\n" + heavy)
        print("RETRIEVAL ANALYSIS")
        print(heavy)

        print(f"\nQuery: {analysis['query']}")
        print(f"Results found: {analysis['num_results']}")
        print(f"Average similarity: {analysis['avg_score']:.4f}")

        print(f"\nSources: {', '.join(analysis['sources'])}")
        print(f"Work types: {', '.join(analysis['work_types'])}")
        if analysis['topics']:
            print(f"Topics: {', '.join(analysis['topics'])}")

        print("\nScore distribution:")
        for rank, score in enumerate(analysis['scores'], 1):
            print(f"  {rank}. {score:.4f}")
|
|
| |
| |
| |
|
|
class BatchQueryProcessor:
    """Process multiple queries"""

    def __init__(self, utils, retriever, analyzer):
        self.utils = utils
        self.logger = utils.logger
        self.retriever = retriever
        self.analyzer = analyzer

    def process_batch(self, queries: List[str]) -> List[Dict[str, Any]]:
        """Run retrieval + analysis for each query and collect the outcomes."""
        self.logger.info(f"\nProcessing {len(queries)} queries...")

        collected = []
        for query in queries:
            print("\n" + "=" * 80)
            self.logger.info(f"Processing: {query}")

            hits = self.retriever.retrieve(query)
            report = self.analyzer.analyze_results(query, hits)

            collected.append({
                'query': query,
                'results': hits,
                'analysis': report,
            })

            self.analyzer.display_analysis(report)

        return collected

    def save_results(self, all_results: List[Dict[str, Any]], output_file: str):
        """Persist batch results to a JSON file under the configured results dir."""
        # Documents are not JSON-serializable; flatten to plain dicts first.
        serializable_results = [
            {
                'query': item['query'],
                'analysis': item['analysis'],
                'results': [
                    {
                        'content': doc.page_content,
                        'metadata': doc.metadata,
                        'score': float(score),
                    }
                    for doc, score in item['results']
                ],
            }
            for item in all_results
        ]

        output_path = self.utils.config.get_path('paths', 'outputs', 'results') / output_file
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2, ensure_ascii=False)

        self.logger.info(f"\nResults saved to: {output_path}")
|
|
| |
| |
| |
|
|
def interactive_mode(retriever, display, analyzer):
    """Interactive query mode: read questions from stdin until quit/EOF."""
    heavy = "=" * 80

    print("\n" + heavy)
    print("🕉️ VIVEKANANDA AI - INTERACTIVE RAG QUERY")
    print(heavy)
    print("Ask questions to test retrieval quality")
    print("Commands:")
    print("  - Type your question")
    print("  - 'quit' or 'exit' to stop")
    print("  - 'config' to see current settings")
    print(heavy)

    while True:
        try:
            query = input("\n[YOU] ").strip()

            if not query:
                continue

            lowered = query.lower()

            if lowered in ('quit', 'exit', 'q'):
                print("\n[INFO] Om Shanti Shanti Shanti 🕉️")
                break

            if lowered == 'config':
                # Show the effective retrieval settings without running a query.
                print("\nCurrent RAG Configuration:")
                print(f"  Top-K: {retriever.config.get('rag.retrieval.top_k')}")
                print(f"  Threshold: {retriever.config.get('rag.retrieval.similarity_threshold')}")
                print(f"  Search Type: {retriever.config.get('rag.retrieval.search_type')}")
                print(f"  Max Context Tokens: {retriever.config.get('rag.context.max_tokens')}")
                continue

            # Retrieve, format, and show everything for this one question.
            hits = retriever.retrieve(query)
            context = retriever.format_context(hits)
            display.display_results(query, hits, context)
            analyzer.display_analysis(analyzer.analyze_results(query, hits))
            display.display_prompt(query, context)

        except KeyboardInterrupt:
            print("\n\n[INFO] Interrupted by user")
            break
        except Exception as e:
            print(f"\n[ERROR] {e}")
            import traceback
            traceback.print_exc()
|
|
| |
| |
| |
|
|
def main():
    """Main execution.

    Parses CLI arguments, loads the vector store, then dispatches to batch
    mode (--batch), single-query mode (positional query words), or the
    interactive REPL (no arguments).

    Returns:
        int: 0 on success, 1 on any fatal error (logged with traceback).
    """
    import argparse
    parser = argparse.ArgumentParser(description="Query RAG System")
    parser.add_argument('query', nargs='*', help='Query to search')
    parser.add_argument('--batch', action='store_true', help='Batch mode with test queries')
    parser.add_argument('--save', type=str, help='Save results to file')
    parser.add_argument('--top-k', type=int, help='Override top-k')
    parser.add_argument('--threshold', type=float, help='Override similarity threshold')

    args = parser.parse_args()

    utils = get_utils()
    logger = utils.logger

    logger.info("="*80)
    logger.info("VIVEKANANDA AI - RAG QUERY SYSTEM")
    logger.info("="*80)

    try:
        retriever = RAGRetriever(utils)
        display = ResultDisplay(utils)
        analyzer = QueryAnalyzer(utils)

        retriever.load_vectorstore()

        # Fix: compare against None instead of truthiness so explicit zero
        # overrides (e.g. --threshold 0.0 or --top-k 0) are not silently
        # dropped; argparse leaves unset optionals as None.
        kwargs = {}
        if args.top_k is not None:
            kwargs['top_k'] = args.top_k
        if args.threshold is not None:
            kwargs['threshold'] = args.threshold

        if args.batch:
            # Config-driven query set, with a sensible built-in default.
            test_queries = utils.config.get('evaluation.test_queries', [
                "What is Karma Yoga?",
                "How can I overcome fear?",
                "What is the purpose of meditation?",
                "What is true knowledge?",
                "How to develop spiritual strength?"
            ])

            batch_processor = BatchQueryProcessor(utils, retriever, analyzer)
            results = batch_processor.process_batch(test_queries)

            if args.save:
                batch_processor.save_results(results, args.save)

        elif args.query:
            # Positional words form a single query string.
            query = ' '.join(args.query)

            results = retriever.retrieve(query, **kwargs)
            context = retriever.format_context(results)

            display.display_results(query, results, context)

            analysis = analyzer.analyze_results(query, results)
            analyzer.display_analysis(analysis)

            display.display_prompt(query, context)

        else:
            interactive_mode(retriever, display, analyzer)

        return 0

    except Exception as e:
        logger.error(f"\nFATAL ERROR: {e}", exc_info=True)
        return 1
|
|
# Script entry point: process exit code mirrors main()'s return value.
if __name__ == "__main__":
    sys.exit(main())