File size: 15,996 Bytes
f6c9376
daffc5f
 
f6c9376
6c134c0
f6c9376
6c134c0
f6c9376
34991da
28f4bd1
6c134c0
 
4032184
 
f6c9376
6c134c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c9376
4032184
f6c9376
6c134c0
 
 
eab288c
f6c9376
 
 
 
 
6c134c0
 
 
28f4bd1
34991da
28f4bd1
 
 
c025e27
 
34991da
 
 
28f4bd1
 
6c134c0
28f4bd1
34991da
f6c9376
28f4bd1
 
 
6c134c0
28f4bd1
 
6c134c0
 
28f4bd1
 
6c134c0
28f4bd1
 
 
6c134c0
28f4bd1
 
 
6c134c0
28f4bd1
6c134c0
28f4bd1
6c134c0
28f4bd1
 
 
6c134c0
 
 
 
28f4bd1
a9dc0f3
28f4bd1
 
 
6c134c0
 
 
 
a9dc0f3
 
6c134c0
 
 
 
 
 
 
 
28f4bd1
 
 
 
 
6c134c0
28f4bd1
 
 
 
 
 
6c134c0
28f4bd1
 
 
44013a5
6c134c0
34991da
 
 
28f4bd1
34991da
 
 
6c134c0
28f4bd1
34991da
 
9dcf8cb
 
 
6c134c0
28f4bd1
6c134c0
d86fb66
34991da
d86fb66
 
 
 
 
 
34991da
d86fb66
 
34991da
6c134c0
34991da
6c134c0
34991da
6c134c0
 
 
 
 
 
34991da
6c134c0
d86fb66
34991da
fd5dbb4
 
 
d86fb66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd5dbb4
 
 
d86fb66
 
 
fd5dbb4
d86fb66
fd5dbb4
 
 
 
 
 
 
d86fb66
fd5dbb4
 
 
 
d86fb66
6c134c0
34991da
6c134c0
34991da
6c134c0
 
 
28f4bd1
 
6c134c0
 
 
 
 
 
 
 
 
 
34991da
6c134c0
 
 
 
 
34991da
6c134c0
34991da
6c134c0
 
 
 
 
 
 
 
34991da
6c134c0
34991da
 
6c134c0
28f4bd1
 
 
 
 
6c134c0
 
28f4bd1
 
 
 
 
 
 
 
6c134c0
28f4bd1
 
6c134c0
 
 
 
 
 
28f4bd1
 
6c134c0
 
 
28f4bd1
 
 
 
 
 
 
6c134c0
28f4bd1
 
6c134c0
 
 
 
 
 
 
 
 
28f4bd1
 
daffc5f
6c134c0
 
 
f6c9376
a9dc0f3
28f4bd1
f6c9376
6c134c0
 
 
 
44013a5
 
6c134c0
34991da
 
a9dc0f3
6c134c0
28f4bd1
34991da
6c134c0
a9dc0f3
c025e27
34991da
6c134c0
 
 
 
34991da
 
 
6c134c0
 
 
34991da
6c134c0
 
 
34991da
 
 
96f79c9
34991da
 
96f79c9
34991da
6c134c0
34991da
6c134c0
a9dc0f3
6c134c0
 
a9dc0f3
6c134c0
 
a9dc0f3
28f4bd1
6c134c0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
from typing import List, Dict

from app.utils import timing_decorator_async
from .config import get_settings
from .gemini_client import GeminiClient, GeminiResponseError
from loguru import logger
import re
import asyncio
import hashlib
import time
from tenacity import retry, stop_after_attempt, wait_exponential

# from .constants import BATCH_STATUS_MESSAGES
# from .utils import get_random_message

# --- Retry decorator cho các lỗi tạm thời của Reranker (network, server-side) ---
retry_on_rerank_transient_error = retry(
    stop=stop_after_attempt(4),  # 1 lần gọi gốc + 3 lần thử lại
    wait=wait_exponential(multiplier=5, min=10, max=60),  # Chờ 10s, 20s, 40s
    retry=lambda retry_state: (
        retry_state.outcome.failed
        and not isinstance(retry_state.outcome.exception(), GeminiResponseError)
    ),
    before_sleep=lambda retry_state: logger.warning(
        f"[RERANK][RETRY] Rerank call failed with transient error, retrying... "
        f"Attempt: {retry_state.attempt_number}, Error: {retry_state.outcome.exception()}"
    ),
)


