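"""FAISS-backed RAG knowledge base.

Builds, loads, and incrementally updates a FAISS vector index over local
source files, with optional cross-encoder reranking of retrieved chunks.
"""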
import os
import logging
import json
import time
import csv
from typing import List, Dict, Optional, Any

import torch
from sentence_transformers import CrossEncoder
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_text_splitters import RecursiveCharacterTextSplitter

from config import (
    RAG_RERANKER_MODEL_NAME, RAG_DETAILED_LOGGING,
    RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP, RAG_CHUNKED_SOURCES_FILENAME,
    RAG_FAISS_INDEX_SUBDIR_NAME, RAG_INITIAL_FETCH_K, RAG_RERANKER_K,
    RAG_MAX_FILES_FOR_INCREMENTAL
)
from utils import FAISS_RAG_SUPPORTED_EXTENSIONS

logger = logging.getLogger(__name__)

class DocumentReranker:
    """Cross-encoder reranker that re-scores retrieved chunks against the query."""

    def __init__(self, model_name: str = RAG_RERANKER_MODEL_NAME):
        self.logger = logging.getLogger(__name__ + ".DocumentReranker")
        self.model_name = model_name
        self.model = None
        try:
            self.logger.info(f"[RERANKER_INIT] Loading reranker model: {self.model_name}")
            start_time = time.time()
            self.model = CrossEncoder(model_name, trust_remote_code=True)
            load_time = time.time() - start_time
            self.logger.info(f"[RERANKER_INIT] Reranker model '{self.model_name}' loaded successfully in {load_time:.2f}s")
        except Exception as e:
            self.logger.error(f"[RERANKER_INIT] Failed to load reranker model '{self.model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize reranker model: {e}") from e

    def rerank_documents(self, query: str, documents: List[Document], top_k: int) -> List[Document]:
        """Score (query, chunk) pairs with the cross-encoder and return the top_k chunks."""
        if not documents or not self.model:
            return documents[:top_k] if documents else []
        try:
            start_time = time.time()
            doc_pairs = [[query, doc.page_content] for doc in documents]
            scores = self.model.predict(doc_pairs)
            rerank_time = time.time() - start_time
            self.logger.info(f"[RERANKER] Computed relevance scores in {rerank_time:.3f}s")
            # Sort by cross-encoder score, highest first, and keep the top_k.
            doc_score_pairs = list(zip(documents, scores))
            doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
            reranked_docs = []
            for doc, score in doc_score_pairs[:top_k]:
                doc.metadata["reranker_score"] = float(score)
                reranked_docs.append(doc)
            return reranked_docs
        except Exception as e:
            self.logger.error(f"[RERANKER] Error during reranking: {e}", exc_info=True)
            # Fall back to the original retrieval order if reranking fails.
            return documents[:top_k] if documents else []

class FAISSRetrieverWithScore(BaseRetriever):
    """Retriever that attaches FAISS distances (and reranker scores) to document metadata."""
    vectorstore: FAISS
    reranker: Optional[DocumentReranker] = None
    initial_fetch_k: int = RAG_INITIAL_FETCH_K
    final_k: int = RAG_RERANKER_K

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        start_time = time.time()
        # Over-fetch when a reranker will trim the list down to final_k afterwards.
        num_to_fetch = self.initial_fetch_k if self.reranker else self.final_k
        logger.info(f"[RETRIEVER] Fetching {num_to_fetch} docs (Rerank={self.reranker is not None})")
        docs_and_scores = self.vectorstore.similarity_search_with_score(query, k=num_to_fetch)
        relevant_docs = []
        for doc, score in docs_and_scores:
            # FAISS returns a distance here (lower = closer), unlike the reranker score
            # where higher means more relevant.
            doc.metadata["retrieval_score"] = float(score)
            relevant_docs.append(doc)
        if self.reranker and relevant_docs:
            relevant_docs = self.reranker.rerank_documents(query, relevant_docs, top_k=self.final_k)
        total_time = time.time() - start_time
        logger.info(f"[RETRIEVER] Completed in {total_time:.3f}s. Returned {len(relevant_docs)} docs.")
        return relevant_docs

class KnowledgeRAG:
    def __init__(
        self,
        index_storage_dir: str,
        embedding_model_name: str,
        use_gpu_for_embeddings: bool,
        chunk_size: int = RAG_CHUNK_SIZE,
        chunk_overlap: int = RAG_CHUNK_OVERLAP,
        reranker_model_name: Optional[str] = None,
        enable_reranker: bool = True,
    ):
        self.logger = logging.getLogger(__name__ + ".KnowledgeRAG")
        self.logger.info("[RAG_INIT] Initializing KnowledgeRAG system")
        self.index_storage_dir = index_storage_dir
        os.makedirs(self.index_storage_dir, exist_ok=True)
        self.embedding_model_name = embedding_model_name
        self.use_gpu_for_embeddings = use_gpu_for_embeddings
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.reranker_model_name = reranker_model_name or RAG_RERANKER_MODEL_NAME
        self.enable_reranker = enable_reranker
        self.reranker = None
        # Pick the embedding device, falling back to CPU when CUDA is unavailable.
        device = "cpu"
        if self.use_gpu_for_embeddings:
            if torch.cuda.is_available():
                self.logger.info("[RAG_INIT] CUDA available. Requesting GPU.")
                device = "cuda"
            else:
                self.logger.warning("[RAG_INIT] CUDA not available. Fallback to CPU.")
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True}
        )
        # The reranker is optional: if it fails to load, degrade to retrieval-only mode.
        if self.enable_reranker:
            try:
                self.reranker = DocumentReranker(self.reranker_model_name)
            except Exception as e:
                self.logger.warning(f"[RAG_INIT] Reranker Init Failed: {e}")
                self.reranker = None
        self.vector_store: Optional[FAISS] = None
        self.retriever: Optional[FAISSRetrieverWithScore] = None
        self.processed_source_files: List[str] = []

    def _save_chunk_config(self):
        """Persist the chunking parameters next to the FAISS index."""
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        config_file = os.path.join(faiss_path, "chunk_config.json")
        with open(config_file, 'w') as f:
            json.dump({"chunk_size": self.chunk_size, "chunk_overlap": self.chunk_overlap}, f)

    def _load_chunk_config(self) -> Optional[dict]:
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        config_file = os.path.join(faiss_path, "chunk_config.json")
        if os.path.exists(config_file):
            with open(config_file, 'r') as f:
                return json.load(f)
        return None

    def chunk_config_has_changed(self) -> bool:
        """Return True if the saved chunk config differs from the current settings."""
        saved = self._load_chunk_config()
        if saved is None:
            # No saved config to compare against; treat as unchanged.
            return False
        changed = saved.get("chunk_size") != self.chunk_size or saved.get("chunk_overlap") != self.chunk_overlap
        if changed:
            self.logger.warning(
                f"[CONFIG_CHANGE] Chunk config mismatch! "
                f"Saved=(size={saved.get('chunk_size')}, overlap={saved.get('chunk_overlap')}) "
                f"Current=(size={self.chunk_size}, overlap={self.chunk_overlap}). "
                f"Index will be rebuilt."
            )
        return changed

    def build_index_from_source_files(self, source_folder_path: str):
        """Build a fresh FAISS index, preferring a pre-chunked JSON dump if one exists."""
        self.logger.info(f"[INDEX_BUILD] Building from: {source_folder_path}")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")
        all_docs = []
        processed_files = []
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        # Fast path: reuse pre-chunked documents if a chunk dump was saved earlier.
        pre_chunked_path = os.path.join(self.index_storage_dir, RAG_CHUNKED_SOURCES_FILENAME)
        if os.path.exists(pre_chunked_path):
            try:
                with open(pre_chunked_path, 'r', encoding='utf-8') as f:
                    chunk_data_list = json.load(f)
                for chunk in chunk_data_list:
                    doc = Document(page_content=chunk.get("page_content", ""), metadata=chunk.get("metadata", {}))
                    all_docs.append(doc)
                    if 'source_document_name' in doc.metadata:
                        processed_files.append(doc.metadata['source_document_name'])
                processed_files = sorted(set(processed_files))
            except Exception as e:
                self.logger.error(f"[INDEX_BUILD] JSON load failed: {e}")
        # Slow path: read, chunk, and collect every supported file in the folder.
        if not all_docs:
            for filename in os.listdir(source_folder_path):
                file_path = os.path.join(source_folder_path, filename)
                if not os.path.isfile(file_path):
                    continue
                file_ext = filename.split('.')[-1].lower()
                if file_ext in FAISS_RAG_SUPPORTED_EXTENSIONS:
                    if file_ext == 'csv':
                        # CSV rows become one document each instead of going through the splitter.
                        try:
                            with open(file_path, mode='r', encoding='utf-8-sig') as f:
                                reader = csv.DictReader(f)
                                for i, row in enumerate(reader):
                                    row_text = "\n".join([f"{k}: {v}" for k, v in row.items() if k and v and str(v).strip()])
                                    meta = {"source_document_name": filename, "chunk_index": i, "source_type": "csv"}
                                    all_docs.append(Document(page_content=row_text, metadata=meta))
                            processed_files.append(filename)
                        except Exception as e:
                            self.logger.error(f"[INDEX_BUILD] Error processing CSV {filename}: {e}")
                    else:
                        text_content = FAISS_RAG_SUPPORTED_EXTENSIONS[file_ext](file_path)
                        if text_content and text_content != "CSV_HANDLED_NATIVELY":
                            chunks = text_splitter.split_text(text_content)
                            for i, chunk_text in enumerate(chunks):
                                meta = {"source_document_name": filename, "chunk_index": i}
                                all_docs.append(Document(page_content=chunk_text, metadata=meta))
                            processed_files.append(filename)
        if not all_docs:
            raise ValueError("No documents to index.")
        self.processed_source_files = processed_files
        self.logger.info(f"[INDEX_BUILD] Creating FAISS index with {len(all_docs)} chunks.")
        self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        self.vector_store.save_local(faiss_path)
        self._save_chunk_config()
        # Persist the processed-file list so load_index_from_disk and incremental
        # updates can see which sources are already indexed.
        with open(os.path.join(faiss_path, "processed_files.json"), 'w') as f:
            json.dump(sorted(self.processed_source_files), f)
        self.retriever = FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=RAG_RERANKER_K
        )

    def load_index_from_disk(self):
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        if not os.path.exists(faiss_path):
            raise FileNotFoundError("Index not found.")
        self.vector_store = FAISS.load_local(
            folder_path=faiss_path,
            embeddings=self.embeddings,
            # The index is written locally by this module, so unpickling it is trusted.
            allow_dangerous_deserialization=True
        )
        self.retriever = FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=RAG_RERANKER_K
        )
        meta_file = os.path.join(faiss_path, "processed_files.json")
        if os.path.exists(meta_file):
            with open(meta_file, 'r') as f:
                self.processed_source_files = json.load(f)
        else:
            self.processed_source_files = ["Loaded from disk (unknown sources)"]
        self.logger.info("[INDEX_LOAD] Success.")

    def update_index_with_new_files(self, source_folder_path: str, max_files_to_process: Optional[int] = None) -> Dict[str, Any]:
        """Incrementally index files not seen before, up to a per-call limit."""
        self.logger.info(f"[INDEX_UPDATE] Checking for new files in: {source_folder_path}")
        if not self.vector_store:
            raise RuntimeError("Cannot update: no index loaded.")
        processed_set = set(self.processed_source_files)
        all_new_files = []
        for filename in sorted(os.listdir(source_folder_path)):
            if filename not in processed_set:
                file_ext = filename.split('.')[-1].lower()
                if file_ext in FAISS_RAG_SUPPORTED_EXTENSIONS:
                    all_new_files.append(filename)
        if not all_new_files:
            return {"status": "success", "message": "No new files found.", "files_added": []}
        # Cap the batch size so a single update cannot block for too long.
        limit = max_files_to_process if max_files_to_process is not None else RAG_MAX_FILES_FOR_INCREMENTAL
        files_to_process = all_new_files[:limit]
        new_docs = []
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        for filename in files_to_process:
            file_path = os.path.join(source_folder_path, filename)
            file_ext = filename.split('.')[-1].lower()
            if file_ext == 'csv':
                # Same row-per-document handling as in build_index_from_source_files.
                try:
                    with open(file_path, mode='r', encoding='utf-8-sig') as f:
                        reader = csv.DictReader(f)
                        for i, row in enumerate(reader):
                            row_text = "\n".join([f"{k}: {v}" for k, v in row.items() if k and v and str(v).strip()])
                            meta = {"source_document_name": filename, "chunk_index": i, "source_type": "csv"}
                            new_docs.append(Document(page_content=row_text, metadata=meta))
                except Exception as e:
                    self.logger.error(f"[INDEX_UPDATE] Error processing CSV {filename}: {e}")
            else:
                text_content = FAISS_RAG_SUPPORTED_EXTENSIONS[file_ext](file_path)
                if text_content and text_content != "CSV_HANDLED_NATIVELY":
                    chunks = text_splitter.split_text(text_content)
                    for i, chunk_text in enumerate(chunks):
                        meta = {"source_document_name": filename, "chunk_index": i}
                        new_docs.append(Document(page_content=chunk_text, metadata=meta))
        if not new_docs:
            return {"status": "warning", "message": "New files found but no text extracted.", "files_added": []}
        self.vector_store.add_documents(new_docs)
        faiss_path = os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)
        self.vector_store.save_local(faiss_path)
        self.processed_source_files.extend(files_to_process)
        with open(os.path.join(faiss_path, "processed_files.json"), 'w') as f:
            json.dump(sorted(self.processed_source_files), f)
        return {
            "status": "success",
            "message": f"Added {len(files_to_process)} files.",
            "files_added": files_to_process,
            "remaining": len(all_new_files) - len(files_to_process)
        }

    def search_knowledge_base(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        if not self.retriever:
            raise RuntimeError("Retriever not initialized.")
        # Temporarily override final_k, restoring it even if retrieval raises.
        original_k = self.retriever.final_k
        if top_k:
            self.retriever.final_k = top_k
        try:
            docs = self.retriever.invoke(query)
            results = []
            for doc in docs:
                # Prefer the reranker score; an explicit None check keeps a legitimate
                # 0.0 reranker score from falling through to the raw retrieval distance.
                score = doc.metadata.get("reranker_score")
                if score is None:
                    score = doc.metadata.get("retrieval_score")
                results.append({
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "score": score
                })
            return results
        finally:
            self.retriever.final_k = original_k
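

# --- Usage sketch ---
# A minimal, illustrative example of building and querying the index; not part
# of the original module. The folder paths and the embedding model name below
# are placeholder assumptions, not values taken from config.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rag = KnowledgeRAG(
        index_storage_dir="./rag_index",  # hypothetical storage location
        embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed model
        use_gpu_for_embeddings=False,
        enable_reranker=False,  # skip the cross-encoder for a quick smoke test
    )
    rag.build_index_from_source_files("./docs")  # hypothetical source folder
    for hit in rag.search_knowledge_base("example query", top_k=3):
        print(f"{hit['score']:.4f}  {hit['metadata'].get('source_document_name')}")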