|
|
import os |
|
|
import json |
|
|
import pickle |
|
|
import logging |
|
|
import heapq |
|
|
import datetime |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any, Tuple |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
from rank_bm25 import BM25Okapi |
|
|
from underthesea import word_tokenize |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Root-logger configuration for this module; level is overridable via the
# LOG_LEVEL environment variable (defaults to INFO).
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s"
)
logger = logging.getLogger("hybrid_retriever")

# Project root: two levels up from this file.
BASE_DIR = Path(__file__).resolve().parent.parent
# Pickled BM25 index produced by an offline indexing step (see BM25Retriever).
BM25_INDEX_PATH = BASE_DIR / "bm25_index.pkl"
# Session artifacts directory; created eagerly at import time (side effect).
SESSION_DIR = BASE_DIR / "sessions"
SESSION_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tokenize_vi(text: str) -> List[str]:
    """Tokenize Vietnamese text into lowercase word tokens.

    underthesea's segmenter with format="text" joins compound words with
    underscores into a single string; we then lowercase it and split on
    whitespace to get the final token list.
    """
    segmented = word_tokenize(text, format="text")
    return segmented.lower().split()
|
|
|
|
|
def rff_fusion(bm25_results: List[Dict[str, Any]], dense_results: List[Dict[str, Any]], |
|
|
k: int = 60, top_n: int = 10) -> List[Dict[str, Any]]: |
|
|
|
|
|
fused_scores = {} |
|
|
provenance = {} |
|
|
|
|
|
chunk_lookup = {} |
|
|
|
|
|
def update_scores(results, source): |
|
|
for rank, result in enumerate(results): |
|
|
chunk_id = result["chunk_id"] |
|
|
orig_score = result["score"] |
|
|
contrib = 1.0 / (k + rank + 1) |
|
|
fused_scores[chunk_id] = fused_scores.get(chunk_id, 0) + contrib |
|
|
provenance.setdefault(chunk_id, {})[source] = { |
|
|
"rank": rank + 1, |
|
|
"orig_score": orig_score, |
|
|
"rrf_contrib": contrib, |
|
|
} |
|
|
|
|
|
if chunk_id not in chunk_lookup: |
|
|
chunk_lookup[chunk_id] = result |
|
|
|
|
|
update_scores(bm25_results, "bm25") |
|
|
update_scores(dense_results, "dense") |
|
|
|
|
|
|
|
|
top_chunks = heapq.nlargest(top_n, fused_scores.items(), key=lambda x: x[1]) |
|
|
|
|
|
|
|
|
final_results = [] |
|
|
for chunk_id, rrf_score in top_chunks: |
|
|
chunk_result_info = chunk_lookup[chunk_id] |
|
|
is_bm25, is_dense = False, False |
|
|
|
|
|
|
|
|
if "bm25" in provenance[chunk_id]: |
|
|
bm25_rank = provenance[chunk_id]["bm25"]["rank"] |
|
|
is_bm25 = bool(bm25_rank <= top_n) |
|
|
|
|
|
if "dense" in provenance[chunk_id]: |
|
|
dense_rank = provenance[chunk_id]["dense"]["rank"] |
|
|
is_dense = bool(dense_rank <= top_n) |
|
|
|
|
|
result_doc = { |
|
|
"chunk_id": chunk_id, |
|
|
"doc_id": chunk_result_info["doc_id"], |
|
|
"doc_path": chunk_result_info["doc_path"], |
|
|
"path": chunk_result_info["path"], |
|
|
"token_count": chunk_result_info["token_count"], |
|
|
"rff_score": float(rrf_score), |
|
|
"is_bm25": is_bm25, |
|
|
"is_dense": is_dense, |
|
|
"text": chunk_result_info["text"], |
|
|
"chunk_for_embedding": chunk_result_info["chunk_for_embedding"] |
|
|
} |
|
|
final_results.append(result_doc) |
|
|
|
|
|
output_path = Path("output.json") |
|
|
with open(output_path, "w", encoding="utf-8") as f: |
|
|
json.dump(final_results, f, ensure_ascii=False, indent=2, sort_keys=True) |
|
|
|
|
|
return final_results |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BM25Retriever:
    """Sparse lexical retriever over a pre-built, pickled BM25 index.

    The index pickle is produced offline and contains the fitted BM25Okapi
    model ("bm25"), the chunk payloads ("chunks") and the tokenized corpus
    ("tokenized_corpus").
    """

    def __init__(self, index_path: str = str(BM25_INDEX_PATH)):
        """Load the BM25 index from index_path.

        Raises:
            FileNotFoundError: if index_path does not exist.
        """
        self.index_path = index_path
        self.index = self._load_index(index_path)
        self.bm25: BM25Okapi = self.index["bm25"]
        self.chunks: List[Dict[str, Any]] = self.index["chunks"]
        self.tokenized_corpus: List[List[str]] = self.index["tokenized_corpus"]
        logger.info("BM25Search loaded %d chunks from %s", len(self.chunks), index_path)

    def _load_index(self, path: str) -> Dict[str, Any]:
        """Unpickle the index file, failing fast with a clear error if missing.

        NOTE(review): pickle.load can execute arbitrary code — only load
        trusted, locally-built index files.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"BM25 index file not found: {path}")
        with open(path, "rb") as f:
            return pickle.load(f)

    def search(self, query: str, top_k: int = 20) -> List[Dict[str, Any]]:
        """Return the top_k chunks for query, best first, with BM25 scores."""
        tokens = tokenize_vi(query)
        scores = self.bm25.get_scores(tokens)

        # heapq.nlargest avoids sorting the entire corpus when top_k is much
        # smaller than the number of chunks.
        ranked = heapq.nlargest(top_k, enumerate(scores), key=lambda x: x[1])

        results = []
        for idx, score in ranked:
            chunk = self.chunks[idx]
            results.append({
                "chunk_id": chunk["id"],
                "doc_id": chunk["doc_id"],
                # Source .docx is named after the doc_id prefix before "_".
                "doc_path": str(BASE_DIR / "raw_docs" / (chunk["doc_id"].split("_")[0] + ".docx")),
                "path": chunk["path"],
                "text": chunk["text"],
                "chunk_for_embedding": chunk["chunk_for_embedding"],
                "token_count": chunk["token_count"],
                "score": float(score),
            })
        return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory containing this module and its parent (the project root).
# Equivalent to BASE_DIR, but kept as plain strings for the os.path-based
# default below.  (The redundant mid-file `import os` was removed; os is
# already imported at the top of the file.)
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
|
|
|
|
|
class DenseRetriever:
    """Semantic retriever over a persisted ChromaDB collection.

    Embeds queries with a SentenceTransformer model and searches the
    collection by vector similarity.
    """

    def __init__(self, persist_dir: str = os.path.join(parent_dir, "chroma_db"),
                 collection: str = "snote",
                 embedding_model_name: str = "AITeamVN/Vietnamese_Embedding_v2",
                 device: str = "cpu"):
        """Open the Chroma collection and load the embedding model.

        NOTE(review): Settings(chroma_db_impl="duckdb+parquet") is the legacy
        chromadb configuration API; recent chromadb versions expect
        chromadb.PersistentClient(path=...) instead — confirm against the
        pinned chromadb version.
        """
        settings = Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_dir)
        self.client = chromadb.Client(settings)
        self.collection = self.client.get_collection(collection)
        self.model = SentenceTransformer(embedding_model_name, device=device)
        logger.info("DenseRetriever ready with model=%s, persist_dir=%s", embedding_model_name, persist_dir)

    def embed_query(self, query: str) -> List[float]:
        """Encode query into a plain list of Python floats."""
        vec = self.model.encode([query], convert_to_numpy=True)[0]
        return vec.astype(float).tolist()

    def search(self, query: str, top_k: int = 20) -> List[Dict[str, Any]]:
        """Return up to top_k chunks for query with similarity scores.

        Results are reshaped into the same dict layout BM25Retriever.search
        produces so the two lists can be fused by rff_fusion.
        """
        query_vec = self.embed_query(query)
        results = self.collection.query(
            query_embeddings=[query_vec],
            n_results=top_k
        )

        # Chroma nests each field one level deep per query; we sent one query.
        ids = results["ids"][0]
        distances = results["distances"][0]
        documents = results["documents"][0] if results["documents"] else [None] * len(ids)
        metadatas = results["metadatas"][0] if results["metadatas"] else [{}] * len(ids)

        formatted_results = []
        for doc_id, distance, document, metadata in zip(ids, distances, documents, metadatas):
            # Convert Chroma's distance to a similarity-style score.
            # NOTE(review): assumes a distance metric where 1 - d is a sensible
            # similarity (e.g. cosine distance) — confirm collection config.
            similarity_score = 1.0 - distance

            # Chunk ids look like "<doc_id>::<suffix>"; fall back to the raw id.
            doc_base_id = metadata.get("doc_id", doc_id.split("::")[0] if "::" in doc_id else doc_id)
            path_info = metadata.get("path", "").split(" | ") if metadata.get("path") else ["Dense Retrieval Result"]
            formatted_results.append({
                "chunk_id": doc_id,
                "doc_id": doc_base_id,
                "doc_path": str(BASE_DIR / "raw_docs" / (doc_base_id.split("_")[0] + ".docx")),
                "path": path_info,
                "text": document if document else f"Document ID: {doc_id}",
                "token_count": metadata.get("token_count", 0),
                "score": float(similarity_score),
                "chunk_for_embedding": metadata.get("chunk_for_embedding", "")
            })

        return formatted_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridRAG:
    """Hybrid retriever: BM25 (sparse) + dense search, fused with RRF."""

    def __init__(self, bm25_retriever: "BM25Retriever | None" = None,
                 dense_retriever: "DenseRetriever | None" = None):
        """Store the two retrievers, constructing defaults lazily.

        The previous signature used `= BM25Retriever()` / `= DenseRetriever()`
        as defaults, which built both retrievers (loading the pickle index and
        the embedding model) at import time even when callers supplied their
        own instances, and shared those defaults across all HybridRAG objects.
        Defaults are now created only when an argument is actually omitted.
        """
        self.bm25_retriever = bm25_retriever if bm25_retriever is not None else BM25Retriever()
        self.dense_retriever = dense_retriever if dense_retriever is not None else DenseRetriever()

    def get_results(self, query: str, top_k: int = 20, top_n: int = 10,
                    session_id: "str | None" = None) -> List[Dict[str, Any]]:
        """Run both retrievers on query and fuse their top_k hits into top_n.

        Args:
            query: user query; stripped of surrounding whitespace before use.
            top_k: candidates requested from each retriever.
            top_n: fused results returned.
            session_id: currently unused; kept for interface compatibility.
        """
        query = query.strip()
        bm25_results = self.bm25_retriever.search(query, top_k=top_k)
        dense_results = self.dense_retriever.search(query, top_k=top_k)
        return rff_fusion(bm25_results, dense_results, k=60, top_n=top_n)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import time

    bm25_retriever = BM25Retriever()
    dense_retriever = DenseRetriever()

    # Remove stale output from a previous run before retrieving.
    # (json / Path are already imported at the top of the file; the previous
    # duplicate mid-script imports were dropped.)
    output_path = Path("output.json")
    if output_path.exists():
        output_path.unlink()

    start_time = time.time()
    query = "Sinh viên không đóng học phí có được bảo vệ Khóa luận không?"
    hybrid_rag = HybridRAG(bm25_retriever, dense_retriever)
    final_results = hybrid_rag.get_results(query, top_k=20, top_n=10)
    # start_time was previously recorded but never reported — log the elapsed time.
    logger.info("Retrieved %d fused chunks in %.2fs", len(final_results), time.time() - start_time)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_results, f, ensure_ascii=False, indent=2, sort_keys=True)