File size: 9,403 Bytes
44c5827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import datetime
import heapq
import json
import logging
import os
import pickle
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import chromadb
from chromadb.config import Settings
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from underthesea import word_tokenize

# ---------------------------
# Config & Logging
# ---------------------------
# Configure root logging once at import time. The level is read from the
# LOG_LEVEL environment variable and falls back to "INFO" when it is unset.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s"
)
# Module logger used by the retriever classes below.
logger = logging.getLogger("hybrid_retriever")

# BASE_DIR is the parent of the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parent.parent
# Default location of the pickled BM25 index loaded by BM25Retriever.
BM25_INDEX_PATH = BASE_DIR / "bm25_index.pkl"
# NOTE: the sessions directory is created as an import-time side effect;
# nothing else in this file, as shown, reads or writes SESSION_DIR.
SESSION_DIR = BASE_DIR / "sessions"
SESSION_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------
# Helper functions
# ---------------------------
def tokenize_vi(text: str) -> List[str]:
    """Tokenize Vietnamese text into a list of lowercase tokens.

    The text is segmented by ``word_tokenize`` in ``format="text"`` mode
    (which yields a single string), lowercased, and then split on whitespace.
    """
    segmented = word_tokenize(text, format="text")
    lowered = segmented.lower()
    return lowered.split()

def rff_fusion(bm25_results: List[Dict[str, Any]], dense_results: List[Dict[str, Any]],
               k: int = 60, top_n: int = 10,
               output_path: Optional[Union[str, Path]] = "output.json") -> List[Dict[str, Any]]:
    """Fuse two ranked result lists with Reciprocal Rank Fusion (RRF).

    Each result at 1-based rank ``r`` in a list contributes ``1 / (k + r)`` to
    the fused score of its ``chunk_id``; contributions from both lists are
    summed and the ``top_n`` highest-scoring chunks are returned.

    Args:
        bm25_results: Ranked BM25 hits, best first. Every hit must carry the
            keys ``chunk_id``, ``score``, ``doc_id``, ``doc_path``, ``path``,
            ``token_count``, ``text`` and ``chunk_for_embedding``.
        dense_results: Ranked dense hits, best first, with the same keys.
        k: RRF smoothing constant added to the rank.
        top_n: Number of fused results to return. Also the rank cut-off used
            for the ``is_bm25`` / ``is_dense`` flags.
        output_path: File that receives a JSON dump of the fused results.
            Defaults to ``"output.json"`` (relative to the current working
            directory), which is what this function has always written.
            Pass ``None`` to skip the dump entirely, e.g. when serving
            concurrent requests or running in a read-only directory.

    Returns:
        A list of dicts sorted by descending fused score. The score is stored
        under the key ``"rff_score"`` (spelling kept for existing consumers).
        ``is_bm25`` / ``is_dense`` are True when the chunk appeared within the
        first ``top_n`` ranks of the respective input list. When a chunk is
        present in both lists, its document fields come from the BM25 hit.

    Raises:
        KeyError: If a result dict is missing one of the required keys.
    """
    fused_scores: Dict[Any, float] = {}
    provenance: Dict[Any, Dict[str, Dict[str, Any]]] = {}
    # First-seen result per chunk; BM25 is processed first, so it wins ties.
    chunk_lookup: Dict[Any, Dict[str, Any]] = {}

    for source, results in (("bm25", bm25_results), ("dense", dense_results)):
        for rank, result in enumerate(results, start=1):
            chunk_id = result["chunk_id"]
            contrib = 1.0 / (k + rank)
            fused_scores[chunk_id] = fused_scores.get(chunk_id, 0.0) + contrib
            provenance.setdefault(chunk_id, {})[source] = {
                "rank": rank,
                "orig_score": result["score"],
                "rrf_contrib": contrib,
            }
            chunk_lookup.setdefault(chunk_id, result)

    top_chunks = heapq.nlargest(top_n, fused_scores.items(), key=lambda item: item[1])

    final_results = []
    for chunk_id, rrf_score in top_chunks:
        info = chunk_lookup[chunk_id]
        sources = provenance[chunk_id]
        final_results.append({
            "chunk_id": chunk_id,
            "doc_id": info["doc_id"],
            "doc_path": info["doc_path"],
            "path": info["path"],
            "token_count": info["token_count"],
            "rff_score": float(rrf_score),
            "is_bm25": "bm25" in sources and sources["bm25"]["rank"] <= top_n,
            "is_dense": "dense" in sources and sources["dense"]["rank"] <= top_n,
            "text": info["text"],
            "chunk_for_embedding": info["chunk_for_embedding"],
        })

    # Optional debug dump; previously a hard-coded, unconditional write.
    if output_path is not None:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2, sort_keys=True)

    return final_results

    
