File size: 10,961 Bytes
62884e7
 
 
 
 
 
 
 
 
 
 
 
 
ef3e36a
62884e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef3e36a
62884e7
ef3e36a
62884e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
RAG (Retrieval-Augmented Generation) Utilities

Provides document loading, chunking, embedding, and retrieval for the AI chatbot.
"""

import os
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import numpy as np

from utils.openrouter_client import get_embedding, get_query_embedding, get_openrouter_client


# Configuration
CHUNK_SIZE = 500  # Target tokens per chunk (approximate; ~4 chars per token)
CHUNK_OVERLAP = 50  # Overlap between consecutive chunks, measured in words
SUPPORTED_EXTENSIONS = {'.txt', '.md'}
CACHE_FILE = "embeddings_cache.json"  # Embedding cache, stored inside the docs folder


class DocumentChunk:
    """A single chunk of a source document, optionally carrying its embedding.

    The MD5 hash of the chunk's content serves as a stable identity, so
    previously computed embeddings can be reused for unchanged text.
    """

    def __init__(
        self,
        content: str,
        source_file: str,
        chunk_index: int,
        embedding: Optional[List[float]] = None
    ):
        self.content = content
        self.source_file = source_file
        self.chunk_index = chunk_index
        self.embedding = embedding
        # Hash of the raw content; acts as the cache key for embeddings.
        self.content_hash = hashlib.md5(content.encode()).hexdigest()

    def to_dict(self) -> Dict:
        """Serialize this chunk (including any embedding) to a JSON-friendly dict."""
        return {
            "content": self.content,
            "source_file": self.source_file,
            "chunk_index": self.chunk_index,
            "embedding": self.embedding,
            "content_hash": self.content_hash,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'DocumentChunk':
        """Rebuild a chunk from a dict produced by ``to_dict``."""
        restored = cls(
            data["content"],
            data["source_file"],
            data["chunk_index"],
            embedding=data.get("embedding"),
        )
        # Prefer the stored hash; fall back to the freshly computed one.
        restored.content_hash = data.get("content_hash", restored.content_hash)
        return restored


class RAGService:
    """Service for managing RAG document retrieval.

    Loads .txt/.md files from a docs folder, splits them into overlapping
    chunks, embeds the chunks via OpenRouter, and retrieves the chunks most
    similar to a query using cosine similarity. Embeddings are cached on
    disk (JSON) keyed by content hash, so unchanged chunks are not re-embedded.
    """
    
    def __init__(self, docs_path: str = "rag_docs"):
        """
        Initialize the RAG service.
        
        Args:
            docs_path: Path to the documents folder
        """
        self.docs_path = Path(docs_path)
        self.cache_path = self.docs_path / CACHE_FILE
        self.chunks: List[DocumentChunk] = []
        self._loaded = False
    
    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough approximation: ~4 chars per token)."""
        return len(text) // 4
    
    def _chunk_text(self, text: str, source_file: str) -> List[DocumentChunk]:
        """
        Split text into chunks with overlap.
        
        Paragraphs (blank-line separated) are accumulated until the estimated
        token count exceeds CHUNK_SIZE; the last CHUNK_OVERLAP words of a
        finished chunk are carried into the next one for context continuity.
        
        Args:
            text: Text content to chunk
            source_file: Name of the source file
            
        Returns:
            List of DocumentChunk objects (embeddings not yet computed)
        """
        chunks = []
        
        # Split into paragraphs first
        paragraphs = text.split('\n\n')
        
        current_chunk = ""
        chunk_index = 0
        
        for para in paragraphs:
            para = para.strip()
            if not para:
                continue
            
            # If adding this paragraph exceeds chunk size, save current and start new
            if self._estimate_tokens(current_chunk + para) > CHUNK_SIZE and current_chunk:
                chunks.append(DocumentChunk(
                    content=current_chunk.strip(),
                    source_file=source_file,
                    chunk_index=chunk_index
                ))
                chunk_index += 1
                
                # Keep overlap from the end of current chunk
                words = current_chunk.split()
                overlap_words = words[-CHUNK_OVERLAP:] if len(words) > CHUNK_OVERLAP else words
                current_chunk = " ".join(overlap_words) + "\n\n"
            
            current_chunk += para + "\n\n"
        
        # Don't forget the last chunk
        if current_chunk.strip():
            chunks.append(DocumentChunk(
                content=current_chunk.strip(),
                source_file=source_file,
                chunk_index=chunk_index
            ))
        
        return chunks
    
    def load_documents(self) -> int:
        """
        Load and chunk all documents from the docs folder.
        
        Chunks whose content hash appears in the on-disk cache reuse the
        cached copy (with its embedding) instead of a fresh, unembedded one.
        
        Returns:
            Number of chunks loaded
        """
        if not self.docs_path.exists():
            print(f"RAG docs folder not found: {self.docs_path}")
            return 0
        
        # Index cached chunks by content hash once, instead of a linear
        # scan per chunk (was O(n^2) over the cache).
        cached_by_hash = {c.content_hash: c for c in self._load_cache()}
        
        new_chunks = []
        
        # Sort the directory listing: iterdir() order is arbitrary, and a
        # stable file order keeps chunk ordering deterministic across runs.
        for file_path in sorted(self.docs_path.iterdir()):
            if file_path.suffix.lower() not in SUPPORTED_EXTENSIONS:
                continue
            if file_path.name == CACHE_FILE or file_path.name.startswith('.'):
                continue
            
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                
                for chunk in self._chunk_text(content, file_path.name):
                    # Prefer the cached version (already embedded) when the
                    # chunk content is unchanged.
                    new_chunks.append(cached_by_hash.get(chunk.content_hash, chunk))
                        
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
        
        self.chunks = new_chunks
        self._loaded = True
        
        return len(self.chunks)
    
    def embed_documents(self) -> int:
        """
        Generate embeddings for all chunks that don't have them.
        
        Returns:
            Number of new embeddings generated
        """
        if not self._loaded:
            self.load_documents()
        
        client = get_openrouter_client()
        if not client.is_available:
            print("OpenRouter client not available, skipping embedding generation")
            return 0
        
        embedded_count = 0
        
        for chunk in self.chunks:
            if chunk.embedding is None:
                embedding = get_embedding(chunk.content)
                if embedding:
                    chunk.embedding = embedding
                    embedded_count += 1
        
        # Save to cache after embedding so new vectors survive restarts.
        if embedded_count > 0:
            self._save_cache()
        
        return embedded_count
    
    def _load_cache(self) -> List[DocumentChunk]:
        """Load cached embeddings from file; returns [] on any failure."""
        if not self.cache_path.exists():
            return []
        
        try:
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [DocumentChunk.from_dict(d) for d in data]
        except Exception as e:
            # Best-effort cache: a corrupt file just means re-embedding.
            print(f"Error loading cache: {e}")
            return []
    
    def _save_cache(self):
        """Save all embedded chunks to the cache file (best-effort)."""
        try:
            data = [c.to_dict() for c in self.chunks if c.embedding is not None]
            with open(self.cache_path, 'w', encoding='utf-8') as f:
                json.dump(data, f)
        except Exception as e:
            print(f"Error saving cache: {e}")
    
    def retrieve(self, query: str, top_k: int = 3) -> List[Tuple[DocumentChunk, float]]:
        """
        Retrieve the most relevant chunks for a query.
        
        Args:
            query: User's query
            top_k: Number of chunks to retrieve
            
        Returns:
            List of (chunk, similarity_score) tuples, best match first
        """
        if not self._loaded:
            self.load_documents()
            self.embed_documents()
        
        # Get query embedding
        query_embedding = get_query_embedding(query)
        if query_embedding is None:
            return []
        
        query_vec = np.array(query_embedding)
        # Hoist the query norm out of the loop; it is loop-invariant.
        query_norm = np.linalg.norm(query_vec)
        
        # Calculate similarities
        results = []
        for chunk in self.chunks:
            if chunk.embedding is None:
                continue
            
            chunk_vec = np.array(chunk.embedding)
            
            # Cosine similarity (epsilon guards against zero-norm vectors)
            similarity = np.dot(query_vec, chunk_vec) / (
                query_norm * np.linalg.norm(chunk_vec) + 1e-8
            )
            results.append((chunk, float(similarity)))
        
        # Sort by similarity and return top_k
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]
    
    def build_context(self, query: str, top_k: int = 3) -> str:
        """
        Build context string from retrieved chunks.
        
        Args:
            query: User's query
            top_k: Number of chunks to include
            
        Returns:
            Formatted context string for the prompt ("" when nothing matches)
        """
        results = self.retrieve(query, top_k)
        
        if not results:
            return ""
        
        context_parts = [
            f"[From {chunk.source_file}]:\n{chunk.content}"
            for chunk, _score in results
        ]
        
        return "\n\n---\n\n".join(context_parts)


