VietCat commited on
Commit
28f4bd1
·
1 Parent(s): 34991da

optimize reranker

Browse files
Files changed (1) hide show
  1. app/reranker.py +138 -54
app/reranker.py CHANGED
@@ -5,6 +5,7 @@ from loguru import logger
5
  import asyncio
6
  import random
7
  import hashlib
 
8
  from .constants import BATCH_STATUS_MESSAGES
9
 
10
  class Reranker:
@@ -21,86 +22,136 @@ class Reranker:
21
  else:
22
  raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")
23
  self.facebook_client = facebook_client
24
- # Cache cho kết quả reranking
 
25
  self._rerank_cache = {}
 
 
 
26
 
27
  def _get_cache_key(self, query: str, docs: List[Dict]) -> str:
28
  """Tạo cache key từ query và docs."""
29
- # Tạo hash từ query doc IDs
 
30
  doc_ids = [str(doc.get('id', '')) for doc in docs[:15]] # Chỉ cache top 15 docs
31
- cache_content = query + "|".join(doc_ids)
32
  return hashlib.md5(cache_content.encode()).hexdigest()
33
 
34
- async def _score_doc(self, query: str, doc: Dict) -> Dict:
35
- """
36
- Score một document với query.
37
- """
38
- content = (doc.get('tieude', '') or '') + ' ' + (doc.get('noidung', '') or '')
39
- # Tối ưu prompt ngắn gọn hơn
40
- prompt = (
41
- f"Luật: {content[:500]}\n" # Giới hạn content length
42
- f"Hỏi: {query}\n"
43
- "Đánh giá mức độ liên quan (0-10). Chỉ trả về số."
44
- )
45
 
46
- try:
47
- if self.provider == 'gemini':
48
- loop = asyncio.get_event_loop()
49
- logger.info(f"[RERANK] Sending prompt to Gemini: {prompt}")
50
- score = await loop.run_in_executor(None, self.client.generate_text, prompt)
51
- logger.info(f"[RERANK] Got score from Gemini: {score}")
52
- else:
53
- raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
 
 
 
 
 
 
 
 
54
 
55
- score = float(str(score).strip().split()[0])
56
- doc['rerank_score'] = score
57
- return doc
 
 
58
 
59
- except Exception as e:
60
- logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
61
- doc['rerank_score'] = 0
62
- return doc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
65
  """
66
  Score nhiều documents cùng lúc bằng một prompt duy nhất.
 
67
  """
68
  if not docs:
69
  return []
70
 
71
- # Tạo prompt batch cho tất cả documents
72
  docs_content = []
73
  for i, doc in enumerate(docs):
74
- content = (doc.get('tieude', '') or '') + ' ' + (doc.get('noidung', '') or '')
75
- docs_content.append(f"{i+1}. {content[:300]}") # Giới hạn length
 
 
76
 
77
  batch_prompt = (
 
78
  f"Câu hỏi: {query}\n\n"
79
  f"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n"
80
- f"Đánh giá mức độ liên quan của từng đoạn (0-10). Trả về dạng: 1.8,2.5,3.0,..."
 
81
  )
82
 
83
  try:
84
  if self.provider == 'gemini':
85
  loop = asyncio.get_event_loop()
86
- logger.info(f"[RERANK] Sending batch prompt to Gemini")
87
  response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt)
88
  logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
89
 
90
- # Parse scores từ response
91
  scores_text = str(response).strip()
92
  scores = []
93
- for score_str in scores_text.split(','):
 
 
 
 
 
 
 
 
 
94
  try:
95
- score = float(score_str.strip().split('.')[0])
96
- scores.append(score)
97
- except:
 
 
 
 
 
98
  scores.append(0)
99
 
100
- # Gán scores cho documents
 
 
101
  for i, doc in enumerate(docs):
102
- doc['rerank_score'] = scores[i] if i < len(scores) else 0
103
 
 
104
  return docs
105
 
106
  else:
@@ -108,15 +159,53 @@ class Reranker:
108
 
109
  except Exception as e:
110
  logger.error(f"[RERANK] Lỗi khi batch score: {e}")
111
- # Fallback về individual scoring
112
  for doc in docs:
113
  doc['rerank_score'] = 0
114
  return docs
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
117
  """
118
  Rerank docs theo độ liên quan với query, trả về top_k docs.
119
- Sử dụng batch processing để tối ưu hiệu suất.
120
  """
121
  logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")
122
 
@@ -125,9 +214,9 @@ class Reranker:
125
 
126
  # Kiểm tra cache trước
127
  cache_key = self._get_cache_key(query, docs)
128
- if cache_key in self._rerank_cache:
129
- logger.info(f"[RERANK] Cache hit for query, returning cached result")
130
- cached_result = self._rerank_cache[cache_key][:top_k]
131
  return cached_result