# ---------------------------
# BM25 Search
# ---------------------------
class BM25Retriever:
    """Lexical retriever backed by a pickled BM25 index.

    The index file is expected to unpickle to a dict with the keys ``"bm25"``
    (a ``BM25Okapi`` instance), ``"chunks"`` (a list of chunk dicts) and
    ``"tokenized_corpus"``.
    """

    def __init__(self, index_path: str = str(BM25_INDEX_PATH)):
        """Load the index at ``index_path`` and expose its parts as attributes.

        Raises:
            FileNotFoundError: If ``index_path`` does not exist.
            KeyError: If the unpickled index lacks one of the expected keys.
        """
        self.index_path = index_path
        self.index = self._load_index(index_path)
        self.bm25: BM25Okapi = self.index["bm25"]
        self.chunks: List[Dict[str, Any]] = self.index["chunks"]
        self.tokenized_corpus: List[List[str]] = self.index["tokenized_corpus"]
        logger.info("BM25Search loaded %d chunks from %s", len(self.chunks), index_path)

    def _load_index(self, path: str) -> Dict[str, Any]:
        """Unpickle and return the index stored at ``path``."""
        if not os.path.exists(path):
            raise FileNotFoundError(f"BM25 index file not found: {path}")
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point this at index files produced by a trusted build step.
        with open(path, "rb") as f:
            return pickle.load(f)

    def search(self, query: str, top_k: int = 20) -> List[Dict[str, Any]]:
        """Return the ``top_k`` highest-scoring chunks for ``query``.

        Args:
            query: Raw query text; tokenized with ``tokenize_vi``.
            top_k: Maximum number of hits to return. Values <= 0 yield an
                empty list.

        Returns:
            Hits ordered by descending BM25 score. Each hit has the keys
            ``chunk_id``, ``doc_id``, ``doc_path``, ``path``, ``text``,
            ``chunk_for_embedding``, ``token_count`` and ``score``.
        """
        tokens = tokenize_vi(query)
        scores = self.bm25.get_scores(tokens)

        # Partial selection instead of sorting every corpus score per query.
        # heapq.nlargest is documented as equivalent to
        # sorted(iterable, key=key, reverse=True)[:n].
        ranked = heapq.nlargest(top_k, enumerate(scores), key=lambda pair: pair[1])

        results = []
        for idx, score in ranked:
            chunk = self.chunks[idx]
            # The source document name is the part of doc_id before the first "_".
            doc_stem = chunk["doc_id"].split("_")[0]
            results.append({
                "chunk_id": chunk["id"],
                "doc_id": chunk["doc_id"],
                "doc_path": str(BASE_DIR / "raw_docs" / (doc_stem + ".docx")),
                "path": chunk["path"],
                "text": chunk["text"],
                "chunk_for_embedding": chunk["chunk_for_embedding"],
                "token_count": chunk["token_count"],
                "score": float(score)
            })
        return results

# ---------------------------
# Dense Retrieval
# ---------------------------
import os
# Directory containing this file, and that directory's parent. The parent is
# where DenseRetriever looks for its "chroma_db" folder by default.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)

