# VivekanandaAI / query_rag.py
# (Hugging Face repository artifact: uploaded by jyotirmoy05, "Upload 18 files",
#  commit 3c15254, verified — kept as a comment so the module remains importable)
"""
Query RAG System - Test Retrieval Quality
Pure LangChain, config-driven, no hardcoding
"""
import sys
from pathlib import Path
from typing import List, Dict, Any, Tuple
import json
# LangChain imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
import faiss
import json
import numpy as np
# Defer SentenceTransformer import to runtime to avoid hard torch dependency in some envs
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).resolve().parents[1]))
# Local imports
from utils import get_utils
# ============================================================================
# RAG RETRIEVER
# ============================================================================
class RAGRetriever:
    """Retrieve relevant context using RAG.

    Wraps a persisted LangChain FAISS vector store and degrades gracefully:
    if the store (or its embedding dependencies) cannot be loaded, it falls
    back to a scikit-learn TF-IDF index over raw texts, and finally to a
    dependency-free bag-of-words overlap scorer. All fallback stores expose
    the same two methods used by retrieve(): similarity_search_with_score()
    and max_marginal_relevance_search().
    """

    def __init__(self, utils):
        # utils: project-wide helper bundle (config accessor, logger, device manager).
        self.utils = utils
        self.config = utils.config
        self.logger = utils.logger
        self.device = utils.device_manager.device
        self.vectorstore = None  # populated by load_vectorstore()
        self.embeddings = None   # populated only when HF embeddings load cleanly

    def load_vectorstore(self):
        """Load pre-created vector store.

        Returns the loaded store (also cached on self.vectorstore).
        Raises RuntimeError when the primary load fails and no fallback
        texts.json exists.
        """
        # Resolve the configured on-disk location of the persisted store.
        vectorstore_root = self.config.get_path('paths', 'vectorstore', 'root')
        vectorstore_path = vectorstore_root / self.config.get('paths.vectorstore.db_name')
        if not vectorstore_path.exists():
            self.logger.warning(f"Vector store directory missing: {vectorstore_path}. Will attempt fallback index.")
        if vectorstore_path.exists():
            self.logger.info(f"Loading vector store from: {vectorstore_path}")
        # Create embeddings model (must match the one used for creation)
        use_hf = self.config.get('embeddings.use_hf', True)
        model_name = self.config.get('embeddings.model_name')
        normalize = self.config.get('embeddings.normalize', True)
        self.embeddings = None
        if use_hf and vectorstore_path.exists():
            try:
                self.embeddings = HuggingFaceEmbeddings(
                    model_name=model_name,
                    model_kwargs={'device': self.device},
                    encode_kwargs={'normalize_embeddings': normalize}
                )
            except Exception as e:
                # e.g. torch/transformers missing in this env; fall through to fallback index.
                self.logger.warning(f"HuggingFaceEmbeddings unavailable: {e}. Will use fallback index.")
        # Load vector store with graceful fallback
        allow_dangerous = self.config.get('vectorstore.persistence.allow_dangerous_deserialization', True)
        try:
            if self.embeddings is None:
                # Deliberately jump to the except-block fallback when embeddings are unusable.
                raise RuntimeError("Embeddings unavailable; skipping LangChain FAISS load.")
            self.vectorstore = FAISS.load_local(
                str(vectorstore_path),
                self.embeddings,
                allow_dangerous_deserialization=allow_dangerous
            )
            total_vectors = self.vectorstore.index.ntotal
            self.logger.info(f"Loaded {total_vectors} vectors (LangChain FAISS)")
            return self.vectorstore
        except Exception as e:
            # Primary load failed: look for a plain-text fallback next to the store.
            self.logger.warning(f"Primary FAISS load failed: {e}. Attempting simple FAISS fallback.")
            fallback_dir = vectorstore_root / 'faiss_index'
            index_file = fallback_dir / 'index.faiss'  # NOTE(review): computed but never used below
            texts_file = fallback_dir / 'texts.json'
            if not (fallback_dir.exists() and texts_file.exists()):
                raise RuntimeError(
                    f"FAISS deserialization failed and no fallback texts found at: {fallback_dir}"
                )
            texts = json.loads(texts_file.read_text(encoding='utf-8'))
            try:
                # Preferred fallback: TF-IDF cosine similarity (needs scikit-learn).
                from sklearn.feature_extraction.text import TfidfVectorizer
                from sklearn.metrics.pairwise import cosine_similarity

                class SimpleTFIDFStore:
                    # Minimal stand-in implementing the two store methods retrieve() uses.
                    def __init__(self, texts):
                        self.texts = texts
                        self.vectorizer = TfidfVectorizer(stop_words='english')
                        self.matrix = self.vectorizer.fit_transform(texts)

                    def similarity_search_with_score(self, query, k=5):
                        q_vec = self.vectorizer.transform([query])
                        sims = cosine_similarity(q_vec, self.matrix).ravel()
                        top_idx = np.argsort(-sims)[:k]
                        results = []
                        for idx in top_idx:
                            doc = Document(page_content=self.texts[idx], metadata={})
                            # Report a distance (1 - similarity) so lower is better,
                            # matching the score semantics retrieve() filters on.
                            results.append((doc, float(1 - sims[idx])))
                        return results

                    def max_marginal_relevance_search(self, query, k=5, lambda_mult=0.5):
                        # NOTE(review): not true MMR — greedy de-duplication of the
                        # top-scored docs; lambda_mult is accepted but ignored.
                        base = self.similarity_search_with_score(query, k * 2)
                        selected = []
                        if base:
                            selected.append(base[0][0])
                        while len(selected) < k and len(selected) < len(base):
                            remaining = [doc for doc, _ in base if doc not in selected]
                            if not remaining:
                                break
                            selected.append(remaining[0])
                        return selected

                self.vectorstore = SimpleTFIDFStore(texts)
                self.logger.info(f"Loaded {len(texts)} texts (TF-IDF fallback)")
                return self.vectorstore
            except Exception as e_sklearn:
                # Last resort: dependency-free bag-of-words token-overlap scoring.
                self.logger.warning(f"TF-IDF fallback unavailable: {e_sklearn}. Using naive bag-of-words.")

                def tokenize(s):
                    # Lowercase, replace non-alphanumerics with spaces, split on whitespace.
                    return [t for t in ''.join(c.lower() if c.isalnum() else ' ' for c in s).split() if t]

                docs_tokens = []
                for t in texts:
                    toks = tokenize(t)
                    docs_tokens.append(toks)

                class SimpleBoWStore:
                    # tokenize() above is captured by the methods below via closure.
                    def __init__(self, texts, docs_tokens):
                        self.texts = texts
                        self.docs_tokens = docs_tokens

                    def similarity_search_with_score(self, query, k=5):
                        q_toks = set(tokenize(query))
                        scores = []
                        for idx, toks in enumerate(self.docs_tokens):
                            overlap = len(q_toks.intersection(set(toks)))
                            scores.append((idx, overlap))
                        top = sorted(scores, key=lambda x: -x[1])[:k]
                        results = []
                        for idx, score in top:
                            doc = Document(page_content=self.texts[idx], metadata={})
                            # Map overlap count to a pseudo-distance in (0, 1]: more overlap -> smaller.
                            dist = float(1.0 / (1 + score))
                            results.append((doc, dist))
                        return results

                    def max_marginal_relevance_search(self, query, k=5, lambda_mult=0.5):
                        # Same greedy de-duplication as the TF-IDF fallback (not true MMR).
                        base = self.similarity_search_with_score(query, k * 2)
                        selected = []
                        if base:
                            selected.append(base[0][0])
                        while len(selected) < k and len(selected) < len(base):
                            remaining = [doc for doc, _ in base if doc not in selected]
                            if not remaining:
                                break
                            selected.append(remaining[0])
                        return selected

                self.vectorstore = SimpleBoWStore(texts, docs_tokens)
                self.logger.info(f"Loaded {len(texts)} texts (Naive BoW fallback)")
                return self.vectorstore

    def retrieve(self, query: str, **kwargs) -> List[Tuple[Document, float]]:
        """
        Retrieve relevant documents for a query

        Args:
            query: User's question
            **kwargs: Override config parameters (top_k, search_type, threshold)

        Returns:
            List of (Document, score) pairs; scores are distances (lower is better).
            MMR results carry a placeholder score of 0.0.

        Raises:
            RuntimeError: if load_vectorstore() has not been called yet.
        """
        if not self.vectorstore:
            raise RuntimeError("Vector store not loaded. Call load_vectorstore() first.")
        # Get retrieval config
        top_k = kwargs.get('top_k', self.config.get('rag.retrieval.top_k', 5))
        search_type = kwargs.get('search_type', self.config.get('rag.retrieval.search_type', 'similarity'))
        self.logger.info(f"\nQuery: '{query}'")
        self.logger.info(f"Retrieving top {top_k} results using {search_type} search")
        # Perform search
        if search_type == 'similarity':
            results = self.vectorstore.similarity_search_with_score(query, k=top_k)
        elif search_type == 'mmr':
            # Maximum Marginal Relevance for diversity
            diversity_score = self.config.get('rag.retrieval.mmr_diversity_score', 0.3)
            docs = self.vectorstore.max_marginal_relevance_search(query, k=top_k, lambda_mult=diversity_score)
            results = [(doc, 0.0) for doc in docs]  # MMR doesn't return scores
        else:
            # Unrecognized search type: fall back to plain similarity search.
            results = self.vectorstore.similarity_search_with_score(query, k=top_k)
        # Filter by threshold (scores are distances, so keep values <= threshold)
        threshold = kwargs.get('threshold', self.config.get('rag.retrieval.similarity_threshold', 0.5))
        filtered_results = [(doc, score) for doc, score in results if score <= threshold]
        if not filtered_results:
            self.logger.warning(f"No results below threshold {threshold}")
            # Return top 3 anyway
            filtered_results = results[:3]
        self.logger.info(f"Retrieved {len(filtered_results)} relevant documents")
        return filtered_results

    def format_context(self, results: List[Tuple[Document, float]], **kwargs) -> str:
        """
        Format retrieved documents into context string

        Args:
            results: List of (document, score) tuples
            **kwargs: Override config parameters (max_tokens, max_chunks,
                include_metadata, metadata_fields)

        Returns:
            A guidance header plus up to max_chunks passages, capped at a
            character budget derived from max_tokens.
        """
        max_tokens = kwargs.get('max_tokens', self.config.get('rag.context.max_tokens', 2000))
        max_chunks = kwargs.get('max_chunks', self.config.get('rag.context.max_chunks', 5))
        include_metadata = kwargs.get('include_metadata', self.config.get('rag.context.include_metadata', True))
        metadata_fields = kwargs.get('metadata_fields', self.config.get('rag.context.metadata_fields', ['source', 'work_type', 'topic']))
        # Add synthesis guidance to discourage verbatim copying
        guidance = (
            "Use the following passages as references. Speak directly to the reader as Swami Vivekananda—"
            "fearless, compassionate, and practical. Summarize the core idea in 1–2 lines, then weave insights "
            "into a cohesive response with 3–5 practical steps. Do not copy long passages verbatim; paraphrase and "
            "synthesize. Include at most one short quote if essential, avoid bracketed numeric citations, and "
            "conclude with an uplifting benediction (e.g., Om Shanti)."
        )
        context_parts = [f"[Guidance] {guidance}"]
        current_length = 0
        max_chars = max_tokens * 4  # Rough estimate: 1 token ≈ 4 chars
        for idx, (doc, score) in enumerate(results[:max_chunks], 1):
            chunk_text = doc.page_content.strip()
            # Build metadata string
            metadata_str = ""
            if include_metadata:
                meta_parts = []
                for field in metadata_fields:
                    value = doc.metadata.get(field)
                    if value:
                        meta_parts.append(f"{field.title()}: {value}")
                if meta_parts:
                    metadata_str = f"[{', '.join(meta_parts)}]"
            # Format chunk
            chunk_formatted = f"--- Passage {idx} ---\n"
            if metadata_str:
                chunk_formatted += f"{metadata_str}\n"
            chunk_formatted += f"{chunk_text}\n"
            # Stop once the character budget would be exceeded.
            if current_length + len(chunk_formatted) > max_chars:
                break
            context_parts.append(chunk_formatted)
            current_length += len(chunk_formatted)
        return "\n".join(context_parts)
