# File metadata (extraction artifact, kept as a comment): 7,591 bytes, 231 lines, commit 6ef4823
"""RAG query engine for HPMOR Q&A system."""
from typing import Optional, List, Dict, Any
import json
from pathlib import Path
from llama_index.core import Document
from src.document_processor import HPMORProcessor
from src.vector_store import VectorStoreManager
from src.model_chain import ModelChain, ModelType
from src.config import config
class RAGEngine:
    """Main RAG engine combining retrieval and generation."""

    def __init__(self, force_recreate: bool = False):
        """Initialize RAG engine components.

        Args:
            force_recreate: When True, reprocess the source documents and
                rebuild the vector index instead of loading an existing one.
        """
        print("Initializing RAG Engine...")

        # Collaborators: document chunking, vector search, and LLM routing.
        self.processor = HPMORProcessor()
        self.vector_store = VectorStoreManager()
        self.model_chain = ModelChain()

        # Build (or load) the index up front so queries are cheap afterwards.
        self._initialize_index(force_recreate)

        # In-memory cache of complete query responses, keyed by
        # question + retrieval/generation parameters.
        # NOTE(review): unbounded — grows without limit in long-lived processes.
        self.response_cache: Dict[str, Dict[str, Any]] = {}

    def _initialize_index(self, force_recreate: bool = False):
        """Process documents and create or load the vector index.

        Args:
            force_recreate: Forwarded to both the processor and the vector
                store to force a full rebuild.
        """
        documents = self.processor.process(force_reprocess=force_recreate)
        self.index = self.vector_store.get_or_create_index(
            documents=documents,
            force_recreate=force_recreate,
        )
        print(f"Index ready with {len(documents)} documents")

    def retrieve_context(self, query: str, top_k: Optional[int] = None) -> tuple[str, List[Dict]]:
        """Retrieve relevant context for a query.

        Args:
            query: Natural-language question to search the index for.
            top_k: Number of chunks to retrieve; defaults to
                ``config.top_k_retrieval`` when None.

        Returns:
            A ``(context, source_info)`` pair: the numbered excerpts joined
            into one string, and a list of per-chunk metadata dicts
            (chunk id, chapter number/title, score, text preview).
        """
        if top_k is None:
            top_k = config.top_k_retrieval

        nodes = self.vector_store.query(query, top_k=top_k)

        context_parts = []
        source_info = []
        for i, node in enumerate(nodes, 1):
            # Number each excerpt so the model can cite it.
            context_parts.append(f"[Excerpt {i}]\n{node.text}")
            source_info.append({
                "chunk_id": node.metadata.get("chunk_id", "unknown"),
                "chapter_number": node.metadata.get("chapter_number", 0),
                "chapter_title": node.metadata.get("chapter_title", "Unknown"),
                # Explicit None check so a legitimate 0.0 score is kept
                # distinct from "no score available".
                "score": float(node.score) if node.score is not None else 0.0,
                "text_preview": node.text[:200] + "..." if len(node.text) > 200 else node.text,
            })

        context = "\n\n".join(context_parts)
        return context, source_info

    def query(
        self,
        question: str,
        top_k: Optional[int] = None,
        force_model: Optional[ModelType] = None,
        return_sources: bool = True,
        use_cache: bool = True,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Execute RAG query with retrieval and generation.

        Args:
            question: The user's question.
            top_k: Number of chunks to retrieve (None = configured default).
            force_model: Pin generation to a specific model instead of the
                chain's normal selection/fallback.
            return_sources: Include per-chunk source metadata in the result.
            use_cache: Serve/store responses from the in-memory cache
                (skipped entirely for streaming responses).
            stream: Request a streaming response from the model chain.

        Returns:
            A response dict with ``question``, ``answer``, ``model_used``,
            ``sources``, ``context_size``, ``streaming`` and
            ``fallback_used`` keys; on generation failure, an error payload
            with an ``error`` key instead of raising.
        """
        # BUGFIX: include return_sources in the cache key. Previously a
        # response cached with return_sources=False (sources stripped to
        # None) could be served to a caller that asked for sources.
        cache_key = f"{question}_{top_k}_{force_model}_{return_sources}"
        if use_cache and cache_key in self.response_cache and not stream:
            print("Returning cached response")
            return self.response_cache[cache_key]

        print(f"Retrieving context for: {question[:100]}...")
        context, sources = self.retrieve_context(question, top_k)

        print("Generating response...")
        try:
            result = self.model_chain.generate_response(
                query=question,
                context=context,
                force_model=force_model,
                stream=stream,
            )
            full_response = {
                "question": question,
                "answer": result.get("response"),
                "model_used": result.get("model_used"),
                "sources": sources if return_sources else None,
                "context_size": len(context),
                "streaming": stream,
                "fallback_used": result.get("fallback", False),
            }
            # Streaming responses wrap a generator, so they cannot be cached.
            if use_cache and not stream:
                self.response_cache[cache_key] = full_response
            return full_response
        except Exception as e:
            # Top-level boundary: degrade to an error payload rather than
            # propagating, so callers always get a dict back.
            print(f"Error generating response: {e}")
            return {
                "question": question,
                "answer": f"Error generating response: {str(e)}",
                "model_used": None,
                "sources": sources if return_sources else None,
                "error": str(e),
            }

    def chat(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False
    ) -> Dict[str, Any]:
        """Handle chat conversation with context.

        Args:
            messages: Chat history as dicts with ``role``/``content`` keys;
                the last entry must be a user message.
            stream: Forwarded to :meth:`query`.

        Returns:
            The :meth:`query` response dict, or ``{"error": ...}`` when the
            last message is missing or not from the user.
        """
        if not messages or messages[-1]["role"] != "user":
            return {"error": "No user message found"}

        current_question = messages[-1]["content"]

        # Fold recent history (last 4 prior messages) into the query so the
        # retriever and model see the conversational context.
        conversation_context = ""
        if len(messages) > 1:
            prev_messages = messages[:-1][-4:]
            for msg in prev_messages:
                role = "Human" if msg["role"] == "user" else "Assistant"
                conversation_context += f"{role}: {msg['content']}\n\n"

        if conversation_context:
            full_query = f"""Previous conversation:
{conversation_context}
Current question: {current_question}"""
        else:
            full_query = current_question

        response = self.query(
            question=full_query,
            return_sources=True,
            stream=stream,
        )
        return response

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the RAG engine.

        Returns:
            Dict with vector-store stats, current cache size, and model
            availability flags for Ollama and Groq.
        """
        vector_stats = self.vector_store.get_stats()
        stats = {
            "vector_store": vector_stats,
            "cache_size": len(self.response_cache),
            "models_available": {
                "ollama": self.model_chain.check_ollama_available(),
                "groq": self.model_chain.groq_available,
            },
        }
        return stats

    def clear_cache(self):
        """Clear response cache."""
        self.response_cache = {}
        print("Response cache cleared")
def main():
    """Smoke-test the RAG engine: run sample queries and print statistics."""
    print("Initializing RAG engine...")
    engine = RAGEngine(force_recreate=False)

    # Representative questions exercising retrieval + generation end to end.
    test_questions = [
        "What is Harry Potter's approach to understanding magic?",
        "How does Harry react when he first learns about magic?",
        "What are Harry's thoughts on the scientific method?",
    ]

    separator = "=" * 80
    for question in test_questions:
        print(f"\n{separator}")
        print(f"Question: {question}")
        print(separator)

        response = engine.query(question, top_k=3)

        print(f"\nModel used: {response['model_used']}")
        print(f"Context size: {response['context_size']} characters")
        if response.get("fallback_used"):
            print("(Fallback to Groq was used)")
        print(f"\nAnswer:\n{response['answer']}")

        sources = response.get("sources")
        if sources:
            print(f"\nSources ({len(sources)} chunks):")
            for idx, src in enumerate(sources, 1):
                print(f" {idx}. Chapter {src['chapter_number']}: {src['chapter_title']}")
                print(f" Score: {src['score']:.4f}")

    # Summarize engine state after the test run.
    print(f"\n{separator}")
    print("Engine Statistics:")
    stats = engine.get_stats()
    print(json.dumps(stats, indent=2))
# Script entry point: only run the smoke test when executed directly.
if __name__ == "__main__":
    main()