# seer / rag_components.py
# Last change: "Update rag_components.py" by SakibAhmed (commit 65c1702, verified), 18 kB.
# (Header recovered from the repository's web file view: Raw | History | Blame | Contribute | Delete.)
import os
import logging
import json
import time
import csv
from typing import List, Dict, Optional, Any
import torch
from sentence_transformers import CrossEncoder
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_text_splitters import RecursiveCharacterTextSplitter
from config import (
RAG_RERANKER_MODEL_NAME, RAG_DETAILED_LOGGING,
RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP, RAG_CHUNKED_SOURCES_FILENAME,
RAG_FAISS_INDEX_SUBDIR_NAME, RAG_INITIAL_FETCH_K, RAG_RERANKER_K,
RAG_MAX_FILES_FOR_INCREMENTAL, RAG_CSV_ROWS_PER_CHUNK
)
from utils import FAISS_RAG_SUPPORTED_EXTENSIONS
# Module-level logger. FAISSRetrieverWithScore logs through it directly, while
# DocumentReranker and KnowledgeRAG create their own child loggers
# ("<module>.DocumentReranker", "<module>.KnowledgeRAG").
logger = logging.getLogger(__name__)
def _row_to_text(row: Dict) -> str:
return "\n".join([
f"{k}: {v}"
for k, v in row.items()
if k and v and str(v).strip()
])
def _csv_rows_to_documents(file_path: str, filename: str, rows_per_chunk: int) -> List[Document]:
    """Read a CSV file and pack its non-empty data rows into Documents.

    Each record from ``csv.DictReader`` is rendered with ``_row_to_text``.
    Records that render to an empty string are skipped. Consecutive rendered
    rows are grouped, up to ``rows_per_chunk`` per Document, and joined with a
    blank line.

    Args:
        file_path: Path of the CSV file to read. It is decoded as UTF-8, and a
            leading BOM is stripped via ``utf-8-sig``.
        filename: Name recorded in each Document's metadata.
        rows_per_chunk: Maximum number of rows per Document. Values below 1
            are treated as 1.

    Returns:
        One Document per group. Its metadata holds ``source_document_name``,
        a running ``chunk_index``, ``source_type="csv"``, a human-readable
        ``full_location``, ``row_start``/``row_end`` and ``rows_in_chunk``.
        Row numbers are 1-based DictReader record numbers (header excluded).
    """
    rows_per_chunk = max(1, int(rows_per_chunk))
    docs: List[Document] = []
    row_buffer: List[str] = []
    row_start = 1
    row_end = 1
    chunk_index = 0

    def flush_buffer() -> None:
        # Emit the buffered rows as one Document and reset the buffer.
        nonlocal row_buffer, chunk_index
        if not row_buffer:
            return
        if len(row_buffer) == 1:
            full_location = f"{filename}, Row {row_start}"
        else:
            full_location = f"{filename}, Rows {row_start}-{row_end}"
        metadata = {
            "source_document_name": filename,
            "chunk_index": chunk_index,
            "source_type": "csv",
            "full_location": full_location,
            "row_start": row_start,
            "row_end": row_end,
            "rows_in_chunk": len(row_buffer),
        }
        docs.append(Document(page_content="\n\n".join(row_buffer), metadata=metadata))
        row_buffer = []
        chunk_index += 1

    # newline='' is required by the csv module; without it, quoted fields that
    # contain line breaks are not parsed correctly.
    with open(file_path, mode='r', encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        for csv_row_number, row in enumerate(reader, start=1):
            row_text = _row_to_text(row)
            if not row_text:
                continue
            if not row_buffer:
                row_start = csv_row_number
            # Track the last row actually buffered. Blank rows skipped inside
            # a chunk would make row_start + len(row_buffer) - 1 under-count.
            row_end = csv_row_number
            row_buffer.append(row_text)
            if len(row_buffer) >= rows_per_chunk:
                flush_buffer()
    flush_buffer()
    return docs
class DocumentReranker:
    """Cross-encoder reranker that reorders retrieved Documents by relevance to a query."""

    def __init__(self, model_name: str = RAG_RERANKER_MODEL_NAME):
        """Load the ``sentence_transformers`` CrossEncoder named ``model_name``.

        Raises:
            RuntimeError: If the model cannot be loaded. The original exception
                is chained as ``__cause__``.
        """
        self.logger = logging.getLogger(__name__ + ".DocumentReranker")
        self.model_name = model_name
        self.model = None
        try:
            self.logger.info(f"[RERANKER_INIT] Loading reranker model: {self.model_name}")
            start_time = time.time()
            # NOTE(review): trust_remote_code=True lets the model repository run
            # its own Python code at load time; only use trusted model names.
            self.model = CrossEncoder(model_name, trust_remote_code=True)
            load_time = time.time() - start_time
            self.logger.info(f"[RERANKER_INIT] Reranker model '{self.model_name}' loaded successfully in {load_time:.2f}s")
        except Exception as e:
            self.logger.error(f"[RERANKER_INIT] Failed to load reranker model '{self.model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize reranker model: {e}") from e

    def rerank_documents(self, query: str, documents: List[Document], top_k: int) -> List[Document]:
        """Return the ``top_k`` documents most relevant to ``query``, best first.

        Each returned Document is a copy whose metadata additionally carries
        ``reranker_score``. The input Documents are not modified.

        If there are no documents, no model, or scoring fails (the error is
        logged), the first ``top_k`` inputs are returned unchanged, in their
        original order.
        """
        if not documents or not self.model:
            return documents[:top_k] if documents else []
        try:
            start_time = time.time()
            doc_pairs = [[query, doc.page_content] for doc in documents]
            scores = self.model.predict(doc_pairs)
            rerank_time = time.time() - start_time
            self.logger.info(f"[RERANKER] Computed relevance scores in {rerank_time:.3f}s")
            # sorted() is stable (also with reverse=True), so equal scores keep
            # their incoming order.
            ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
            # Return annotated copies. The inputs may be the vector store's own
            # Document objects, and writing per-query scores into them would
            # leak those scores into the stored index and into later queries.
            return [
                Document(
                    page_content=doc.page_content,
                    metadata={**doc.metadata, "reranker_score": float(score)},
                    id=getattr(doc, "id", None),
                )
                for doc, score in ranked[:top_k]
            ]
        except Exception as e:
            self.logger.error(f"[RERANKER] Error during reranking: {e}", exc_info=True)
            return documents[:top_k]
class FAISSRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a FAISS store that records scores and can rerank.

    Fields:
        vectorstore: FAISS store queried with ``similarity_search_with_score``.
        reranker: Optional DocumentReranker. When set, a candidate pool is
            fetched and reranked down to ``final_k``.
        initial_fetch_k: Candidate pool size used when reranking.
        final_k: Maximum number of Documents returned.
    """

    vectorstore: FAISS
    reranker: Optional[DocumentReranker] = None
    initial_fetch_k: int = RAG_INITIAL_FETCH_K
    final_k: int = RAG_RERANKER_K

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``final_k`` Documents for ``query``.

        Each result is a copy of the stored Document whose metadata carries
        ``retrieval_score``. Reranked results additionally carry
        ``reranker_score``.
        """
        start_time = time.time()
        if self.reranker:
            # Never fetch fewer candidates than we intend to return, even when
            # final_k has been set above initial_fetch_k.
            num_to_fetch = max(self.initial_fetch_k, self.final_k)
        else:
            num_to_fetch = self.final_k
        logger.info(f"[RETRIEVER] Fetching {num_to_fetch} docs (Rerank={self.reranker is not None})")
        docs_and_scores = self.vectorstore.similarity_search_with_score(query, k=num_to_fetch)
        # Copy rather than mutate: the FAISS docstore may hand back its own
        # Document objects, and per-query scores must not be written into them.
        # NOTE(review): with LangChain's default FAISS setup the score is a
        # distance (lower = closer); confirm against the index's distance strategy.
        relevant_docs = [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "retrieval_score": float(score)},
                id=getattr(doc, "id", None),
            )
            for doc, score in docs_and_scores
        ]
        if self.reranker and relevant_docs:
            relevant_docs = self.reranker.rerank_documents(query, relevant_docs, top_k=self.final_k)
        total_time = time.time() - start_time
        logger.info(f"[RETRIEVER] Completed in {total_time:.3f}s. Returned {len(relevant_docs)} docs.")
        return relevant_docs