class DenseRetriever:
    """Embedding-based retriever over a persisted Chroma collection.

    Queries are embedded with a SentenceTransformer model and matched against
    the collection. Hits are returned with the same keys that
    ``BM25Retriever.search`` produces, so both result lists can be fused.
    """

    def __init__(
        self,
        persist_dir: str = os.path.join(parent_dir, "chroma_db"),
        collection: str = "snote",
        embedding_model_name: str = "AITeamVN/Vietnamese_Embedding_v2",
        device: str = "cpu",
    ):
        """Open the Chroma collection and load the embedding model.

        Args:
            persist_dir: Directory holding the persisted Chroma database.
            collection: Name of the collection to query.
            embedding_model_name: SentenceTransformer model identifier.
            device: Device string passed to SentenceTransformer.
        """
        # NOTE(review): `chroma_db_impl="duckdb+parquet"` looks like the
        # legacy Chroma client configuration; confirm it matches the pinned
        # chromadb version before upgrading the dependency.
        settings = Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_dir)
        self.client = chromadb.Client(settings)
        self.collection = self.client.get_collection(collection)

        # load model
        self.model = SentenceTransformer(embedding_model_name, device=device)
        logger.info("DenseRetriever ready with model=%s, persist_dir=%s", embedding_model_name, persist_dir)

    def embed_query(self, query: str) -> List[float]:
        """Encode ``query`` and return its embedding as a list of floats."""
        vec = self.model.encode([query], convert_to_numpy=True)[0]
        return vec.astype(float).tolist()

    def search(self, query: str, top_k: int = 20) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` nearest chunks for ``query``.

        Returns:
            Hits in the order Chroma returned them. Each hit has the keys
            ``chunk_id``, ``doc_id``, ``doc_path``, ``path``, ``text``,
            ``token_count``, ``score`` and ``chunk_for_embedding``.
            ``score`` is ``1.0 - distance``.
        """
        query_vec = self.embed_query(query)
        results = self.collection.query(
            query_embeddings=[query_vec],
            n_results=top_k
        )

        # One query was sent, so only the first row of each field is relevant.
        ids = results["ids"][0]
        # NOTE(review): treated as cosine distance (lower is better); the
        # collection's actual distance metric is not visible here, confirm.
        distances = results["distances"][0]
        documents = results["documents"][0] if results["documents"] else [None] * len(ids)
        metadatas = results["metadatas"][0] if results["metadatas"] else [{}] * len(ids)

        formatted_results = []
        for chunk_id, distance, document, metadata in zip(ids, distances, documents, metadatas):
            # An individual entry may have no metadata; treat it as empty
            # instead of failing on `.get`.
            metadata = metadata or {}

            # Prefer the stored doc_id; otherwise derive it from the chunk id
            # ("<doc_id>::<rest>") or fall back to the chunk id itself. A
            # stored None/empty doc_id takes the fallback as well.
            fallback_doc_id = chunk_id.split("::")[0] if "::" in chunk_id else chunk_id
            doc_base_id = metadata.get("doc_id") or fallback_doc_id

            raw_path = metadata.get("path")
            path_info = raw_path.split(" | ") if raw_path else ["Dense Retrieval Result"]

            formatted_results.append({
                "chunk_id": chunk_id,
                "doc_id": doc_base_id,
                "doc_path": str(BASE_DIR / "raw_docs" / (doc_base_id.split("_")[0] + ".docx")),
                "path": path_info,
                "text": document if document else f"Document ID: {chunk_id}",
                "token_count": metadata.get("token_count", 0),
                # Flip distance so that higher means better, like BM25 scores.
                "score": float(1.0 - distance),
                "chunk_for_embedding": metadata.get("chunk_for_embedding", "")
            })

        return formatted_results

# ---------------------------
# Hybrid RAG
# ---------------------------
class HybridRAG:
    """Hybrid retriever that fuses BM25 and dense results with RRF."""

    def __init__(self, bm25_retriever: Optional["BM25Retriever"] = None,
                 dense_retriever: Optional["DenseRetriever"] = None):
        """Store the given retrievers, building defaults only when needed.

        The defaults used to be ``BM25Retriever()`` / ``DenseRetriever()`` in
        the signature. Default values are evaluated once, when the class is
        defined, so the index and the embedding model were loaded on import,
        even for callers that inject their own retrievers. ``None`` sentinels
        defer that work to the first instantiation that actually needs it.

        Args:
            bm25_retriever: Lexical retriever; a default ``BM25Retriever`` is
                created when omitted.
            dense_retriever: Dense retriever; a default ``DenseRetriever`` is
                created when omitted.
        """
        self.bm25_retriever = bm25_retriever if bm25_retriever is not None else BM25Retriever()
        self.dense_retriever = dense_retriever if dense_retriever is not None else DenseRetriever()

    def get_results(self, query: str, top_k: int = 20, top_n: int = 10,
                    session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Run both retrievers on ``query`` and return the fused top results.

        Args:
            query: User query; surrounding whitespace is stripped.
            top_k: Number of candidates requested from each retriever.
            top_n: Number of fused results to return.
            session_id: Accepted for interface compatibility; currently unused.

        Returns:
            The list produced by ``rff_fusion`` with ``k=60``.
        """
        query = query.strip()
        bm25_results = self.bm25_retriever.search(query, top_k=top_k)
        dense_results = self.dense_retriever.search(query, top_k=top_k)
        return rff_fusion(bm25_results, dense_results, k=60, top_n=top_n)


if __name__ == "__main__":
    # Smoke-test entry point: run one sample query through the hybrid
    # pipeline and write the fused results to output.json.
    import time

    bm25_retriever = BM25Retriever()
    dense_retriever = DenseRetriever()

    # Remove any previous result so a failed run cannot leave stale output.
    output_path = Path("output.json")
    if output_path.exists():
        output_path.unlink()

    query = "Sinh viên không đóng học phí có được bảo vệ Khóa luận không?"
    hybrid_rag = HybridRAG(bm25_retriever, dense_retriever)

    # Time only the retrieval itself, not model/index loading.
    start_time = time.perf_counter()
    final_results = hybrid_rag.get_results(query, top_k=20, top_n=10)
    elapsed = time.perf_counter() - start_time
    logger.info("Hybrid retrieval returned %d results in %.3f s", len(final_results), elapsed)

    # Pretty-print the JSON with indentation; ensure_ascii=False keeps the
    # Vietnamese text readable instead of escaping it.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_results, f, ensure_ascii=False, indent=2, sort_keys=True)