# ============================================================================
# RESULT DISPLAY
# ============================================================================
class ResultDisplay:
    """Render retrieval results to stdout for human inspection."""

    def __init__(self, utils):
        # Keep handles to the shared utility bundle for later use.
        self.utils = utils
        self.logger = utils.logger
        self.config = utils.config

    def display_results(self, query: str, results: "List[Tuple[Document, float]]", context: str):
        """Print each retrieved document (score, metadata, preview) and the
        formatted context that will be handed to the LLM."""
        bar = "=" * 80
        rule = "-" * 80
        # Bookkeeping-style metadata keys are noise for a human reader.
        hidden_keys = ('chunk_index', 'char_count', 'word_count', 'file_path')
        print("\n" + bar)
        print("RAG RETRIEVAL RESULTS")
        print(bar)
        print(f"\nQuery: {query}")
        print(f"Results: {len(results)}")
        print("\n" + rule)
        print("RETRIEVED DOCUMENTS")
        print(rule)
        for position, (document, distance) in enumerate(results, 1):
            print(f"\n[{position}] Similarity Score: {distance:.4f}")
            print("Metadata:")
            for key, value in document.metadata.items():
                if value and key not in hidden_keys:
                    print(f" - {key}: {value}")
            print("\nContent:")
            body = document.page_content.strip()
            # Truncate long passages to a 300-character preview.
            preview = f" {body[:300]}..." if len(body) > 300 else f" {body}"
            print(preview)
            print(rule)
        print("\n" + bar)
        print("FORMATTED CONTEXT (for LLM)")
        print(bar)
        print(context)
        print(bar)

    def display_prompt(self, query: str, context: str):
        """Print the fully assembled prompt produced by the project's prompt builder."""
        full_prompt = self.utils.prompt_builder.get_full_prompt(query, context)
        bar = "=" * 80
        print("\n" + bar)
        print("COMPLETE PROMPT FOR LLM")
        print(bar)
        print(full_prompt)
        print(bar)
