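"""FAISS-backed RAG knowledge base with optional cross-encoder reranking.

Components:
    DocumentReranker        -- scores (query, chunk) pairs with a CrossEncoder.
    FAISSRetrieverWithScore -- retriever that attaches similarity scores and
                               optionally reranks an over-fetched candidate set.
    KnowledgeRAG            -- builds, loads, and incrementally updates the
                               FAISS index from a folder of source files.
"""
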
import os
import logging
import json
import time
import csv
from typing import List, Dict, Optional, Any

import torch
from sentence_transformers import CrossEncoder

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_text_splitters import RecursiveCharacterTextSplitter

from config import (
    RAG_RERANKER_MODEL_NAME, RAG_DETAILED_LOGGING,
    RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP, RAG_CHUNKED_SOURCES_FILENAME,
    RAG_FAISS_INDEX_SUBDIR_NAME, RAG_INITIAL_FETCH_K, RAG_RERANKER_K,
    RAG_MAX_FILES_FOR_INCREMENTAL
)
from utils import FAISS_RAG_SUPPORTED_EXTENSIONS

logger = logging.getLogger(__name__)


class DocumentReranker:
    """Scores (query, chunk) pairs with a cross-encoder and keeps the best."""

    def __init__(self, model_name: str = RAG_RERANKER_MODEL_NAME):
        self.logger = logging.getLogger(__name__ + ".DocumentReranker")
        self.model_name = model_name
        self.model = None
        
        try:
            self.logger.info(f"[RERANKER_INIT] Loading reranker model: {self.model_name}")
            start_time = time.time()
            self.model = CrossEncoder(model_name, trust_remote_code=True)
            load_time = time.time() - start_time
            self.logger.info(f"[RERANKER_INIT] Reranker model '{self.model_name}' loaded successfully in {load_time:.2f}s")
        except Exception as e:
            self.logger.error(f"[RERANKER_INIT] Failed to load reranker model '{self.model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize reranker model: {e}") from e

    def rerank_documents(self, query: str, documents: List[Document], top_k: int) -> List[Document]:
        """Return the top_k documents ranked by cross-encoder relevance to the query."""
        if not documents or not self.model:
            return documents[:top_k] if documents else []

        try:
            start_time = time.time()
            doc_pairs = [[query, doc.page_content] for doc in documents]
            scores = self.model.predict(doc_pairs)
            rerank_time = time.time() - start_time
            self.logger.info(f"[RERANKER] Computed relevance scores in {rerank_time:.3f}s")
            
            doc_score_pairs = list(zip(documents, scores))
            doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
            
            reranked_docs = []
            for doc, score in doc_score_pairs[:top_k]:
                doc.metadata["reranker_score"] = float(score)
                reranked_docs.append(doc)
            
            return reranked_docs
        except Exception as e:
            self.logger.error(f"[RERANKER] Error during reranking: {e}", exc_info=True)
            return documents[:top_k] if documents else []


class FAISSRetrieverWithScore(BaseRetriever):
    """FAISS retriever that records similarity scores in document metadata.

    When a reranker is configured, initial_fetch_k candidates are over-fetched
    and then narrowed to final_k by cross-encoder relevance.
    """

    vectorstore: FAISS
    reranker: Optional[DocumentReranker] = None
    initial_fetch_k: int = RAG_INITIAL_FETCH_K
    final_k: int = RAG_RERANKER_K

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        start_time = time.time()
        num_to_fetch = self.initial_fetch_k if self.reranker else self.final_k
        
        logger.info(f"[RETRIEVER] Fetching {num_to_fetch} docs (Rerank={self.reranker is not None})")

        docs_and_scores = self.vectorstore.similarity_search_with_score(query, k=num_to_fetch)
        
        relevant_docs = []
        for doc, score in docs_and_scores:
            # FAISS reports a raw distance here (lower means more similar under
            # the default L2 strategy); it is stored unmodified for inspection.
            doc.metadata["retrieval_score"] = float(score)
            relevant_docs.append(doc)
        
        if self.reranker and relevant_docs:
            relevant_docs = self.reranker.rerank_documents(query, relevant_docs, top_k=self.final_k)
        
        total_time = time.time() - start_time
        logger.info(f"[RETRIEVER] Completed in {total_time:.3f}s. Returned {len(relevant_docs)} docs.")
        return relevant_docs


class KnowledgeRAG:
    """Builds, loads, incrementally updates, and queries a FAISS index over
    a folder of local source files."""

    def __init__(
        self,
        index_storage_dir: str,
        embedding_model_name: str,
        use_gpu_for_embeddings: bool,
        chunk_size: int = RAG_CHUNK_SIZE,
        chunk_overlap: int = RAG_CHUNK_OVERLAP,
        reranker_model_name: Optional[str] = None,
        enable_reranker: bool = True,
    ):
        self.logger = logging.getLogger(__name__ + ".KnowledgeRAG")
        self.logger.info(f"[RAG_INIT] Initializing KnowledgeRAG system")
        
        self.index_storage_dir = index_storage_dir
        os.makedirs(self.index_storage_dir, exist_ok=True)

        self.embedding_model_name = embedding_model_name
        self.use_gpu_for_embeddings = use_gpu_for_embeddings
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.reranker_model_name = reranker_model_name or RAG_RERANKER_MODEL_NAME
        self.enable_reranker = enable_reranker
        self.reranker = None

        device = "cpu"
        if self.use_gpu_for_embeddings:
            if torch.cuda.is_available():
                self.logger.info(f"[RAG_INIT] CUDA available. Requesting GPU.")
                device = "cuda"
            else:
                self.logger.warning("[RAG_INIT] CUDA not available. Fallback to CPU.")
        
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True}
        )

        if self.enable_reranker:
            try:
                self.reranker = DocumentReranker(self.reranker_model_name)
            except Exception as e:
                # Degrade gracefully: retrieval still works without reranking.
                self.logger.warning(f"[RAG_INIT] Reranker init failed: {e}")
                self.reranker = None

        self.vector_store: Optional[FAISS] = None
        self.retriever: Optional[FAISSRetrieverWithScore] = None
        self.processed_source_files: List[str] = []

    def _save_chunk_config(self):
        # Persist the chunking parameters next to the index so later runs can
        # detect a mismatch via chunk_config_has_changed().
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        os.makedirs(faiss_path, exist_ok=True)
        config_file = os.path.join(faiss_path, "chunk_config.json")
        with open(config_file, 'w') as f:
            json.dump({"chunk_size": self.chunk_size, "chunk_overlap": self.chunk_overlap}, f)

    def _load_chunk_config(self) -> Optional[dict]:
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        config_file = os.path.join(faiss_path, "chunk_config.json")
        if os.path.exists(config_file):
            with open(config_file, 'r') as f:
                return json.load(f)
        return None

    def chunk_config_has_changed(self) -> bool:
        saved = self._load_chunk_config()
        if saved is None:
            # No saved config (first build or legacy index): nothing to compare
            # against, so treat the configuration as unchanged.
            return False
        changed = saved.get("chunk_size") != self.chunk_size or saved.get("chunk_overlap") != self.chunk_overlap
        if changed:
            self.logger.warning(
                f"[CONFIG_CHANGE] Chunk config mismatch! "
                f"Saved=(size={saved.get('chunk_size')}, overlap={saved.get('chunk_overlap')}) "
                f"Current=(size={self.chunk_size}, overlap={self.chunk_overlap}). "
                f"Index will be rebuilt."
            )
        return changed

    def build_index_from_source_files(self, source_folder_path: str):
        self.logger.info(f"[INDEX_BUILD] Building from: {source_folder_path}")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")

        all_docs = []
        processed_files = []
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)

        # Prefer a pre-chunked JSON dump of the sources when one exists;
        # otherwise fall back to reading and chunking the raw files below.
        pre_chunked_path = os.path.join(self.index_storage_dir, RAG_CHUNKED_SOURCES_FILENAME)
        if os.path.exists(pre_chunked_path):
            try:
                with open(pre_chunked_path, 'r', encoding='utf-8') as f:
                    chunk_data_list = json.load(f)
                for chunk in chunk_data_list:
                    doc = Document(page_content=chunk.get("page_content", ""), metadata=chunk.get("metadata", {}))
                    all_docs.append(doc)
                    if 'source_document_name' in doc.metadata:
                        processed_files.append(doc.metadata['source_document_name'])
                processed_files = sorted(list(set(processed_files)))
            except Exception as e:
                self.logger.error(f"[INDEX_BUILD] JSON load failed: {e}")

        if not all_docs:
            for filename in os.listdir(source_folder_path):
                file_path = os.path.join(source_folder_path, filename)
                if not os.path.isfile(file_path):
                    continue
                file_ext = filename.split('.')[-1].lower()
                
                if file_ext in FAISS_RAG_SUPPORTED_EXTENSIONS:
                    # CSV files are indexed row-by-row: each row becomes its own
                    # Document instead of going through the text splitter.
                    if file_ext == 'csv':
                        try:
                            with open(file_path, mode='r', encoding='utf-8-sig') as f:
                                reader = csv.DictReader(f)
                                for i, row in enumerate(reader):
                                    row_text = "\n".join([f"{k}: {v}" for k, v in row.items() if k and v and str(v).strip()])
                                    meta = {"source_document_name": filename, "chunk_index": i, "source_type": "csv"}
                                    all_docs.append(Document(page_content=row_text, metadata=meta))
                            processed_files.append(filename)
                        except Exception as e:
                            self.logger.error(f"[INDEX_BUILD] Error processing CSV {filename}: {e}")
                    else:
                        # Dispatch to the extension-specific loader. The
                        # "CSV_HANDLED_NATIVELY" sentinel marks content the CSV
                        # branch is expected to handle, so skip it here.
                        text_content = FAISS_RAG_SUPPORTED_EXTENSIONS[file_ext](file_path)
                        if text_content and text_content != "CSV_HANDLED_NATIVELY":
                            chunks = text_splitter.split_text(text_content)
                            for i, chunk_text in enumerate(chunks):
                                meta = {"source_document_name": filename, "chunk_index": i}
                                all_docs.append(Document(page_content=chunk_text, metadata=meta))
                            processed_files.append(filename)

        if not all_docs:
            raise ValueError("No documents to index.")

        self.processed_source_files = processed_files
        self.logger.info(f"[INDEX_BUILD] Creating FAISS index with {len(all_docs)} chunks.")
        
        self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        self.vector_store.save_local(faiss_path)
        self._save_chunk_config()
        
        self.retriever = FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=RAG_RERANKER_K
        )

    def load_index_from_disk(self):
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        if not os.path.exists(faiss_path):
            raise FileNotFoundError("Index not found.")

        self.vector_store = FAISS.load_local(
            folder_path=faiss_path,
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True
        )
        self.retriever = FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=RAG_RERANKER_K
        )
        
        meta_file = os.path.join(faiss_path, "processed_files.json")
        if os.path.exists(meta_file):
            with open(meta_file, 'r') as f:
                self.processed_source_files = json.load(f)
        else:
            # No manifest on disk: record a placeholder. Note that
            # update_index_with_new_files() will then treat every file as new.
            self.processed_source_files = ["Loaded from disk (unknown sources)"]
            
        self.logger.info("[INDEX_LOAD] Success.")

    def update_index_with_new_files(self, source_folder_path: str, max_files_to_process: Optional[int] = None) -> Dict[str, Any]:
        self.logger.info(f"[INDEX_UPDATE] Checking for new files in: {source_folder_path}")
        
        if not self.vector_store:
            raise RuntimeError("Cannot update: no index loaded.")
        
        processed_set = set(self.processed_source_files)
        all_new_files = []
        for filename in sorted(os.listdir(source_folder_path)):
            if filename not in processed_set:
                file_ext = filename.split('.')[-1].lower()
                if file_ext in FAISS_RAG_SUPPORTED_EXTENSIONS:
                    all_new_files.append(filename)

        if not all_new_files:
            return {"status": "success", "message": "No new files found.", "files_added": []}
            
        # Cap how many files one call processes so a large backlog cannot stall
        # a single incremental update; callers can page through the remainder.
        limit = max_files_to_process if max_files_to_process is not None else RAG_MAX_FILES_FOR_INCREMENTAL
        files_to_process = all_new_files[:limit]
        
        new_docs = []
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        
        for filename in files_to_process:
            file_path = os.path.join(source_folder_path, filename)
            file_ext = filename.split('.')[-1].lower()

            if file_ext == 'csv':
                try:
                    with open(file_path, mode='r', encoding='utf-8-sig') as f:
                        reader = csv.DictReader(f)
                        for i, row in enumerate(reader):
                            row_text = "\n".join([f"{k}: {v}" for k, v in row.items() if k and v and str(v).strip()])
                            meta = {"source_document_name": filename, "chunk_index": i, "source_type": "csv"}
                            new_docs.append(Document(page_content=row_text, metadata=meta))
                except Exception as e:
                    self.logger.error(f"[INDEX_UPDATE] Error processing CSV {filename}: {e}")
            else:
                text_content = FAISS_RAG_SUPPORTED_EXTENSIONS[file_ext](file_path)
                if text_content and text_content != "CSV_HANDLED_NATIVELY":
                    chunks = text_splitter.split_text(text_content)
                    for i, chunk_text in enumerate(chunks):
                        meta = {"source_document_name": filename, "chunk_index": i}
                        new_docs.append(Document(page_content=chunk_text, metadata=meta))
        
        if not new_docs:
            return {"status": "warning", "message": "New files found but no text extracted.", "files_added": []}
            
        self.vector_store.add_documents(new_docs)
        
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        self.vector_store.save_local(faiss_path)
        
        self.processed_source_files.extend(files_to_process)
        with open(os.path.join(faiss_path, "processed_files.json"), 'w') as f:
            json.dump(sorted(self.processed_source_files), f)

        return {
            "status": "success",
            "message": f"Added {len(files_to_process)} files.",
            "files_added": files_to_process,
            "remaining": len(all_new_files) - len(files_to_process)
        }

    def search_knowledge_base(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        if not self.retriever:
            raise RuntimeError("Retriever not initialized.")

        # Temporarily override final_k for this query, restoring it afterwards.
        original_k = self.retriever.final_k
        if top_k is not None:
            self.retriever.final_k = top_k

        try:
            docs = self.retriever.invoke(query)
            results = []
            for doc in docs:
                # Prefer the reranker score when present; an explicit None check
                # avoids discarding a legitimate score of 0.0.
                score = doc.metadata.get("reranker_score")
                if score is None:
                    score = doc.metadata.get("retrieval_score")
                results.append({
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "score": score
                })
            return results
        finally:
            self.retriever.final_k = original_k
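

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch (illustrative only): the directory paths, model id,
    # and query below are hypothetical placeholders, not values defined in
    # this module or its config.
    logging.basicConfig(level=logging.INFO)
    rag = KnowledgeRAG(
        index_storage_dir="./rag_index",
        embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
        use_gpu_for_embeddings=False,
    )
    rag.build_index_from_source_files("./knowledge_docs")
    for hit in rag.search_knowledge_base("example question", top_k=3):
        print(hit["score"], hit["metadata"].get("source_document_name"))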