# rag_components.py — RAG pipeline components for AI-Agent-RAG-Bot.
import os
import logging
import json
import time
import csv
from typing import List, Dict, Optional, Any
import torch
from sentence_transformers import CrossEncoder
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_text_splitters import RecursiveCharacterTextSplitter
from config import (
RAG_RERANKER_MODEL_NAME, RAG_DETAILED_LOGGING,
RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP, RAG_CHUNKED_SOURCES_FILENAME,
RAG_FAISS_INDEX_SUBDIR_NAME, RAG_INITIAL_FETCH_K, RAG_RERANKER_K,
RAG_MAX_FILES_FOR_INCREMENTAL
)
from utils import FAISS_RAG_SUPPORTED_EXTENSIONS
logger = logging.getLogger(__name__)
class DocumentReranker:
    """Cross-encoder based re-ranking of retrieved documents.

    Wraps a sentence-transformers ``CrossEncoder`` and re-orders candidate
    documents by relevance to a query, stamping each kept document with a
    ``reranker_score`` metadata entry.
    """

    def __init__(self, model_name: str = RAG_RERANKER_MODEL_NAME):
        """Load the cross-encoder; raises RuntimeError if loading fails."""
        self.logger = logging.getLogger(__name__ + ".DocumentReranker")
        self.model_name = model_name
        self.model = None
        try:
            self.logger.info(f"[RERANKER_INIT] Loading reranker model: {self.model_name}")
            t0 = time.time()
            self.model = CrossEncoder(model_name, trust_remote_code=True)
            elapsed = time.time() - t0
            self.logger.info(f"[RERANKER_INIT] Reranker model '{self.model_name}' loaded successfully in {elapsed:.2f}s")
        except Exception as e:
            self.logger.error(f"[RERANKER_INIT] Failed to load reranker model '{self.model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize reranker model: {e}") from e

    def rerank_documents(self, query: str, documents: List[Document], top_k: int) -> List[Document]:
        """Return the ``top_k`` most relevant documents for ``query``.

        Falls back to simple truncation when there is no model or when
        scoring raises; never propagates reranking errors to the caller.
        """
        if not documents or not self.model:
            # Nothing to rank (or no model available): truncate pass-through.
            return documents[:top_k] if documents else []
        try:
            t0 = time.time()
            pairs = [[query, d.page_content] for d in documents]
            scores = self.model.predict(pairs)
            elapsed = time.time() - t0
            self.logger.info(f"[RERANKER] Computed relevance scores in {elapsed:.3f}s")
            # Stable sort by descending score (same ordering as in-place sort).
            ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
            top = []
            for doc, score in ranked[:top_k]:
                doc.metadata["reranker_score"] = float(score)
                top.append(doc)
            return top
        except Exception as e:
            self.logger.error(f"[RERANKER] Error during reranking: {e}", exc_info=True)
            return documents[:top_k] if documents else []
class FAISSRetrieverWithScore(BaseRetriever):
    """Retriever that records FAISS scores and optionally reranks.

    Every returned Document carries a ``retrieval_score`` metadata entry
    (raw FAISS similarity score). When a reranker is configured, the
    candidate set is over-fetched (``initial_fetch_k``), re-ordered, and
    trimmed to ``final_k`` with a ``reranker_score`` added.
    """

    vectorstore: FAISS
    reranker: Optional[DocumentReranker] = None
    initial_fetch_k: int = RAG_INITIAL_FETCH_K
    final_k: int = RAG_RERANKER_K

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        t0 = time.time()
        # Over-fetch only when a reranker will trim the list afterwards.
        fetch_k = self.initial_fetch_k if self.reranker is not None else self.final_k
        logger.info(f"[RETRIEVER] Fetching {fetch_k} docs (Rerank={self.reranker is not None})")
        hits: List[Document] = []
        for doc, score in self.vectorstore.similarity_search_with_score(query, k=fetch_k):
            doc.metadata["retrieval_score"] = float(score)
            hits.append(doc)
        if self.reranker is not None and hits:
            hits = self.reranker.rerank_documents(query, hits, top_k=self.final_k)
        elapsed = time.time() - t0
        logger.info(f"[RETRIEVER] Completed in {elapsed:.3f}s. Returned {len(hits)} docs.")
        return hits
