Commit · 6ef48230
Initial commit
Parent(s): none

Files added:
- .env +30 -0
- .gitignore +15 -0
- .python-version +1 -0
- README.md +0 -0
- data/processed/chapters.json +0 -0
- data/processed/documents.json +0 -0
- data/raw/hpmor.html +0 -0
- main.py +171 -0
- pyproject.toml +20 -0
- src/__init__.py +1 -0
- src/chat_interface.py +244 -0
- src/config.py +104 -0
- src/document_processor.py +218 -0
- src/model_chain.py +322 -0
- src/rag_engine.py +231 -0
- src/vector_store.py +198 -0
- uv.lock +0 -0
.env
ADDED
@@ -0,0 +1,30 @@
# Groq API Configuration
GROQ_API_KEY=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

# Ollama Configuration
OLLAMA_HOST=http://localhost:11434

# Model Configuration
LOCAL_MODEL_SMALL=llama3.2:3b
LOCAL_MODEL_LARGE=llama3.1:8b
GROQ_MODEL=llama-3.3-70b-versatile

# Embedding Model
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2

# Processing Parameters
CHUNK_SIZE=1000
CHUNK_OVERLAP=200
TOP_K_RETRIEVAL=5

# Model Selection Thresholds
COMPLEXITY_THRESHOLD=0.7
MAX_LOCAL_CONTEXT_SIZE=4000

# ChromaDB Settings
CHROMA_PERSIST_DIR=./chroma_db
COLLECTION_NAME=hpmor_collection

# Gradio Settings
GRADIO_SERVER_PORT=7860
GRADIO_SHARE=False

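A note on how these settings are consumed: src/config.py (below) loads this file with python-dotenv and reads each variable through os.getenv with a fallback default. A minimal sketch of that flow, using only variable names that appear above:

# Sketch: how one .env setting reaches Python. Mirrors the
# os.getenv(...) pattern used throughout src/config.py below.
import os
from dotenv import load_dotenv

load_dotenv()  # copies .env entries into os.environ (already-set vars win)

chunk_size = int(os.getenv("CHUNK_SIZE", "1000"))  # default used if the variable is absent
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
print(f"chunk_size={chunk_size}, ollama_host={ollama_host}")
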
.gitignore
ADDED
@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# Database
chroma_db/
blobs/
models/

.python-version
ADDED
@@ -0,0 +1 @@
3.11

README.md
ADDED
Binary file (4.05 kB).

data/processed/chapters.json
ADDED
The diff for this file is too large to render.

data/processed/documents.json
ADDED
The diff for this file is too large to render.

data/raw/hpmor.html
ADDED
The diff for this file is too large to render.

main.py
ADDED
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""Main entry point for HPMOR Q&A System."""

import sys
import argparse
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path(__file__).parent))

from src.config import config
from src.document_processor import HPMORProcessor
from src.vector_store import VectorStoreManager
from src.rag_engine import RAGEngine
from src.chat_interface import ChatInterface


def setup_system(force_recreate: bool = False):
    """Set up the HPMOR Q&A system."""
    print("=" * 80)
    print("HPMOR Q&A System Setup")
    print("=" * 80)

    # Process documents
    print("\n1. Processing HPMOR document...")
    processor = HPMORProcessor()
    documents = processor.process(force_reprocess=force_recreate)
    print(f"   ✓ Processed {len(documents)} chunks")

    # Create vector index
    print("\n2. Creating vector index...")
    vector_store = VectorStoreManager()
    index = vector_store.get_or_create_index(documents, force_recreate=force_recreate)
    stats = vector_store.get_stats()
    print(f"   ✓ Index created with {stats['num_vectors']} vectors")

    print("\n✅ Setup complete! The system is ready to use.")
    return True


def test_system():
    """Test the system with sample queries."""
    print("=" * 80)
    print("HPMOR Q&A System Test")
    print("=" * 80)

    engine = RAGEngine(force_recreate=False)

    test_questions = [
        "What is Harry Potter's full name in HPMOR?",
        "How does Harry first react to learning about magic?",
    ]

    for question in test_questions:
        print(f"\nQ: {question}")
        response = engine.query(question, top_k=3)

        if isinstance(response["answer"], str):
            answer = response["answer"]
        else:
            answer = str(response["answer"])

        print(f"A: {answer[:500]}...")
        print(f"   (Model: {response['model_used']})")


def check_ollama():
    """Check if Ollama is installed and running."""
    import httpx

    print("\nChecking Ollama status...")
    try:
        response = httpx.get(f"{config.ollama_host}/api/tags", timeout=2.0)
        if response.status_code == 200:
            print("✓ Ollama is running")
            data = response.json()
            if data.get("models"):
                print(f"  Available models: {', '.join([m['name'] for m in data['models']])}")
            else:
                # Suggest the configured small model (llama3.2 ships in 1b/3b sizes)
                print("  ⚠ No models installed. Run: ollama pull llama3.2:3b")
            return True
        else:
            print("✗ Ollama is not responding correctly")
            return False
    except Exception as e:
        print("✗ Ollama is not running. Please start it with: ollama serve")
        print(f"  Error: {e}")
        return False


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="HPMOR Q&A System")
    parser.add_argument(
        "command",
        choices=["setup", "chat", "test", "check"],
        help="Command to run"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force recreate index and reprocess documents"
    )

    args = parser.parse_args()

    if args.command == "setup":
        setup_system(force_recreate=args.force)

    elif args.command == "check":
        print("System Check")
        print("-" * 40)

        # Check Ollama
        ollama_ok = check_ollama()

        # Check Groq
        print("\nChecking Groq API...")
        if config.has_groq_api():
            print("✓ Groq API key configured")
        else:
            print("✗ Groq API key not configured")
            print("  Add your key to .env file")

        # Check data
        print("\nChecking data files...")
        if config.hpmor_file.exists():
            print(f"✓ HPMOR file found: {config.hpmor_file}")
        else:
            print(f"✗ HPMOR file not found: {config.hpmor_file}")

        # Check processed data
        processed_docs = config.processed_data_dir / "documents.json"
        if processed_docs.exists():
            print("✓ Processed documents found")
        else:
            print("✗ No processed documents. Run: python main.py setup")

    elif args.command == "test":
        test_system()

    elif args.command == "chat":
        print("=" * 80)
        print("HPMOR Q&A Chat Interface")
        print("=" * 80)

        # Check system
        check_ollama()

        if not config.has_groq_api():
            print("\n⚠ Warning: Groq API key not configured.")
            print("  The system will only use local models (if Ollama is running).")
            print("  For best performance, add your Groq API key to the .env file.")

        # Check if setup is needed
        processed_docs = config.processed_data_dir / "documents.json"
        if not processed_docs.exists():
            print("\n⚠ No processed documents found. Running setup...")
            setup_system()

        # Launch chat interface
        print("\nStarting chat interface...")
        chat = ChatInterface()
        chat.launch()


if __name__ == "__main__":
    if len(sys.argv) == 1:
        # No arguments provided, default to chat
        sys.argv.append("chat")
    main()

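The four subcommands map directly to the functions above (setup to setup_system, check to check_ollama plus config checks, test to test_system, chat to ChatInterface.launch). The same flow that `python main.py chat` runs can also be driven from Python; a minimal sketch using only names defined in this file and in src/:

# Sketch: the `python main.py chat` flow, driven programmatically.
from main import check_ollama, setup_system
from src.chat_interface import ChatInterface

check_ollama()                      # warn early if the local backend is down
setup_system(force_recreate=False)  # re-uses cached documents.json when present

ChatInterface().launch()            # serves Gradio on the configured port
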
pyproject.toml
ADDED
@@ -0,0 +1,20 @@
[project]
name = "hpmor-qa"
version = "0.1.0"
description = "RAG-based question answering over Harry Potter and the Methods of Rationality"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "beautifulsoup4>=4.12",  # imported directly in src/document_processor.py
    "chromadb>=1.1.1",
    "gradio>=5.49.1",
    "httpx>=0.28.1",
    "langchain>=0.3.27",
    "langchain-groq>=0.3.8",
    "litellm>=1.78.0",
    "llama-index>=0.14.4",
    "llama-index-embeddings-huggingface>=0.6.1",
    "llama-index-llms-groq>=0.4.1",
    "llama-index-llms-ollama>=0.8.0",
    "llama-index-vector-stores-chroma>=0.5.3",
    "lxml>=6.0.2",
    "pydantic>=2.0",         # imported directly in src/config.py
    "python-dotenv>=1.0",    # imported directly in src/config.py
]

src/__init__.py
ADDED
@@ -0,0 +1 @@
# HPMOR Q&A System

src/chat_interface.py
ADDED
@@ -0,0 +1,244 @@
"""Gradio chat interface for HPMOR Q&A system."""

import gradio as gr
import json
from typing import List, Tuple, Optional
from datetime import datetime

from src.rag_engine import RAGEngine
from src.model_chain import ModelType
from src.config import config


class ChatInterface:
    """Gradio-based chat interface for HPMOR Q&A."""

    def __init__(self):
        """Initialize the chat interface."""
        print("Initializing HPMOR Q&A Chat Interface...")
        self.engine = RAGEngine(force_recreate=False)
        self.conversation_history = []

    def format_sources(self, sources: List[dict]) -> str:
        """Format sources for display."""
        if not sources:
            return "No sources found"

        formatted = []
        for i, source in enumerate(sources, 1):
            formatted.append(
                f"**Source {i}** - Chapter {source['chapter_number']}: {source['chapter_title']}\n"
                f"Relevance Score: {source['score']:.2f}\n"
                f"Preview: *{source['text_preview'][:150]}...*"
            )
        return "\n\n".join(formatted)

    def process_message(
        self,
        message: str,
        history: List[List[str]],
        model_choice: str,
        top_k: int,
        show_sources: bool
    ) -> Tuple[str, str, str]:
        """Process a chat message and return response."""
        if not message:
            return "", "", "Please enter a question."

        # Convert model choice to enum
        model_map = {
            "Auto (Intelligent Routing)": None,
            "Local Small (Fast)": ModelType.LOCAL_SMALL,
            "Local Large (Better)": ModelType.LOCAL_LARGE,
            "Groq API (Best)": ModelType.GROQ_API
        }
        force_model = model_map.get(model_choice)

        # Convert history to messages format
        messages = []
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        try:
            # Get response from engine
            response = self.engine.chat(messages, stream=False)

            # Extract answer
            if isinstance(response.get("answer"), str):
                answer = response["answer"]
            else:
                # Handle LlamaIndex response object
                answer = str(response.get("answer", "No response generated"))

            # Format model info
            model_info = f"**Model Used:** {response.get('model_used', 'Unknown')}"
            if response.get("fallback_used"):
                model_info += " (via fallback)"
            model_info += f"\n**Context Size:** {response.get('context_size', 0)} characters"

            # Format sources if requested
            sources_text = ""
            if show_sources and response.get("sources"):
                sources_text = self.format_sources(response["sources"])

            return answer, sources_text, model_info

        except Exception as e:
            error_msg = f"Error: {str(e)}"
            return error_msg, "", "Error occurred"

    def clear_conversation(self):
        """Clear conversation history and cache."""
        self.conversation_history = []
        self.engine.clear_cache()
        return None, "", "", "Conversation cleared"

    def create_interface(self) -> gr.Blocks:
        """Create the Gradio interface."""
        with gr.Blocks(title="HPMOR Q&A System", theme=gr.themes.Soft()) as interface:
            gr.Markdown(
                """
                # 📚 Harry Potter and the Methods of Rationality - Q&A System

                Ask questions about HPMOR and get intelligent answers powered by RAG (Retrieval-Augmented Generation).
                The system uses local models when possible and falls back to Groq API for complex queries.
                """
            )

            with gr.Row():
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        label="Chat",
                        height=500,
                        show_copy_button=True
                    )

                    with gr.Row():
                        msg_input = gr.Textbox(
                            label="Your Question",
                            placeholder="Ask anything about HPMOR...",
                            lines=2,
                            scale=4
                        )
                        submit_btn = gr.Button("Send", variant="primary", scale=1)

                with gr.Column(scale=1):
                    gr.Markdown("### Settings")

                    model_choice = gr.Radio(
                        choices=[
                            "Auto (Intelligent Routing)",
                            "Local Small (Fast)",
                            "Local Large (Better)",
                            "Groq API (Best)"
                        ],
                        value="Auto (Intelligent Routing)",
                        label="Model Selection"
                    )

                    top_k = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of Context Chunks"
                    )

                    show_sources = gr.Checkbox(
                        value=True,
                        label="Show Sources"
                    )

                    clear_btn = gr.Button("Clear Conversation", variant="secondary")

                    gr.Markdown("### Model Info")
                    model_info = gr.Markdown(
                        value="Ready to answer questions",
                        elem_classes=["model-info"]
                    )

            with gr.Row():
                sources_display = gr.Markdown(
                    label="Retrieved Sources",
                    value="",
                    visible=True
                )

            # Example questions
            gr.Examples(
                examples=[
                    "What is Harry's initial reaction to learning about magic?",
                    "How does Harry apply the scientific method to understand magic?",
                    "What are the key differences between Harry and Hermione's approaches to learning?",
                    "Explain the concept of 'rationality' as presented in the story",
                    "What magical experiments does Harry conduct?",
                ],
                inputs=msg_input,
                label="Example Questions"
            )

            # Event handlers
            def respond(message, history, model, topk, sources):
                """Handle message submission."""
                answer, sources_text, info = self.process_message(
                    message, history, model, topk, sources
                )
                history.append([message, answer])
                return "", history, sources_text, info

            msg_input.submit(
                respond,
                inputs=[msg_input, chatbot, model_choice, top_k, show_sources],
                outputs=[msg_input, chatbot, sources_display, model_info]
            )

            submit_btn.click(
                respond,
                inputs=[msg_input, chatbot, model_choice, top_k, show_sources],
                outputs=[msg_input, chatbot, sources_display, model_info]
            )

            clear_btn.click(
                lambda: self.clear_conversation(),
                outputs=[chatbot, sources_display, msg_input, model_info]
            )

            # Add custom CSS
            interface.css = """
            .model-info {
                background-color: #f0f0f0;
                padding: 10px;
                border-radius: 5px;
                font-size: 0.9em;
            }
            """

        return interface

    def launch(self):
        """Launch the Gradio interface."""
        interface = self.create_interface()

        print("\nLaunching HPMOR Q&A Chat Interface...")
        print(f"Server will be available at: http://localhost:{config.gradio_server_port}")

        interface.launch(
            server_name="0.0.0.0",
            server_port=config.gradio_server_port,
            share=config.gradio_share,
            favicon_path=None
        )


def main():
    """Launch the chat interface."""
    chat = ChatInterface()
    chat.launch()


if __name__ == "__main__":
    main()

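Because process_message is a plain method, the interface logic can be exercised without starting the Gradio server; a minimal sketch (the model_choice string must be one of the Radio choices defined above):

# Sketch: exercising the chat logic headlessly, no browser needed.
from src.chat_interface import ChatInterface

ui = ChatInterface()  # builds the RAG engine; the UI is not launched
answer, sources_md, model_info = ui.process_message(
    message="What is Harry's full name?",
    history=[],                                 # no earlier turns
    model_choice="Auto (Intelligent Routing)",  # one of the Radio options
    top_k=5,
    show_sources=True,
)
print(model_info)
print(answer)

One caveat visible in the code as committed: process_message computes force_model from model_choice and accepts top_k, but engine.chat() is called with neither, so the model-selection and chunk-count controls are not yet wired through to retrieval.
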
src/config.py
ADDED
@@ -0,0 +1,104 @@
"""Configuration management for HPMOR Q&A System."""

import os
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
from dotenv import load_dotenv

# Load environment variables
load_dotenv()


class Config(BaseModel):
    """Application configuration."""

    # API Keys
    groq_api_key: Optional[str] = Field(default=os.getenv("GROQ_API_KEY"))

    # Ollama Settings
    ollama_host: str = Field(default=os.getenv("OLLAMA_HOST", "http://localhost:11434"))

    # Model Names (defaults match .env; llama3.2 only ships in 1b/3b sizes,
    # so the small default is 3b and the large default is llama3.1:8b)
    local_model_small: str = Field(default=os.getenv("LOCAL_MODEL_SMALL", "llama3.2:3b"))
    local_model_large: str = Field(default=os.getenv("LOCAL_MODEL_LARGE", "llama3.1:8b"))
    groq_model: str = Field(default=os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile"))

    # Embedding Model
    embedding_model: str = Field(
        default=os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    )

    # Processing Parameters
    chunk_size: int = Field(default=int(os.getenv("CHUNK_SIZE", "1000")))
    chunk_overlap: int = Field(default=int(os.getenv("CHUNK_OVERLAP", "200")))
    top_k_retrieval: int = Field(default=int(os.getenv("TOP_K_RETRIEVAL", "5")))

    # Model Selection Thresholds
    complexity_threshold: float = Field(
        default=float(os.getenv("COMPLEXITY_THRESHOLD", "0.7"))
    )
    max_local_context_size: int = Field(
        default=int(os.getenv("MAX_LOCAL_CONTEXT_SIZE", "4000"))
    )

    # ChromaDB Settings
    chroma_persist_dir: Path = Field(
        default=Path(os.getenv("CHROMA_PERSIST_DIR", "./chroma_db"))
    )
    collection_name: str = Field(
        default=os.getenv("COLLECTION_NAME", "hpmor_collection")
    )

    # Gradio Settings
    gradio_server_port: int = Field(
        default=int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    )
    gradio_share: bool = Field(
        default=os.getenv("GRADIO_SHARE", "False").lower() == "true"
    )

    # File Paths
    data_dir: Path = Field(default=Path("data"))
    raw_data_dir: Path = Field(default=Path("data/raw"))
    processed_data_dir: Path = Field(default=Path("data/processed"))
    hpmor_file: Path = Field(default=Path("data/raw/hpmor.html"))

    def validate_paths(self) -> None:
        """Create necessary directories if they don't exist."""
        for dir_path in [self.data_dir, self.raw_data_dir, self.processed_data_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

        self.chroma_persist_dir.mkdir(parents=True, exist_ok=True)

    def has_groq_api(self) -> bool:
        """Check if Groq API key is configured."""
        # bool(...) so the annotated return type holds even when the key is None
        return bool(self.groq_api_key and self.groq_api_key != "your_groq_api_key_here")

    def get_model_config(self, model_type: str) -> dict:
        """Get configuration for a specific model type."""
        configs = {
            "local_small": {
                "model": self.local_model_small,
                "type": "ollama",
                "max_tokens": 2048,
                "temperature": 0.7,
            },
            "local_large": {
                "model": self.local_model_large,
                "type": "ollama",
                "max_tokens": 4096,
                "temperature": 0.7,
            },
            "groq": {
                "model": self.groq_model,
                "type": "groq",
                "api_key": self.groq_api_key,
                "max_tokens": 8192,
                "temperature": 0.7,
            },
        }
        return configs.get(model_type, configs["local_small"])


# Create global config instance
config = Config()
config.validate_paths()

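Because every Field default is evaluated from os.getenv when the module is first imported, a setting can be overridden by exporting the variable before src.config is imported (load_dotenv does not overwrite variables that already exist). A minimal sketch:

# Sketch: overriding one setting via the environment.
import os
os.environ["CHUNK_SIZE"] = "500"  # must be set before src.config is imported

from src.config import config

assert config.chunk_size == 500
print(config.get_model_config("groq")["model"])  # llama-3.3-70b-versatile by default
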
src/document_processor.py
ADDED
@@ -0,0 +1,218 @@
"""Document processor for parsing and chunking HPMOR HTML."""

import re
import json
from pathlib import Path
from typing import List, Dict, Optional
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter
from src.config import config


class HPMORProcessor:
    """Process HPMOR HTML document into chunks for RAG."""

    def __init__(self):
        self.chunk_size = config.chunk_size
        self.chunk_overlap = config.chunk_overlap
        self.processed_dir = config.processed_data_dir

    def parse_html(self, file_path: Path) -> List[Dict]:
        """Parse HTML file and extract chapters with metadata."""
        print(f"Parsing HTML file: {file_path}")

        with open(file_path, 'r', encoding='utf-8') as f:
            html_content = f.read()

        soup = BeautifulSoup(html_content, 'lxml')

        # Remove style and script tags
        for tag in soup(['style', 'script']):
            tag.decompose()

        # Try to identify chapters by common patterns
        chapters = []
        chapter_pattern = re.compile(r'Chapter\s+(\d+)', re.IGNORECASE)

        # Find all h1, h2, h3 tags that might be chapter headers
        headers = soup.find_all(['h1', 'h2', 'h3'])

        current_chapter = None
        current_content = []
        chapter_num = 0

        for header in headers:
            header_text = header.get_text(strip=True)
            match = chapter_pattern.search(header_text)

            if match:
                # Save previous chapter if exists
                if current_chapter and current_content:
                    chapters.append({
                        'chapter_number': current_chapter['number'],
                        'chapter_title': current_chapter['title'],
                        'content': '\n'.join(current_content)
                    })

                # Start new chapter
                chapter_num = int(match.group(1))
                current_chapter = {
                    'number': chapter_num,
                    'title': header_text
                }
                current_content = []

                # Get content after this header until next chapter
                for sibling in header.find_next_siblings():
                    if sibling.name in ['h1', 'h2', 'h3']:
                        if chapter_pattern.search(sibling.get_text()):
                            break
                    text = sibling.get_text(strip=True)
                    if text:
                        current_content.append(text)

        # Add the last chapter
        if current_chapter and current_content:
            chapters.append({
                'chapter_number': current_chapter['number'],
                'chapter_title': current_chapter['title'],
                'content': '\n'.join(current_content)
            })

        # If no chapters found, treat entire content as one document
        if not chapters:
            print("No chapter structure found, processing as single document")
            text_content = soup.get_text(separator='\n', strip=True)
            chapters = [{
                'chapter_number': 0,
                'chapter_title': 'Harry Potter and the Methods of Rationality',
                'content': text_content
            }]

        print(f"Extracted {len(chapters)} chapters")
        return chapters

    def create_chunks(self, chapters: List[Dict]) -> List[Document]:
        """Create overlapping chunks from chapters."""
        print(f"Creating chunks with size={self.chunk_size}, overlap={self.chunk_overlap}")

        documents = []
        splitter = SentenceSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )

        for chapter in chapters:
            # Create a document for the chapter
            chapter_doc = Document(
                text=chapter['content'],
                metadata={
                    'chapter_number': chapter['chapter_number'],
                    'chapter_title': chapter['chapter_title'],
                    'source': 'hpmor.html'
                }
            )

            # Split into chunks
            nodes = splitter.get_nodes_from_documents([chapter_doc])

            # Convert nodes back to documents with enhanced metadata
            for i, node in enumerate(nodes):
                doc = Document(
                    text=node.text,
                    metadata={
                        **chapter_doc.metadata,
                        'chunk_id': f"ch{chapter['chapter_number']}_chunk{i}",
                        'chunk_index': i,
                        'total_chunks_in_chapter': len(nodes)
                    }
                )
                documents.append(doc)

        print(f"Created {len(documents)} chunks total")
        return documents

    def save_processed_data(self, documents: List[Document], chapters: List[Dict]) -> None:
        """Save processed documents and metadata to disk."""
        # Save documents as JSON for easy loading
        docs_data = []
        for doc in documents:
            docs_data.append({
                'text': doc.text,
                'metadata': doc.metadata
            })

        docs_file = self.processed_dir / 'documents.json'
        with open(docs_file, 'w', encoding='utf-8') as f:
            json.dump(docs_data, f, indent=2, ensure_ascii=False)
        print(f"Saved {len(docs_data)} documents to {docs_file}")

        # Save chapter metadata
        chapters_file = self.processed_dir / 'chapters.json'
        with open(chapters_file, 'w', encoding='utf-8') as f:
            json.dump(chapters, f, indent=2, ensure_ascii=False)
        print(f"Saved chapter metadata to {chapters_file}")

    def load_processed_data(self) -> Optional[List[Document]]:
        """Load previously processed documents."""
        docs_file = self.processed_dir / 'documents.json'

        if not docs_file.exists():
            return None

        with open(docs_file, 'r', encoding='utf-8') as f:
            docs_data = json.load(f)

        documents = []
        for doc_data in docs_data:
            doc = Document(
                text=doc_data['text'],
                metadata=doc_data['metadata']
            )
            documents.append(doc)

        print(f"Loaded {len(documents)} documents from cache")
        return documents

    def process(self, force_reprocess: bool = False) -> List[Document]:
        """Main processing pipeline."""
        # Check if already processed
        if not force_reprocess:
            documents = self.load_processed_data()
            if documents:
                return documents

        # Process from scratch
        print("Processing HPMOR document from scratch...")

        if not config.hpmor_file.exists():
            raise FileNotFoundError(f"HPMOR file not found: {config.hpmor_file}")

        # Parse HTML
        chapters = self.parse_html(config.hpmor_file)

        # Create chunks
        documents = self.create_chunks(chapters)

        # Save processed data
        self.save_processed_data(documents, chapters)

        return documents


def main():
    """Process HPMOR document."""
    processor = HPMORProcessor()
    documents = processor.process(force_reprocess=True)
    print(f"\nProcessing complete! Created {len(documents)} document chunks.")

    # Show sample
    if documents:
        print("\nSample chunk:")
        print(f"Text: {documents[0].text[:200]}...")
        print(f"Metadata: {documents[0].metadata}")


if __name__ == "__main__":
    main()

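parse_html and create_chunks are independent stages, so the pipeline can be inspected piecewise; a minimal sketch (the path shown is the default from config):

# Sketch: running the two processing stages separately.
from pathlib import Path
from src.document_processor import HPMORProcessor

processor = HPMORProcessor()
chapters = processor.parse_html(Path("data/raw/hpmor.html"))  # config.hpmor_file default
print(len(chapters), "chapters; first:", chapters[0]["chapter_title"])

docs = processor.create_chunks(chapters)  # SentenceSplitter, 1000 chars with 200 overlap
print(len(docs), "chunks; metadata keys:", sorted(docs[0].metadata))
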
src/model_chain.py
ADDED
@@ -0,0 +1,322 @@
"""Model chaining logic with Groq fallback."""

import re
from typing import Optional, Dict, Any, List
from enum import Enum

from llama_index.llms.ollama import Ollama
from llama_index.llms.groq import Groq
from llama_index.core.llms import LLM
from litellm import completion
import httpx

from src.config import config


class ModelType(Enum):
    """Model types for routing."""
    LOCAL_SMALL = "local_small"
    LOCAL_LARGE = "local_large"
    GROQ_API = "groq"


class QueryComplexity(Enum):
    """Query complexity levels."""
    SIMPLE = "simple"      # Factual questions, definitions
    MODERATE = "moderate"  # Analysis, reasoning
    COMPLEX = "complex"    # Creative, multi-step reasoning


class ModelChain:
    """Intelligent model routing with fallback to Groq."""

    def __init__(self):
        self.models = {}
        self.groq_available = config.has_groq_api()

        # Initialize models lazily
        self._ollama_available = None

    def check_ollama_available(self) -> bool:
        """Check if Ollama is running and available."""
        if self._ollama_available is not None:
            return self._ollama_available

        try:
            # Try to connect to Ollama
            response = httpx.get(f"{config.ollama_host}/api/tags", timeout=2.0)
            self._ollama_available = response.status_code == 200
            if self._ollama_available:
                print("Ollama is available")
            else:
                print("Ollama is not responding correctly")
        except Exception as e:
            print(f"Ollama not available: {e}")
            self._ollama_available = False

        return self._ollama_available

    def get_model(self, model_type: ModelType) -> Optional[LLM]:
        """Get or initialize a model."""
        if model_type in self.models:
            return self.models[model_type]

        if model_type == ModelType.GROQ_API:
            if not self.groq_available:
                print("Groq API key not configured")
                return None

            try:
                # For groq/compound model, we'll use litellm
                # Return a wrapper that uses litellm
                return "groq"  # Special marker for litellm usage
            except Exception as e:
                print(f"Failed to initialize Groq: {e}")
                return None

        elif model_type in [ModelType.LOCAL_SMALL, ModelType.LOCAL_LARGE]:
            if not self.check_ollama_available():
                print("Ollama not available, falling back to Groq")
                return None

            model_config = config.get_model_config(model_type.value)
            try:
                model = Ollama(
                    model=model_config["model"],
                    base_url=config.ollama_host,
                    temperature=model_config["temperature"],
                    request_timeout=120.0,
                )
                self.models[model_type] = model
                print(f"Initialized {model_type.value} model: {model_config['model']}")
                return model
            except Exception as e:
                print(f"Failed to initialize Ollama model: {e}")
                return None

        return None

    def analyze_query_complexity(self, query: str, context_size: int = 0) -> QueryComplexity:
        """Analyze query complexity to determine which model to use."""
        query_lower = query.lower()

        # Simple queries - factual questions
        simple_patterns = [
            r"what is",
            r"who is",
            r"when did",
            r"where is",
            r"define",
            r"list",
            r"name",
            r"how many",
            r"yes or no",
        ]

        # Complex queries - requiring reasoning or creativity
        complex_patterns = [
            r"explain why",
            r"analyze",
            r"compare and contrast",
            r"what would happen if",
            r"imagine",
            r"create",
            r"write a",
            r"develop",
            r"design",
            r"evaluate",
            r"critique",
            r"synthesize",
        ]

        # Check for simple patterns
        for pattern in simple_patterns:
            if re.search(pattern, query_lower):
                return QueryComplexity.SIMPLE

        # Check for complex patterns
        for pattern in complex_patterns:
            if re.search(pattern, query_lower):
                return QueryComplexity.COMPLEX

        # Check query length and context size
        if len(query.split()) > 50 or context_size > config.max_local_context_size:
            return QueryComplexity.COMPLEX

        # Default to moderate
        return QueryComplexity.MODERATE

    def route_query(
        self,
        query: str,
        context: Optional[str] = None,
        force_model: Optional[ModelType] = None
    ) -> ModelType:
        """Determine which model to use for the query."""
        if force_model:
            return force_model

        context_size = len(context) if context else 0
        complexity = self.analyze_query_complexity(query, context_size)

        # Check Ollama availability
        ollama_available = self.check_ollama_available()

        # Routing logic
        if complexity == QueryComplexity.SIMPLE:
            if ollama_available:
                return ModelType.LOCAL_SMALL
            elif self.groq_available:
                return ModelType.GROQ_API
        elif complexity == QueryComplexity.MODERATE:
            if ollama_available:
                return ModelType.LOCAL_LARGE
            elif self.groq_available:
                return ModelType.GROQ_API
        else:  # COMPLEX
            if self.groq_available:
                return ModelType.GROQ_API
            elif ollama_available:
                return ModelType.LOCAL_LARGE

        # Final fallback
        if self.groq_available:
            return ModelType.GROQ_API
        elif ollama_available:
            return ModelType.LOCAL_SMALL
        else:
            raise RuntimeError("No models available! Please check Ollama or configure Groq API key.")

    def generate_response(
        self,
        query: str,
        context: Optional[str] = None,
        force_model: Optional[ModelType] = None,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Generate response using appropriate model."""
        # Determine which model to use
        model_type = self.route_query(query, context, force_model)
        print(f"Using model: {model_type.value}")

        # Prepare prompt
        if context:
            prompt = f"""Context from Harry Potter and the Methods of Rationality:
{context}

Question: {query}

Please provide a detailed answer based on the context provided above."""
        else:
            prompt = query

        # Try primary model
        try:
            model = self.get_model(model_type)

            if model == "groq":  # Special handling for Groq via litellm
                # Use litellm for Groq
                response = completion(
                    model=f"groq/{config.groq_model}",
                    messages=[{"role": "user", "content": prompt}],
                    api_key=config.groq_api_key,
                    temperature=0.7,
                    max_tokens=2048,
                    stream=stream
                )

                if stream:
                    return {
                        "response": response,
                        "model_used": model_type.value,
                        "streaming": True
                    }
                else:
                    return {
                        "response": response.choices[0].message.content,
                        "model_used": model_type.value,
                        "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None
                    }

            elif model:
                # Use LlamaIndex model
                if stream:
                    response = model.stream_complete(prompt)
                else:
                    response = model.complete(prompt)

                return {
                    "response": response,
                    "model_used": model_type.value,
                    "streaming": stream
                }

        except Exception as e:
            print(f"Error with {model_type.value}: {e}")

            # Try fallback
            if model_type != ModelType.GROQ_API and self.groq_available:
                print("Falling back to Groq API...")
                model_type = ModelType.GROQ_API
                try:
                    response = completion(
                        model=f"groq/{config.groq_model}",
                        messages=[{"role": "user", "content": prompt}],
                        api_key=config.groq_api_key,
                        temperature=0.7,
                        max_tokens=2048,
                        stream=stream
                    )

                    if stream:
                        return {
                            "response": response,
                            "model_used": model_type.value,
                            "streaming": True,
                            "fallback": True
                        }
                    else:
                        return {
                            "response": response.choices[0].message.content,
                            "model_used": model_type.value,
                            "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None,
                            "fallback": True
                        }
                except Exception as e2:
                    print(f"Fallback to Groq also failed: {e2}")
                    raise RuntimeError(f"All models failed. Last error: {e2}")

        raise RuntimeError("No models available for response generation")


def main():
    """Test model chaining."""
    chain = ModelChain()

    # Test queries of different complexities
    test_queries = [
        ("What is Harry's full name?", QueryComplexity.SIMPLE),
        ("Explain Harry's reasoning about magic", QueryComplexity.MODERATE),
        ("Analyze the philosophical implications of Harry's scientific approach to magic", QueryComplexity.COMPLEX),
    ]

    for query, expected_complexity in test_queries:
        print(f"\nQuery: {query}")
        complexity = chain.analyze_query_complexity(query)
        print(f"Detected complexity: {complexity}")
        print(f"Expected complexity: {expected_complexity}")

        try:
            model_type = chain.route_query(query)
            print(f"Selected model: {model_type.value}")

            # Generate response
            result = chain.generate_response(query)
            print(f"Model used: {result['model_used']}")
            print(f"Response preview: {str(result['response'])[:200]}...")
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()

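The routing heuristic reduces to a small table: SIMPLE goes to the local small model, MODERATE to the local large model, COMPLEX to Groq, each degrading to whichever backend is actually reachable. The decision can be inspected without generating anything; a minimal sketch (assumes at least one backend is configured, otherwise route_query raises):

# Sketch: inspecting routing decisions without calling any model.
from src.model_chain import ModelChain, ModelType

chain = ModelChain()
q = "Compare and contrast Harry's and Hermione's approaches to learning"

print(chain.analyze_query_complexity(q))  # QueryComplexity.COMPLEX ("compare and contrast")
print(chain.route_query(q).value)         # "groq" if a key is set, else a local model

# Forcing a backend bypasses the heuristic entirely:
assert chain.route_query(q, force_model=ModelType.LOCAL_SMALL) is ModelType.LOCAL_SMALL
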
src/rag_engine.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAG query engine for HPMOR Q&A system."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional, List, Dict, Any
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from llama_index.core import Document
|
| 8 |
+
from src.document_processor import HPMORProcessor
|
| 9 |
+
from src.vector_store import VectorStoreManager
|
| 10 |
+
from src.model_chain import ModelChain, ModelType
|
| 11 |
+
from src.config import config
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RAGEngine:
|
| 15 |
+
"""Main RAG engine combining retrieval and generation."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, force_recreate: bool = False):
|
| 18 |
+
"""Initialize RAG engine components."""
|
| 19 |
+
print("Initializing RAG Engine...")
|
| 20 |
+
|
| 21 |
+
# Initialize components
|
| 22 |
+
self.processor = HPMORProcessor()
|
| 23 |
+
self.vector_store = VectorStoreManager()
|
| 24 |
+
self.model_chain = ModelChain()
|
| 25 |
+
|
| 26 |
+
# Process and index documents
|
| 27 |
+
self._initialize_index(force_recreate)
|
| 28 |
+
|
| 29 |
+
# Cache for responses
|
| 30 |
+
self.response_cache = {}
|
| 31 |
+
|
| 32 |
+
def _initialize_index(self, force_recreate: bool = False):
|
| 33 |
+
"""Initialize or load the vector index."""
|
| 34 |
+
# Process documents
|
| 35 |
+
documents = self.processor.process(force_reprocess=force_recreate)
|
| 36 |
+
|
| 37 |
+
# Create or load index
|
| 38 |
+
self.index = self.vector_store.get_or_create_index(
|
| 39 |
+
documents=documents,
|
| 40 |
+
force_recreate=force_recreate
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
print(f"Index ready with {len(documents)} documents")
|
| 44 |
+
|
| 45 |
+
def retrieve_context(self, query: str, top_k: Optional[int] = None) -> tuple[str, List[Dict]]:
|
| 46 |
+
"""Retrieve relevant context for a query."""
|
| 47 |
+
if top_k is None:
|
| 48 |
+
top_k = config.top_k_retrieval
|
| 49 |
+
|
| 50 |
+
# Query vector store
|
| 51 |
+
nodes = self.vector_store.query(query, top_k=top_k)
|
| 52 |
+
|
| 53 |
+
# Format context
|
| 54 |
+
context_parts = []
|
| 55 |
+
source_info = []
|
| 56 |
+
|
| 57 |
+
for i, node in enumerate(nodes, 1):
|
| 58 |
+
# Add to context
|
| 59 |
+
context_parts.append(f"[Excerpt {i}]\n{node.text}")
|
| 60 |
+
|
| 61 |
+
# Collect source info
|
| 62 |
+
source_info.append({
|
| 63 |
+
"chunk_id": node.metadata.get("chunk_id", "unknown"),
|
| 64 |
+
"chapter_number": node.metadata.get("chapter_number", 0),
|
| 65 |
+
"chapter_title": node.metadata.get("chapter_title", "Unknown"),
|
| 66 |
+
"score": float(node.score) if node.score else 0.0,
|
| 67 |
+
"text_preview": node.text[:200] + "..." if len(node.text) > 200 else node.text
|
| 68 |
+
})
|
| 69 |
+
|
| 70 |
+
context = "\n\n".join(context_parts)
|
| 71 |
+
return context, source_info
|
| 72 |
+
|
| 73 |
+
def query(
|
| 74 |
+
self,
|
| 75 |
+
question: str,
|
| 76 |
+
top_k: Optional[int] = None,
|
| 77 |
+
force_model: Optional[ModelType] = None,
|
| 78 |
+
return_sources: bool = True,
|
| 79 |
+
use_cache: bool = True,
|
| 80 |
+
stream: bool = False
|
| 81 |
+
) -> Dict[str, Any]:
|
| 82 |
+
"""Execute RAG query with retrieval and generation."""
|
| 83 |
+
# Check cache
|
| 84 |
+
cache_key = f"{question}_{top_k}_{force_model}"
|
| 85 |
+
if use_cache and cache_key in self.response_cache and not stream:
|
| 86 |
+
print("Returning cached response")
|
| 87 |
+
return self.response_cache[cache_key]
|
| 88 |
+
|
| 89 |
+
# Retrieve context
|
| 90 |
+
print(f"Retrieving context for: {question[:100]}...")
|
| 91 |
+
context, sources = self.retrieve_context(question, top_k)
|
| 92 |
+
|
| 93 |
+
# Generate response
|
| 94 |
+
print("Generating response...")
|
| 95 |
+
try:
|
| 96 |
+
result = self.model_chain.generate_response(
|
| 97 |
+
query=question,
|
| 98 |
+
context=context,
|
| 99 |
+
force_model=force_model,
|
| 100 |
+
stream=stream
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Prepare full response
|
| 104 |
+
full_response = {
|
| 105 |
+
"question": question,
|
| 106 |
+
"answer": result.get("response"),
|
| 107 |
+
"model_used": result.get("model_used"),
|
| 108 |
+
"sources": sources if return_sources else None,
|
| 109 |
+
"context_size": len(context),
|
| 110 |
+
"streaming": stream,
|
| 111 |
+
"fallback_used": result.get("fallback", False)
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
# Cache if not streaming
|
| 115 |
+
if use_cache and not stream:
|
| 116 |
+
self.response_cache[cache_key] = full_response
|
| 117 |
+
|
| 118 |
+
return full_response
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"Error generating response: {e}")
|
| 122 |
+
return {
|
| 123 |
+
"question": question,
|
| 124 |
+
"answer": f"Error generating response: {str(e)}",
|
| 125 |
+
"model_used": None,
|
| 126 |
+
"sources": sources if return_sources else None,
|
| 127 |
+
"error": str(e)
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
def chat(
|
| 131 |
+
self,
|
| 132 |
+
messages: List[Dict[str, str]],
|
| 133 |
+
stream: bool = False
|
| 134 |
+
) -> Dict[str, Any]:
|
| 135 |
+
"""Handle chat conversation with context."""
|
| 136 |
+
# Get the latest user message
|
| 137 |
+
if not messages or messages[-1]["role"] != "user":
|
| 138 |
+
return {"error": "No user message found"}
|
| 139 |
+
|
| 140 |
+
current_question = messages[-1]["content"]
|
| 141 |
+
|
| 142 |
+
# Build conversation context if multiple messages
|
| 143 |
+
conversation_context = ""
|
| 144 |
+
if len(messages) > 1:
|
| 145 |
+
prev_messages = messages[:-1][-4:] # Keep last 4 messages for context
|
| 146 |
+
for msg in prev_messages:
|
| 147 |
+
role = "Human" if msg["role"] == "user" else "Assistant"
|
| 148 |
+
conversation_context += f"{role}: {msg['content']}\n\n"
|
| 149 |
+
|
| 150 |
+
# Modify question to include conversation context
|
| 151 |
+
if conversation_context:
|
| 152 |
+
full_query = f"""Previous conversation:
|
| 153 |
+
{conversation_context}
|
| 154 |
+
|
| 155 |
+
Current question: {current_question}"""
|
| 156 |
+
else:
|
| 157 |
+
full_query = current_question
|
| 158 |
+
|
| 159 |
+
# Execute RAG query
|
| 160 |
+
response = self.query(
|
| 161 |
+
question=full_query,
|
| 162 |
+
return_sources=True,
|
| 163 |
+
stream=stream
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
return response
|
| 167 |
+
|
| 168 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 169 |
+
"""Get statistics about the RAG engine."""
|
| 170 |
+
vector_stats = self.vector_store.get_stats()
|
| 171 |
+
|
| 172 |
+
stats = {
|
| 173 |
+
"vector_store": vector_stats,
|
| 174 |
+
"cache_size": len(self.response_cache),
|
| 175 |
+
"models_available": {
|
| 176 |
+
"ollama": self.model_chain.check_ollama_available(),
|
| 177 |
+
"groq": self.model_chain.groq_available
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
return stats
|
| 182 |
+
|
| 183 |
+
def clear_cache(self):
|
| 184 |
+
"""Clear response cache."""
|
| 185 |
+
self.response_cache = {}
|
| 186 |
+
print("Response cache cleared")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def main():
|
| 190 |
+
"""Test RAG engine."""
|
| 191 |
+
# Initialize engine
|
| 192 |
+
print("Initializing RAG engine...")
|
| 193 |
+
engine = RAGEngine(force_recreate=False)
|
| 194 |
+
|
| 195 |
+
# Test queries
|
| 196 |
+
test_questions = [
|
| 197 |
+
"What is Harry Potter's approach to understanding magic?",
|
| 198 |
+
"How does Harry react when he first learns about magic?",
|
| 199 |
+
"What are Harry's thoughts on the scientific method?",
|
| 200 |
+
]
|
| 201 |
+
|
| 202 |
+
for question in test_questions:
|
| 203 |
+
print(f"\n{'='*80}")
|
| 204 |
+
print(f"Question: {question}")
|
| 205 |
+
print(f"{'='*80}")
|
| 206 |
+
|
| 207 |
+
response = engine.query(question, top_k=3)
|
| 208 |
+
|
| 209 |
+
print(f"\nModel used: {response['model_used']}")
|
| 210 |
+
print(f"Context size: {response['context_size']} characters")
|
| 211 |
+
|
| 212 |
+
if response.get("fallback_used"):
|
| 213 |
+
print("(Fallback to Groq was used)")
|
| 214 |
+
|
| 215 |
+
print(f"\nAnswer:\n{response['answer']}")
|
| 216 |
+
|
| 217 |
+
if response.get("sources"):
|
| 218 |
+
print(f"\nSources ({len(response['sources'])} chunks):")
|
| 219 |
+
for i, source in enumerate(response['sources'], 1):
|
| 220 |
+
print(f" {i}. Chapter {source['chapter_number']}: {source['chapter_title']}")
|
| 221 |
+
print(f" Score: {source['score']:.4f}")
|
| 222 |
+
|
| 223 |
+
# Show stats
|
| 224 |
+
print(f"\n{'='*80}")
|
| 225 |
+
print("Engine Statistics:")
|
| 226 |
+
stats = engine.get_stats()
|
| 227 |
+
print(json.dumps(stats, indent=2))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
main()
|
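A minimal usage sketch for the chat() API above, assuming an already-built ChromaDB index; the questions in the message list are hypothetical:

# Sketch: drive RAGEngine.chat() with an OpenAI-style message history.
# Assumes chroma_db/ is already populated; message contents are made up.
from src.rag_engine import RAGEngine

engine = RAGEngine(force_recreate=False)

messages = [
    {"role": "user", "content": "Who teaches Battle Magic at Hogwarts?"},
    {"role": "assistant", "content": "Professor Quirrell teaches Battle Magic."},
    # chat() requires the final message to come from the user.
    {"role": "user", "content": "What does Harry think of his teaching style?"},
]

response = engine.chat(messages)
print(response["model_used"])
print(response["answer"])

Only the last four prior turns are prepended to the retrieval query, which bounds the context size while still letting follow-up questions resolve references to earlier answers.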
src/vector_store.py
ADDED
@@ -0,0 +1,198 @@
"""Vector store management for document embeddings."""

import os
from typing import List, Optional
from pathlib import Path

import chromadb
from chromadb.config import Settings
from llama_index.core import Document, VectorStoreIndex, StorageContext
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter

from src.config import config


class VectorStoreManager:
    """Manage ChromaDB vector store for document embeddings."""

    def __init__(self):
        self.collection_name = config.collection_name
        self.persist_dir = str(config.chroma_persist_dir)
        self.embedding_model = config.embedding_model

        # Initialize embedding model
        print(f"Loading embedding model: {self.embedding_model}")
        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model,
            cache_folder="./models"
        )

        # Initialize ChromaDB client
        self.chroma_client = chromadb.PersistentClient(
            path=self.persist_dir,
            settings=Settings(anonymized_telemetry=False)
        )

        # Collection, vector store, and index are created lazily
        self.collection = None
        self.vector_store = None
        self.index = None

    def initialize_collection(self, reset: bool = False) -> None:
        """Initialize ChromaDB collection."""
        if reset:
            # Delete existing collection if it exists
            try:
                self.chroma_client.delete_collection(name=self.collection_name)
                print(f"Deleted existing collection: {self.collection_name}")
            except Exception:
                pass

        # Create or get collection
        self.collection = self.chroma_client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"}
        )
        print(f"Using collection: {self.collection_name}")

        # Initialize vector store
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.collection,
            embedding_function=self.embed_model
        )

    def create_index(self, documents: List[Document], show_progress: bool = True) -> VectorStoreIndex:
        """Create vector index from documents."""
        if not self.vector_store:
            self.initialize_collection()

        print(f"Creating index from {len(documents)} documents...")

        # Create storage context
        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Create index with documents
        self.index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model,
            show_progress=show_progress
        )

        print("Index created successfully!")
        return self.index

    def load_index(self) -> Optional[VectorStoreIndex]:
        """Load existing index from storage."""
        if not self.vector_store:
            self.initialize_collection()

        # Check if collection has data
        if self.collection.count() == 0:
            print("No existing index found in ChromaDB")
            return None

        print(f"Loading index with {self.collection.count()} vectors")

        # Create storage context
        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Load index
        self.index = VectorStoreIndex.from_vector_store(
            self.vector_store,
            storage_context=storage_context,
            embed_model=self.embed_model
        )

        return self.index

    def get_or_create_index(
        self,
        documents: Optional[List[Document]] = None,
        force_recreate: bool = False
    ) -> VectorStoreIndex:
        """Get existing index or create new one."""
        if not force_recreate:
            # Try to load existing index
            index = self.load_index()
            if index:
                return index

        # Create new index
        if not documents:
            raise ValueError("No documents provided for creating index")

        self.initialize_collection(reset=True)
        return self.create_index(documents)

    def query(self, query_text: str, top_k: Optional[int] = None) -> List:
        """Query the vector store."""
        if not self.index:
            raise ValueError("Index not initialized. Call get_or_create_index first.")

        if top_k is None:
            top_k = config.top_k_retrieval

        # Use retriever directly instead of query engine to avoid LLM requirement
        retriever = self.index.as_retriever(
            similarity_top_k=top_k
        )

        # Retrieve nodes
        nodes = retriever.retrieve(query_text)
        return nodes

    def get_stats(self) -> dict:
        """Get statistics about the vector store."""
        if not self.collection:
            self.initialize_collection()

        stats = {
            "collection_name": self.collection_name,
            "persist_dir": self.persist_dir,
            "embedding_model": self.embedding_model,
            "num_vectors": self.collection.count(),
            "metadata": self.collection.metadata
        }

        return stats


def main():
    """Test vector store functionality."""
    from src.document_processor import HPMORProcessor

    # Process documents
    processor = HPMORProcessor()
    documents = processor.process()

    # Create vector store
    vector_store = VectorStoreManager()
    index = vector_store.get_or_create_index(documents, force_recreate=True)

    # Get stats
    stats = vector_store.get_stats()
    print("\nVector Store Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # Test query
    test_query = "What is Harry's opinion on magic?"
    print(f"\nTest query: '{test_query}'")
    results = vector_store.query(test_query, top_k=3)

    print(f"\nFound {len(results)} relevant chunks:")
    for i, node in enumerate(results, 1):
        print(f"\n{i}. Score: {node.score:.4f}")
        print(f"   Chapter: {node.metadata.get('chapter_title', 'Unknown')}")
        print(f"   Text preview: {node.text[:200]}...")


if __name__ == "__main__":
    main()
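A similar sketch for reusing the persisted vector store without rebuilding it, assuming chroma_db/ already holds the collection created by main() above; the query string is hypothetical:

# Sketch: reload the persisted index and run a retrieval-only query.
from src.vector_store import VectorStoreManager

store = VectorStoreManager()
store.get_or_create_index(force_recreate=False)  # loads the existing index from disk

for node in store.query("How does Harry test the laws of magic?", top_k=3):
    # node.score is the retriever's similarity score; higher means more relevant.
    print(f"{node.score:.4f}  Chapter: {node.metadata.get('chapter_title', 'Unknown')}")

Because query() uses a bare retriever rather than a query engine, this path needs no LLM at all; synthesis over the retrieved chunks is handled separately by the RAG engine.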
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff