"""Hybrid retrieval (BM25 + dense embeddings) fused with Reciprocal Rank Fusion.

Pipeline:
  1. BM25Retriever  -- lexical search over a pickled BM25 index.
  2. DenseRetriever -- embedding search against a persisted ChromaDB collection.
  3. rff_fusion     -- merges both ranked lists with RRF and returns the top_n.
"""

import os
import json
import pickle
import logging
import heapq
import time
import datetime  # noqa: F401  (kept: may be used by code outside this view)
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple  # noqa: F401
from concurrent.futures import ThreadPoolExecutor, as_completed  # noqa: F401

from rank_bm25 import BM25Okapi
from underthesea import word_tokenize
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings

# ---------------------------
# Config & Logging
# ---------------------------
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("hybrid_retriever")

# Project root is one directory above this file's directory.
BASE_DIR = Path(__file__).resolve().parent.parent
BM25_INDEX_PATH = BASE_DIR / "bm25_index.pkl"
SESSION_DIR = BASE_DIR / "sessions"
SESSION_DIR.mkdir(parents=True, exist_ok=True)


# ---------------------------
# Helper functions
# ---------------------------
def tokenize_vi(text: str) -> List[str]:
    """Tokenize Vietnamese *text* into lowercase, whitespace-split word tokens."""
    return word_tokenize(text, format="text").lower().split()


def _doc_path(doc_id: str) -> str:
    """Map a document id to its source .docx path under raw_docs/.

    The id's prefix before the first underscore is taken as the file stem
    (this mirrors the construction both retrievers previously inlined).
    """
    return str(BASE_DIR / "raw_docs" / (doc_id.split("_")[0] + ".docx"))


def rff_fusion(bm25_results: List[Dict[str, Any]],
               dense_results: List[Dict[str, Any]],
               k: int = 60,
               top_n: int = 10) -> List[Dict[str, Any]]:
    """Fuse BM25 and dense result lists with Reciprocal Rank Fusion (RRF).

    Each appearance of a chunk at rank r (0-based) contributes 1 / (k + r + 1)
    to its fused score; the top_n chunks by fused score are returned.

    Args:
        bm25_results: ranked results from ``BM25Retriever.search``.
        dense_results: ranked results from ``DenseRetriever.search``.
        k: RRF smoothing constant (60 is the value from Cormack et al.).
        top_n: number of fused results to return.

    Returns:
        Result dicts sorted by descending ``rff_score``, each carrying
        ``is_bm25`` / ``is_dense`` provenance flags.
    """
    fused_scores: Dict[str, float] = {}
    provenance: Dict[str, Dict[str, Dict[str, Any]]] = {}
    chunk_lookup: Dict[str, Dict[str, Any]] = {}  # chunk_id -> full result payload

    def update_scores(results: List[Dict[str, Any]], source: str) -> None:
        for rank, result in enumerate(results):
            chunk_id = result["chunk_id"]
            contrib = 1.0 / (k + rank + 1)
            fused_scores[chunk_id] = fused_scores.get(chunk_id, 0) + contrib
            provenance.setdefault(chunk_id, {})[source] = {
                "rank": rank + 1,
                "orig_score": result["score"],
                "rrf_contrib": contrib,
            }
            # First occurrence wins: keep the full payload for final assembly.
            chunk_lookup.setdefault(chunk_id, result)

    update_scores(bm25_results, "bm25")
    update_scores(dense_results, "dense")

    # Top fused chunks without sorting the whole score table.
    top_chunks = heapq.nlargest(top_n, fused_scores.items(), key=lambda x: x[1])

    final_results: List[Dict[str, Any]] = []
    for chunk_id, rrf_score in top_chunks:
        info = chunk_lookup[chunk_id]
        sources = provenance[chunk_id]
        # A source "contributed" when it ranked the chunk within the top_n of
        # its own list (not merely anywhere among the retrieved candidates).
        is_bm25 = "bm25" in sources and sources["bm25"]["rank"] <= top_n
        is_dense = "dense" in sources and sources["dense"]["rank"] <= top_n
        final_results.append({
            "chunk_id": chunk_id,
            "doc_id": info["doc_id"],
            "doc_path": info["doc_path"],
            "path": info["path"],
            "token_count": info["token_count"],
            "rff_score": float(rrf_score),
            "is_bm25": is_bm25,
            "is_dense": is_dense,
            "text": info["text"],
            "chunk_for_embedding": info["chunk_for_embedding"],
        })
    # NOTE(fix): the original also dumped final_results to output.json here as
    # a hidden side effect, duplicating the identical write in __main__.
    # Writing output belongs to the caller, so that write was removed.
    return final_results


