deenaik committed on
Commit 6ef4823 · 0 Parent(s)

Initial commit
.env ADDED
@@ -0,0 +1,30 @@
+ # Groq API Configuration
+ GROQ_API_KEY=your_groq_api_key_here
+
+ # Ollama Configuration
+ OLLAMA_HOST=http://localhost:11434
+
+ # Model Configuration
+ LOCAL_MODEL_SMALL=llama3.2:3b
+ LOCAL_MODEL_LARGE=llama3.1:8b
+ GROQ_MODEL=llama-3.3-70b-versatile
+
+ # Embedding Model
+ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
+
+ # Processing Parameters
+ CHUNK_SIZE=1000
+ CHUNK_OVERLAP=200
+ TOP_K_RETRIEVAL=5
+
+ # Model Selection Thresholds
+ COMPLEXITY_THRESHOLD=0.7
+ MAX_LOCAL_CONTEXT_SIZE=4000
+
+ # ChromaDB Settings
+ CHROMA_PERSIST_DIR=./chroma_db
+ COLLECTION_NAME=hpmor_collection
+
+ # Gradio Settings
+ GRADIO_SERVER_PORT=7860
+ GRADIO_SHARE=False
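
Note: the GROQ_API_KEY above is shown as a placeholder; a real key should never be committed. A minimal sketch (assuming python-dotenv, which src/config.py below relies on) of how these entries are read at startup:

    import os
    from dotenv import load_dotenv

    load_dotenv()                                       # pulls .env into the process environment
    groq_key = os.getenv("GROQ_API_KEY")                # None if absent
    chunk_size = int(os.getenv("CHUNK_SIZE", "1000"))   # numeric values arrive as strings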
.gitignore ADDED
@@ -0,0 +1,15 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ # Database
+ chroma_db/
+ blobs/
+ models/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
README.md ADDED
Binary file (4.05 kB).
 
data/processed/chapters.json ADDED
The diff for this file is too large to render.
 
data/processed/documents.json ADDED
The diff for this file is too large to render.
 
data/raw/hpmor.html ADDED
The diff for this file is too large to render.
 
main.py ADDED
@@ -0,0 +1,171 @@
+ #!/usr/bin/env python3
+ """Main entry point for HPMOR Q&A System."""
+
+ import sys
+ import argparse
+ from pathlib import Path
+
+ # Add src to path
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ from src.config import config
+ from src.document_processor import HPMORProcessor
+ from src.vector_store import VectorStoreManager
+ from src.rag_engine import RAGEngine
+ from src.chat_interface import ChatInterface
+
+
+ def setup_system(force_recreate: bool = False):
+     """Set up the HPMOR Q&A system."""
+     print("=" * 80)
+     print("HPMOR Q&A System Setup")
+     print("=" * 80)
+
+     # Process documents
+     print("\n1. Processing HPMOR document...")
+     processor = HPMORProcessor()
+     documents = processor.process(force_reprocess=force_recreate)
+     print(f"   ✓ Processed {len(documents)} chunks")
+
+     # Create vector index
+     print("\n2. Creating vector index...")
+     vector_store = VectorStoreManager()
+     index = vector_store.get_or_create_index(documents, force_recreate=force_recreate)
+     stats = vector_store.get_stats()
+     print(f"   ✓ Index created with {stats['num_vectors']} vectors")
+
+     print("\n✅ Setup complete! The system is ready to use.")
+     return True
+
+
+ def test_system():
+     """Test the system with sample queries."""
+     print("=" * 80)
+     print("HPMOR Q&A System Test")
+     print("=" * 80)
+
+     engine = RAGEngine(force_recreate=False)
+
+     test_questions = [
+         "What is Harry Potter's full name in HPMOR?",
+         "How does Harry first react to learning about magic?",
+     ]
+
+     for question in test_questions:
+         print(f"\nQ: {question}")
+         response = engine.query(question, top_k=3)
+
+         if isinstance(response["answer"], str):
+             answer = response["answer"]
+         else:
+             answer = str(response["answer"])
+
+         print(f"A: {answer[:500]}...")
+         print(f"   (Model: {response['model_used']})")
+
+
+ def check_ollama():
+     """Check if Ollama is installed and running."""
+     import httpx
+
+     print("\nChecking Ollama status...")
+     try:
+         response = httpx.get(f"{config.ollama_host}/api/tags", timeout=2.0)
+         if response.status_code == 200:
+             print("✓ Ollama is running")
+             data = response.json()
+             if data.get("models"):
+                 print(f"  Available models: {', '.join([m['name'] for m in data['models']])}")
+             else:
+                 print("  ⚠ No models installed. Run: ollama pull llama3.2:3b")
+             return True
+         else:
+             print("✗ Ollama is not responding correctly")
+             return False
+     except Exception as e:
+         print("✗ Ollama is not running. Please start it with: ollama serve")
+         print(f"  Error: {e}")
+         return False
+
+
+ def main():
+     """Main entry point."""
+     parser = argparse.ArgumentParser(description="HPMOR Q&A System")
+     parser.add_argument(
+         "command",
+         choices=["setup", "chat", "test", "check"],
+         help="Command to run"
+     )
+     parser.add_argument(
+         "--force",
+         action="store_true",
+         help="Force recreate index and reprocess documents"
+     )
+
+     args = parser.parse_args()
+
+     if args.command == "setup":
+         setup_system(force_recreate=args.force)
+
+     elif args.command == "check":
+         print("System Check")
+         print("-" * 40)
+
+         # Check Ollama
+         ollama_ok = check_ollama()
+
+         # Check Groq
+         print("\nChecking Groq API...")
+         if config.has_groq_api():
+             print("✓ Groq API key configured")
+         else:
+             print("✗ Groq API key not configured")
+             print("  Add your key to the .env file")
+
+         # Check data
+         print("\nChecking data files...")
+         if config.hpmor_file.exists():
+             print(f"✓ HPMOR file found: {config.hpmor_file}")
+         else:
+             print(f"✗ HPMOR file not found: {config.hpmor_file}")
+
+         # Check processed data
+         processed_docs = config.processed_data_dir / "documents.json"
+         if processed_docs.exists():
+             print("✓ Processed documents found")
+         else:
+             print("✗ No processed documents. Run: python main.py setup")
+
+     elif args.command == "test":
+         test_system()
+
+     elif args.command == "chat":
+         print("=" * 80)
+         print("HPMOR Q&A Chat Interface")
+         print("=" * 80)
+
+         # Check system
+         check_ollama()
+
+         if not config.has_groq_api():
+             print("\n⚠ Warning: Groq API key not configured.")
+             print("  The system will only use local models (if Ollama is running).")
+             print("  For best performance, add your Groq API key to the .env file.")
+
+         # Check if setup is needed
+         processed_docs = config.processed_data_dir / "documents.json"
+         if not processed_docs.exists():
+             print("\n⚠ No processed documents found. Running setup...")
+             setup_system()
+
+         # Launch chat interface
+         print("\nStarting chat interface...")
+         chat = ChatInterface()
+         chat.launch()
+
+
+ if __name__ == "__main__":
+     if len(sys.argv) == 1:
+         # No arguments provided, default to chat
+         sys.argv.append("chat")
+     main()
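
The argparse commands can also be driven programmatically. A short sketch (assuming imports resolve from the repository root):

    # Equivalent to `python main.py check` / `setup` / `test` without argparse
    from main import check_ollama, setup_system, test_system

    check_ollama()   # verify the local Ollama server and its installed models
    setup_system()   # parse, chunk, and index the book
    test_system()    # run the two bundled sample queries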
pyproject.toml ADDED
@@ -0,0 +1,21 @@
+ [project]
+ name = "hpmor-qa"
+ version = "0.1.0"
+ description = "RAG-based Q&A system for Harry Potter and the Methods of Rationality"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "beautifulsoup4>=4.12.0",
+     "chromadb>=1.1.1",
+     "gradio>=5.49.1",
+     "httpx>=0.28.1",
+     "langchain>=0.3.27",
+     "langchain-groq>=0.3.8",
+     "litellm>=1.78.0",
+     "llama-index>=0.14.4",
+     "llama-index-embeddings-huggingface>=0.6.1",
+     "llama-index-llms-groq>=0.4.1",
+     "llama-index-llms-ollama>=0.8.0",
+     "llama-index-vector-stores-chroma>=0.5.3",
+     "lxml>=6.0.2",
+ ]
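
beautifulsoup4 is listed explicitly because src/document_processor.py imports it directly; relying on it arriving transitively through llama-index would be fragile. A quick sanity-check sketch for the resolved environment:

    import importlib.metadata as md

    for pkg in ("beautifulsoup4", "chromadb", "gradio", "llama-index", "litellm"):
        print(pkg, md.version(pkg))  # raises PackageNotFoundError if a package is missing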
src/__init__.py ADDED
@@ -0,0 +1 @@
+ # HPMOR Q&A System
src/chat_interface.py ADDED
@@ -0,0 +1,244 @@
+ """Gradio chat interface for HPMOR Q&A system."""
+
+ import gradio as gr
+ from typing import List, Tuple
+
+ from src.rag_engine import RAGEngine
+ from src.model_chain import ModelType
+ from src.config import config
+
+
+ class ChatInterface:
+     """Gradio-based chat interface for HPMOR Q&A."""
+
+     def __init__(self):
+         """Initialize the chat interface."""
+         print("Initializing HPMOR Q&A Chat Interface...")
+         self.engine = RAGEngine(force_recreate=False)
+         self.conversation_history = []
+
+     def format_sources(self, sources: List[dict]) -> str:
+         """Format sources for display."""
+         if not sources:
+             return "No sources found"
+
+         formatted = []
+         for i, source in enumerate(sources, 1):
+             formatted.append(
+                 f"**Source {i}** - Chapter {source['chapter_number']}: {source['chapter_title']}\n"
+                 f"Relevance Score: {source['score']:.2f}\n"
+                 f"Preview: *{source['text_preview'][:150]}...*"
+             )
+         return "\n\n".join(formatted)
+
+     def process_message(
+         self,
+         message: str,
+         history: List[List[str]],
+         model_choice: str,
+         top_k: int,
+         show_sources: bool
+     ) -> Tuple[str, str, str]:
+         """Process a chat message and return response."""
+         if not message:
+             return "", "", "Please enter a question."
+
+         # Convert model choice to enum
+         model_map = {
+             "Auto (Intelligent Routing)": None,
+             "Local Small (Fast)": ModelType.LOCAL_SMALL,
+             "Local Large (Better)": ModelType.LOCAL_LARGE,
+             "Groq API (Best)": ModelType.GROQ_API
+         }
+         force_model = model_map.get(model_choice)
+
+         # Convert history to messages format
+         messages = []
+         for user_msg, assistant_msg in history:
+             if user_msg:
+                 messages.append({"role": "user", "content": user_msg})
+             if assistant_msg:
+                 messages.append({"role": "assistant", "content": assistant_msg})
+         messages.append({"role": "user", "content": message})
+
+         try:
+             # Get response from engine, honoring the UI's model and top-k settings
+             response = self.engine.chat(
+                 messages, stream=False, force_model=force_model, top_k=top_k
+             )
+
+             # Extract answer
+             if isinstance(response.get("answer"), str):
+                 answer = response["answer"]
+             else:
+                 # Handle LlamaIndex response object
+                 answer = str(response.get("answer", "No response generated"))
+
+             # Format model info
+             model_info = f"**Model Used:** {response.get('model_used', 'Unknown')}"
+             if response.get("fallback_used"):
+                 model_info += " (via fallback)"
+             model_info += f"\n**Context Size:** {response.get('context_size', 0)} characters"
+
+             # Format sources if requested
+             sources_text = ""
+             if show_sources and response.get("sources"):
+                 sources_text = self.format_sources(response["sources"])
+
+             return answer, sources_text, model_info
+
+         except Exception as e:
+             error_msg = f"Error: {str(e)}"
+             return error_msg, "", "Error occurred"
+
+     def clear_conversation(self):
+         """Clear conversation history and cache."""
+         self.conversation_history = []
+         self.engine.clear_cache()
+         return None, "", "", "Conversation cleared"
+
+     def create_interface(self) -> gr.Blocks:
+         """Create the Gradio interface."""
+         # Custom CSS must be passed to gr.Blocks at construction time;
+         # assigning interface.css after the fact has no effect.
+         custom_css = """
+         .model-info {
+             background-color: #f0f0f0;
+             padding: 10px;
+             border-radius: 5px;
+             font-size: 0.9em;
+         }
+         """
+
+         with gr.Blocks(title="HPMOR Q&A System", theme=gr.themes.Soft(), css=custom_css) as interface:
+             gr.Markdown(
+                 """
+                 # 📚 Harry Potter and the Methods of Rationality - Q&A System
+
+                 Ask questions about HPMOR and get intelligent answers powered by RAG (Retrieval-Augmented Generation).
+                 The system uses local models when possible and falls back to the Groq API for complex queries.
+                 """
+             )
+
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     chatbot = gr.Chatbot(
+                         label="Chat",
+                         height=500,
+                         show_copy_button=True
+                     )
+
+                     with gr.Row():
+                         msg_input = gr.Textbox(
+                             label="Your Question",
+                             placeholder="Ask anything about HPMOR...",
+                             lines=2,
+                             scale=4
+                         )
+                         submit_btn = gr.Button("Send", variant="primary", scale=1)
+
+                 with gr.Column(scale=1):
+                     gr.Markdown("### Settings")
+
+                     model_choice = gr.Radio(
+                         choices=[
+                             "Auto (Intelligent Routing)",
+                             "Local Small (Fast)",
+                             "Local Large (Better)",
+                             "Groq API (Best)"
+                         ],
+                         value="Auto (Intelligent Routing)",
+                         label="Model Selection"
+                     )
+
+                     top_k = gr.Slider(
+                         minimum=1,
+                         maximum=10,
+                         value=5,
+                         step=1,
+                         label="Number of Context Chunks"
+                     )
+
+                     show_sources = gr.Checkbox(
+                         value=True,
+                         label="Show Sources"
+                     )
+
+                     clear_btn = gr.Button("Clear Conversation", variant="secondary")
+
+                     gr.Markdown("### Model Info")
+                     model_info = gr.Markdown(
+                         value="Ready to answer questions",
+                         elem_classes=["model-info"]
+                     )
+
+             with gr.Row():
+                 sources_display = gr.Markdown(
+                     label="Retrieved Sources",
+                     value="",
+                     visible=True
+                 )
+
+             # Example questions
+             gr.Examples(
+                 examples=[
+                     "What is Harry's initial reaction to learning about magic?",
+                     "How does Harry apply the scientific method to understand magic?",
+                     "What are the key differences between Harry and Hermione's approaches to learning?",
+                     "Explain the concept of 'rationality' as presented in the story",
+                     "What magical experiments does Harry conduct?",
+                 ],
+                 inputs=msg_input,
+                 label="Example Questions"
+             )
+
+             # Event handlers
+             def respond(message, history, model, topk, sources):
+                 """Handle message submission."""
+                 answer, sources_text, info = self.process_message(
+                     message, history, model, topk, sources
+                 )
+                 history.append([message, answer])
+                 return "", history, sources_text, info
+
+             msg_input.submit(
+                 respond,
+                 inputs=[msg_input, chatbot, model_choice, top_k, show_sources],
+                 outputs=[msg_input, chatbot, sources_display, model_info]
+             )
+
+             submit_btn.click(
+                 respond,
+                 inputs=[msg_input, chatbot, model_choice, top_k, show_sources],
+                 outputs=[msg_input, chatbot, sources_display, model_info]
+             )
+
+             clear_btn.click(
+                 self.clear_conversation,
+                 outputs=[chatbot, sources_display, msg_input, model_info]
+             )
+
+         return interface
+
+     def launch(self):
+         """Launch the Gradio interface."""
+         interface = self.create_interface()
+
+         print("\nLaunching HPMOR Q&A Chat Interface...")
+         print(f"Server will be available at: http://localhost:{config.gradio_server_port}")
+
+         interface.launch(
+             server_name="0.0.0.0",
+             server_port=config.gradio_server_port,
+             share=config.gradio_share,
+             favicon_path=None
+         )
+
+
+ def main():
+     """Launch the chat interface."""
+     chat = ChatInterface()
+     chat.launch()
+
+
+ if __name__ == "__main__":
+     main()
src/config.py ADDED
@@ -0,0 +1,104 @@
+ """Configuration management for HPMOR Q&A System."""
+
+ import os
+ from pathlib import Path
+ from typing import Optional
+ from pydantic import BaseModel, Field
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ class Config(BaseModel):
+     """Application configuration."""
+
+     # API Keys
+     groq_api_key: Optional[str] = Field(default=os.getenv("GROQ_API_KEY"))
+
+     # Ollama Settings
+     ollama_host: str = Field(default=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
+
+     # Model Names (defaults match the tags Ollama actually publishes)
+     local_model_small: str = Field(default=os.getenv("LOCAL_MODEL_SMALL", "llama3.2:3b"))
+     local_model_large: str = Field(default=os.getenv("LOCAL_MODEL_LARGE", "llama3.1:8b"))
+     groq_model: str = Field(default=os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile"))
+
+     # Embedding Model
+     embedding_model: str = Field(
+         default=os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
+     )
+
+     # Processing Parameters
+     chunk_size: int = Field(default=int(os.getenv("CHUNK_SIZE", "1000")))
+     chunk_overlap: int = Field(default=int(os.getenv("CHUNK_OVERLAP", "200")))
+     top_k_retrieval: int = Field(default=int(os.getenv("TOP_K_RETRIEVAL", "5")))
+
+     # Model Selection Thresholds
+     complexity_threshold: float = Field(
+         default=float(os.getenv("COMPLEXITY_THRESHOLD", "0.7"))
+     )
+     max_local_context_size: int = Field(
+         default=int(os.getenv("MAX_LOCAL_CONTEXT_SIZE", "4000"))
+     )
+
+     # ChromaDB Settings
+     chroma_persist_dir: Path = Field(
+         default=Path(os.getenv("CHROMA_PERSIST_DIR", "./chroma_db"))
+     )
+     collection_name: str = Field(
+         default=os.getenv("COLLECTION_NAME", "hpmor_collection")
+     )
+
+     # Gradio Settings
+     gradio_server_port: int = Field(
+         default=int(os.getenv("GRADIO_SERVER_PORT", "7860"))
+     )
+     gradio_share: bool = Field(
+         default=os.getenv("GRADIO_SHARE", "False").lower() == "true"
+     )
+
+     # File Paths
+     data_dir: Path = Field(default=Path("data"))
+     raw_data_dir: Path = Field(default=Path("data/raw"))
+     processed_data_dir: Path = Field(default=Path("data/processed"))
+     hpmor_file: Path = Field(default=Path("data/raw/hpmor.html"))
+
+     def validate_paths(self) -> None:
+         """Create necessary directories if they don't exist."""
+         for dir_path in [self.data_dir, self.raw_data_dir, self.processed_data_dir]:
+             dir_path.mkdir(parents=True, exist_ok=True)
+
+         self.chroma_persist_dir.mkdir(parents=True, exist_ok=True)
+
+     def has_groq_api(self) -> bool:
+         """Check if a Groq API key is configured."""
+         return bool(self.groq_api_key and self.groq_api_key != "your_groq_api_key_here")
+
+     def get_model_config(self, model_type: str) -> dict:
+         """Get configuration for a specific model type."""
+         configs = {
+             "local_small": {
+                 "model": self.local_model_small,
+                 "type": "ollama",
+                 "max_tokens": 2048,
+                 "temperature": 0.7,
+             },
+             "local_large": {
+                 "model": self.local_model_large,
+                 "type": "ollama",
+                 "max_tokens": 4096,
+                 "temperature": 0.7,
+             },
+             "groq": {
+                 "model": self.groq_model,
+                 "type": "groq",
+                 "api_key": self.groq_api_key,
+                 "max_tokens": 8192,
+                 "temperature": 0.7,
+             },
+         }
+         return configs.get(model_type, configs["local_small"])
+
+ # Create global config instance
+ config = Config()
+ config.validate_paths()
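
Because the Field defaults are evaluated from os.getenv when the module is first imported, any override must be in place before src.config is imported. A small sketch of that ordering constraint:

    import os
    os.environ["CHUNK_SIZE"] = "500"   # must happen before the import below

    from src.config import Config
    print(Config().chunk_size)         # -> 500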
src/document_processor.py ADDED
@@ -0,0 +1,218 @@
+ """Document processor for parsing and chunking HPMOR HTML."""
+
+ import re
+ import json
+ from pathlib import Path
+ from typing import List, Dict, Optional
+ from bs4 import BeautifulSoup
+ from llama_index.core import Document
+ from llama_index.core.node_parser import SentenceSplitter
+ from src.config import config
+
+
+ class HPMORProcessor:
+     """Process HPMOR HTML document into chunks for RAG."""
+
+     def __init__(self):
+         self.chunk_size = config.chunk_size
+         self.chunk_overlap = config.chunk_overlap
+         self.processed_dir = config.processed_data_dir
+
+     def parse_html(self, file_path: Path) -> List[Dict]:
+         """Parse HTML file and extract chapters with metadata."""
+         print(f"Parsing HTML file: {file_path}")
+
+         with open(file_path, 'r', encoding='utf-8') as f:
+             html_content = f.read()
+
+         soup = BeautifulSoup(html_content, 'lxml')
+
+         # Remove style and script tags
+         for tag in soup(['style', 'script']):
+             tag.decompose()
+
+         # Try to identify chapters by common patterns
+         chapters = []
+         chapter_pattern = re.compile(r'Chapter\s+(\d+)', re.IGNORECASE)
+
+         # Find all h1, h2, h3 tags that might be chapter headers
+         headers = soup.find_all(['h1', 'h2', 'h3'])
+
+         current_chapter = None
+         current_content = []
+         chapter_num = 0
+
+         for header in headers:
+             header_text = header.get_text(strip=True)
+             match = chapter_pattern.search(header_text)
+
+             if match:
+                 # Save previous chapter if exists
+                 if current_chapter and current_content:
+                     chapters.append({
+                         'chapter_number': current_chapter['number'],
+                         'chapter_title': current_chapter['title'],
+                         'content': '\n'.join(current_content)
+                     })
+
+                 # Start new chapter
+                 chapter_num = int(match.group(1))
+                 current_chapter = {
+                     'number': chapter_num,
+                     'title': header_text
+                 }
+                 current_content = []
+
+                 # Get content after this header until next chapter
+                 for sibling in header.find_next_siblings():
+                     if sibling.name in ['h1', 'h2', 'h3']:
+                         if chapter_pattern.search(sibling.get_text()):
+                             break
+                     text = sibling.get_text(strip=True)
+                     if text:
+                         current_content.append(text)
+
+         # Add the last chapter
+         if current_chapter and current_content:
+             chapters.append({
+                 'chapter_number': current_chapter['number'],
+                 'chapter_title': current_chapter['title'],
+                 'content': '\n'.join(current_content)
+             })
+
+         # If no chapters found, treat the entire content as one document
+         if not chapters:
+             print("No chapter structure found, processing as single document")
+             text_content = soup.get_text(separator='\n', strip=True)
+             chapters = [{
+                 'chapter_number': 0,
+                 'chapter_title': 'Harry Potter and the Methods of Rationality',
+                 'content': text_content
+             }]
+
+         print(f"Extracted {len(chapters)} chapters")
+         return chapters
+
+     def create_chunks(self, chapters: List[Dict]) -> List[Document]:
+         """Create overlapping chunks from chapters."""
+         print(f"Creating chunks with size={self.chunk_size}, overlap={self.chunk_overlap}")
+
+         documents = []
+         splitter = SentenceSplitter(
+             chunk_size=self.chunk_size,
+             chunk_overlap=self.chunk_overlap,
+         )
+
+         for chapter in chapters:
+             # Create a document for the chapter
+             chapter_doc = Document(
+                 text=chapter['content'],
+                 metadata={
+                     'chapter_number': chapter['chapter_number'],
+                     'chapter_title': chapter['chapter_title'],
+                     'source': 'hpmor.html'
+                 }
+             )
+
+             # Split into chunks
+             nodes = splitter.get_nodes_from_documents([chapter_doc])
+
+             # Convert nodes back to documents with enhanced metadata
+             for i, node in enumerate(nodes):
+                 doc = Document(
+                     text=node.text,
+                     metadata={
+                         **chapter_doc.metadata,
+                         'chunk_id': f"ch{chapter['chapter_number']}_chunk{i}",
+                         'chunk_index': i,
+                         'total_chunks_in_chapter': len(nodes)
+                     }
+                 )
+                 documents.append(doc)
+
+         print(f"Created {len(documents)} chunks total")
+         return documents
+
+     def save_processed_data(self, documents: List[Document], chapters: List[Dict]) -> None:
+         """Save processed documents and metadata to disk."""
+         # Save documents as JSON for easy loading
+         docs_data = []
+         for doc in documents:
+             docs_data.append({
+                 'text': doc.text,
+                 'metadata': doc.metadata
+             })
+
+         docs_file = self.processed_dir / 'documents.json'
+         with open(docs_file, 'w', encoding='utf-8') as f:
+             json.dump(docs_data, f, indent=2, ensure_ascii=False)
+         print(f"Saved {len(docs_data)} documents to {docs_file}")
+
+         # Save chapter metadata
+         chapters_file = self.processed_dir / 'chapters.json'
+         with open(chapters_file, 'w', encoding='utf-8') as f:
+             json.dump(chapters, f, indent=2, ensure_ascii=False)
+         print(f"Saved chapter metadata to {chapters_file}")
+
+     def load_processed_data(self) -> Optional[List[Document]]:
+         """Load previously processed documents."""
+         docs_file = self.processed_dir / 'documents.json'
+
+         if not docs_file.exists():
+             return None
+
+         with open(docs_file, 'r', encoding='utf-8') as f:
+             docs_data = json.load(f)
+
+         documents = []
+         for doc_data in docs_data:
+             doc = Document(
+                 text=doc_data['text'],
+                 metadata=doc_data['metadata']
+             )
+             documents.append(doc)
+
+         print(f"Loaded {len(documents)} documents from cache")
+         return documents
+
+     def process(self, force_reprocess: bool = False) -> List[Document]:
+         """Main processing pipeline."""
+         # Check if already processed
+         if not force_reprocess:
+             documents = self.load_processed_data()
+             if documents:
+                 return documents
+
+         # Process from scratch
+         print("Processing HPMOR document from scratch...")
+
+         if not config.hpmor_file.exists():
+             raise FileNotFoundError(f"HPMOR file not found: {config.hpmor_file}")
+
+         # Parse HTML
+         chapters = self.parse_html(config.hpmor_file)
+
+         # Create chunks
+         documents = self.create_chunks(chapters)
+
+         # Save processed data
+         self.save_processed_data(documents, chapters)
+
+         return documents
+
+
+ def main():
+     """Process HPMOR document."""
+     processor = HPMORProcessor()
+     documents = processor.process(force_reprocess=True)
+     print(f"\nProcessing complete! Created {len(documents)} document chunks.")
+
+     # Show sample
+     if documents:
+         print("\nSample chunk:")
+         print(f"Text: {documents[0].text[:200]}...")
+         print(f"Metadata: {documents[0].metadata}")
+
+
+ if __name__ == "__main__":
+     main()
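
With chunk_size=1000 and chunk_overlap=200, consecutive chunks advance by roughly 800 tokens each, so per-chapter chunk counts can be estimated up front. A back-of-the-envelope sketch (the SentenceSplitter respects sentence boundaries, so real counts drift slightly):

    import math

    def approx_chunk_count(n_tokens: int, size: int = 1000, overlap: int = 200) -> int:
        # Each chunk after the first covers size - overlap new tokens.
        if n_tokens <= size:
            return 1
        return math.ceil((n_tokens - overlap) / (size - overlap))

    print(approx_chunk_count(5000))  # -> 6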
src/model_chain.py ADDED
@@ -0,0 +1,322 @@
+ """Model chaining logic with Groq fallback."""
+
+ import re
+ from typing import Optional, Dict, Any
+ from enum import Enum
+
+ from llama_index.llms.ollama import Ollama
+ from litellm import completion
+ import httpx
+
+ from src.config import config
+
+
+ class ModelType(Enum):
+     """Model types for routing."""
+     LOCAL_SMALL = "local_small"
+     LOCAL_LARGE = "local_large"
+     GROQ_API = "groq"
+
+
+ class QueryComplexity(Enum):
+     """Query complexity levels."""
+     SIMPLE = "simple"        # Factual questions, definitions
+     MODERATE = "moderate"    # Analysis, reasoning
+     COMPLEX = "complex"      # Creative, multi-step reasoning
+
+
+ class ModelChain:
+     """Intelligent model routing with fallback to Groq."""
+
+     def __init__(self):
+         self.models = {}
+         self.groq_available = config.has_groq_api()
+
+         # Initialize models lazily
+         self._ollama_available = None
+
+     def check_ollama_available(self) -> bool:
+         """Check if Ollama is running and available."""
+         if self._ollama_available is not None:
+             return self._ollama_available
+
+         try:
+             # Try to connect to Ollama
+             response = httpx.get(f"{config.ollama_host}/api/tags", timeout=2.0)
+             self._ollama_available = response.status_code == 200
+             if self._ollama_available:
+                 print("Ollama is available")
+             else:
+                 print("Ollama is not responding correctly")
+         except Exception as e:
+             print(f"Ollama not available: {e}")
+             self._ollama_available = False
+
+         return self._ollama_available
+
+     def get_model(self, model_type: ModelType):
+         """Get or initialize a model.
+
+         Returns an Ollama LLM instance, the string "groq" (a sentinel meaning
+         "call Groq through litellm"), or None if the model is unavailable.
+         """
+         if model_type in self.models:
+             return self.models[model_type]
+
+         if model_type == ModelType.GROQ_API:
+             if not self.groq_available:
+                 print("Groq API key not configured")
+                 return None
+             return "groq"  # Sentinel for litellm usage
+
+         elif model_type in [ModelType.LOCAL_SMALL, ModelType.LOCAL_LARGE]:
+             if not self.check_ollama_available():
+                 print("Ollama not available, falling back to Groq")
+                 return None
+
+             model_config = config.get_model_config(model_type.value)
+             try:
+                 model = Ollama(
+                     model=model_config["model"],
+                     base_url=config.ollama_host,
+                     temperature=model_config["temperature"],
+                     request_timeout=120.0,
+                 )
+                 self.models[model_type] = model
+                 print(f"Initialized {model_type.value} model: {model_config['model']}")
+                 return model
+             except Exception as e:
+                 print(f"Failed to initialize Ollama model: {e}")
+                 return None
+
+         return None
+
+     def analyze_query_complexity(self, query: str, context_size: int = 0) -> QueryComplexity:
+         """Analyze query complexity to determine which model to use."""
+         query_lower = query.lower()
+
+         # Simple queries - factual questions
+         simple_patterns = [
+             r"what is",
+             r"who is",
+             r"when did",
+             r"where is",
+             r"define",
+             r"list",
+             r"name",
+             r"how many",
+             r"yes or no",
+         ]
+
+         # Complex queries - requiring reasoning or creativity
+         complex_patterns = [
+             r"explain why",
+             r"analyze",
+             r"compare and contrast",
+             r"what would happen if",
+             r"imagine",
+             r"create",
+             r"write a",
+             r"develop",
+             r"design",
+             r"evaluate",
+             r"critique",
+             r"synthesize",
+         ]
+
+         # Check for simple patterns
+         for pattern in simple_patterns:
+             if re.search(pattern, query_lower):
+                 return QueryComplexity.SIMPLE
+
+         # Check for complex patterns
+         for pattern in complex_patterns:
+             if re.search(pattern, query_lower):
+                 return QueryComplexity.COMPLEX
+
+         # Check query length and context size
+         if len(query.split()) > 50 or context_size > config.max_local_context_size:
+             return QueryComplexity.COMPLEX
+
+         # Default to moderate
+         return QueryComplexity.MODERATE
+
+     def route_query(
+         self,
+         query: str,
+         context: Optional[str] = None,
+         force_model: Optional[ModelType] = None
+     ) -> ModelType:
+         """Determine which model to use for the query."""
+         if force_model:
+             return force_model
+
+         context_size = len(context) if context else 0
+         complexity = self.analyze_query_complexity(query, context_size)
+
+         # Check Ollama availability
+         ollama_available = self.check_ollama_available()
+
+         # Routing logic
+         if complexity == QueryComplexity.SIMPLE:
+             if ollama_available:
+                 return ModelType.LOCAL_SMALL
+             elif self.groq_available:
+                 return ModelType.GROQ_API
+         elif complexity == QueryComplexity.MODERATE:
+             if ollama_available:
+                 return ModelType.LOCAL_LARGE
+             elif self.groq_available:
+                 return ModelType.GROQ_API
+         else:  # COMPLEX
+             if self.groq_available:
+                 return ModelType.GROQ_API
+             elif ollama_available:
+                 return ModelType.LOCAL_LARGE
+
+         # Final fallback
+         if self.groq_available:
+             return ModelType.GROQ_API
+         elif ollama_available:
+             return ModelType.LOCAL_SMALL
+         else:
+             raise RuntimeError("No models available! Please check Ollama or configure a Groq API key.")
+
+     def generate_response(
+         self,
+         query: str,
+         context: Optional[str] = None,
+         force_model: Optional[ModelType] = None,
+         stream: bool = False
+     ) -> Dict[str, Any]:
+         """Generate a response using the appropriate model."""
+         # Determine which model to use
+         model_type = self.route_query(query, context, force_model)
+         print(f"Using model: {model_type.value}")
+
+         # Prepare prompt
+         if context:
+             prompt = f"""Context from Harry Potter and the Methods of Rationality:
+ {context}
+
+ Question: {query}
+
+ Please provide a detailed answer based on the context provided above."""
+         else:
+             prompt = query
+
+         # Try primary model
+         try:
+             model = self.get_model(model_type)
+
+             if model == "groq":  # Special handling for Groq via litellm
+                 response = completion(
+                     model=f"groq/{config.groq_model}",
+                     messages=[{"role": "user", "content": prompt}],
+                     api_key=config.groq_api_key,
+                     temperature=0.7,
+                     max_tokens=2048,
+                     stream=stream
+                 )
+
+                 if stream:
+                     return {
+                         "response": response,
+                         "model_used": model_type.value,
+                         "streaming": True
+                     }
+                 else:
+                     return {
+                         "response": response.choices[0].message.content,
+                         "model_used": model_type.value,
+                         "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None
+                     }
+
+             elif model:
+                 # Use LlamaIndex model
+                 if stream:
+                     response = model.stream_complete(prompt)
+                 else:
+                     response = model.complete(prompt)
+
+                 return {
+                     "response": response,
+                     "model_used": model_type.value,
+                     "streaming": stream
+                 }
+
+         except Exception as e:
+             print(f"Error with {model_type.value}: {e}")
+
+             # Try fallback
+             if model_type != ModelType.GROQ_API and self.groq_available:
+                 print("Falling back to Groq API...")
+                 model_type = ModelType.GROQ_API
+                 try:
+                     response = completion(
+                         model=f"groq/{config.groq_model}",
+                         messages=[{"role": "user", "content": prompt}],
+                         api_key=config.groq_api_key,
+                         temperature=0.7,
+                         max_tokens=2048,
+                         stream=stream
+                     )
+
+                     if stream:
+                         return {
+                             "response": response,
+                             "model_used": model_type.value,
+                             "streaming": True,
+                             "fallback": True
+                         }
+                     else:
+                         return {
+                             "response": response.choices[0].message.content,
+                             "model_used": model_type.value,
+                             "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None,
+                             "fallback": True
+                         }
+                 except Exception as e2:
+                     print(f"Fallback to Groq also failed: {e2}")
+                     raise RuntimeError(f"All models failed. Last error: {e2}")
+
+         raise RuntimeError("No models available for response generation")
+
+
+ def main():
+     """Test model chaining."""
+     chain = ModelChain()
+
+     # Test queries of different complexities
+     test_queries = [
+         ("What is Harry's full name?", QueryComplexity.SIMPLE),
+         ("Explain Harry's reasoning about magic", QueryComplexity.MODERATE),
+         ("Analyze the philosophical implications of Harry's scientific approach to magic", QueryComplexity.COMPLEX),
+     ]
+
+     for query, expected_complexity in test_queries:
+         print(f"\nQuery: {query}")
+         complexity = chain.analyze_query_complexity(query)
+         print(f"Detected complexity: {complexity}")
+         print(f"Expected complexity: {expected_complexity}")
+
+         try:
+             model_type = chain.route_query(query)
+             print(f"Selected model: {model_type.value}")
+
+             # Generate response
+             result = chain.generate_response(query)
+             print(f"Model used: {result['model_used']}")
+             print(f"Response preview: {str(result['response'])[:200]}...")
+         except Exception as e:
+             print(f"Error: {e}")
+
+
+ if __name__ == "__main__":
+     main()
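
The complexity heuristics are pure string matching, so they can be exercised offline without Ollama or a Groq key. A quick sketch:

    from src.model_chain import ModelChain

    chain = ModelChain()
    print(chain.analyze_query_complexity("What is a Patronus?"))      # SIMPLE ("what is")
    print(chain.analyze_query_complexity("Analyze Harry's bargain"))  # COMPLEX ("analyze")
    print(chain.analyze_query_complexity("Tell me about Azkaban"))    # MODERATE (no pattern hit)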
src/rag_engine.py ADDED
@@ -0,0 +1,231 @@
+ """RAG query engine for HPMOR Q&A system."""
+
+ import json
+ from typing import Optional, List, Dict, Any
+
+ from src.document_processor import HPMORProcessor
+ from src.vector_store import VectorStoreManager
+ from src.model_chain import ModelChain, ModelType
+ from src.config import config
+
+
+ class RAGEngine:
+     """Main RAG engine combining retrieval and generation."""
+
+     def __init__(self, force_recreate: bool = False):
+         """Initialize RAG engine components."""
+         print("Initializing RAG Engine...")
+
+         # Initialize components
+         self.processor = HPMORProcessor()
+         self.vector_store = VectorStoreManager()
+         self.model_chain = ModelChain()
+
+         # Process and index documents
+         self._initialize_index(force_recreate)
+
+         # Cache for responses
+         self.response_cache = {}
+
+     def _initialize_index(self, force_recreate: bool = False):
+         """Initialize or load the vector index."""
+         # Process documents
+         documents = self.processor.process(force_reprocess=force_recreate)
+
+         # Create or load index
+         self.index = self.vector_store.get_or_create_index(
+             documents=documents,
+             force_recreate=force_recreate
+         )
+
+         print(f"Index ready with {len(documents)} documents")
+
+     def retrieve_context(self, query: str, top_k: Optional[int] = None) -> tuple[str, List[Dict]]:
+         """Retrieve relevant context for a query."""
+         if top_k is None:
+             top_k = config.top_k_retrieval
+
+         # Query vector store
+         nodes = self.vector_store.query(query, top_k=top_k)
+
+         # Format context
+         context_parts = []
+         source_info = []
+
+         for i, node in enumerate(nodes, 1):
+             # Add to context
+             context_parts.append(f"[Excerpt {i}]\n{node.text}")
+
+             # Collect source info
+             source_info.append({
+                 "chunk_id": node.metadata.get("chunk_id", "unknown"),
+                 "chapter_number": node.metadata.get("chapter_number", 0),
+                 "chapter_title": node.metadata.get("chapter_title", "Unknown"),
+                 "score": float(node.score) if node.score else 0.0,
+                 "text_preview": node.text[:200] + "..." if len(node.text) > 200 else node.text
+             })
+
+         context = "\n\n".join(context_parts)
+         return context, source_info
+
+     def query(
+         self,
+         question: str,
+         top_k: Optional[int] = None,
+         force_model: Optional[ModelType] = None,
+         return_sources: bool = True,
+         use_cache: bool = True,
+         stream: bool = False
+     ) -> Dict[str, Any]:
+         """Execute RAG query with retrieval and generation."""
+         # Check cache
+         cache_key = f"{question}_{top_k}_{force_model}"
+         if use_cache and cache_key in self.response_cache and not stream:
+             print("Returning cached response")
+             return self.response_cache[cache_key]
+
+         # Retrieve context
+         print(f"Retrieving context for: {question[:100]}...")
+         context, sources = self.retrieve_context(question, top_k)
+
+         # Generate response
+         print("Generating response...")
+         try:
+             result = self.model_chain.generate_response(
+                 query=question,
+                 context=context,
+                 force_model=force_model,
+                 stream=stream
+             )
+
+             # Prepare full response
+             full_response = {
+                 "question": question,
+                 "answer": result.get("response"),
+                 "model_used": result.get("model_used"),
+                 "sources": sources if return_sources else None,
+                 "context_size": len(context),
+                 "streaming": stream,
+                 "fallback_used": result.get("fallback", False)
+             }
+
+             # Cache if not streaming
+             if use_cache and not stream:
+                 self.response_cache[cache_key] = full_response
+
+             return full_response
+
+         except Exception as e:
+             print(f"Error generating response: {e}")
+             return {
+                 "question": question,
+                 "answer": f"Error generating response: {str(e)}",
+                 "model_used": None,
+                 "sources": sources if return_sources else None,
+                 "error": str(e)
+             }
+
+     def chat(
+         self,
+         messages: List[Dict[str, str]],
+         stream: bool = False,
+         force_model: Optional[ModelType] = None,
+         top_k: Optional[int] = None
+     ) -> Dict[str, Any]:
+         """Handle a chat conversation with context."""
+         # Get the latest user message
+         if not messages or messages[-1]["role"] != "user":
+             return {"error": "No user message found"}
+
+         current_question = messages[-1]["content"]
+
+         # Build conversation context if multiple messages
+         conversation_context = ""
+         if len(messages) > 1:
+             prev_messages = messages[:-1][-4:]  # Keep last 4 messages for context
+             for msg in prev_messages:
+                 role = "Human" if msg["role"] == "user" else "Assistant"
+                 conversation_context += f"{role}: {msg['content']}\n\n"
+
+         # Modify question to include conversation context
+         if conversation_context:
+             full_query = f"""Previous conversation:
+ {conversation_context}
+
+ Current question: {current_question}"""
+         else:
+             full_query = current_question
+
+         # Execute RAG query, passing through any caller overrides
+         response = self.query(
+             question=full_query,
+             top_k=top_k,
+             force_model=force_model,
+             return_sources=True,
+             stream=stream
+         )
+
+         return response
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get statistics about the RAG engine."""
+         vector_stats = self.vector_store.get_stats()
+
+         stats = {
+             "vector_store": vector_stats,
+             "cache_size": len(self.response_cache),
+             "models_available": {
+                 "ollama": self.model_chain.check_ollama_available(),
+                 "groq": self.model_chain.groq_available
+             }
+         }
+
+         return stats
+
+     def clear_cache(self):
+         """Clear the response cache."""
+         self.response_cache = {}
+         print("Response cache cleared")
+
+
+ def main():
+     """Test RAG engine."""
+     # Initialize engine
+     print("Initializing RAG engine...")
+     engine = RAGEngine(force_recreate=False)
+
+     # Test queries
+     test_questions = [
+         "What is Harry Potter's approach to understanding magic?",
+         "How does Harry react when he first learns about magic?",
+         "What are Harry's thoughts on the scientific method?",
+     ]
+
+     for question in test_questions:
+         print(f"\n{'='*80}")
+         print(f"Question: {question}")
+         print(f"{'='*80}")
+
+         response = engine.query(question, top_k=3)
+
+         print(f"\nModel used: {response['model_used']}")
+         print(f"Context size: {response['context_size']} characters")
+
+         if response.get("fallback_used"):
+             print("(Fallback to Groq was used)")
+
+         print(f"\nAnswer:\n{response['answer']}")
+
+         if response.get("sources"):
+             print(f"\nSources ({len(response['sources'])} chunks):")
+             for i, source in enumerate(response['sources'], 1):
+                 print(f"  {i}. Chapter {source['chapter_number']}: {source['chapter_title']}")
+                 print(f"     Score: {source['score']:.4f}")
+
+     # Show stats
+     print(f"\n{'='*80}")
+     print("Engine Statistics:")
+     stats = engine.get_stats()
+     print(json.dumps(stats, indent=2))
+
+
+ if __name__ == "__main__":
+     main()
src/vector_store.py ADDED
@@ -0,0 +1,198 @@
+ """Vector store management for document embeddings."""
+
+ from typing import List, Optional
+
+ import chromadb
+ from chromadb.config import Settings
+ from llama_index.core import Document, VectorStoreIndex, StorageContext
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+
+ from src.config import config
+
+
+ class VectorStoreManager:
+     """Manage ChromaDB vector store for document embeddings."""
+
+     def __init__(self):
+         self.collection_name = config.collection_name
+         self.persist_dir = str(config.chroma_persist_dir)
+         self.embedding_model = config.embedding_model
+
+         # Initialize embedding model
+         print(f"Loading embedding model: {self.embedding_model}")
+         self.embed_model = HuggingFaceEmbedding(
+             model_name=self.embedding_model,
+             cache_folder="./models"
+         )
+
+         # Initialize ChromaDB client
+         self.chroma_client = chromadb.PersistentClient(
+             path=self.persist_dir,
+             settings=Settings(anonymized_telemetry=False)
+         )
+
+         # Get or create collection
+         self.collection = None
+         self.vector_store = None
+         self.index = None
+
+     def initialize_collection(self, reset: bool = False) -> None:
+         """Initialize the ChromaDB collection."""
+         if reset:
+             # Delete existing collection if it exists
+             try:
+                 self.chroma_client.delete_collection(name=self.collection_name)
+                 print(f"Deleted existing collection: {self.collection_name}")
+             except Exception:
+                 pass
+
+         # Create or get collection
+         self.collection = self.chroma_client.get_or_create_collection(
+             name=self.collection_name,
+             metadata={"hnsw:space": "cosine"}
+         )
+         print(f"Using collection: {self.collection_name}")
+
+         # Initialize vector store; embeddings are supplied to the index via
+         # embed_model, so no embedding function is attached here
+         self.vector_store = ChromaVectorStore(
+             chroma_collection=self.collection
+         )
+
+     def create_index(self, documents: List[Document], show_progress: bool = True) -> VectorStoreIndex:
+         """Create vector index from documents."""
+         if not self.vector_store:
+             self.initialize_collection()
+
+         print(f"Creating index from {len(documents)} documents...")
+
+         # Create storage context
+         storage_context = StorageContext.from_defaults(
+             vector_store=self.vector_store
+         )
+
+         # Create index with documents
+         self.index = VectorStoreIndex.from_documents(
+             documents,
+             storage_context=storage_context,
+             embed_model=self.embed_model,
+             show_progress=show_progress
+         )
+
+         print("Index created successfully!")
+         return self.index
+
+     def load_index(self) -> Optional[VectorStoreIndex]:
+         """Load existing index from storage."""
+         if not self.vector_store:
+             self.initialize_collection()
+
+         # Check if collection has data
+         if self.collection.count() == 0:
+             print("No existing index found in ChromaDB")
+             return None
+
+         print(f"Loading index with {self.collection.count()} vectors")
+
+         # Create storage context
+         storage_context = StorageContext.from_defaults(
+             vector_store=self.vector_store
+         )
+
+         # Load index
+         self.index = VectorStoreIndex.from_vector_store(
+             self.vector_store,
+             storage_context=storage_context,
+             embed_model=self.embed_model
+         )
+
+         return self.index
+
+     def get_or_create_index(
+         self,
+         documents: Optional[List[Document]] = None,
+         force_recreate: bool = False
+     ) -> VectorStoreIndex:
+         """Get existing index or create a new one."""
+         if not force_recreate:
+             # Try to load existing index
+             index = self.load_index()
+             if index:
+                 return index
+
+         # Create new index
+         if not documents:
+             raise ValueError("No documents provided for creating index")
+
+         self.initialize_collection(reset=True)
+         return self.create_index(documents)
+
+     def query(self, query_text: str, top_k: Optional[int] = None) -> List:
+         """Query the vector store."""
+         if not self.index:
+             raise ValueError("Index not initialized. Call get_or_create_index first.")
+
+         if top_k is None:
+             top_k = config.top_k_retrieval
+
+         # Use the retriever directly instead of a query engine to avoid the LLM requirement
+         retriever = self.index.as_retriever(
+             similarity_top_k=top_k
+         )
+
+         # Retrieve nodes
+         nodes = retriever.retrieve(query_text)
+         return nodes
+
+     def get_stats(self) -> dict:
+         """Get statistics about the vector store."""
+         if not self.collection:
+             self.initialize_collection()
+
+         stats = {
+             "collection_name": self.collection_name,
+             "persist_dir": self.persist_dir,
+             "embedding_model": self.embedding_model,
+             "num_vectors": self.collection.count(),
+             "metadata": self.collection.metadata
+         }
+
+         return stats
+
+
+ def main():
+     """Test vector store functionality."""
+     from src.document_processor import HPMORProcessor
+
+     # Process documents
+     processor = HPMORProcessor()
+     documents = processor.process()
+
+     # Create vector store
+     vector_store = VectorStoreManager()
+     index = vector_store.get_or_create_index(documents, force_recreate=True)
+
+     # Get stats
+     stats = vector_store.get_stats()
+     print("\nVector Store Statistics:")
+     for key, value in stats.items():
+         print(f"  {key}: {value}")
+
+     # Test query
+     test_query = "What is Harry's opinion on magic?"
+     print(f"\nTest query: '{test_query}'")
+     results = vector_store.query(test_query, top_k=3)
+
+     print(f"\nFound {len(results)} relevant chunks:")
+     for i, node in enumerate(results, 1):
+         print(f"\n{i}. Score: {node.score:.4f}")
+         print(f"   Chapter: {node.metadata.get('chapter_title', 'Unknown')}")
+         print(f"   Text preview: {node.text[:200]}...")
+
+
+ if __name__ == "__main__":
+     main()
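
The persisted collection can also be inspected with the chromadb client directly, bypassing LlamaIndex; a sketch using the default paths from .env:

    import chromadb

    client = chromadb.PersistentClient(path="./chroma_db")
    collection = client.get_or_create_collection("hpmor_collection")
    print(collection.count())  # number of stored embeddings
    print(collection.peek(1))  # one raw record: ids, embeddings, documents, metadatas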
uv.lock ADDED
The diff for this file is too large to render.