# Singleton instance, created lazily by get_rag_service(); module-level so
# all callers share one loaded/embedded document set.
_rag_instance: Optional[RAGService] = None


def get_rag_service(docs_path: str = "rag_docs") -> RAGService:
    """Get or create the singleton RAG service instance.

    Note: docs_path only takes effect on the first call; subsequent calls
    return the already-created instance unchanged.
    """
    global _rag_instance
    instance = _rag_instance
    if instance is None:
        instance = RAGService(docs_path)
        _rag_instance = instance
    return instance


def retrieve_relevant_chunks(query: str, top_k: int = 3) -> List[Tuple[DocumentChunk, float]]:
    """
    Module-level shortcut for ``RAGService.retrieve`` on the singleton service.

    Args:
        query: User's query
        top_k: Number of chunks to retrieve

    Returns:
        List of (chunk, score) tuples
    """
    return get_rag_service().retrieve(query, top_k)


def build_rag_context(query: str, top_k: int = 3) -> str:
    """
    Module-level shortcut for ``RAGService.build_context`` on the singleton service.

    Args:
        query: User's query
        top_k: Number of chunks to include

    Returns:
        Formatted context string
    """
    return get_rag_service().build_context(query, top_k)


def initialize_rag(docs_path: str = "rag_docs") -> int:
    """
    Initialize the RAG service by loading and embedding documents.

    Args:
        docs_path: Path to documents folder

    Returns:
        Number of chunks loaded
    """
    svc = get_rag_service(docs_path)
    loaded = svc.load_documents()
    svc.embed_documents()
    return loaded