# ============================================================================
# QUERY ANALYZER
# ============================================================================
class QueryAnalyzer:
    """Summarize retrieval quality for a single query."""

    def __init__(self, utils):
        self.utils = utils
        self.logger = utils.logger

    def analyze_results(self, query: str, results: "List[Tuple[Document, float]]") -> Dict[str, Any]:
        """Build a summary dict: counts, score statistics, and the distinct
        sources / work types / topics seen in the retrieved metadata."""
        score_list = [distance for _, distance in results]
        # Guard against division by zero when nothing was retrieved.
        mean_score = sum(score_list) / len(results) if results else 0
        return {
            'query': query,
            'num_results': len(results),
            'scores': score_list,
            'avg_score': mean_score,
            'sources': list({doc.metadata.get('source', 'Unknown') for doc, _ in results}),
            'work_types': list({doc.metadata.get('work_type', 'Unknown') for doc, _ in results}),
            'topics': list({doc.metadata.get('topic', 'Unknown') for doc, _ in results if doc.metadata.get('topic')}),
        }

    def display_analysis(self, analysis: Dict[str, Any]):
        """Pretty-print a summary produced by analyze_results()."""
        bar = "=" * 80
        print("\n" + bar)
        print("RETRIEVAL ANALYSIS")
        print(bar)
        print(f"\nQuery: {analysis['query']}")
        print(f"Results found: {analysis['num_results']}")
        print(f"Average similarity: {analysis['avg_score']:.4f}")
        print(f"\nSources: {', '.join(analysis['sources'])}")
        print(f"Work types: {', '.join(analysis['work_types'])}")
        if analysis['topics']:
            print(f"Topics: {', '.join(analysis['topics'])}")
        print("\nScore distribution:")
        for rank, distance in enumerate(analysis['scores'], 1):
            print(f" {rank}. {distance:.4f}")