132
 
133
  # Giới hạn số lượng docs để rerank - chỉ rerank top 15 docs có similarity cao nhất
@@ -152,7 +241,7 @@ class Reranker:
152
  doc['rerank_score'] = 0
153
  scored.append(doc)
154
 
155
- # Gửi Facebook message chỉ một lần sau khi hoàn thành
156
  if self.facebook_client:
157
  try:
158
  message = random.choice(BATCH_STATUS_MESSAGES)
@@ -164,13 +253,8 @@ class Reranker:
164
  scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True)
165
  result = scored[:top_k]
166
 
167
- # Cache kết quả
168
- self._rerank_cache[cache_key] = scored
169
- # Giới hạn cache size để tránh memory leak
170
- if len(self._rerank_cache) > 100:
171
- # Xóa cache cũ nhất
172
- oldest_key = next(iter(self._rerank_cache))
173
- del self._rerank_cache[oldest_key]
174
 
175
  logger.info(f"[RERANK] Top reranked docs: {result}")
176
  return result
 
5
  import asyncio
6
  import random
7
  import hashlib
8
+ import time
9
  from .constants import BATCH_STATUS_MESSAGES
10
 
11
  class Reranker:
 
22
  else:
23
  raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")
24
  self.facebook_client = facebook_client
25
+
26
+ # Cải thiện cache với TTL và quản lý memory
27
  self._rerank_cache = {}
28
+ self._cache_ttl = 3600 # 1 giờ
29
+ self._max_cache_size = 200 # Tăng cache size
30
+ self._cache_timestamps = {}
31
 
32
  def _get_cache_key(self, query: str, docs: List[Dict]) -> str:
33
  """Tạo cache key từ query và docs."""
34
+ # Tối ưu hóa cache key generation
35
+ query_normalized = query.lower().strip()
36
  doc_ids = [str(doc.get('id', '')) for doc in docs[:15]] # Chỉ cache top 15 docs
37
+ cache_content = query_normalized + "|".join(sorted(doc_ids))
38
  return hashlib.md5(cache_content.encode()).hexdigest()
39
 
40
+ def _clean_cache(self):
41
+ """Dọn dẹp cache cũ và quản lý memory."""
42
+ current_time = time.time()
 
 
 
 
 
 
 
 
43
 
44
+ # Xóa cache entries đã hết hạn
45
+ expired_keys = [
46
+ key for key, timestamp in self._cache_timestamps.items()
47
+ if current_time - timestamp > self._cache_ttl
48
+ ]
49
+
50
+ for key in expired_keys:
51
+ del self._rerank_cache[key]
52
+ del self._cache_timestamps[key]
53
+
54
+ # Nếu cache vẫn quá lớn, xóa entries cũ nhất
55
+ if len(self._rerank_cache) > self._max_cache_size:
56
+ sorted_keys = sorted(
57
+ self._cache_timestamps.keys(),
58
+ key=lambda k: self._cache_timestamps[k]
59
+ )
60
 
