|
|
"""RAG query engine for HPMOR Q&A system.""" |
|
|
|
|
|
from typing import Optional, List, Dict, Any |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
from llama_index.core import Document |
|
|
from src.document_processor import HPMORProcessor |
|
|
from src.vector_store import VectorStoreManager |
|
|
from src.model_chain import ModelChain, ModelType |
|
|
from src.config import config |
|
|
|
|
|
|
|
|
class RAGEngine:
    """Main RAG engine combining retrieval and generation.

    Wires together the document processor, vector store, and model chain,
    and keeps a simple in-memory cache of full responses keyed on the
    query parameters.
    """

    def __init__(self, force_recreate: bool = False):
        """Initialize RAG engine components.

        Args:
            force_recreate: When True, reprocess the documents and rebuild
                the vector index instead of loading existing data.
        """
        print("Initializing RAG Engine...")

        self.processor = HPMORProcessor()
        self.vector_store = VectorStoreManager()
        self.model_chain = ModelChain()

        self._initialize_index(force_recreate)

        # NOTE(review): cache is unbounded and lives for the engine's
        # lifetime; call clear_cache() to reclaim memory if needed.
        self.response_cache: Dict[tuple, Dict[str, Any]] = {}

    def _initialize_index(self, force_recreate: bool = False):
        """Process source documents and build (or load) the vector index."""
        documents = self.processor.process(force_reprocess=force_recreate)

        self.index = self.vector_store.get_or_create_index(
            documents=documents,
            force_recreate=force_recreate
        )

        print(f"Index ready with {len(documents)} documents")

    def retrieve_context(self, query: str, top_k: Optional[int] = None) -> tuple[str, List[Dict]]:
        """Retrieve relevant context for a query.

        Args:
            query: Natural-language question to search for.
            top_k: Number of chunks to retrieve; defaults to
                ``config.top_k_retrieval`` when None.

        Returns:
            Tuple of (concatenated excerpt text, per-chunk source metadata).
        """
        if top_k is None:
            top_k = config.top_k_retrieval

        nodes = self.vector_store.query(query, top_k=top_k)

        context_parts = []
        source_info = []

        for i, node in enumerate(nodes, 1):
            context_parts.append(f"[Excerpt {i}]\n{node.text}")

            source_info.append({
                "chunk_id": node.metadata.get("chunk_id", "unknown"),
                "chapter_number": node.metadata.get("chapter_number", 0),
                "chapter_title": node.metadata.get("chapter_title", "Unknown"),
                # Fix: explicit None check so a legitimate falsy score
                # (e.g. 0.0) is kept rather than treated as missing.
                "score": float(node.score) if node.score is not None else 0.0,
                "text_preview": node.text[:200] + "..." if len(node.text) > 200 else node.text
            })

        context = "\n\n".join(context_parts)
        return context, source_info

    def query(
        self,
        question: str,
        top_k: Optional[int] = None,
        force_model: Optional[ModelType] = None,
        return_sources: bool = True,
        use_cache: bool = True,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Execute RAG query with retrieval and generation.

        Args:
            question: The user's question.
            top_k: Number of context chunks to retrieve (None = config default).
            force_model: Pin generation to a specific model, if given.
            return_sources: Include retrieval source metadata in the result.
            use_cache: Reuse/store cached responses for identical parameters.
            stream: Request a streaming response (never cached).

        Returns:
            Dict with answer, model info, sources, and context size; on
            failure, a dict carrying an "error" key instead.
        """
        # Fix: tuple key instead of an underscore-joined string, which could
        # collide for questions that themselves contain underscores.
        cache_key = (question, top_k, force_model)
        if use_cache and not stream and cache_key in self.response_cache:
            print("Returning cached response")
            return self.response_cache[cache_key]

        # Pre-initialize so the except-path dict is well-formed even if
        # retrieval itself fails.
        sources: List[Dict] = []
        try:
            # Fix: retrieval now sits inside the try block so a vector-store
            # failure yields the same structured error dict as a generation
            # failure, instead of propagating to the caller.
            print(f"Retrieving context for: {question[:100]}...")
            context, sources = self.retrieve_context(question, top_k)

            print("Generating response...")
            result = self.model_chain.generate_response(
                query=question,
                context=context,
                force_model=force_model,
                stream=stream
            )

            full_response = {
                "question": question,
                "answer": result.get("response"),
                "model_used": result.get("model_used"),
                "sources": sources if return_sources else None,
                "context_size": len(context),
                "streaming": stream,
                "fallback_used": result.get("fallback", False)
            }

            # Streaming answers may be one-shot generators, so they are
            # never cached.
            if use_cache and not stream:
                self.response_cache[cache_key] = full_response

            return full_response

        except Exception as e:
            print(f"Error generating response: {e}")
            return {
                "question": question,
                "answer": f"Error generating response: {str(e)}",
                "model_used": None,
                "sources": sources if return_sources else None,
                "error": str(e)
            }

    def chat(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False
    ) -> Dict[str, Any]:
        """Handle chat conversation with context.

        Args:
            messages: Chat history as dicts with "role" and "content" keys;
                the last entry must be a user message.
            stream: Forwarded to :meth:`query`.

        Returns:
            The :meth:`query` response dict, or ``{"error": ...}`` when no
            trailing user message is present.
        """
        # Fix: .get() instead of direct indexing so a malformed message
        # (missing "role"/"content") returns the structured error rather
        # than raising KeyError.
        if not messages or messages[-1].get("role") != "user":
            return {"error": "No user message found"}

        current_question = messages[-1].get("content", "")

        # Fold up to the last four prior turns into textual context.
        conversation_context = ""
        if len(messages) > 1:
            prev_messages = messages[:-1][-4:]
            for msg in prev_messages:
                role = "Human" if msg.get("role") == "user" else "Assistant"
                conversation_context += f"{role}: {msg.get('content', '')}\n\n"

        if conversation_context:
            full_query = f"""Previous conversation:
{conversation_context}

Current question: {current_question}"""
        else:
            full_query = current_question

        response = self.query(
            question=full_query,
            return_sources=True,
            stream=stream
        )

        return response

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the RAG engine.

        Returns:
            Dict with vector-store stats, cache size, and model availability.
        """
        vector_stats = self.vector_store.get_stats()

        stats = {
            "vector_store": vector_stats,
            "cache_size": len(self.response_cache),
            "models_available": {
                "ollama": self.model_chain.check_ollama_available(),
                "groq": self.model_chain.groq_available
            }
        }

        return stats

    def clear_cache(self):
        """Clear response cache."""
        # Fix: clear in place instead of rebinding so any external
        # references to the cache dict stay valid.
        self.response_cache.clear()
        print("Response cache cleared")
|
|
|
|
|
|
|
|
def main():
    """Smoke-test the RAG engine against a few sample questions."""
    banner = "=" * 80

    print("Initializing RAG engine...")
    engine = RAGEngine(force_recreate=False)

    sample_questions = [
        "What is Harry Potter's approach to understanding magic?",
        "How does Harry react when he first learns about magic?",
        "What are Harry's thoughts on the scientific method?",
    ]

    for q in sample_questions:
        print(f"\n{banner}")
        print(f"Question: {q}")
        print(banner)

        reply = engine.query(q, top_k=3)

        print(f"\nModel used: {reply['model_used']}")
        print(f"Context size: {reply['context_size']} characters")

        if reply.get("fallback_used"):
            print("(Fallback to Groq was used)")

        print(f"\nAnswer:\n{reply['answer']}")

        src_list = reply.get("sources")
        if src_list:
            print(f"\nSources ({len(src_list)} chunks):")
            for idx, src in enumerate(src_list, 1):
                print(f"  {idx}. Chapter {src['chapter_number']}: {src['chapter_title']}")
                print(f"     Score: {src['score']:.4f}")

    print(f"\n{banner}")
    print("Engine Statistics:")
    stats = engine.get_stats()
    print(json.dumps(stats, indent=2))
|
|
|
|
|
|
|
|
# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    main()