| import os
|
| import logging
|
| import json
|
| import time
|
| from typing import List, Dict, Optional, Any
|
|
|
| import torch
|
| from sentence_transformers import CrossEncoder
|
|
|
| from langchain_huggingface import HuggingFaceEmbeddings
|
| from langchain_community.vectorstores import FAISS
|
|
|
| from langchain_core.documents import Document
|
| from langchain_core.retrievers import BaseRetriever
|
| from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
| from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
|
| from config import (
|
| RAG_RERANKER_MODEL_NAME, RAG_DETAILED_LOGGING,
|
| RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP, RAG_CHUNKED_SOURCES_FILENAME,
|
| RAG_FAISS_INDEX_SUBDIR_NAME, RAG_INITIAL_FETCH_K, RAG_RERANKER_K,
|
| RAG_MAX_FILES_FOR_INCREMENTAL
|
| )
|
| from utils import FAISS_RAG_SUPPORTED_EXTENSIONS
|
|
|
# Module-level logger: FAISSRetrieverWithScore logs through it, while
# DocumentReranker and KnowledgeRAG derive their own child loggers from __name__.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class DocumentReranker:
    """Reorder candidate documents by cross-encoder relevance to a query.

    Wraps a ``sentence_transformers.CrossEncoder``: every (query, passage) pair
    is scored by the model and documents are returned highest score first.
    """

    def __init__(self, model_name: str = RAG_RERANKER_MODEL_NAME):
        """Load the cross-encoder model.

        Args:
            model_name: Identifier passed to ``CrossEncoder``; defaults to
                ``RAG_RERANKER_MODEL_NAME`` from ``config``.

        Raises:
            RuntimeError: If the model fails to load (the original error is chained).
        """
        self.logger = logging.getLogger(__name__ + ".DocumentReranker")
        self.model_name = model_name
        self.model = None

        try:
            self.logger.info(f"[RERANKER_INIT] Loading reranker model: {self.model_name}")
            start_time = time.time()
            # NOTE(review): trust_remote_code=True lets the model repository execute
            # its own Python code at load time; keep model_name pointed at trusted repos.
            self.model = CrossEncoder(model_name, trust_remote_code=True)
            load_time = time.time() - start_time
            self.logger.info(f"[RERANKER_INIT] Reranker model '{self.model_name}' loaded successfully in {load_time:.2f}s")
        except Exception as e:
            self.logger.error(f"[RERANKER_INIT] Failed to load reranker model '{self.model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize reranker model: {e}") from e

    @staticmethod
    def _with_score(doc: Document, score: Any) -> Document:
        """Return a copy of ``doc`` whose metadata additionally holds ``reranker_score``.

        The input is never modified: documents handed in by a retriever may be the
        very objects held in a vector store's docstore, so writing into their
        metadata would carry this query's score into later queries.
        """
        fields: Dict[str, Any] = {
            "page_content": doc.page_content,
            "metadata": {**doc.metadata, "reranker_score": float(score)},
        }
        doc_id = getattr(doc, "id", None)
        if doc_id is not None:
            # Preserve the id on langchain-core versions whose Document defines one.
            fields["id"] = doc_id
        return Document(**fields)

    def rerank_documents(self, query: str, documents: List[Document], top_k: int) -> List[Document]:
        """Return the ``top_k`` documents most relevant to ``query``.

        Args:
            query: The search query.
            documents: Candidate documents, in any order.
            top_k: Maximum number of documents to return.

        Returns:
            New Document objects ordered by descending cross-encoder score, each with
            ``reranker_score`` in its metadata. If no model is loaded or scoring
            fails, the first ``top_k`` input documents are returned as-is, unscored.
        """
        if not documents or self.model is None:
            return documents[:top_k] if documents else []

        try:
            start_time = time.time()
            scores = self.model.predict([[query, doc.page_content] for doc in documents])
            rerank_time = time.time() - start_time
            self.logger.info(f"[RERANKER] Computed relevance scores in {rerank_time:.3f}s")

            # sorted() is stable (also with reverse=True): equal scores keep input order.
            ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
            return [self._with_score(doc, score) for doc, score in ranked[:top_k]]
        except Exception as e:
            self.logger.error(f"[RERANKER] Error during reranking: {e}", exc_info=True)
            return documents[:top_k] if documents else []