# ---------------------------
# BM25 Search
# ---------------------------
class BM25Retriever:
    """Lexical retriever over a BM25 index pickled by an offline build step."""

    def __init__(self, index_path: str = str(BM25_INDEX_PATH)):
        self.index_path = index_path
        self.index = self._load_index(index_path)
        self.bm25: BM25Okapi = self.index["bm25"]
        self.chunks: List[Dict[str, Any]] = self.index["chunks"]
        self.tokenized_corpus: List[List[str]] = self.index["tokenized_corpus"]
        logger.info("BM25Search loaded %d chunks from %s", len(self.chunks), index_path)

    def _load_index(self, path: str) -> Dict[str, Any]:
        """Load the pickled index dict; raise FileNotFoundError if absent.

        SECURITY NOTE: ``pickle.load`` executes arbitrary code from the file --
        only load indexes produced by this project's own trusted build step.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"BM25 index file not found: {path}")
        with open(path, "rb") as f:
            return pickle.load(f)

    def search(self, query: str, top_k: int = 20) -> List[Dict[str, Any]]:
        """Return the top_k chunks for *query*, highest BM25 score first."""
        tokens = tokenize_vi(query)
        scores = self.bm25.get_scores(tokens)
        ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:top_k]
        results = []
        for idx, score in ranked:
            chunk = self.chunks[idx]
            results.append({
                "chunk_id": chunk["id"],
                "doc_id": chunk["doc_id"],
                "doc_path": _doc_path(chunk["doc_id"]),
                "path": chunk["path"],
                "text": chunk["text"],
                "chunk_for_embedding": chunk["chunk_for_embedding"],
                "token_count": chunk["token_count"],
                "score": float(score),
            })
        return results


# ---------------------------
# Dense Retrieval
# ---------------------------
class DenseRetriever:
    """Embedding retriever backed by a persisted ChromaDB collection."""

    def __init__(self,
                 persist_dir: str = str(BASE_DIR / "chroma_db"),
                 collection: str = "snote",
                 embedding_model_name: str = "AITeamVN/Vietnamese_Embedding_v2",
                 device: str = "cpu"):
        # NOTE(review): Settings(chroma_db_impl=...) is the legacy (<0.4)
        # ChromaDB client API; newer releases use chromadb.PersistentClient.
        # Kept as-is to match the pinned dependency -- confirm before upgrading.
        settings = Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_dir)
        self.client = chromadb.Client(settings)
        self.collection = self.client.get_collection(collection)
        # Load the sentence-embedding model used for query encoding.
        self.model = SentenceTransformer(embedding_model_name, device=device)
        logger.info("DenseRetriever ready with model=%s, persist_dir=%s",
                    embedding_model_name, persist_dir)

    def embed_query(self, query: str) -> List[float]:
        """Encode *query* into a plain list of Python floats for ChromaDB."""
        vec = self.model.encode([query], convert_to_numpy=True)[0]
        return vec.astype(float).tolist()

    def search(self, query: str, top_k: int = 20) -> List[Dict[str, Any]]:
        """Return the top_k chunks for *query* in the same schema BM25 emits."""
        query_vec = self.embed_query(query)
        results = self.collection.query(
            query_embeddings=[query_vec],
            n_results=top_k,
        )

        ids = results["ids"][0]
        distances = results["distances"][0]  # cosine distance (lower is better)
        documents = results["documents"][0] if results["documents"] else [None] * len(ids)
        metadatas = results["metadatas"][0] if results["metadatas"] else [{}] * len(ids)

        formatted_results = []
        for doc_id, distance, document, metadata in zip(ids, distances, documents, metadatas):
            # Flip distance into a similarity so "higher is better", like BM25.
            similarity_score = 1.0 - distance
            # Prefer explicit metadata; otherwise derive the base id from a
            # "base::suffix"-style chunk id.
            doc_base_id = metadata.get(
                "doc_id", doc_id.split("::")[0] if "::" in doc_id else doc_id
            )
            path_info = (metadata.get("path", "").split(" | ")
                         if metadata.get("path") else ["Dense Retrieval Result"])
            formatted_results.append({
                "chunk_id": doc_id,
                "doc_id": doc_base_id,
                "doc_path": _doc_path(doc_base_id),
                "path": path_info,
                "text": document if document else f"Document ID: {doc_id}",
                "token_count": metadata.get("token_count", 0),
                "score": float(similarity_score),
                "chunk_for_embedding": metadata.get("chunk_for_embedding", ""),
            })
        return formatted_results


# ---------------------------
# Hybrid RAG
# ---------------------------
class HybridRAG:
    """Runs BM25 and dense retrieval for a query and RRF-fuses the results."""

    def __init__(self,
                 bm25_retriever: Optional[BM25Retriever] = None,
                 dense_retriever: Optional[DenseRetriever] = None):
        # BUG FIX: the original defaults were `BM25Retriever()` /
        # `DenseRetriever()`, evaluated once at class-definition (import)
        # time -- loading the pickle index and the embedding model even when
        # callers supplied their own retrievers, and crashing the import
        # outright if the index file was missing. Build them lazily instead.
        self.bm25_retriever = bm25_retriever if bm25_retriever is not None else BM25Retriever()
        self.dense_retriever = dense_retriever if dense_retriever is not None else DenseRetriever()

    def get_results(self, query: str, top_k: int = 20, top_n: int = 10,
                    session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Retrieve and fuse results for *query*.

        Args:
            query: user question; surrounding whitespace is stripped.
            top_k: candidates fetched from each retriever before fusion.
            top_n: fused results to return.
            session_id: accepted for interface compatibility; currently unused.
        """
        query = query.strip()
        bm25_results = self.bm25_retriever.search(query, top_k=top_k)
        dense_results = self.dense_retriever.search(query, top_k=top_k)
        return rff_fusion(bm25_results, dense_results, k=60, top_n=top_n)


if __name__ == "__main__":
    output_path = Path("output.json")
    # Start from a clean slate so a stale file is never mistaken for output.
    if output_path.exists():
        output_path.unlink()

    start_time = time.time()
    query = "Sinh viên không đóng học phí có được bảo vệ Khóa luận không?"
    hybrid_rag = HybridRAG(BM25Retriever(), DenseRetriever())
    final_results = hybrid_rag.get_results(query, top_k=20, top_n=10)
    logger.info("Retrieved %d results in %.2fs",
                len(final_results), time.time() - start_time)

    # indent + ensure_ascii=False so Vietnamese text is rendered correctly.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_results, f, ensure_ascii=False, indent=2, sort_keys=True)