class Reranker:
    def __init__(self):
        settings = get_settings()
        self.provider = getattr(settings, "rerank_provider", settings.llm_provider)
        self.model = getattr(settings, "rerank_model", settings.llm_model)
        if self.provider == "gemini":
            self.client = GeminiClient()
        # elif self.provider == 'openai':
        #     self.client = OpenAIClient(settings.openai_api_key, model=self.model)
        # elif self.provider == 'cohere':
        #     self.client = CohereClient(settings.cohere_api_key, model=self.model)
        else:
            raise NotImplementedError(
                f"Rerank provider {self.provider} not supported yet."
            )
        # Cải thiện cache với TTL và quản lý memory
        self._rerank_cache = {}
        self._cache_ttl = 3600  # 1 giờ
        self._max_cache_size = 200  # Tăng cache size
        self._cache_timestamps = {}
        # Sử dụng max_docs_to_rerank từ config
        self.max_docs_to_rerank = settings.max_docs_to_rerank

    def _get_cache_key(self, query: str, docs: List[Dict]) -> str:
        """Tạo cache key từ query và docs."""
        # Tối ưu hóa cache key generation
        query_normalized = query.lower().strip()
        doc_ids = [str(doc.get("id", "")) for doc in docs[:15]]  # Chỉ cache top 15 docs
        cache_content = query_normalized + "|".join(sorted(doc_ids))
        return hashlib.md5(cache_content.encode()).hexdigest()

    def _clean_cache(self):
        """Dọn dẹp cache cũ và quản lý memory."""
        current_time = time.time()

        # Xóa cache entries đã hết hạn
        expired_keys = [
            key
            for key, timestamp in self._cache_timestamps.items()
            if current_time - timestamp > self._cache_ttl
        ]

        for key in expired_keys:
            del self._rerank_cache[key]
            del self._cache_timestamps[key]

        # Nếu cache vẫn quá lớn, xóa entries cũ nhất
        if len(self._rerank_cache) > self._max_cache_size:
            sorted_keys = sorted(
                self._cache_timestamps.keys(), key=lambda k: self._cache_timestamps[k]
            )

            # Xóa 20% cache entries cũ nhất
            keys_to_remove = sorted_keys[: len(sorted_keys) // 5]
            for key in keys_to_remove:
                del self._rerank_cache[key]
                del self._cache_timestamps[key]

            logger.info(
                f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries"
            )

    def _get_cached_result(self, cache_key: str, min_score: float) -> List[Dict]:
        """Lấy kết quả từ cache nếu có và còn hợp lệ."""
        if cache_key in self._rerank_cache:
            current_time = time.time()
            if (
                current_time - self._cache_timestamps.get(cache_key, 0)
                <= self._cache_ttl
            ):
                # Lọc theo điểm thay vì lấy top_k
                cached_docs = self._rerank_cache[cache_key]
                cached_result = [
                    doc
                    for doc in cached_docs
                    if doc.get("rerank_score", 0) >= min_score
                ]
                logger.info(
                    f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results with score >= {min_score}"
                )
                return cached_result
            else:
                # Cache đã hết hạn, xóa
                del self._rerank_cache[cache_key]
                del self._cache_timestamps[cache_key]

        return []

    def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
        """Lưu kết quả vào cache."""
        self._rerank_cache[cache_key] = scored_docs
        self._cache_timestamps[cache_key] = time.time()

        # Dọn dẹp cache nếu cần
        if len(self._rerank_cache) > self._max_cache_size:
            self._clean_cache()

    @retry_on_rerank_transient_error
    async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
        """
        Score nhiều documents cùng lúc bằng một prompt duy nhất.
        Không cắt bớt nội dung luật.
        """
        if not docs:
            return []

        # Không giới hạn content length, giữ nguyên nội dung luật
        docs_content = []
        for i, doc in enumerate(docs):
            # tieude = (doc.get('tieude') or '').strip()
            # noidung = (doc.get('noidung') or '').strip()
            # content = f"{tieude} {noidung}".strip()
            content = (doc.get("fullcontent") or "").strip()
            docs_content.append(f"{i+1}. {content}")

        # Sửa: Prompt được làm chặt chẽ hơn để yêu cầu LLM chỉ trả về điểm số.
        batch_prompt = (
            "Bạn là một hệ thống đánh giá. Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật được đánh số sau đây.\n"
            "QUY TẮC:\n"
            "1. Chỉ trả về MỘT DÒNG DUY NHẤT.\n"
            "2. Dòng đó CHỈ chứa danh sách các điểm số (từ 0 đến 10), mỗi điểm tương ứng với một đoạn luật.\n"
            "3. Các điểm số phải được phân cách bởi dấu phẩy.\n"
            "4. KHÔNG giải thích, KHÔNG định dạng markdown, KHÔNG thêm bất kỳ văn bản nào khác.\n\n"
            f"Câu hỏi: {query}\n\n"
            "Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n"
            "ĐIỂM SỐ:"
        )

        try:
            if self.provider == "gemini":
                loop = asyncio.get_event_loop()
                logger.info(
                    f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs"
                )
                response = await loop.run_in_executor(
                    None, self.client.generate_text, batch_prompt
                )
                logger.info(f"[RERANK] Got batch scores from Gemini: {response}")

                # --- START: Cải thiện logic trích xuất điểm (Sửa lỗi) ---
                scores_text = str(response).strip()
                scores_line = ""
                score_strings = []

                # Tách response thành các dòng
                lines = scores_text.split("\n")

                # Ưu tiên 1: Tìm dòng cuối cùng chỉ chứa số, dấu phẩy, khoảng trắng.
                # Đây là trường hợp lý tưởng khi LLM tuân thủ prompt nghiêm ngặt.
                for line in reversed(lines):
                    line = line.strip()
                    if line and re.fullmatch(r"[0-9.,\s]+", line):
                        scores_line = line
                        logger.debug(
                            f"[RERANK] Found pure score line (best case): '{scores_line}'"
                        )
                        break

                # Ưu tiên 2: Nếu không tìm thấy, tìm dòng có chứa keyword và điểm số.
                # Regex này linh hoạt hơn để xử lý markdown và các biến thể keyword.
                if not scores_line:
                    keyword_regex = (
                        r"(?i)(?:Kết quả|Scores|Trả về|Điểm số)[\s\*:]*([0-9.,\s]+)$"
                    )
                    for line in reversed(lines):
                        line = line.strip()
                        match = re.search(keyword_regex, line)
                        if match:
                            scores_line = match.group(1).strip()
                            logger.debug(
                                f"[RERANK] Found scores line using keyword regex: '{scores_line}'"
                            )
                            break

                if scores_line:
                    # Trích xuất tất cả các số từ dòng đã tìm thấy
                    score_strings = re.findall(r"\b\d+(?:\.\d+)?\b", scores_line)
                else:
                    # Fallback cuối cùng: tìm số trong toàn bộ response nếu các phương pháp trên thất bại.
                    logger.warning(
                        "[RERANK] Could not find a dedicated score line. Falling back to parsing all numbers from response."
                    )
                    score_strings = re.findall(r"\b\d+(?:\.\d+)?\b", scores_text)
                # --- END: Cải thiện logic trích xuất điểm (Sửa lỗi) ---

                scores = []
                for s in score_strings:
                    try:
                        score = float(s)
                        # Chỉ chấp nhận các điểm số trong khoảng 0-10 để tăng độ chính xác
                        if 0 <= score <= 10:
                            scores.append(score)
                    except (ValueError, TypeError):
                        # Bỏ qua các giá trị không phải là số hợp lệ
                        continue

                # Đảm bảo số lượng điểm khớp với số lượng văn bản
                # Nếu thiếu, thêm điểm 0. Nếu thừa, cắt bớt.
                if len(scores) < len(docs):
                    scores.extend([0.0] * (len(docs) - len(scores)))
                else:
                    scores = scores[: len(docs)]

                for i, doc in enumerate(docs):
                    doc["rerank_score"] = scores[i]

                logger.info(
                    f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}"
                )
                return docs

            else:
                raise NotImplementedError(
                    f"Rerank provider {self.provider} not supported yet in batch method."
                )

        except GeminiResponseError as e:
            # Lỗi nội dung không thể retry (safety, max_tokens), gán điểm 0 và trả về.
            # Các lỗi khác (network, 500) sẽ được decorator retry.
            logger.error(f"[RERANK] Lỗi nội dung không thể retry khi batch score: {e}")
            for doc in docs:
                doc["rerank_score"] = 0
            return docs

    @retry_on_rerank_transient_error
    async def _score_doc(self, query: str, doc: Dict) -> Dict:
        """
        Score một document với query.
        Không cắt bớt nội dung luật.
        """
        tieude = (doc.get("tieude") or "").strip()
        noidung = (doc.get("noidung") or "").strip()
        content = f"{tieude} {noidung}".strip()
        prompt = (
            f"Đánh giá mức độ liên quan:\n"
            f"Luật: {content}\n"
            f"Hỏi: {query}\n"
            f"Điểm (0-10):"
        )
        try:
            if self.provider == "gemini":
                loop = asyncio.get_event_loop()
                logger.info(f"[RERANK] Sending individual prompt to Gemini")
                score_response = await loop.run_in_executor(
                    None, self.client.generate_text, prompt
                )
                logger.info(
                    f"[RERANK] Got individual score from Gemini: {score_response}"
                )
                score_text = str(score_response).strip()
                try:
                    clean_score = "".join(
                        c for c in score_text if c.isdigit() or c == "."
                    )
                    if clean_score:
                        score = float(clean_score)
                        score = max(0, min(10, score))
                    else:
                        score = 0
                except (ValueError, TypeError):
                    score = 0
                doc["rerank_score"] = score
                return doc
            else:
                raise NotImplementedError(
                    f"Rerank provider {self.provider} not supported yet in rerank method."
                )
        except GeminiResponseError as e:
            # Lỗi nội dung không thể retry (safety, max_tokens), gán điểm 0 và trả về.
            logger.error(
                f"[RERANK] Lỗi nội dung không thể retry khi tính score: {e} | doc: {doc}"
            )
            doc["rerank_score"] = 0
            return doc

    @timing_decorator_async
    async def rerank(
        self, query: str, docs: List[Dict], min_score: float = 7.0
    ) -> List[Dict]:
        """
        Rerank docs theo độ liên quan với query, trả về các docs có điểm >= min_score.
        Sử dụng batch processing và caching để tối ưu hiệu suất.
        """
        logger.info(
            f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | min_score: {min_score}"
        )

        if not docs:
            return []

        # Kiểm tra cache trước
        cache_key = self._get_cache_key(query, docs)
        cached_result = self._get_cached_result(cache_key, min_score)

        if cached_result:
            return cached_result

        # Giới hạn số lượng docs để rerank - chỉ rerank top N docs có similarity cao nhất
        max_docs_to_rerank = self.max_docs_to_rerank
        docs_to_rerank = docs[:max_docs_to_rerank]
        logger.info(
            f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})"
        )

        # Sử dụng batch processing thay vì individual scoring
        try:
            scored = await self._batch_score_docs(query, docs_to_rerank)
            logger.info(
                f"[RERANK] Batch processing completed, scored {len(scored)} docs"
            )
        except Exception as e:
            logger.error(
                f"[RERANK] Batch processing failed, falling back to individual scoring: {e}"
            )
            # Fallback về individual scoring nếu batch processing thất bại
            scored = []
            for doc in docs_to_rerank:
                try:
                    scored_doc = await self._score_doc(query, doc)
                    scored.append(scored_doc)
                except Exception as e:
                    logger.error(f"[RERANK] Error scoring individual doc: {e}")
                    doc["rerank_score"] = 0
                    scored.append(doc)

        # Sort theo score
        scored = sorted(scored, key=lambda x: x.get("rerank_score", 0), reverse=True)

        # Lọc theo min_score
        result = [doc for doc in scored if doc.get("rerank_score", 0) >= min_score]

        # Cache kết quả đã được chấm điểm (toàn bộ, trước khi lọc)
        self._set_cached_result(cache_key, scored)

        logger.info(
            f"[RERANK] Found {len(result)} docs with score >= {min_score}. Top results: {result[:2]}...{result[-2:] if len(result) > 2 else ''}"
        )
        return result