|
|
|
|
|
class FAISSRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a FAISS store with optional cross-encoder reranking.

    Every returned Document is a copy whose metadata carries ``retrieval_score``
    (the raw score from ``FAISS.similarity_search_with_score``; presumably an L2
    distance, lower = closer, under the wrapper's default index -- verify if a
    different distance strategy is configured) and, when the reranker ran
    successfully, ``reranker_score`` (higher = more relevant).
    """

    # FAISS vector store queried for candidates.
    vectorstore: FAISS
    # Optional cross-encoder; when None the FAISS order is final.
    reranker: Optional[DocumentReranker] = None
    # Candidate pool pulled from FAISS when a reranker is attached.
    initial_fetch_k: int = RAG_INITIAL_FETCH_K
    # Number of documents returned to the caller.
    final_k: int = RAG_RERANKER_K

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Fetch candidates from FAISS, optionally rerank them, return up to ``final_k``."""
        start_time = time.time()
        rerank = self.reranker is not None
        # Give the reranker a wider pool, but never fewer than final_k candidates,
        # otherwise a small initial_fetch_k would silently cap the result count.
        num_to_fetch = max(self.initial_fetch_k, self.final_k) if rerank else self.final_k

        logger.info(f"[RETRIEVER] Fetching {num_to_fetch} docs (Rerank={rerank})")

        docs_and_scores = self.vectorstore.similarity_search_with_score(query, k=num_to_fetch)

        relevant_docs: List[Document] = []
        for doc, score in docs_and_scores:
            # Annotate a copy: the FAISS wrapper may hand back the docstore's own
            # Document objects, and mutating their metadata would leak scores across
            # queries and into the index written by save_local().
            fields = {
                "page_content": doc.page_content,
                "metadata": {**doc.metadata, "retrieval_score": float(score)},
            }
            doc_id = getattr(doc, "id", None)
            if doc_id is not None:
                fields["id"] = doc_id
            relevant_docs.append(Document(**fields))

        if rerank and relevant_docs:
            relevant_docs = self.reranker.rerank_documents(query, relevant_docs, top_k=self.final_k)

        total_time = time.time() - start_time
        logger.info(f"[RETRIEVER] Completed in {total_time:.3f}s. Returned {len(relevant_docs)} docs.")
        return relevant_docs