61
+ # Xóa 20% cache entries cũ nhất
62
+ keys_to_remove = sorted_keys[:len(sorted_keys) // 5]
63
+ for key in keys_to_remove:
64
+ del self._rerank_cache[key]
65
+ del self._cache_timestamps[key]
66
 
67
+ logger.info(f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries")
68
+
69
+ def _get_cached_result(self, cache_key: str, top_k: int) -> List[Dict]:
70
+ """Lấy kết quả từ cache nếu có và còn hợp lệ."""
71
+ if cache_key in self._rerank_cache:
72
+ current_time = time.time()
73
+ if current_time - self._cache_timestamps.get(cache_key, 0) <= self._cache_ttl:
74
+ cached_result = self._rerank_cache[cache_key][:top_k]
75
+ logger.info(f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results")
76
+ return cached_result
77
+ else:
78
+ # Cache đã hết hạn, xóa
79
+ del self._rerank_cache[cache_key]
80
+ del self._cache_timestamps[cache_key]
81
+
82
+ return []
83
+
84
+ def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
85
+ """Lưu kết quả vào cache."""
86
+ self._rerank_cache[cache_key] = scored_docs
87
+ self._cache_timestamps[cache_key] = time.time()
88
+
89
+ # Dọn dẹp cache nếu cần
90
+ if len(self._rerank_cache) > self._max_cache_size:
91
+ self._clean_cache()
92
 
93
  async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
94
  """
95
  Score nhiều documents cùng lúc bằng một prompt duy nhất.
96
+ Không cắt bớt nội dung luật.
97
  """
98
  if not docs:
99
  return []
100
 
101
+ # Không giới hạn content length, giữ nguyên nội dung luật
102
  docs_content = []
103
  for i, doc in enumerate(docs):
104
+ tieude = doc.get('tieude', '').strip()
105
+ noidung = doc.get('noidung', '').strip()
106
+ content = f"{tieude} {noidung}".strip()
107
+ docs_content.append(f"{i+1}. {content}")
108
 
109
  batch_prompt = (
110
+ f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
111
  f"Câu hỏi: {query}\n\n"
112
  f"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n"
113
+ f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
114
+ f"Ví dụ: 8,5,7,3,9"
115
  )
116
 
117
  try:
118
  if self.provider == 'gemini':
119
  loop = asyncio.get_event_loop()
120
+ logger.info(f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs")
121
  response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt)
122
  logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
123
 
124
+ # Cải thiện parsing scores
125
  scores_text = str(response).strip()
126
  scores = []
127
+
128
+ # Xử lý nhiều format response có thể có
129
+ if ',' in scores_text:
130
+ score_parts = scores_text.split(',')
131
+ elif ' ' in scores_text:
132
+ score_parts = scores_text.split()
133
+ else:
134
+ score_parts = scores_text.replace('.', ',').split(',')
135
+
136
+ for score_str in score_parts:
137
  try:
138
+ clean_score = ''.join(c for c in score_str.strip() if c.isdigit() or c == '.')
139
+ if clean_score:
140
+ score = float(clean_score)
141
+ score = max(0, min(10, score))
142
+ scores.append(score)
143
+ else:
144
+ scores.append(0)
145
+ except (ValueError, TypeError):
146
  scores.append(0)
147
 
148
+ while len(scores) < len(docs):
149
+ scores.append(0)
150
+
151
  for i, doc in enumerate(docs):
152
+ doc['rerank_score'] = scores[i]
153
 
154
+ logger.info(f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}")
155
  return docs
156
 
157
  else:
 
159
 
160
  except Exception as e:
161
  logger.error(f"[RERANK] Lỗi khi batch score: {e}")
 
162
  for doc in docs:
163
  doc['rerank_score'] = 0
164
  return docs
165
 
166
+ async def _score_doc(self, query: str, doc: Dict) -> Dict:
167
+ """
168
+ Score một document với query.
169
+ Không cắt bớt nội dung luật.
170
+ """
171
+ tieude = doc.get('tieude', '').strip()
172
+ noidung = doc.get('noidung', '').strip()
173
+ content = f"{tieude} {noidung}".strip()
174
+ prompt = (
175
+ f"Đánh giá mức độ liên quan:\n"
176
+ f"Luật: {content}\n"
177
+ f"Hỏi: {query}\n"
178
+ f"Điểm (0-10):"
179
+ )
180
+ try:
181
+ if self.provider == 'gemini':
182
+ loop = asyncio.get_event_loop()
183
+ logger.info(f"[RERANK] Sending individual prompt to Gemini")
184
+ score_response = await loop.run_in_executor(None, self.client.generate_text, prompt)
185
+ logger.info(f"[RERANK] Got individual score from Gemini: {score_response}")
186
+ score_text = str(score_response).strip()
187
+ try:
188
+ clean_score = ''.join(c for c in score_text if c.isdigit() or c == '.')
189
+ if clean_score:
190
+ score = float(clean_score)
191
+ score = max(0, min(10, score))
192
+ else:
193
+ score = 0
194
+ except (ValueError, TypeError):
195
+ score = 0
196
+ doc['rerank_score'] = score
197
+ return doc
198
+ else:
199
+ raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
200
+ except Exception as e:
201
+ logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
202
+ doc['rerank_score'] = 0
203
+ return doc
204
+
205
  async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
206
  """
207
  Rerank docs theo độ liên quan với query, trả về top_k docs.
208
+ Sử dụng batch processing và caching để tối ưu hiệu suất.
209
  """
210
  logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")
211
 
 
214
 
215
  # Kiểm tra cache trước
216
  cache_key = self._get_cache_key(query, docs)
217
+ cached_result = self._get_cached_result(cache_key, top_k)
218
+
219
+ if cached_result:
220
  return cached_result
221
 
222
  # Giới hạn số lượng docs để rerank - chỉ rerank top 15 docs có similarity cao nhất
 
241
  doc['rerank_score'] = 0
242
  scored.append(doc)
243
 
244
+ # Gửi Facebook message sau khi hoàn thành
245
  if self.facebook_client:
246
  try:
247
  message = random.choice(BATCH_STATUS_MESSAGES)
 
253
  scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True)
254
  result = scored[:top_k]
255
 
256
+ # Cache kết quả với system mới
257
+ self._set_cached_result(cache_key, scored)
 
 
 
 
 
258
 
259
  logger.info(f"[RERANK] Top reranked docs: {result}")
260
  return result