# ============================================================================
# BATCH QUERY
# ============================================================================
class BatchQueryProcessor:
    """Run retrieval + analysis over a list of queries and persist the output."""

    def __init__(self, utils, retriever, analyzer):
        self.utils = utils
        self.logger = utils.logger
        self.retriever = retriever
        self.analyzer = analyzer

    def process_batch(self, queries: List[str]) -> List[Dict[str, Any]]:
        """Retrieve and analyze every query; return one record per query."""
        self.logger.info(f"\nProcessing {len(queries)} queries...")
        collected = []
        for question in queries:
            print("\n" + "=" * 80)
            self.logger.info(f"Processing: {question}")
            docs_with_scores = self.retriever.retrieve(question)
            summary = self.analyzer.analyze_results(question, docs_with_scores)
            collected.append({
                'query': question,
                'results': docs_with_scores,
                'analysis': summary
            })
            # Echo the per-query summary as we go.
            self.analyzer.display_analysis(summary)
        return collected

    def save_results(self, all_results: List[Dict[str, Any]], output_file: str):
        """Write batch results as JSON under the configured results directory."""
        # Documents are not JSON-serializable; flatten them to plain dicts.
        serializable = [
            {
                'query': entry['query'],
                'analysis': entry['analysis'],
                'results': [
                    {
                        'content': doc.page_content,
                        'metadata': doc.metadata,
                        'score': float(distance)
                    }
                    for doc, distance in entry['results']
                ],
            }
            for entry in all_results
        ]
        target = self.utils.config.get_path('paths', 'outputs', 'results') / output_file
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(serializable, handle, indent=2, ensure_ascii=False)
        self.logger.info(f"\nResults saved to: {target}")