|
|
|
|
|
class KnowledgeRAG:
    """FAISS knowledge base over a folder of source files, with optional reranking.

    The index lives in ``<index_storage_dir>/<RAG_FAISS_INDEX_SUBDIR_NAME>`` next
    to two JSON sidecar files:

    * ``chunk_config.json`` -- chunk_size / chunk_overlap the index was built
      with, compared by :meth:`chunk_config_has_changed`.
    * ``processed_files.json`` -- sorted names of source files already indexed,
      used by :meth:`update_index_with_new_files` to detect new files.
    """

    # Regex separators tried in order by the recursive splitter: first break right
    # before a "Q:" / "Question 3:"-style line so question/answer pairs stay
    # together, then fall back to paragraph, line, word and character breaks.
    _QA_SEPARATORS = (
        r"\n\n(?=[Qq](?:uestion|UESTION)?\s*\d*\s*:)",
        r"\n(?=[Qq](?:uestion|UESTION)?\s*\d*\s*:)",
        r"\n\n",
        r"\n",
        r" ",
        r"",
    )
    _CHUNK_CONFIG_FILENAME = "chunk_config.json"
    _PROCESSED_FILES_FILENAME = "processed_files.json"
    # Stored in processed_source_files when the indexed sources cannot be determined.
    _UNKNOWN_SOURCES = "Loaded from disk (unknown sources)"

    def __init__(
        self,
        index_storage_dir: str,
        embedding_model_name: str,
        use_gpu_for_embeddings: bool,
        chunk_size: int = RAG_CHUNK_SIZE,
        chunk_overlap: int = RAG_CHUNK_OVERLAP,
        reranker_model_name: Optional[str] = None,
        enable_reranker: bool = True,
    ):
        """Create the embedding model (and reranker, if enabled); no index is loaded yet.

        Args:
            index_storage_dir: Directory for the FAISS index and the pre-chunked
                sources file; created if missing.
            embedding_model_name: Model name passed to ``HuggingFaceEmbeddings``.
            use_gpu_for_embeddings: Request CUDA for embeddings; falls back to CPU
                with a warning when CUDA is unavailable.
            chunk_size: Text splitter chunk size.
            chunk_overlap: Text splitter chunk overlap.
            reranker_model_name: Cross-encoder name; defaults to ``RAG_RERANKER_MODEL_NAME``.
            enable_reranker: Load a reranker. A load failure is logged and the
                system continues without reranking.
        """
        self.logger = logging.getLogger(__name__ + ".KnowledgeRAG")
        self.logger.info("[RAG_INIT] Initializing KnowledgeRAG system")

        self.index_storage_dir = index_storage_dir
        os.makedirs(self.index_storage_dir, exist_ok=True)

        self.embedding_model_name = embedding_model_name
        self.use_gpu_for_embeddings = use_gpu_for_embeddings
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.reranker_model_name = reranker_model_name or RAG_RERANKER_MODEL_NAME
        self.enable_reranker = enable_reranker
        self.reranker: Optional[DocumentReranker] = None

        device = "cpu"
        if self.use_gpu_for_embeddings:
            if torch.cuda.is_available():
                self.logger.info("[RAG_INIT] CUDA available. Requesting GPU.")
                device = "cuda"
            else:
                self.logger.warning("[RAG_INIT] CUDA not available. Fallback to CPU.")

        # normalize_embeddings asks sentence-transformers for unit-length vectors.
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True},
        )

        if self.enable_reranker:
            try:
                self.reranker = DocumentReranker(self.reranker_model_name)
            except Exception as e:
                # Reranking is optional: degrade to plain FAISS ranking.
                self.logger.warning(f"[RAG_INIT] Reranker Init Failed: {e}")
                self.reranker = None

        self.vector_store: Optional[FAISS] = None
        self.retriever: Optional[FAISSRetrieverWithScore] = None
        self.processed_source_files: List[str] = []

    # ----------------------------------------------------------------- helpers

    def _faiss_path(self) -> str:
        """Return the directory holding the FAISS index and its JSON sidecars."""
        return os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)

    def _make_text_splitter(self) -> RecursiveCharacterTextSplitter:
        """Return a Q&A-aware splitter using this instance's chunk settings."""
        return RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separators=list(self._QA_SEPARATORS),
            is_separator_regex=True,
        )

    @staticmethod
    def _file_extension(filename: str) -> str:
        """Return the lower-cased text after the last '.' (the whole name if there is none)."""
        return filename.split('.')[-1].lower()

    def _list_supported_files(self, source_folder_path: str) -> List[str]:
        """Return sorted names of regular files in the folder that have a registered loader."""
        return [
            filename
            for filename in sorted(os.listdir(source_folder_path))
            if os.path.isfile(os.path.join(source_folder_path, filename))
            and self._file_extension(filename) in FAISS_RAG_SUPPORTED_EXTENSIONS
        ]

    def _extract_text(self, source_folder_path: str, filename: str):
        """Run the extension's loader from FAISS_RAG_SUPPORTED_EXTENSIONS; falsy means no text."""
        loader = FAISS_RAG_SUPPORTED_EXTENSIONS[self._file_extension(filename)]
        return loader(os.path.join(source_folder_path, filename))

    @staticmethod
    def _split_into_documents(text_content: str, filename: str, text_splitter) -> List[Document]:
        """Split text into chunk Documents tagged with their source file name and position."""
        return [
            Document(page_content=chunk_text, metadata={"source_document_name": filename, "chunk_index": i})
            for i, chunk_text in enumerate(text_splitter.split_text(text_content))
        ]

    def _save_processed_files(self):
        """Write the sorted list of indexed source names next to the FAISS index."""
        meta_file = os.path.join(self._faiss_path(), self._PROCESSED_FILES_FILENAME)
        with open(meta_file, 'w', encoding='utf-8') as f:
            json.dump(sorted(self.processed_source_files), f)

    def _sources_from_docstore(self) -> List[str]:
        """Best-effort recovery of indexed source names from the loaded vector store.

        Used for indexes saved without processed_files.json. Reads the docstore's
        ``_dict`` mapping -- a private attribute of LangChain's in-memory docstore
        (NOTE(review): may change between versions) -- and returns [] if absent.
        """
        stored = getattr(getattr(self.vector_store, "docstore", None), "_dict", None)
        if not isinstance(stored, dict):
            return []
        names = {
            doc.metadata.get("source_document_name")
            for doc in stored.values()
            if isinstance(doc, Document)
        }
        return sorted(name for name in names if isinstance(name, str))

    def _build_retriever(self):
        """(Re)create the retriever over the current vector store."""
        self.retriever = FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=RAG_RERANKER_K,
        )

    # ------------------------------------------------------ chunk config files

    def _save_chunk_config(self):
        """Persist current chunk settings alongside the FAISS index for change detection."""
        config_file = os.path.join(self._faiss_path(), self._CHUNK_CONFIG_FILENAME)
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump({"chunk_size": self.chunk_size, "chunk_overlap": self.chunk_overlap}, f)

    def _load_chunk_config(self) -> Optional[dict]:
        """Load previously saved chunk config. Returns None if not found."""
        config_file = os.path.join(self._faiss_path(), self._CHUNK_CONFIG_FILENAME)
        if not os.path.exists(config_file):
            return None
        with open(config_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def chunk_config_has_changed(self) -> bool:
        """Return True if chunk_size or chunk_overlap differ from what the index was built with.

        An index with no saved chunk config is reported as unchanged (False).
        """
        saved = self._load_chunk_config()
        if saved is None:
            return False
        changed = saved.get("chunk_size") != self.chunk_size or saved.get("chunk_overlap") != self.chunk_overlap
        if changed:
            self.logger.warning(
                f"[CONFIG_CHANGE] Chunk config mismatch! "
                f"Saved=(size={saved.get('chunk_size')}, overlap={saved.get('chunk_overlap')}) "
                f"Current=(size={self.chunk_size}, overlap={self.chunk_overlap}). "
                f"Index will be rebuilt."
            )
        return changed

    # ---------------------------------------------------------- index lifecycle

    def build_index_from_source_files(self, source_folder_path: str):
        """Build a fresh FAISS index, save it with its sidecars, and create the retriever.

        Chunks come from ``<index_storage_dir>/RAG_CHUNKED_SOURCES_FILENAME`` when
        that JSON file exists and yields at least one chunk; otherwise every
        supported file in ``source_folder_path`` is loaded and split.

        Raises:
            FileNotFoundError: ``source_folder_path`` is not a directory.
            ValueError: No chunks could be produced.
        """
        self.logger.info(f"[INDEX_BUILD] Building from: {source_folder_path}")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")

        # Built before the pre-chunked check so the splitter's argument validation always runs.
        text_splitter = self._make_text_splitter()
        all_docs: List[Document] = []
        processed_files: List[str] = []

        pre_chunked_path = os.path.join(self.index_storage_dir, RAG_CHUNKED_SOURCES_FILENAME)
        if os.path.exists(pre_chunked_path):
            try:
                with open(pre_chunked_path, 'r', encoding='utf-8') as f:
                    chunk_data_list = json.load(f)
                for chunk in chunk_data_list:
                    doc = Document(page_content=chunk.get("page_content", ""), metadata=chunk.get("metadata", {}))
                    all_docs.append(doc)
                    if 'source_document_name' in doc.metadata:
                        processed_files.append(doc.metadata['source_document_name'])
                processed_files = sorted(set(processed_files))
            except Exception as e:
                self.logger.error(f"[INDEX_BUILD] JSON load failed: {e}")
                # Drop any partially parsed chunks so the folder fallback below runs
                # instead of indexing a truncated subset of the JSON.
                all_docs, processed_files = [], []

        if not all_docs:
            for filename in self._list_supported_files(source_folder_path):
                text_content = self._extract_text(source_folder_path, filename)
                if text_content:
                    all_docs.extend(self._split_into_documents(text_content, filename, text_splitter))
                    processed_files.append(filename)

        if not all_docs:
            raise ValueError("No documents to index.")

        self.processed_source_files = processed_files
        self.logger.info(f"[INDEX_BUILD] Creating FAISS index with {len(all_docs)} chunks.")

        self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
        self.vector_store.save_local(self._faiss_path())
        self._save_chunk_config()
        # Persist the source list so load_index_from_disk() can restore it and
        # update_index_with_new_files() does not re-index these files as "new".
        self._save_processed_files()
        self._build_retriever()

    def load_index_from_disk(self):
        """Load the saved FAISS index, create the retriever and restore processed_source_files.

        Raises:
            FileNotFoundError: No index directory exists.
        """
        faiss_path = self._faiss_path()
        if not os.path.exists(faiss_path):
            raise FileNotFoundError("Index not found.")

        # NOTE(review): allow_dangerous_deserialization unpickles the saved docstore.
        # Only safe if the index directory comes from a trusted source (e.g. written
        # by build_index_from_source_files / update_index_with_new_files).
        self.vector_store = FAISS.load_local(
            folder_path=faiss_path,
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True,
        )
        self._build_retriever()

        meta_file = os.path.join(faiss_path, self._PROCESSED_FILES_FILENAME)
        if os.path.exists(meta_file):
            with open(meta_file, 'r', encoding='utf-8') as f:
                self.processed_source_files = json.load(f)
        else:
            # Index saved without the sidecar: recover names from the stored chunks so
            # update_index_with_new_files() does not treat every file as new.
            self.processed_source_files = self._sources_from_docstore() or [self._UNKNOWN_SOURCES]

        self.logger.info("[INDEX_LOAD] Success.")

    def update_index_with_new_files(self, source_folder_path: str, max_files_to_process: Optional[int] = None) -> Dict[str, Any]:
        """Incrementally add supported files that are not yet in processed_source_files.

        At most ``max_files_to_process`` new files (default
        ``RAG_MAX_FILES_FOR_INCREMENTAL``) are handled per call, in sorted-name
        order. Every attempted file is recorded as processed, including files that
        yield no text, so they are not retried forever and cannot block later files.

        Returns:
            A dict with ``status``, ``message`` and ``files_added``; when documents
            were added it also has ``remaining`` (new files left for later calls).

        Raises:
            RuntimeError: No index is loaded.
            FileNotFoundError: ``source_folder_path`` is not a directory.
        """
        self.logger.info(f"[INDEX_UPDATE] Checking for new files in: {source_folder_path}")

        if self.vector_store is None:
            raise RuntimeError("Cannot update: no index loaded.")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")

        processed_set = set(self.processed_source_files)
        all_new_files = [
            filename
            for filename in self._list_supported_files(source_folder_path)
            if filename not in processed_set
        ]

        if not all_new_files:
            return {"status": "success", "message": "No new files found.", "files_added": []}

        limit = max_files_to_process if max_files_to_process is not None else RAG_MAX_FILES_FOR_INCREMENTAL
        files_to_process = all_new_files[:limit]

        text_splitter = self._make_text_splitter()
        new_docs: List[Document] = []
        for filename in files_to_process:
            text_content = self._extract_text(source_folder_path, filename)
            if text_content:
                new_docs.extend(self._split_into_documents(text_content, filename, text_splitter))

        if not new_docs:
            # Still record the attempt; otherwise the same empty files would fill the
            # batch on every call and the files behind them would never be reached.
            self.processed_source_files.extend(files_to_process)
            self._save_processed_files()
            return {"status": "warning", "message": "New files found but no text extracted.", "files_added": []}

        self.vector_store.add_documents(new_docs)
        self.vector_store.save_local(self._faiss_path())

        # Only mark files processed once their chunks are in the saved index.
        self.processed_source_files.extend(files_to_process)
        self._save_processed_files()

        return {
            "status": "success",
            "message": f"Added {len(files_to_process)} files.",
            "files_added": files_to_process,
            "remaining": len(all_new_files) - len(files_to_process),
        }

    def search_knowledge_base(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """Retrieve (and, if configured, rerank) the chunks most relevant to ``query``.

        Args:
            query: The search query.
            top_k: Number of results; a falsy value (None or 0) uses the
                retriever's configured ``final_k``.

        Returns:
            Dicts with ``content``, ``metadata`` and ``score``. ``score`` is the
            ``reranker_score`` (higher = more relevant) when present, otherwise the
            FAISS ``retrieval_score`` (presumably an L2 distance, lower = closer --
            the two scales are not comparable).

        Raises:
            RuntimeError: No index has been built or loaded.
        """
        if self.retriever is None:
            raise RuntimeError("Retriever not initialized.")

        retriever = self.retriever
        if top_k:
            # Per-call retriever instead of temporarily mutating the shared one, so
            # overlapping searches never observe each other's top_k.
            retriever = FAISSRetrieverWithScore(
                vectorstore=self.retriever.vectorstore,
                reranker=self.retriever.reranker,
                initial_fetch_k=self.retriever.initial_fetch_k,
                final_k=top_k,
            )

        results = []
        for doc in retriever.invoke(query):
            reranker_score = doc.metadata.get("reranker_score")
            results.append({
                "content": doc.page_content,
                "metadata": doc.metadata,
                # 0.0 is a legitimate reranker score; fall back only when it is absent.
                "score": reranker_score if reranker_score is not None else doc.metadata.get("retrieval_score"),
            })
        return results