"""RAG query engine for HPMOR Q&A system."""

from typing import Optional, List, Dict, Any
import json
from pathlib import Path

from llama_index.core import Document

from src.document_processor import HPMORProcessor
from src.vector_store import VectorStoreManager
from src.model_chain import ModelChain, ModelType
from src.config import config


class RAGEngine:
    """Main RAG engine combining retrieval and generation."""

    def __init__(self, force_recreate: bool = False):
        """Initialize RAG engine components.

        Args:
            force_recreate: If True, reprocess the source documents and
                rebuild the vector index instead of loading a cached one.
        """
        print("Initializing RAG Engine...")

        # Initialize components
        self.processor = HPMORProcessor()
        self.vector_store = VectorStoreManager()
        self.model_chain = ModelChain()

        # Process and index documents
        self._initialize_index(force_recreate)

        # Cache of full query responses keyed by (question, top_k, model).
        # NOTE: unbounded — grows for the lifetime of the engine; use
        # clear_cache() to reset it.
        self.response_cache: Dict[str, Dict[str, Any]] = {}

    def _initialize_index(self, force_recreate: bool = False):
        """Initialize or load the vector index.

        Args:
            force_recreate: Forwarded to both the document processor and the
                vector store to force a full rebuild.
        """
        # Process documents (reuses prior processing unless forced)
        documents = self.processor.process(force_reprocess=force_recreate)

        # Create or load index
        self.index = self.vector_store.get_or_create_index(
            documents=documents,
            force_recreate=force_recreate
        )
        print(f"Index ready with {len(documents)} documents")

    def retrieve_context(self, query: str, top_k: Optional[int] = None) -> tuple[str, List[Dict]]:
        """Retrieve relevant context for a query.

        Args:
            query: Natural-language question to search for.
            top_k: Number of chunks to retrieve; defaults to
                config.top_k_retrieval when None.

        Returns:
            Tuple of (context, source_info): the concatenated excerpt text
            and a list of per-chunk metadata dicts.
        """
        if top_k is None:
            top_k = config.top_k_retrieval

        # Query vector store
        nodes = self.vector_store.query(query, top_k=top_k)

        # Format context
        context_parts = []
        source_info = []

        for i, node in enumerate(nodes, 1):
            # Add to context
            context_parts.append(f"[Excerpt {i}]\n{node.text}")

            # Collect source info. Guard the score with "is not None" so a
            # legitimate 0.0 relevance score is kept rather than treated as
            # missing (truthiness would conflate the two).
            source_info.append({
                "chunk_id": node.metadata.get("chunk_id", "unknown"),
                "chapter_number": node.metadata.get("chapter_number", 0),
                "chapter_title": node.metadata.get("chapter_title", "Unknown"),
                "score": float(node.score) if node.score is not None else 0.0,
                "text_preview": node.text[:200] + "..." if len(node.text) > 200 else node.text
            })

        context = "\n\n".join(context_parts)
        return context, source_info

    def query(
        self,
        question: str,
        top_k: Optional[int] = None,
        force_model: Optional[ModelType] = None,
        return_sources: bool = True,
        use_cache: bool = True,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Execute RAG query with retrieval and generation.

        Args:
            question: The user's question.
            top_k: Number of context chunks to retrieve (None = config default).
            force_model: Skip model selection and use this model if given.
            return_sources: Include retrieval source metadata in the response.
            use_cache: Serve/store cached responses (ignored when streaming).
            stream: Request a streaming generation from the model chain.

        Returns:
            Response dict with "question", "answer", "model_used", "sources",
            "context_size", "streaming", and "fallback_used" keys; on
            generation failure the same shape plus an "error" key.
        """
        # Check cache (streamed responses are never cached)
        cache_key = f"{question}_{top_k}_{force_model}"
        if use_cache and cache_key in self.response_cache and not stream:
            print("Returning cached response")
            return self.response_cache[cache_key]

        # Retrieve context
        print(f"Retrieving context for: {question[:100]}...")
        context, sources = self.retrieve_context(question, top_k)

        # Generate response
        print("Generating response...")
        try:
            result = self.model_chain.generate_response(
                query=question,
                context=context,
                force_model=force_model,
                stream=stream
            )

            # Prepare full response
            full_response = {
                "question": question,
                "answer": result.get("response"),
                "model_used": result.get("model_used"),
                "sources": sources if return_sources else None,
                "context_size": len(context),
                "streaming": stream,
                "fallback_used": result.get("fallback", False)
            }

            # Cache if not streaming
            if use_cache and not stream:
                self.response_cache[cache_key] = full_response

            return full_response

        except Exception as e:
            print(f"Error generating response: {e}")
            # Keep the error response shape-compatible with the success
            # response: callers (e.g. main()) read "context_size" and
            # "fallback_used" unconditionally, so omitting them here would
            # turn a generation failure into a KeyError downstream.
            return {
                "question": question,
                "answer": f"Error generating response: {str(e)}",
                "model_used": None,
                "sources": sources if return_sources else None,
                "context_size": len(context),
                "streaming": stream,
                "fallback_used": False,
                "error": str(e)
            }

    def chat(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False
    ) -> Dict[str, Any]:
        """Handle chat conversation with context.

        Args:
            messages: Chat history as dicts with "role" ("user"/"assistant")
                and "content" keys; the last entry must be a user message.
            stream: Forwarded to query().

        Returns:
            The query() response dict, or {"error": ...} if the last message
            is not from the user.
        """
        # Get the latest user message
        if not messages or messages[-1]["role"] != "user":
            return {"error": "No user message found"}

        current_question = messages[-1]["content"]

        # Build conversation context if multiple messages
        conversation_context = ""
        if len(messages) > 1:
            prev_messages = messages[:-1][-4:]  # Keep last 4 messages for context
            for msg in prev_messages:
                role = "Human" if msg["role"] == "user" else "Assistant"
                conversation_context += f"{role}: {msg['content']}\n\n"

        # Modify question to include conversation context
        if conversation_context:
            full_query = f"""Previous conversation:
{conversation_context}
Current question: {current_question}"""
        else:
            full_query = current_question

        # Execute RAG query
        response = self.query(
            question=full_query,
            return_sources=True,
            stream=stream
        )

        return response

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the RAG engine.

        Returns:
            Dict with vector-store stats, cache size, and model availability.
        """
        vector_stats = self.vector_store.get_stats()

        stats = {
            "vector_store": vector_stats,
            "cache_size": len(self.response_cache),
            "models_available": {
                "ollama": self.model_chain.check_ollama_available(),
                "groq": self.model_chain.groq_available
            }
        }

        return stats

    def clear_cache(self):
        """Clear response cache."""
        self.response_cache = {}
        print("Response cache cleared")


def main():
    """Test RAG engine."""
    # Initialize engine
    print("Initializing RAG engine...")
    engine = RAGEngine(force_recreate=False)

    # Test queries
    test_questions = [
        "What is Harry Potter's approach to understanding magic?",
        "How does Harry react when he first learns about magic?",
        "What are Harry's thoughts on the scientific method?",
    ]

    for question in test_questions:
        print(f"\n{'='*80}")
        print(f"Question: {question}")
        print(f"{'='*80}")

        response = engine.query(question, top_k=3)

        print(f"\nModel used: {response['model_used']}")
        print(f"Context size: {response['context_size']} characters")
        if response.get("fallback_used"):
            print("(Fallback to Groq was used)")

        print(f"\nAnswer:\n{response['answer']}")

        if response.get("sources"):
            print(f"\nSources ({len(response['sources'])} chunks):")
            for i, source in enumerate(response['sources'], 1):
                print(f"  {i}. Chapter {source['chapter_number']}: {source['chapter_title']}")
                print(f"     Score: {source['score']:.4f}")

    # Show stats
    print(f"\n{'='*80}")
    print("Engine Statistics:")
    stats = engine.get_stats()
    print(json.dumps(stats, indent=2))


if __name__ == "__main__":
    main()