class KnowledgeRAG:
    """FAISS-backed knowledge base with optional cross-encoder reranking.

    Responsibilities:
      * build a FAISS index from a pre-chunked JSON dump or raw source files,
      * persist/load the index plus its chunk config and source manifest,
      * incrementally add newly discovered source files,
      * answer similarity queries via FAISSRetrieverWithScore.
    """

    def __init__(
        self,
        index_storage_dir: str,
        embedding_model_name: str,
        use_gpu_for_embeddings: bool,
        chunk_size: int = RAG_CHUNK_SIZE,
        chunk_overlap: int = RAG_CHUNK_OVERLAP,
        reranker_model_name: Optional[str] = None,
        enable_reranker: bool = True,
    ):
        """Set up embeddings (GPU when requested and available) and reranker.

        Reranker initialization failures are non-fatal: the system degrades
        to plain vector retrieval without reranking.
        """
        self.logger = logging.getLogger(__name__ + ".KnowledgeRAG")
        self.logger.info(f"[RAG_INIT] Initializing KnowledgeRAG system")
        self.index_storage_dir = index_storage_dir
        os.makedirs(self.index_storage_dir, exist_ok=True)
        self.embedding_model_name = embedding_model_name
        self.use_gpu_for_embeddings = use_gpu_for_embeddings
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.reranker_model_name = reranker_model_name or RAG_RERANKER_MODEL_NAME
        self.enable_reranker = enable_reranker
        self.reranker = None
        device = "cpu"
        if self.use_gpu_for_embeddings:
            if torch.cuda.is_available():
                self.logger.info(f"[RAG_INIT] CUDA available. Requesting GPU.")
                device = "cuda"
            else:
                self.logger.warning("[RAG_INIT] CUDA not available. Fallback to CPU.")
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True}
        )
        if self.enable_reranker:
            try:
                self.reranker = DocumentReranker(self.reranker_model_name)
            except Exception as e:
                # Non-fatal: retrieval still works without reranking.
                self.logger.warning(f"[RAG_INIT] Reranker Init Failed: {e}")
                self.reranker = None
        self.vector_store: Optional[FAISS] = None
        self.retriever: Optional[FAISSRetrieverWithScore] = None
        self.processed_source_files: List[str] = []

    # ---------- internal helpers ----------

    def _faiss_path(self) -> str:
        """Directory holding the FAISS index and its sidecar metadata files."""
        return os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)

    def _save_chunk_config(self):
        """Persist chunking parameters next to the index for change detection."""
        config_file = os.path.join(self._faiss_path(), "chunk_config.json")
        with open(config_file, 'w') as f:
            json.dump({"chunk_size": self.chunk_size, "chunk_overlap": self.chunk_overlap}, f)

    def _load_chunk_config(self) -> Optional[dict]:
        """Return the saved chunk config, or None when the index has none."""
        config_file = os.path.join(self._faiss_path(), "chunk_config.json")
        if os.path.exists(config_file):
            with open(config_file, 'r') as f:
                return json.load(f)
        return None

    def _save_processed_files(self):
        """Persist the sorted manifest of ingested source file names."""
        with open(os.path.join(self._faiss_path(), "processed_files.json"), 'w') as f:
            json.dump(sorted(self.processed_source_files), f)

    def _make_retriever(self) -> FAISSRetrieverWithScore:
        """Build a retriever over the current vector store."""
        return FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=RAG_RERANKER_K
        )

    def _ingest_file(self, file_path: str, filename: str, text_splitter, log_tag: str) -> List[Document]:
        """Convert one source file into a list of Documents.

        CSV files become one Document per row ("key: value" lines, empty
        cells skipped); other supported types are extracted by their handler
        in FAISS_RAG_SUPPORTED_EXTENSIONS and split with ``text_splitter``.
        Returns an empty list when nothing could be extracted. This helper
        centralizes logic previously duplicated between build and update.
        """
        file_ext = filename.split('.')[-1].lower()
        docs: List[Document] = []
        if file_ext == 'csv':
            try:
                with open(file_path, mode='r', encoding='utf-8-sig') as f:
                    reader = csv.DictReader(f)
                    for i, row in enumerate(reader):
                        row_text = "\n".join([f"{k}: {v}" for k, v in row.items() if k and v and str(v).strip()])
                        meta = {"source_document_name": filename, "chunk_index": i, "source_type": "csv"}
                        docs.append(Document(page_content=row_text, metadata=meta))
            except Exception as e:
                # FIX: name the offending file instead of logging "(unknown)".
                self.logger.error(f"[{log_tag}] Error processing CSV '{filename}': {e}")
        else:
            text_content = FAISS_RAG_SUPPORTED_EXTENSIONS[file_ext](file_path)
            # "CSV_HANDLED_NATIVELY" is a sentinel from the generic handler.
            if text_content and text_content != "CSV_HANDLED_NATIVELY":
                for i, chunk_text in enumerate(text_splitter.split_text(text_content)):
                    meta = {"source_document_name": filename, "chunk_index": i}
                    docs.append(Document(page_content=chunk_text, metadata=meta))
        return docs

    # ---------- public API ----------

    def chunk_config_has_changed(self) -> bool:
        """True when the saved index was built with different chunk settings.

        A missing saved config (legacy index) is treated as unchanged.
        """
        saved = self._load_chunk_config()
        if saved is None:
            return False
        changed = saved.get("chunk_size") != self.chunk_size or saved.get("chunk_overlap") != self.chunk_overlap
        if changed:
            self.logger.warning(
                f"[CONFIG_CHANGE] Chunk config mismatch! "
                f"Saved=(size={saved.get('chunk_size')}, overlap={saved.get('chunk_overlap')}) "
                f"Current=(size={self.chunk_size}, overlap={self.chunk_overlap}). "
                f"Index will be rebuilt."
            )
        return changed

    def build_index_from_source_files(self, source_folder_path: str):
        """Build a fresh FAISS index and persist it with its metadata.

        Prefers a pre-chunked JSON dump (RAG_CHUNKED_SOURCES_FILENAME) in the
        index storage dir; otherwise scans ``source_folder_path`` for
        supported files. Saves the index, chunk config, and processed-file
        manifest, then installs a retriever.

        Raises:
            FileNotFoundError: ``source_folder_path`` is not a directory.
            ValueError: no documents could be extracted.
        """
        self.logger.info(f"[INDEX_BUILD] Building from: {source_folder_path}")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        all_docs: List[Document] = []
        processed_files: List[str] = []
        pre_chunked_path = os.path.join(self.index_storage_dir, RAG_CHUNKED_SOURCES_FILENAME)
        if os.path.exists(pre_chunked_path):
            try:
                with open(pre_chunked_path, 'r', encoding='utf-8') as f:
                    chunk_data_list = json.load(f)
                for chunk in chunk_data_list:
                    doc = Document(page_content=chunk.get("page_content", ""), metadata=chunk.get("metadata", {}))
                    all_docs.append(doc)
                    if 'source_document_name' in doc.metadata:
                        processed_files.append(doc.metadata['source_document_name'])
                processed_files = sorted(set(processed_files))
            except Exception as e:
                self.logger.error(f"[INDEX_BUILD] JSON load failed: {e}")
                # FIX: discard partially loaded docs so we fall back to a
                # clean filesystem scan instead of indexing a partial dump.
                all_docs = []
                processed_files = []
        if not all_docs:
            for filename in os.listdir(source_folder_path):
                file_path = os.path.join(source_folder_path, filename)
                if not os.path.isfile(file_path):
                    continue
                if filename.split('.')[-1].lower() not in FAISS_RAG_SUPPORTED_EXTENSIONS:
                    continue
                docs = self._ingest_file(file_path, filename, text_splitter, log_tag="INDEX_BUILD")
                if docs:
                    all_docs.extend(docs)
                    processed_files.append(filename)
        if not all_docs:
            raise ValueError("No documents to index.")
        self.processed_source_files = processed_files
        self.logger.info(f"[INDEX_BUILD] Creating FAISS index with {len(all_docs)} chunks.")
        self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
        self.vector_store.save_local(self._faiss_path())
        self._save_chunk_config()
        # FIX: persist the manifest so load_index_from_disk() and
        # update_index_with_new_files() see the real source inventory
        # (previously it was only written during incremental updates).
        self._save_processed_files()
        self.retriever = self._make_retriever()

    def load_index_from_disk(self):
        """Load a previously saved FAISS index and its source manifest.

        Raises:
            FileNotFoundError: no saved index exists.
        """
        faiss_path = self._faiss_path()
        if not os.path.exists(faiss_path):
            raise FileNotFoundError("Index not found.")
        # NOTE: pickle-based deserialization — only load indexes this
        # application produced itself.
        self.vector_store = FAISS.load_local(
            folder_path=faiss_path,
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True
        )
        self.retriever = self._make_retriever()
        meta_file = os.path.join(faiss_path, "processed_files.json")
        if os.path.exists(meta_file):
            with open(meta_file, 'r') as f:
                self.processed_source_files = json.load(f)
        else:
            self.processed_source_files = ["Loaded from disk (unknown sources)"]
        self.logger.info("[INDEX_LOAD] Success.")

    def update_index_with_new_files(self, source_folder_path: str, max_files_to_process: Optional[int] = None) -> Dict[str, Any]:
        """Incrementally index supported files not yet in the manifest.

        At most ``max_files_to_process`` files (default
        RAG_MAX_FILES_FOR_INCREMENTAL) are added per call; the ``remaining``
        count in the result lets callers iterate until done.

        Returns:
            Status dict with ``status``/``message``/``files_added`` keys
            (plus ``remaining`` on success).

        Raises:
            RuntimeError: no index is currently loaded.
        """
        self.logger.info(f"[INDEX_UPDATE] Checking for new files in: {source_folder_path}")
        if not self.vector_store:
            raise RuntimeError("Cannot update: no index loaded.")
        processed_set = set(self.processed_source_files)
        all_new_files = []
        for filename in sorted(os.listdir(source_folder_path)):
            if filename in processed_set:
                continue
            # FIX: skip directories etc. that merely look like source files
            # (the build path already did this check; the update path didn't).
            if not os.path.isfile(os.path.join(source_folder_path, filename)):
                continue
            if filename.split('.')[-1].lower() in FAISS_RAG_SUPPORTED_EXTENSIONS:
                all_new_files.append(filename)
        if not all_new_files:
            return {"status": "success", "message": "No new files found.", "files_added": []}
        limit = max_files_to_process if max_files_to_process is not None else RAG_MAX_FILES_FOR_INCREMENTAL
        files_to_process = all_new_files[:limit]
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        new_docs: List[Document] = []
        for filename in files_to_process:
            file_path = os.path.join(source_folder_path, filename)
            new_docs.extend(self._ingest_file(file_path, filename, text_splitter, log_tag="INDEX_UPDATE"))
        if not new_docs:
            return {"status": "warning", "message": "New files found but no text extracted.", "files_added": []}
        self.vector_store.add_documents(new_docs)
        self.vector_store.save_local(self._faiss_path())
        self.processed_source_files.extend(files_to_process)
        self._save_processed_files()
        return {
            "status": "success",
            "message": f"Added {len(files_to_process)} files.",
            "files_added": files_to_process,
            "remaining": len(all_new_files) - len(files_to_process)
        }

    def search_knowledge_base(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """Run retrieval (plus optional rerank) and return plain result dicts.

        Each result has ``content``, ``metadata`` and ``score`` (reranker
        score when present, else the raw retrieval score). A truthy
        ``top_k`` temporarily overrides the retriever's ``final_k``.

        Raises:
            RuntimeError: the retriever has not been initialized.
        """
        if not self.retriever:
            raise RuntimeError("Retriever not initialized.")
        original_k = self.retriever.final_k
        if top_k:
            self.retriever.final_k = top_k
        try:
            docs = self.retriever.invoke(query)
            return [
                {
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "score": doc.metadata.get("reranker_score") or doc.metadata.get("retrieval_score")
                }
                for doc in docs
            ]
        finally:
            # Always restore final_k, even when retrieval raises.
            self.retriever.final_k = original_k