File size: 7,591 Bytes
6ef4823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""RAG query engine for HPMOR Q&A system."""

from typing import Optional, List, Dict, Any
import json
from pathlib import Path

from llama_index.core import Document
from src.document_processor import HPMORProcessor
from src.vector_store import VectorStoreManager
from src.model_chain import ModelChain, ModelType
from src.config import config


class RAGEngine:
    """Main RAG engine combining retrieval and generation.

    Wires together document processing, vector-store retrieval, and the
    model chain, and adds an in-memory response cache on top.
    """

    def __init__(self, force_recreate: bool = False):
        """Initialize RAG engine components.

        Args:
            force_recreate: When True, reprocess the corpus and rebuild the
                vector index instead of loading any existing one.
        """
        print("Initializing RAG Engine...")

        # Core components: chunking, retrieval, and generation.
        self.processor = HPMORProcessor()
        self.vector_store = VectorStoreManager()
        self.model_chain = ModelChain()

        # Process and index documents before serving queries.
        self._initialize_index(force_recreate)

        # Response cache keyed by (question, top_k, force_model, return_sources).
        # Exact tuple keys avoid the collisions a concatenated string key
        # could produce, and including return_sources prevents serving a
        # cached sourceless answer to a caller who asked for sources.
        self.response_cache: Dict[tuple, Dict[str, Any]] = {}

    def _initialize_index(self, force_recreate: bool = False):
        """Process the corpus and create or load the vector index.

        Args:
            force_recreate: Forwarded to both the processor (reprocess) and
                the vector store (rebuild index).
        """
        documents = self.processor.process(force_reprocess=force_recreate)

        self.index = self.vector_store.get_or_create_index(
            documents=documents,
            force_recreate=force_recreate
        )

        print(f"Index ready with {len(documents)} documents")

    def retrieve_context(self, query: str, top_k: Optional[int] = None) -> tuple[str, List[Dict]]:
        """Retrieve relevant context for a query.

        Args:
            query: Natural-language question to search with.
            top_k: Number of chunks to retrieve; defaults to
                ``config.top_k_retrieval`` when None.

        Returns:
            A ``(context, source_info)`` pair: the concatenated excerpt text
            and a list of per-chunk metadata dicts.
        """
        if top_k is None:
            top_k = config.top_k_retrieval

        nodes = self.vector_store.query(query, top_k=top_k)

        context_parts: List[str] = []
        source_info: List[Dict] = []

        for i, node in enumerate(nodes, 1):
            context_parts.append(f"[Excerpt {i}]\n{node.text}")

            source_info.append({
                "chunk_id": node.metadata.get("chunk_id", "unknown"),
                "chapter_number": node.metadata.get("chapter_number", 0),
                "chapter_title": node.metadata.get("chapter_title", "Unknown"),
                # Explicit None check: a legitimate 0.0 score must not be
                # confused with a missing score.
                "score": float(node.score) if node.score is not None else 0.0,
                "text_preview": node.text[:200] + "..." if len(node.text) > 200 else node.text
            })

        context = "\n\n".join(context_parts)
        return context, source_info

    def query(
        self,
        question: str,
        top_k: Optional[int] = None,
        force_model: Optional[ModelType] = None,
        return_sources: bool = True,
        use_cache: bool = True,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Execute RAG query with retrieval and generation.

        Args:
            question: The user's question.
            top_k: Retrieval depth override (None -> configured default).
            force_model: Pin generation to a specific model, if given.
            return_sources: Include retrieval metadata in the response.
            use_cache: Read/write the in-memory response cache.
            stream: Request a streaming response (never cached).

        Returns:
            A response dict with answer, model info, sources, and context
            size; on generation failure the dict carries an ``error`` key.
        """
        # Tuple key is collision-free and distinguishes return_sources so
        # cached responses always match what the caller asked for.
        cache_key = (question, top_k, force_model, return_sources)
        if use_cache and not stream and cache_key in self.response_cache:
            print("Returning cached response")
            return self.response_cache[cache_key]

        # Retrieve context
        print(f"Retrieving context for: {question[:100]}...")
        context, sources = self.retrieve_context(question, top_k)

        # Generate response
        print("Generating response...")
        try:
            result = self.model_chain.generate_response(
                query=question,
                context=context,
                force_model=force_model,
                stream=stream
            )

            full_response = {
                "question": question,
                "answer": result.get("response"),
                "model_used": result.get("model_used"),
                "sources": sources if return_sources else None,
                "context_size": len(context),
                "streaming": stream,
                "fallback_used": result.get("fallback", False)
            }

            # Streaming responses hold generators, so only cache non-stream.
            if use_cache and not stream:
                self.response_cache[cache_key] = full_response

            return full_response

        except Exception as e:
            # Degrade gracefully: surface the error in the response instead
            # of propagating, so callers always get a dict back.
            print(f"Error generating response: {e}")
            return {
                "question": question,
                "answer": f"Error generating response: {str(e)}",
                "model_used": None,
                "sources": sources if return_sources else None,
                "error": str(e)
            }

    def chat(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False
    ) -> Dict[str, Any]:
        """Handle chat conversation with context.

        Args:
            messages: Chat history as ``{"role", "content"}`` dicts; the
                last entry must be a user message.
            stream: Forwarded to :meth:`query`.

        Returns:
            The :meth:`query` response dict, or ``{"error": ...}`` when no
            trailing user message exists.
        """
        # .get avoids a KeyError on malformed messages lacking a "role".
        if not messages or messages[-1].get("role") != "user":
            return {"error": "No user message found"}

        current_question = messages[-1]["content"]

        # Fold up to the last 4 prior turns into a textual context block.
        conversation_context = ""
        if len(messages) > 1:
            prev_messages = messages[:-1][-4:]  # Keep last 4 messages for context
            for msg in prev_messages:
                role = "Human" if msg["role"] == "user" else "Assistant"
                conversation_context += f"{role}: {msg['content']}\n\n"

        # Prefix the question with the conversation so retrieval and
        # generation both see the dialogue history.
        if conversation_context:
            full_query = f"""Previous conversation:
{conversation_context}

Current question: {current_question}"""
        else:
            full_query = current_question

        response = self.query(
            question=full_query,
            return_sources=True,
            stream=stream
        )

        return response

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics about the vector store, cache, and models."""
        vector_stats = self.vector_store.get_stats()

        stats = {
            "vector_store": vector_stats,
            "cache_size": len(self.response_cache),
            "models_available": {
                "ollama": self.model_chain.check_ollama_available(),
                "groq": self.model_chain.groq_available
            }
        }

        return stats

    def clear_cache(self):
        """Clear the response cache in place."""
        # .clear() empties the existing dict rather than rebinding, so any
        # external references to the cache stay consistent.
        self.response_cache.clear()
        print("Response cache cleared")


def main():
    """Smoke-test the RAG engine against a few sample questions."""
    print("Initializing RAG engine...")
    engine = RAGEngine(force_recreate=False)

    sample_questions = [
        "What is Harry Potter's approach to understanding magic?",
        "How does Harry react when he first learns about magic?",
        "What are Harry's thoughts on the scientific method?",
    ]

    divider = "=" * 80
    for question in sample_questions:
        print(f"\n{divider}")
        print(f"Question: {question}")
        print(divider)

        response = engine.query(question, top_k=3)

        print(f"\nModel used: {response['model_used']}")
        print(f"Context size: {response['context_size']} characters")

        if response.get("fallback_used"):
            print("(Fallback to Groq was used)")

        print(f"\nAnswer:\n{response['answer']}")

        sources = response.get("sources")
        if sources:
            print(f"\nSources ({len(sources)} chunks):")
            for idx, src in enumerate(sources, 1):
                print(f"  {idx}. Chapter {src['chapter_number']}: {src['chapter_title']}")
                print(f"     Score: {src['score']:.4f}")

    # Final summary of engine state.
    print(f"\n{divider}")
    print("Engine Statistics:")
    print(json.dumps(engine.get_stats(), indent=2))


if __name__ == "__main__":
    main()