import csv
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
from sentence_transformers import CrossEncoder
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_text_splitters import RecursiveCharacterTextSplitter

from config import (
    RAG_RERANKER_MODEL_NAME,
    RAG_DETAILED_LOGGING,
    RAG_CHUNK_SIZE,
    RAG_CHUNK_OVERLAP,
    RAG_CHUNKED_SOURCES_FILENAME,
    RAG_FAISS_INDEX_SUBDIR_NAME,
    RAG_INITIAL_FETCH_K,
    RAG_RERANKER_K,
    RAG_MAX_FILES_FOR_INCREMENTAL,
)
from utils import FAISS_RAG_SUPPORTED_EXTENSIONS

logger = logging.getLogger(__name__)

# File (inside the FAISS index folder) listing the source files already indexed.
_PROCESSED_FILES_FILENAME = "processed_files.json"
# File (inside the FAISS index folder) recording the chunk size/overlap used for the index.
_CHUNK_CONFIG_FILENAME = "chunk_config.json"
# Stored in ``processed_source_files`` when the indexed sources cannot be determined.
_UNKNOWN_SOURCES_PLACEHOLDER = "Loaded from disk (unknown sources)"
# Marker compared against extension-handler output; presumably what the CSV handler in
# ``utils`` returns, since CSV files are parsed natively here — TODO confirm.
_CSV_HANDLED_SENTINEL = "CSV_HANDLED_NATIVELY"


def _file_extension(filename: str) -> str:
    """Return the lower-cased text after the last '.' in *filename*, or '' if there is none.

    Unlike ``filename.split('.')[-1]``, a name with no dot at all (e.g. a file literally
    named ``csv``) is not mistaken for having that extension.
    """
    if "." not in filename:
        return ""
    return filename.rsplit(".", 1)[1].lower()


def _csv_row_to_text(row: Dict[Any, Any]) -> str:
    """Render one ``csv.DictReader`` row as newline-separated ``"key: value"`` lines.

    Cells with an empty/missing header (including DictReader's ``None`` overflow key) or an
    empty / whitespace-only value are omitted, so an all-blank row yields ``""``.
    """
    return "\n".join(f"{k}: {v}" for k, v in row.items() if k and v and str(v).strip())


class DocumentReranker:
    """Cross-encoder reranker that reorders candidate documents by relevance to a query."""

    def __init__(self, model_name: str = RAG_RERANKER_MODEL_NAME):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``sentence_transformers.CrossEncoder``.

        Raises:
            RuntimeError: If the model cannot be loaded (original error chained as cause).
        """
        self.logger = logging.getLogger(__name__ + ".DocumentReranker")
        self.model_name = model_name
        self.model = None
        try:
            self.logger.info(f"[RERANKER_INIT] Loading reranker model: {self.model_name}")
            start_time = time.time()
            # NOTE(security): trust_remote_code=True runs Python code shipped with the model
            # repository; only point model_name at repositories you trust.
            self.model = CrossEncoder(model_name, trust_remote_code=True)
            load_time = time.time() - start_time
            self.logger.info(
                f"[RERANKER_INIT] Reranker model '{self.model_name}' loaded successfully in {load_time:.2f}s"
            )
        except Exception as e:
            self.logger.error(
                f"[RERANKER_INIT] Failed to load reranker model '{self.model_name}': {e}", exc_info=True
            )
            raise RuntimeError(f"Could not initialize reranker model: {e}") from e

    def rerank_documents(self, query: str, documents: List[Document], top_k: int) -> List[Document]:
        """Return the ``top_k`` documents the cross-encoder scores highest for ``query``.

        Each returned document gets ``metadata["reranker_score"]`` set in place. If there is
        no model or scoring raises, the first ``top_k`` documents are returned unchanged and
        in their original order (best-effort fallback; the error is logged).
        """
        if not documents or not self.model:
            return documents[:top_k] if documents else []
        try:
            start_time = time.time()
            doc_pairs = [[query, doc.page_content] for doc in documents]
            scores = self.model.predict(doc_pairs)
            rerank_time = time.time() - start_time
            self.logger.info(f"[RERANKER] Computed relevance scores in {rerank_time:.3f}s")

            # Highest score first; Python's sort is stable, so ties keep retrieval order.
            doc_score_pairs = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
            reranked_docs = []
            for doc, score in doc_score_pairs[:top_k]:
                doc.metadata["reranker_score"] = float(score)
                reranked_docs.append(doc)
            return reranked_docs
        except Exception as e:
            self.logger.error(f"[RERANKER] Error during reranking: {e}", exc_info=True)
            return documents[:top_k] if documents else []


class FAISSRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a FAISS store that records scores and can rerank.

    Without a reranker, ``final_k`` hits are fetched and returned. With a reranker,
    ``initial_fetch_k`` candidates are fetched and the reranker keeps the best ``final_k``.
    Every returned document carries ``metadata["retrieval_score"]`` — the raw score from
    ``similarity_search_with_score`` (for FAISS this is presumably a distance where lower
    means closer; confirm against the index's distance strategy).
    """

    vectorstore: FAISS
    reranker: Optional[DocumentReranker] = None
    initial_fetch_k: int = RAG_INITIAL_FETCH_K
    final_k: int = RAG_RERANKER_K

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        start_time = time.time()
        num_to_fetch = self.initial_fetch_k if self.reranker else self.final_k
        logger.info(f"[RETRIEVER] Fetching {num_to_fetch} docs (Rerank={self.reranker is not None})")
        docs_and_scores = self.vectorstore.similarity_search_with_score(query, k=num_to_fetch)

        relevant_docs = []
        for doc, score in docs_and_scores:
            # Annotate a fresh Document rather than the one returned by the store: those
            # may be the docstore's own objects, and writing per-query scores into them
            # would leak into later queries and into the index saved by save_local().
            extra = {"id": doc.id} if getattr(doc, "id", None) is not None else {}
            relevant_docs.append(
                Document(
                    page_content=doc.page_content,
                    metadata={**doc.metadata, "retrieval_score": float(score)},
                    **extra,
                )
            )

        if self.reranker and relevant_docs:
            relevant_docs = self.reranker.rerank_documents(query, relevant_docs, top_k=self.final_k)

        total_time = time.time() - start_time
        logger.info(f"[RETRIEVER] Completed in {total_time:.3f}s. Returned {len(relevant_docs)} docs.")
        return relevant_docs