# ============================================================================
# INTERACTIVE MODE
# ============================================================================
def interactive_mode(retriever, display, analyzer):
    """Interactive REPL for testing retrieval quality.

    Reads questions from stdin until 'quit'/'exit'/'q'; 'config' dumps the
    current RAG settings. Each question is retrieved, displayed, analyzed,
    and shown as a complete LLM prompt.
    """
    bar = "=" * 80
    print("\n" + bar)
    print("🕉️ VIVEKANANDA AI - INTERACTIVE RAG QUERY")
    print(bar)
    print("Ask questions to test retrieval quality")
    print("Commands:")
    print(" - Type your question")
    print(" - 'quit' or 'exit' to stop")
    print(" - 'config' to see current settings")
    print(bar)
    while True:
        try:
            user_text = input("\n[YOU] ").strip()
            if not user_text:
                continue  # ignore blank lines
            lowered = user_text.lower()
            if lowered in ('quit', 'exit', 'q'):
                print("\n[INFO] Om Shanti Shanti Shanti 🕉️")
                break
            if lowered == 'config':
                print("\nCurrent RAG Configuration:")
                print(f" Top-K: {retriever.config.get('rag.retrieval.top_k')}")
                print(f" Threshold: {retriever.config.get('rag.retrieval.similarity_threshold')}")
                print(f" Search Type: {retriever.config.get('rag.retrieval.search_type')}")
                print(f" Max Context Tokens: {retriever.config.get('rag.context.max_tokens')}")
                continue
            # Full pipeline: retrieve -> format -> display -> analyze -> prompt.
            hits = retriever.retrieve(user_text)
            context = retriever.format_context(hits)
            display.display_results(user_text, hits, context)
            analyzer.display_analysis(analyzer.analyze_results(user_text, hits))
            display.display_prompt(user_text, context)
        except KeyboardInterrupt:
            print("\n\n[INFO] Interrupted by user")
            break
        except Exception as exc:
            # Keep the REPL alive on per-query failures; show the traceback for debugging.
            print(f"\n[ERROR] {exc}")
            import traceback
            traceback.print_exc()
# ============================================================================
# MAIN
# ============================================================================
def main():
    """Main execution: parse CLI args, load the vector store, and dispatch
    to batch, single-query, or interactive mode.

    Returns:
        0 on success, 1 on any fatal error (logged with traceback).
    """
    # Parse arguments
    import argparse
    parser = argparse.ArgumentParser(description="Query RAG System")
    parser.add_argument('query', nargs='*', help='Query to search')
    parser.add_argument('--batch', action='store_true', help='Batch mode with test queries')
    parser.add_argument('--save', type=str, help='Save results to file')
    parser.add_argument('--top-k', type=int, help='Override top-k')
    parser.add_argument('--threshold', type=float, help='Override similarity threshold')
    args = parser.parse_args()
    # Initialize shared utilities (config, logger, ...)
    utils = get_utils()
    logger = utils.logger
    logger.info("=" * 80)
    logger.info("VIVEKANANDA AI - RAG QUERY SYSTEM")
    logger.info("=" * 80)
    try:
        # Initialize components
        retriever = RAGRetriever(utils)
        display = ResultDisplay(utils)
        analyzer = QueryAnalyzer(utils)
        # Load vector store (may fall back to a TF-IDF/BoW index internally)
        retriever.load_vectorstore()
        # Prepare retrieval overrides. Compare against None rather than
        # truthiness: the old `if args.threshold:` silently dropped an
        # explicit `--threshold 0.0`, which is a meaningful distance cutoff.
        kwargs = {}
        if args.top_k is not None:
            kwargs['top_k'] = args.top_k
        if args.threshold is not None:
            kwargs['threshold'] = args.threshold
        if args.batch:
            # Batch mode over configured (or default) test queries.
            # NOTE(review): --top-k/--threshold are not applied in batch mode;
            # BatchQueryProcessor.process_batch accepts no retrieval kwargs.
            test_queries = utils.config.get('evaluation.test_queries', [
                "What is Karma Yoga?",
                "How can I overcome fear?",
                "What is the purpose of meditation?",
                "What is true knowledge?",
                "How to develop spiritual strength?"
            ])
            batch_processor = BatchQueryProcessor(utils, retriever, analyzer)
            results = batch_processor.process_batch(test_queries)
            if args.save:
                batch_processor.save_results(results, args.save)
        elif args.query:
            # Single query mode: join the positional words into one question.
            query = ' '.join(args.query)
            results = retriever.retrieve(query, **kwargs)
            context = retriever.format_context(results)
            display.display_results(query, results, context)
            analysis = analyzer.analyze_results(query, results)
            analyzer.display_analysis(analysis)
            display.display_prompt(query, context)
        else:
            # No arguments: drop into the interactive REPL.
            interactive_mode(retriever, display, analyzer)
        return 0
    except Exception as e:
        logger.error(f"\nFATAL ERROR: {e}", exc_info=True)
        return 1


if __name__ == "__main__":
    sys.exit(main())