class KnowledgeRAG:
    """FAISS-backed knowledge base: build, load, incrementally update and search an index.

    Persisted under ``<index_storage_dir>/<RAG_FAISS_INDEX_SUBDIR_NAME>``:

    - the FAISS index (via ``save_local``);
    - ``chunk_config.json``, the chunking parameters the index was built with;
    - ``processed_files.json``, the source file names already indexed.
    """

    # Stand-in for processed_source_files when the index on disk has no
    # processed_files.json.
    _UNKNOWN_SOURCES_PLACEHOLDER = "Loaded from disk (unknown sources)"
    _PROCESSED_FILES_NAME = "processed_files.json"
    _CHUNK_CONFIG_NAME = "chunk_config.json"

    def __init__(
        self,
        index_storage_dir: str,
        embedding_model_name: str,
        use_gpu_for_embeddings: bool,
        chunk_size: int = RAG_CHUNK_SIZE,
        chunk_overlap: int = RAG_CHUNK_OVERLAP,
        reranker_model_name: Optional[str] = None,
        enable_reranker: bool = True,
    ):
        """Create the storage directory, the embedding model and (optionally) the reranker.

        Embeddings run on CUDA when ``use_gpu_for_embeddings`` is set and CUDA
        is available, otherwise on CPU. A reranker that fails to load is logged
        and disabled rather than raised. No index is loaded here: call
        ``build_index_from_source_files`` or ``load_index_from_disk`` first.
        """
        self.logger = logging.getLogger(__name__ + ".KnowledgeRAG")
        self.logger.info(f"[RAG_INIT] Initializing KnowledgeRAG system")
        self.index_storage_dir = index_storage_dir
        os.makedirs(self.index_storage_dir, exist_ok=True)
        self.embedding_model_name = embedding_model_name
        self.use_gpu_for_embeddings = use_gpu_for_embeddings
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.csv_rows_per_chunk = RAG_CSV_ROWS_PER_CHUNK
        self.reranker_model_name = reranker_model_name or RAG_RERANKER_MODEL_NAME
        self.enable_reranker = enable_reranker
        self.reranker = None

        device = "cpu"
        if self.use_gpu_for_embeddings:
            if torch.cuda.is_available():
                self.logger.info(f"[RAG_INIT] CUDA available. Requesting GPU.")
                device = "cuda"
            else:
                self.logger.warning("[RAG_INIT] CUDA not available. Fallback to CPU.")
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True}
        )

        if self.enable_reranker:
            try:
                self.reranker = DocumentReranker(self.reranker_model_name)
            except Exception as e:
                self.logger.warning(f"[RAG_INIT] Reranker Init Failed: {e}")
                self.reranker = None

        self.vector_store: Optional[FAISS] = None
        self.retriever: Optional[FAISSRetrieverWithScore] = None
        self.processed_source_files: List[str] = []

    # ------------------------------------------------------------------ helpers

    def _faiss_path(self) -> str:
        """Directory holding the persisted FAISS index and its JSON sidecar files."""
        return os.path.join(self.index_storage_dir, RAG_FAISS_INDEX_SUBDIR_NAME)

    @staticmethod
    def _file_extension(filename: str) -> str:
        """Return the lower-cased extension of ``filename`` without the dot ('' if none)."""
        return os.path.splitext(filename)[1][1:].lower()

    def _make_retriever(self, final_k: int = RAG_RERANKER_K) -> FAISSRetrieverWithScore:
        """Wrap the current vector store (and reranker, if any) in a retriever."""
        return FAISSRetrieverWithScore(
            vectorstore=self.vector_store,
            reranker=self.reranker,
            initial_fetch_k=RAG_INITIAL_FETCH_K,
            final_k=final_k
        )

    def _save_processed_files(self) -> None:
        """Persist the de-duplicated, sorted list of indexed source files next to the index."""
        meta_file = os.path.join(self._faiss_path(), self._PROCESSED_FILES_NAME)
        with open(meta_file, 'w', encoding='utf-8') as f:
            json.dump(sorted(set(self.processed_source_files)), f)

    def _extract_documents(self, file_path: str, filename: str, file_ext: str, text_splitter) -> Optional[List[Document]]:
        """Turn one supported source file into chunk Documents.

        CSV files are grouped by rows via ``_csv_rows_to_documents``. Any other
        extension is read with its loader from ``FAISS_RAG_SUPPORTED_EXTENSIONS``
        and split with ``text_splitter``.

        Returns:
            The chunk Documents (possibly empty), or None when the loader
            yields no usable text.

        Raises:
            Whatever the CSV reader or the extension's loader raises.
        """
        if file_ext == 'csv':
            return _csv_rows_to_documents(file_path, filename, self.csv_rows_per_chunk)
        text_content = FAISS_RAG_SUPPORTED_EXTENSIONS[file_ext](file_path)
        # NOTE(review): "CSV_HANDLED_NATIVELY" appears to be a sentinel some
        # loader may return; CSVs are already handled above.
        if not text_content or text_content == "CSV_HANDLED_NATIVELY":
            return None
        return [
            Document(page_content=chunk_text, metadata={"source_document_name": filename, "chunk_index": i})
            for i, chunk_text in enumerate(text_splitter.split_text(text_content))
        ]

    def _load_pre_chunked_documents(self) -> tuple:
        """Load optional pre-chunked Documents from ``RAG_CHUNKED_SOURCES_FILENAME``.

        Returns:
            ``(documents, sorted unique source_document_name values)``. A
            missing file yields two empty lists. A file that fails to load is
            logged and ignored entirely.
        """
        pre_chunked_path = os.path.join(self.index_storage_dir, RAG_CHUNKED_SOURCES_FILENAME)
        if not os.path.exists(pre_chunked_path):
            return [], []
        try:
            self.logger.info(f"[INDEX_BUILD] Loading pre-chunked documents from {pre_chunked_path}")
            with open(pre_chunked_path, 'r', encoding='utf-8') as f:
                chunk_data_list = json.load(f)
            docs = [
                Document(page_content=chunk.get("page_content", ""), metadata=chunk.get("metadata", {}))
                for chunk in chunk_data_list
            ]
            sources = sorted({
                doc.metadata['source_document_name']
                for doc in docs
                if 'source_document_name' in doc.metadata
            })
        except Exception as e:
            self.logger.error(f"[INDEX_BUILD] JSON load failed: {e}")
            return [], []
        return docs, sources

    # ------------------------------------------------------- chunk configuration

    def _save_chunk_config(self):
        """Record the chunking parameters used for the index currently on disk."""
        config_file = os.path.join(self._faiss_path(), self._CHUNK_CONFIG_NAME)
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump({
                "chunk_size": self.chunk_size,
                "chunk_overlap": self.chunk_overlap,
                "csv_rows_per_chunk": self.csv_rows_per_chunk
            }, f)

    def _load_chunk_config(self) -> Optional[dict]:
        """Return the saved chunking parameters, or None if absent or unreadable."""
        config_file = os.path.join(self._faiss_path(), self._CHUNK_CONFIG_NAME)
        if not os.path.exists(config_file):
            return None
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                saved = json.load(f)
        except (OSError, ValueError) as e:
            # A corrupt file is treated like a missing one (forces a rebuild)
            # instead of crashing the caller.
            self.logger.warning(f"[CONFIG_CHANGE] Could not read chunk config {config_file}: {e}")
            return None
        return saved if isinstance(saved, dict) else None

    def chunk_config_has_changed(self) -> bool:
        """Return True if an existing index was built with different chunking settings.

        Returns False when no index directory exists. Returns True (forcing a
        rebuild) when the index exists but its chunk config is missing or
        unreadable.
        """
        if not os.path.exists(self._faiss_path()):
            return False
        saved = self._load_chunk_config()
        if saved is None:
            self.logger.warning("[CONFIG_CHANGE] No chunk config found. Forcing rebuild to ensure consistency.")
            return True
        # Configs saved without csv_rows_per_chunk are treated as 1 row per chunk.
        saved_csv_rows_per_chunk = saved.get("csv_rows_per_chunk", 1)
        changed = (
            saved.get("chunk_size") != self.chunk_size
            or saved.get("chunk_overlap") != self.chunk_overlap
            or saved_csv_rows_per_chunk != self.csv_rows_per_chunk
        )
        if changed:
            self.logger.warning(
                f"[CONFIG_CHANGE] Chunk config mismatch! "
                f"Saved=(size={saved.get('chunk_size')}, overlap={saved.get('chunk_overlap')}, "
                f"csv_rows_per_chunk={saved_csv_rows_per_chunk}) "
                f"Current=(size={self.chunk_size}, overlap={self.chunk_overlap}, "
                f"csv_rows_per_chunk={self.csv_rows_per_chunk}). "
                f"Index will be rebuilt."
            )
        return changed

    # ------------------------------------------------------------ index lifecycle

    def build_index_from_source_files(self, source_folder_path: str):
        """Build a fresh FAISS index from pre-chunked JSON plus files in ``source_folder_path``.

        Files already covered by the pre-chunked JSON are skipped. A file whose
        loader fails is logged and skipped, so one bad file does not abort the
        build. The index, its chunk config and the processed-file list are
        saved to disk, and a retriever is created.

        Raises:
            FileNotFoundError: If ``source_folder_path`` is not a directory.
            ValueError: If nothing could be indexed.
        """
        self.logger.info(f"[INDEX_BUILD] Building from: {source_folder_path}")
        if not os.path.isdir(source_folder_path):
            raise FileNotFoundError(f"Source folder not found: '{source_folder_path}'.")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)

        all_docs, processed_files = self._load_pre_chunked_documents()
        processed_set = set(processed_files)

        for filename in sorted(os.listdir(source_folder_path)):
            if filename in processed_set:
                continue
            file_path = os.path.join(source_folder_path, filename)
            if not os.path.isfile(file_path):
                continue
            file_ext = self._file_extension(filename)
            if file_ext not in FAISS_RAG_SUPPORTED_EXTENSIONS:
                continue
            try:
                docs = self._extract_documents(file_path, filename, file_ext, text_splitter)
            except Exception as e:
                self.logger.error(f"[INDEX_BUILD] Error processing {filename}: {e}")
                continue
            if docs is None:
                continue
            all_docs.extend(docs)
            processed_files.append(filename)
            if file_ext == 'csv':
                self.logger.info(f"[INDEX_BUILD] Processed CSV {filename} into {len(docs)} chunks.")

        if not all_docs:
            raise ValueError("No documents to index.")
        self.processed_source_files = processed_files
        self.logger.info(f"[INDEX_BUILD] Creating FAISS index with {len(all_docs)} chunks.")
        self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
        self.vector_store.save_local(self._faiss_path())
        self._save_chunk_config()
        # Persist the processed-file list so that after load_index_from_disk,
        # update_index_with_new_files does not re-add (duplicate) these files.
        self._save_processed_files()
        self.retriever = self._make_retriever()

    def load_index_from_disk(self):
        """Load the persisted FAISS index and its processed-file list; create a retriever.

        Raises:
            FileNotFoundError: If no index directory exists.
        """
        faiss_path = self._faiss_path()
        if not os.path.exists(faiss_path):
            raise FileNotFoundError("Index not found.")
        # SECURITY NOTE: load_local unpickles the stored docstore. Only enable
        # allow_dangerous_deserialization for index directories this
        # application wrote itself.
        self.vector_store = FAISS.load_local(
            folder_path=faiss_path,
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True
        )
        self.retriever = self._make_retriever()
        meta_file = os.path.join(faiss_path, self._PROCESSED_FILES_NAME)
        if os.path.exists(meta_file):
            with open(meta_file, 'r', encoding='utf-8') as f:
                self.processed_source_files = json.load(f)
        else:
            self.processed_source_files = [self._UNKNOWN_SOURCES_PLACEHOLDER]
        self.logger.info("[INDEX_LOAD] Success.")

    def update_index_with_new_files(self, source_folder_path: str, max_files_to_process: Optional[int] = None) -> Dict[str, Any]:
        """Add up to ``max_files_to_process`` not-yet-indexed files to the loaded index.

        New files are processed in sorted order, with the default limit
        ``RAG_MAX_FILES_FOR_INCREMENTAL``. Every attempted file is recorded as
        processed, including ones that failed or yielded no text, so a batch of
        unreadable files cannot block later updates. A full rebuild retries them.

        Returns:
            A status dict with ``status``, ``message``, ``files_added`` (files
            that contributed chunks), and, when new files were found,
            ``files_skipped`` and ``remaining``.

        Raises:
            RuntimeError: If no index is loaded.
        """
        self.logger.info(f"[INDEX_UPDATE] Checking for new files in: {source_folder_path}")
        if not self.vector_store:
            raise RuntimeError("Cannot update: no index loaded.")
        processed_set = set(self.processed_source_files)
        all_new_files = [
            filename
            for filename in sorted(os.listdir(source_folder_path))
            if filename not in processed_set
            and self._file_extension(filename) in FAISS_RAG_SUPPORTED_EXTENSIONS
            and os.path.isfile(os.path.join(source_folder_path, filename))
        ]
        if not all_new_files:
            return {"status": "success", "message": "No new files found.", "files_added": []}

        limit = max_files_to_process if max_files_to_process is not None else RAG_MAX_FILES_FOR_INCREMENTAL
        files_to_process = all_new_files[:limit]
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)

        new_docs: List[Document] = []
        files_added: List[str] = []
        for filename in files_to_process:
            file_path = os.path.join(source_folder_path, filename)
            try:
                docs = self._extract_documents(file_path, filename, self._file_extension(filename), text_splitter)
            except Exception as e:
                self.logger.error(f"[INDEX_UPDATE] Error processing {filename}: {e}")
                continue
            if docs:
                new_docs.extend(docs)
                files_added.append(filename)

        if new_docs:
            self.vector_store.add_documents(new_docs)
            self.vector_store.save_local(self._faiss_path())
        self.processed_source_files = sorted(set(self.processed_source_files).union(files_to_process))
        self._save_processed_files()

        files_skipped = [name for name in files_to_process if name not in files_added]
        remaining = len(all_new_files) - len(files_to_process)
        if not new_docs:
            return {
                "status": "warning",
                "message": "New files found but no text extracted.",
                "files_added": [],
                "files_skipped": files_skipped,
                "remaining": remaining,
            }
        return {
            "status": "success",
            "message": f"Added {len(files_added)} files.",
            "files_added": files_added,
            "files_skipped": files_skipped,
            "remaining": remaining,
        }

    # -------------------------------------------------------------------- search

    def search_knowledge_base(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """Search the index and return ``{"content", "metadata", "score"}`` dicts.

        ``score`` is the reranker score when present, otherwise the retrieval
        score. A truthy ``top_k`` overrides the retriever's ``final_k`` for
        this call only.

        Raises:
            RuntimeError: If no retriever has been initialized.
        """
        if not self.retriever:
            raise RuntimeError("Retriever not initialized.")
        retriever = self.retriever
        if top_k:
            # Use a per-call retriever instead of temporarily mutating the
            # shared one, so an override never leaks into other searches.
            retriever = FAISSRetrieverWithScore(
                vectorstore=self.retriever.vectorstore,
                reranker=self.retriever.reranker,
                initial_fetch_k=self.retriever.initial_fetch_k,
                final_k=top_k
            )
        docs = retriever.invoke(query)
        results = []
        for doc in docs:
            # Explicit None check: a legitimate reranker score of 0.0 must not
            # fall through to the retrieval score.
            score = doc.metadata.get("reranker_score")
            if score is None:
                score = doc.metadata.get("retrieval_score")
            results.append({
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": score
            })
        return results