# hpmor/src/rag_engine.py
"""RAG query engine for HPMOR Q&A system."""
from typing import Optional, List, Dict, Any
import json
from pathlib import Path
from llama_index.core import Document
from src.document_processor import HPMORProcessor
from src.vector_store import VectorStoreManager
from src.model_chain import ModelChain, ModelType
from src.config import config
class RAGEngine:
    """Main RAG engine combining retrieval and generation.

    Wires together three project components:
      * ``HPMORProcessor``    -- loads and chunks the source text into documents.
      * ``VectorStoreManager`` -- builds/loads and queries the vector index.
      * ``ModelChain``        -- generates answers (with optional fallback model).

    Non-streaming responses are memoized in an in-memory dict keyed by
    ``(question, top_k, force_model)``.
    """

    def __init__(self, force_recreate: bool = False):
        """Initialize RAG engine components.

        Args:
            force_recreate: When True, reprocess the corpus and rebuild the
                vector index from scratch instead of loading a persisted one.
        """
        print("Initializing RAG Engine...")
        # Initialize components
        self.processor = HPMORProcessor()
        self.vector_store = VectorStoreManager()
        self.model_chain = ModelChain()
        # Process and index documents up front so queries are immediately ready.
        self._initialize_index(force_recreate)
        # Cache of full response dicts for non-streaming queries.
        self.response_cache: Dict[Any, Dict[str, Any]] = {}

    def _initialize_index(self, force_recreate: bool = False):
        """Process the source documents and create or load the vector index."""
        documents = self.processor.process(force_reprocess=force_recreate)
        self.index = self.vector_store.get_or_create_index(
            documents=documents,
            force_recreate=force_recreate,
        )
        print(f"Index ready with {len(documents)} documents")

    def retrieve_context(self, query: str, top_k: Optional[int] = None) -> tuple[str, List[Dict]]:
        """Retrieve relevant context for a query.

        Args:
            query: Natural-language question to search the index with.
            top_k: Number of chunks to retrieve; defaults to
                ``config.top_k_retrieval`` when None.

        Returns:
            Tuple of (formatted context string with numbered excerpts,
            list of per-chunk source-metadata dicts).
        """
        if top_k is None:
            top_k = config.top_k_retrieval
        nodes = self.vector_store.query(query, top_k=top_k)
        context_parts = []
        source_info = []
        for i, node in enumerate(nodes, 1):
            context_parts.append(f"[Excerpt {i}]\n{node.text}")
            source_info.append({
                "chunk_id": node.metadata.get("chunk_id", "unknown"),
                "chapter_number": node.metadata.get("chapter_number", 0),
                "chapter_title": node.metadata.get("chapter_title", "Unknown"),
                # `is not None` (not truthiness) so a legitimate 0.0 similarity
                # score is preserved rather than conflated with "missing".
                "score": float(node.score) if node.score is not None else 0.0,
                "text_preview": node.text[:200] + "..." if len(node.text) > 200 else node.text,
            })
        context = "\n\n".join(context_parts)
        return context, source_info

    def query(
        self,
        question: str,
        top_k: Optional[int] = None,
        force_model: Optional["ModelType"] = None,
        return_sources: bool = True,
        use_cache: bool = True,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Execute RAG query with retrieval and generation.

        Args:
            question: User question to answer.
            top_k: Retrieval depth override (None -> config default).
            force_model: Pin generation to a specific model instead of the chain.
            return_sources: Include per-chunk source metadata in the result.
            use_cache: Serve/store cached responses (non-streaming only).
            stream: Pass-through flag asking the model chain to stream.

        Returns:
            Response dict; on failure it carries an ``"error"`` key but keeps
            the same shape as the success payload so callers can index safely.
        """
        # Tuple key avoids collisions that a "_"-joined string could produce
        # when the question itself contains underscores.
        cache_key = (question, top_k, force_model)
        if use_cache and not stream and cache_key in self.response_cache:
            print("Returning cached response")
            return self.response_cache[cache_key]
        # Retrieve context
        print(f"Retrieving context for: {question[:100]}...")
        context, sources = self.retrieve_context(question, top_k)
        # Generate response
        print("Generating response...")
        try:
            result = self.model_chain.generate_response(
                query=question,
                context=context,
                force_model=force_model,
                stream=stream
            )
            full_response = {
                "question": question,
                "answer": result.get("response"),
                "model_used": result.get("model_used"),
                "sources": sources if return_sources else None,
                "context_size": len(context),
                "streaming": stream,
                "fallback_used": result.get("fallback", False),
            }
            # Only non-streaming responses are cacheable (streams are consumed).
            if use_cache and not stream:
                self.response_cache[cache_key] = full_response
            return full_response
        except Exception as e:
            print(f"Error generating response: {e}")
            # Mirror the success payload's keys so callers (e.g. main()) can
            # read context_size/streaming/fallback_used without a KeyError.
            return {
                "question": question,
                "answer": f"Error generating response: {str(e)}",
                "model_used": None,
                "sources": sources if return_sources else None,
                "context_size": len(context),
                "streaming": stream,
                "fallback_used": False,
                "error": str(e),
            }

    def chat(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False
    ) -> Dict[str, Any]:
        """Handle a chat conversation with rolling context.

        Args:
            messages: OpenAI-style list of {"role", "content"} dicts; the last
                entry must be a user message.
            stream: Forwarded to :meth:`query`.

        Returns:
            The :meth:`query` response dict, or ``{"error": ...}`` when the
            last message is not from the user.
        """
        if not messages or messages[-1]["role"] != "user":
            return {"error": "No user message found"}
        current_question = messages[-1]["content"]
        # Build conversation context from at most the 4 preceding turns to
        # bound prompt size.
        conversation_context = ""
        if len(messages) > 1:
            prev_messages = messages[:-1][-4:]
            for msg in prev_messages:
                role = "Human" if msg["role"] == "user" else "Assistant"
                conversation_context += f"{role}: {msg['content']}\n\n"
        if conversation_context:
            full_query = f"""Previous conversation:
{conversation_context}
Current question: {current_question}"""
        else:
            full_query = current_question
        return self.query(
            question=full_query,
            return_sources=True,
            stream=stream
        )

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the RAG engine (index, cache, model availability)."""
        return {
            "vector_store": self.vector_store.get_stats(),
            "cache_size": len(self.response_cache),
            "models_available": {
                "ollama": self.model_chain.check_ollama_available(),
                "groq": self.model_chain.groq_available,
            },
        }

    def clear_cache(self):
        """Clear response cache."""
        self.response_cache = {}
        print("Response cache cleared")
def main():
    """Smoke-test the RAG engine with a few sample questions.

    Builds (or loads) the index, runs three canned queries, prints answers
    with their sources, then dumps engine statistics as JSON.
    """
    # Initialize engine
    print("Initializing RAG engine...")
    engine = RAGEngine(force_recreate=False)

    test_questions = [
        "What is Harry Potter's approach to understanding magic?",
        "How does Harry react when he first learns about magic?",
        "What are Harry's thoughts on the scientific method?",
    ]
    for question in test_questions:
        print(f"\n{'='*80}")
        print(f"Question: {question}")
        print(f"{'='*80}")
        response = engine.query(question, top_k=3)
        # Use .get() so the smoke test keeps running even when query()
        # returned an error payload that lacks some success-path keys.
        print(f"\nModel used: {response.get('model_used')}")
        print(f"Context size: {response.get('context_size', 0)} characters")
        if response.get("fallback_used"):
            print("(Fallback to Groq was used)")
        print(f"\nAnswer:\n{response.get('answer')}")
        if response.get("sources"):
            print(f"\nSources ({len(response['sources'])} chunks):")
            for i, source in enumerate(response['sources'], 1):
                print(f"  {i}. Chapter {source['chapter_number']}: {source['chapter_title']}")
                print(f"     Score: {source['score']:.4f}")

    # Show stats
    print(f"\n{'='*80}")
    print("Engine Statistics:")
    stats = engine.get_stats()
    print(json.dumps(stats, indent=2))


if __name__ == "__main__":
    main()