class KnowledgeRAG:
    """FAISS-backed knowledge base: build, persist, incrementally update and search an index."""

    def __init__(
        self,
        index_storage_dir: str,
        embedding_model_name: str,
        use_gpu_for_embeddings: bool,
        chunk_size: int = RAG_CHUNK_SIZE,
        chunk_overlap: int = RAG_CHUNK_OVERLAP,
        reranker_model_name: Optional[str] = None,
        enable_reranker: bool = True,
    ):
        """Set up embeddings (and optionally a reranker); no index is loaded yet.

        Args:
            index_storage_dir: Folder holding the FAISS index subfolder and pre-chunked
                sources; created if missing.
            embedding_model_name: Model name for ``HuggingFaceEmbeddings``.
            use_gpu_for_embeddings: Request CUDA; falls back to CPU (with a warning) when
                ``torch.cuda.is_available()`` is False.
            chunk_size: Text-splitter chunk size for non-CSV files.
            chunk_overlap: Text-splitter chunk overlap for non-CSV files.
            reranker_model_name: Reranker model; defaults to ``RAG_RERANKER_MODEL_NAME``.
            enable_reranker: Try to load the reranker. A load failure is logged and the
                system continues without reranking.
        """
        self.logger = logging.getLogger(__name__ + ".KnowledgeRAG")
        self.logger.info("[RAG_INIT] Initializing KnowledgeRAG system")
        self.index_storage_dir = index_storage_dir
        os.makedirs(self.index_storage_dir, exist_ok=True)
        self.embedding_model_name = embedding_model_name
        self.use_gpu_for_embeddings = use_gpu_for_embeddings
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.reranker_model_name = reranker_model_name or RAG_RERANKER_MODEL_NAME
        self.enable_reranker = enable_reranker
        self.reranker = None

        device = "cpu"
        if self.use_gpu_for_embeddings:
            if torch.cuda.is_available():
                self.logger.info("[RAG_INIT] CUDA available. Requesting GPU.")
                device = "cuda"
            else:
                self.logger.warning("[RAG_INIT] CUDA not available. Fallback to CPU.")

        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True},
        )

        if self.enable_reranker:
            try:
                self.reranker = DocumentReranker(self.reranker_model_name)
            except Exception as e:
                self.logger.warning(f"[RAG_INIT] Reranker Init Failed: {e}")
                self.reranker = None

        self.vector_store: Optional[FAISS] = None
        self.retriever: Optional[FAISSRetrieverWithScore] = None
        self.processed_source_files: List[str] = []

    # ------------------------------------------------------------------ helpers

    def _faiss_index_path(self) -> str:
        """Return the folder where the FAISS index and its JSON side files live."""
        return os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)

    def _make_retriever(self) -> FAISSRetrieverWithScore:
        """Build a retriever over the current vector store using the configured k values."""
        return FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=RAG_RERANKER_K,
        )

    def _save_chunk_config(self):
        """Persist the current chunk size/overlap next to the FAISS index."""
        config_file = os.path.join(self._faiss_index_path(), _CHUNK_CONFIG_FILENAME)
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump({"chunk_size": self.chunk_size, "chunk_overlap": self.chunk_overlap}, f)

    def _load_chunk_config(self) -> Optional[dict]:
        """Return the persisted chunk config, or None if none has been saved."""
        config_file = os.path.join(self._faiss_index_path(), _CHUNK_CONFIG_FILENAME)
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                return json.load(f)
        return None

    def _save_processed_files(self):
        """Persist the sorted list of indexed source files next to the FAISS index."""
        meta_file = os.path.join(self._faiss_index_path(), _PROCESSED_FILES_FILENAME)
        with open(meta_file, "w", encoding="utf-8") as f:
            json.dump(sorted(self.processed_source_files), f)

    def _recover_processed_files_from_index(self) -> List[str]:
        """Best-effort: collect ``source_document_name`` values stored in the loaded docstore.

        Used when ``processed_files.json`` is missing (e.g. indexes built before it was
        written), so incremental updates do not re-index every file a second time.
        """
        docstore = getattr(self.vector_store, "docstore", None)
        id_map = getattr(self.vector_store, "index_to_docstore_id", None)
        if docstore is None or not id_map:
            return []
        names = set()
        for doc_id in id_map.values():
            doc = docstore.search(doc_id)
            if isinstance(doc, Document):
                name = doc.metadata.get("source_document_name")
                if name:
                    names.add(name)
        return sorted(names)

    def _load_prechunked_documents(self, pre_chunked_path: str) -> Tuple[List[Document], List[str]]:
        """Load documents from the pre-chunked JSON file, all-or-nothing.

        Returns:
            ``(documents, sorted unique source_document_name values)``; ``([], [])`` when the
            file is absent or cannot be parsed — a partial parse is never returned, so a
            malformed file falls back to scanning the source folder.
        """
        if not os.path.exists(pre_chunked_path):
            return [], []
        try:
            with open(pre_chunked_path, "r", encoding="utf-8") as f:
                chunk_data_list = json.load(f)
            docs = []
            source_names = set()
            for chunk in chunk_data_list:
                doc = Document(page_content=chunk.get("page_content", ""), metadata=chunk.get("metadata", {}))
                docs.append(doc)
                if "source_document_name" in doc.metadata:
                    source_names.add(doc.metadata["source_document_name"])
        except Exception as e:
            self.logger.error(f"[INDEX_BUILD] JSON load failed: {e}")
            return [], []
        return docs, sorted(source_names)

    def _load_documents_from_file(
        self,
        file_path: str,
        filename: str,
        text_splitter: RecursiveCharacterTextSplitter,
        log_tag: str,
    ) -> Tuple[List[Document], bool]:
        """Turn one supported source file into indexable documents.

        CSV files become one document per non-blank row (``source_type="csv"``); other
        files go through their handler in ``FAISS_RAG_SUPPORTED_EXTENSIONS`` and are split
        into chunks. Errors are logged under ``log_tag`` and the file is skipped.

        Returns:
            ``(documents, extracted)`` where ``extracted`` is True when the file was read
            successfully (CSV parsed without error, or the handler returned usable text).
        """
        file_ext = _file_extension(filename)
        if file_ext == "csv":
            docs = []
            try:
                # newline='' lets the csv module handle quoted fields containing newlines.
                with open(file_path, mode="r", encoding="utf-8-sig", newline="") as f:
                    for i, row in enumerate(csv.DictReader(f)):
                        row_text = _csv_row_to_text(row)
                        if not row_text:
                            continue  # all-blank row: nothing worth embedding
                        meta = {"source_document_name": filename, "chunk_index": i, "source_type": "csv"}
                        docs.append(Document(page_content=row_text, metadata=meta))
            except Exception as e:
                # Drop rows read before the failure so a file is never half-indexed.
                self.logger.error(f"{log_tag} Error processing CSV {filename}: {e}")
                return [], False
            return docs, True

        try:
            text_content = FAISS_RAG_SUPPORTED_EXTENSIONS[file_ext](file_path)
        except Exception as e:
            self.logger.error(f"{log_tag} Error processing {filename}: {e}")
            return [], False
        if not text_content or text_content == _CSV_HANDLED_SENTINEL:
            return [], False
        chunks = text_splitter.split_text(text_content)
        docs = [
            Document(page_content=chunk_text, metadata={"source_document_name": filename, "chunk_index": i})
            for i, chunk_text in enumerate(chunks)
        ]
        return docs, True

    # --------------------------------------------------------------- public API

    def chunk_config_has_changed(self) -> bool:
        """Return True if a saved chunk config exists and differs from the current one.

        Returns False when no config has been saved. Logs a warning on mismatch.
        """
        saved = self._load_chunk_config()
        if saved is None:
            return False
        changed = saved.get("chunk_size") != self.chunk_size or saved.get("chunk_overlap") != self.chunk_overlap
        if changed:
            self.logger.warning(
                f"[CONFIG_CHANGE] Chunk config mismatch! "
                f"Saved=(size={saved.get('chunk_size')}, overlap={saved.get('chunk_overlap')}) "
                f"Current=(size={self.chunk_size}, overlap={self.chunk_overlap}). "
                f"Index will be rebuilt."
            )
        return changed

    def build_index_from_source_files(self, source_folder_path: str):
        """Build a new FAISS index, save it with its side files, and create the retriever.

        Documents come from the pre-chunked JSON in ``index_storage_dir`` when it loads
        successfully and is non-empty; otherwise every supported file directly inside
        ``source_folder_path`` is extracted.

        Raises:
            FileNotFoundError: ``source_folder_path`` is not a directory.
            ValueError: No documents could be produced.
        """
        self.logger.info(f"[INDEX_BUILD] Building from: {source_folder_path}")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)

        pre_chunked_path = os.path.join(self.index_storage_dir, RAG_CHUNKED_SOURCES_FILENAME)
        all_docs, processed_files = self._load_prechunked_documents(pre_chunked_path)

        if not all_docs:
            for filename in sorted(os.listdir(source_folder_path)):
                file_path = os.path.join(source_folder_path, filename)
                if not os.path.isfile(file_path):
                    continue
                if _file_extension(filename) not in FAISS_RAG_SUPPORTED_EXTENSIONS:
                    continue
                docs, extracted = self._load_documents_from_file(
                    file_path, filename, text_splitter, "[INDEX_BUILD]"
                )
                all_docs.extend(docs)
                if extracted:
                    processed_files.append(filename)

        if not all_docs:
            raise ValueError("No documents to index.")

        self.logger.info(f"[INDEX_BUILD] Creating FAISS index with {len(all_docs)} chunks.")
        self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
        self.processed_source_files = processed_files
        self.vector_store.save_local(self._faiss_index_path())
        self._save_chunk_config()
        # Persist the sources so load_index_from_disk() + incremental updates know what is
        # already indexed (previously only updates wrote this file).
        self._save_processed_files()
        self.retriever = self._make_retriever()

    def load_index_from_disk(self):
        """Load the saved FAISS index, create the retriever and restore the source list.

        If ``processed_files.json`` is missing, the source list is recovered from the
        docstore metadata; if nothing can be recovered a placeholder entry is used.

        Raises:
            FileNotFoundError: The index folder does not exist.
        """
        faiss_path = self._faiss_index_path()
        if not os.path.exists(faiss_path):
            raise FileNotFoundError("Index not found.")
        # NOTE(security): load_local unpickles the saved docstore. This is acceptable only
        # because the folder is written by this class — never point it at untrusted files.
        self.vector_store = FAISS.load_local(
            folder_path=faiss_path, embeddings=self.embeddings, allow_dangerous_deserialization=True
        )
        self.retriever = self._make_retriever()

        meta_file = os.path.join(faiss_path, _PROCESSED_FILES_FILENAME)
        if os.path.exists(meta_file):
            with open(meta_file, "r", encoding="utf-8") as f:
                self.processed_source_files = json.load(f)
        else:
            recovered = self._recover_processed_files_from_index()
            if recovered:
                self.logger.warning(
                    f"[INDEX_LOAD] {_PROCESSED_FILES_FILENAME} missing; recovered "
                    f"{len(recovered)} source names from the index."
                )
            self.processed_source_files = recovered or [_UNKNOWN_SOURCES_PLACEHOLDER]
        self.logger.info("[INDEX_LOAD] Success.")

    def update_index_with_new_files(
        self, source_folder_path: str, max_files_to_process: Optional[int] = None
    ) -> Dict[str, Any]:
        """Add up to ``max_files_to_process`` not-yet-indexed files to the loaded index.

        New files are supported regular files in ``source_folder_path`` whose names are not
        in ``processed_source_files``, taken in sorted order. The limit defaults to
        ``RAG_MAX_FILES_FOR_INCREMENTAL``. After adding documents the index and
        ``processed_files.json`` are saved; every attempted file (even one whose extraction
        failed) is recorded as processed.

        Returns:
            A status dict with ``status``, ``message``, ``files_added`` and, on success with
            additions, ``remaining`` (new files left for a later call).

        Raises:
            RuntimeError: No index is loaded.
            FileNotFoundError: ``source_folder_path`` is not a directory.
        """
        self.logger.info(f"[INDEX_UPDATE] Checking for new files in: {source_folder_path}")
        if self.vector_store is None:
            raise RuntimeError("Cannot update: no index loaded.")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")

        processed_set = set(self.processed_source_files)
        if _UNKNOWN_SOURCES_PLACEHOLDER in processed_set:
            self.logger.warning(
                "[INDEX_UPDATE] Indexed sources are unknown; files already in the index may be added again."
            )
        all_new_files = [
            filename
            for filename in sorted(os.listdir(source_folder_path))
            if filename not in processed_set
            and os.path.isfile(os.path.join(source_folder_path, filename))
            and _file_extension(filename) in FAISS_RAG_SUPPORTED_EXTENSIONS
        ]
        if not all_new_files:
            return {"status": "success", "message": "No new files found.", "files_added": []}

        limit = max_files_to_process if max_files_to_process is not None else RAG_MAX_FILES_FOR_INCREMENTAL
        files_to_process = all_new_files[:limit]

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        new_docs = []
        for filename in files_to_process:
            docs, _ = self._load_documents_from_file(
                os.path.join(source_folder_path, filename), filename, text_splitter, "[INDEX_UPDATE]"
            )
            new_docs.extend(docs)

        if not new_docs:
            return {"status": "warning", "message": "New files found but no text extracted.", "files_added": []}

        self.vector_store.add_documents(new_docs)
        self.vector_store.save_local(self._faiss_index_path())
        self.processed_source_files.extend(files_to_process)
        self._save_processed_files()

        return {
            "status": "success",
            "message": f"Added {len(files_to_process)} files.",
            "files_added": files_to_process,
            "remaining": len(all_new_files) - len(files_to_process),
        }

    def search_knowledge_base(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """Retrieve documents for ``query`` as ``{"content", "metadata", "score"}`` dicts.

        ``top_k`` overrides the retriever's ``final_k`` for this call when truthy (None or 0
        keep the default). ``score`` is the reranker score when present, otherwise the raw
        retrieval score — the two are on different scales.

        NOTE: the override temporarily mutates ``self.retriever.final_k`` (restored in
        ``finally``), so concurrent calls with different ``top_k`` can interfere.

        Raises:
            RuntimeError: No retriever has been initialized.
        """
        if not self.retriever:
            raise RuntimeError("Retriever not initialized.")
        original_k = self.retriever.final_k
        if top_k:
            self.retriever.final_k = top_k
        try:
            docs = self.retriever.invoke(query)
            results = []
            for doc in docs:
                reranker_score = doc.metadata.get("reranker_score")
                # Explicit None check: a legitimate reranker score of 0.0 must not fall back
                # to the (differently scaled) retrieval score.
                score = reranker_score if reranker_score is not None else doc.metadata.get("retrieval_score")
                results.append({"content": doc.page_content, "metadata": doc.metadata, "score": score})
            return results
        finally:
            self.retriever.final_k = original_k