VietCat commited on
Commit
6c134c0
·
1 Parent(s): a9dc0f3

update query and rerank in parallel

Browse files
Files changed (2) hide show
  1. app/gemini_client.py +151 -80
  2. app/reranker.py +152 -90
app/gemini_client.py CHANGED
@@ -6,29 +6,40 @@ from typing import Dict, List, Optional
6
  from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold
7
 
8
  from .request_limit_manager import RequestLimitManager
9
- from .utils import (
10
- _safe_truncate
11
- )
12
  from .config import get_settings
13
 
 
14
  class GeminiResponseError(Exception):
15
  """Custom exception for non-retriable Gemini response issues like safety or token limits."""
 
16
  def __init__(self, message, finish_reason=None, usage_metadata=None):
17
  super().__init__(message)
18
  self.finish_reason = finish_reason
19
  self.usage_metadata = usage_metadata
20
 
21
  def __str__(self):
22
- usage_str = f"Prompt: {self.usage_metadata.prompt_token_count}, Candidates: {self.usage_metadata.candidates_token_count}, Total: {self.usage_metadata.total_token_count}" if self.usage_metadata else "N/A"
 
 
 
 
23
  return f"{super().__str__()} (Finish Reason: {self.finish_reason}, Usage: {usage_str})"
24
 
 
25
  class GeminiClient:
26
  def __init__(self):
27
  self.limit_manager = RequestLimitManager("gemini")
28
  settings = get_settings()
29
- num_keys = len(settings.gemini_api_keys.split(',')) if settings.gemini_api_keys else 0
30
- num_models = len(settings.gemini_models.split(',')) if settings.gemini_models else 0
31
- logger.info(f"[GEMINI_INIT] Limiter is considering {num_keys} API keys and {num_models} models.")
 
 
 
 
 
 
32
  self._cached_model = None
33
  self._cached_key = None
34
  self._cached_model_instance = None
@@ -44,26 +55,32 @@ class GeminiClient:
44
  """
45
  Cache model instance để tránh recreate mỗi lần.
46
  """
47
- if (self._cached_key == key and
48
- self._cached_model == model and
49
- self._cached_model_instance is not None):
 
 
50
  return self._cached_model_instance
51
-
52
  # Configure và tạo model instance mới
53
  configure(api_key=key)
54
  self._cached_model_instance = GenerativeModel(model)
55
  self._cached_key = key
56
  self._cached_model = model
57
-
58
- logger.info(f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}")
 
 
59
  return self._cached_model_instance
60
 
61
  def _clear_cache_if_needed(self, new_key: str, new_model: str):
62
  """
63
  Chỉ clear cache khi key/model thực sự thay đổi.
64
  """
65
- if (self._cached_key != new_key or self._cached_model != new_model):
66
- logger.info(f"[GEMINI] Clearing cache due to key/model change: {self._cached_key}->{new_key}, {self._cached_model}->{new_model}")
 
 
67
  self._cached_model_instance = None
68
  self._cached_key = None
69
  self._cached_model = None
@@ -71,21 +88,19 @@ class GeminiClient:
71
  def generate_text(self, prompt: str, **kwargs) -> str:
72
  last_error = None
73
  max_retries = 3
74
-
75
  for attempt in range(max_retries):
76
  try:
77
  # Lấy current key/model từ manager
78
  key, model = self.limit_manager.get_current_key_model()
79
-
80
  # Sử dụng cached model instance
81
  _model = self._get_model_instance(key, model)
82
-
83
  response = _model.generate_content(
84
- prompt,
85
- safety_settings=self.safety_settings,
86
- **kwargs
87
  )
88
-
89
  # Log toàn bộ nội dung response ở mức INFO để tiện gỡ lỗi
90
  logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
91
 
@@ -93,47 +108,81 @@ class GeminiClient:
93
  # 1. Kiểm tra response có hợp lệ không
94
  if not response.candidates:
95
  # Lỗi này nên được coi là lỗi tạm thời, thử lại với key/model khác
96
- raise ValueError("Gemini response is missing 'candidates' field. Retrying...")
 
 
97
 
98
- candidate = response.candidates[0]
99
- finish_reason_name = getattr(getattr(candidate, 'finish_reason', None), 'name', 'UNKNOWN')
 
 
100
  # Kiểm tra xem có nội dung thực sự không
101
  # Sửa: Dùng getattr để tránh AttributeError nếu 'parts' không tồn tại
102
- has_content = bool(candidate.content and getattr(candidate.content, 'parts', None))
 
 
103
 
104
  # 2. Phân loại lỗi và xử lý
105
  # Case 1: Lỗi nội dung không thể thử lại (SAFETY, MAX_TOKENS, etc.)
106
  if finish_reason_name != "STOP":
107
- usage_metadata = response.usage_metadata if hasattr(response, 'usage_metadata') else None
 
 
 
 
108
  error_message = f"Gemini response finished with non-OK reason: {finish_reason_name}."
109
  raise GeminiResponseError(
110
- error_message, finish_reason=finish_reason_name, usage_metadata=usage_metadata
 
 
111
  )
112
 
113
  # Case 2: Lỗi có thể thử lại (STOP nhưng không có nội dung)
114
- if not has_content: # Tại đây, ta biết chắc chắn finish_reason_name là "STOP"
115
- usage_metadata = response.usage_metadata if hasattr(response, 'usage_metadata') else None
116
- last_error = GeminiResponseError("Gemini response finished with STOP but has no content parts.", finish_reason='STOP_NO_CONTENT', usage_metadata=usage_metadata)
117
- logger.warning(f"[GEMINI] Model returned STOP with no content. Retrying with another key/model... (Attempt {attempt + 1}/{max_retries})")
118
- self.limit_manager.log_request(key, model, success=False, retry_delay=5)
119
- continue # Thử lại vòng lặp với key/model mới
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # Case 3: Thành công (STOP và có nội dung)
122
  self.limit_manager.log_request(key, model, success=True)
123
- if hasattr(response, 'usage_metadata'):
124
- logger.info(f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}")
125
-
 
 
126
  try:
127
- logger.info(f"[GEMINI][TEXT_RESPONSE] {_safe_truncate(response.text)}")
 
 
128
  return response.text
129
  except ValueError as ve:
130
  # Safety net: Nếu truy cập .text thất bại dù các kiểm tra trước đó đã qua,
131
  # coi như đây là lỗi STOP_NO_CONTENT và ném ra để tầng trên xử lý.
132
- usage_metadata = response.usage_metadata if hasattr(response, 'usage_metadata') else None
 
 
 
 
133
  raise GeminiResponseError(
134
  f"Gemini response has no valid content part. Original error: {ve}",
135
- finish_reason='STOP_NO_CONTENT',
136
- usage_metadata=usage_metadata
137
  ) from ve
138
  # --- END: Cải tiến logic xử lý response ---
139
  except GeminiResponseError as e:
@@ -142,28 +191,41 @@ class GeminiClient:
142
  raise e
143
  except Exception as e:
144
  import re
 
145
  msg = str(e)
146
- if "429" in msg or "rate limit" in msg.lower():
147
- retry_delay = 60
148
- m = re.search(r'retry_delay.*?seconds: (\d+)', msg)
149
- if m:
150
- retry_delay = int(m.group(1))
151
-
152
- # Log failure với key/model thực tế đang được sử dụng
153
- self.limit_manager.log_request(key, model, success=False, retry_delay=retry_delay)
154
-
155
- # Chỉ clear cache nếu key/model thay đổi
156
- # Không clear cache ngay lập tức để tránh recreate không cần thiết
157
-
158
- logger.warning(f"[GEMINI] Rate limit hit, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
 
 
 
 
 
 
 
 
 
 
159
  last_error = e
160
- continue
161
  else:
162
- # Lỗi khác không phải rate limit (vd: timeout, server error)
163
- # sẽ được propagate lên để lớp llm.py xử lý retry với backoff.
164
- logger.error(f"[GEMINI] Error generating text: {e}")
 
 
165
  raise e
166
-
167
  raise last_error or RuntimeError("No available Gemini API key/model")
168
 
169
  def count_tokens(self, prompt: str) -> int:
@@ -175,51 +237,60 @@ class GeminiClient:
175
  logger.error(f"[GEMINI] Error counting tokens: {e}")
176
  return 0
177
 
178
- def create_embedding(self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query") -> list:
 
 
179
  last_error = None
180
  max_retries = 3
181
-
182
  for attempt in range(max_retries):
183
  try:
184
  key, default_model = self.limit_manager.get_current_key_model()
185
-
186
  # Ưu tiên model được truyền vào parameter, chỉ fallback về default_model nếu không có
187
  use_model = model if model and model.strip() else default_model
188
-
189
  if not use_model:
190
  raise ValueError("No model specified for embedding")
191
-
192
- logger.info(f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}")
193
-
 
 
194
  configure(api_key=key)
195
  response = embed_content(
196
- model=use_model,
197
- content=text,
198
- task_type=task_type
199
  )
200
-
201
  self.limit_manager.log_request(key, use_model, success=True)
202
- logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}")
203
- return response['embedding']
204
-
 
 
205
  except Exception as e:
206
  import re
 
207
  msg = str(e)
208
  if "429" in msg or "rate limit" in msg.lower():
209
  retry_delay = 60
210
- m_retry = re.search(r'retry_delay.*?seconds: (\d+)', msg)
211
  if m_retry:
212
  retry_delay = int(m_retry.group(1))
213
-
214
  # Log failure và trigger scan cho key/model mới
215
- self.limit_manager.log_request(key, use_model, success=False, retry_delay=retry_delay)
216
-
217
- logger.warning(f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
 
 
 
 
218
  last_error = e
219
  continue
220
  else:
221
  logger.error(f"[GEMINI] Error creating embedding: {e}")
222
  last_error = e
223
  break
224
-
225
- raise last_error or RuntimeError("No available Gemini API key/model")
 
6
  from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold
7
 
8
  from .request_limit_manager import RequestLimitManager
9
+ from .utils import _safe_truncate
 
 
10
  from .config import get_settings
11
 
12
+
13
  class GeminiResponseError(Exception):
14
  """Custom exception for non-retriable Gemini response issues like safety or token limits."""
15
+
16
  def __init__(self, message, finish_reason=None, usage_metadata=None):
17
  super().__init__(message)
18
  self.finish_reason = finish_reason
19
  self.usage_metadata = usage_metadata
20
 
21
  def __str__(self):
22
+ usage_str = (
23
+ f"Prompt: {self.usage_metadata.prompt_token_count}, Candidates: {self.usage_metadata.candidates_token_count}, Total: {self.usage_metadata.total_token_count}"
24
+ if self.usage_metadata
25
+ else "N/A"
26
+ )
27
  return f"{super().__str__()} (Finish Reason: {self.finish_reason}, Usage: {usage_str})"
28
 
29
+
30
  class GeminiClient:
31
  def __init__(self):
32
  self.limit_manager = RequestLimitManager("gemini")
33
  settings = get_settings()
34
+ num_keys = (
35
+ len(settings.gemini_api_keys.split(",")) if settings.gemini_api_keys else 0
36
+ )
37
+ num_models = (
38
+ len(settings.gemini_models.split(",")) if settings.gemini_models else 0
39
+ )
40
+ logger.info(
41
+ f"[GEMINI_INIT] Limiter is considering {num_keys} API keys and {num_models} models."
42
+ )
43
  self._cached_model = None
44
  self._cached_key = None
45
  self._cached_model_instance = None
 
55
  """
56
  Cache model instance để tránh recreate mỗi lần.
57
  """
58
+ if (
59
+ self._cached_key == key
60
+ and self._cached_model == model
61
+ and self._cached_model_instance is not None
62
+ ):
63
  return self._cached_model_instance
64
+
65
  # Configure và tạo model instance mới
66
  configure(api_key=key)
67
  self._cached_model_instance = GenerativeModel(model)
68
  self._cached_key = key
69
  self._cached_model = model
70
+
71
+ logger.info(
72
+ f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}"
73
+ )
74
  return self._cached_model_instance
75
 
76
  def _clear_cache_if_needed(self, new_key: str, new_model: str):
77
  """
78
  Chỉ clear cache khi key/model thực sự thay đổi.
79
  """
80
+ if self._cached_key != new_key or self._cached_model != new_model:
81
+ logger.info(
82
+ f"[GEMINI] Clearing cache due to key/model change: {self._cached_key}->{new_key}, {self._cached_model}->{new_model}"
83
+ )
84
  self._cached_model_instance = None
85
  self._cached_key = None
86
  self._cached_model = None
 
88
  def generate_text(self, prompt: str, **kwargs) -> str:
89
  last_error = None
90
  max_retries = 3
91
+
92
  for attempt in range(max_retries):
93
  try:
94
  # Lấy current key/model từ manager
95
  key, model = self.limit_manager.get_current_key_model()
96
+
97
  # Sử dụng cached model instance
98
  _model = self._get_model_instance(key, model)
99
+
100
  response = _model.generate_content(
101
+ prompt, safety_settings=self.safety_settings, **kwargs
 
 
102
  )
103
+
104
  # Log toàn bộ nội dung response ở mức INFO để tiện gỡ lỗi
105
  logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
106
 
 
108
  # 1. Kiểm tra response có hợp lệ không
109
  if not response.candidates:
110
  # Lỗi này nên được coi là lỗi tạm thời, thử lại với key/model khác
111
+ raise ValueError(
112
+ "Gemini response is missing 'candidates' field. Retrying..."
113
+ )
114
 
115
+ candidate = response.candidates[0]
116
+ finish_reason_name = getattr(
117
+ getattr(candidate, "finish_reason", None), "name", "UNKNOWN"
118
+ )
119
  # Kiểm tra xem có nội dung thực sự không
120
  # Sửa: Dùng getattr để tránh AttributeError nếu 'parts' không tồn tại
121
+ has_content = bool(
122
+ candidate.content and getattr(candidate.content, "parts", None)
123
+ )
124
 
125
  # 2. Phân loại lỗi và xử lý
126
  # Case 1: Lỗi nội dung không thể thử lại (SAFETY, MAX_TOKENS, etc.)
127
  if finish_reason_name != "STOP":
128
+ usage_metadata = (
129
+ response.usage_metadata
130
+ if hasattr(response, "usage_metadata")
131
+ else None
132
+ )
133
  error_message = f"Gemini response finished with non-OK reason: {finish_reason_name}."
134
  raise GeminiResponseError(
135
+ error_message,
136
+ finish_reason=finish_reason_name,
137
+ usage_metadata=usage_metadata,
138
  )
139
 
140
  # Case 2: Lỗi có thể thử lại (STOP nhưng không có nội dung)
141
+ if (
142
+ not has_content
143
+ ): # Tại đây, ta biết chắc chắn finish_reason_name "STOP"
144
+ usage_metadata = (
145
+ response.usage_metadata
146
+ if hasattr(response, "usage_metadata")
147
+ else None
148
+ )
149
+ last_error = GeminiResponseError(
150
+ "Gemini response finished with STOP but has no content parts.",
151
+ finish_reason="STOP_NO_CONTENT",
152
+ usage_metadata=usage_metadata,
153
+ )
154
+ logger.warning(
155
+ f"[GEMINI] Model returned STOP with no content. Retrying with another key/model... (Attempt {attempt + 1}/{max_retries})"
156
+ )
157
+ self.limit_manager.log_request(
158
+ key, model, success=False, retry_delay=5
159
+ )
160
+ continue # Thử lại vòng lặp với key/model mới
161
 
162
  # Case 3: Thành công (STOP và có nội dung)
163
  self.limit_manager.log_request(key, model, success=True)
164
+ if hasattr(response, "usage_metadata"):
165
+ logger.info(
166
+ f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}"
167
+ )
168
+
169
  try:
170
+ logger.info(
171
+ f"[GEMINI][TEXT_RESPONSE] {_safe_truncate(response.text)}"
172
+ )
173
  return response.text
174
  except ValueError as ve:
175
  # Safety net: Nếu truy cập .text thất bại dù các kiểm tra trước đó đã qua,
176
  # coi như đây là lỗi STOP_NO_CONTENT và ném ra để tầng trên xử lý.
177
+ usage_metadata = (
178
+ response.usage_metadata
179
+ if hasattr(response, "usage_metadata")
180
+ else None
181
+ )
182
  raise GeminiResponseError(
183
  f"Gemini response has no valid content part. Original error: {ve}",
184
+ finish_reason="STOP_NO_CONTENT",
185
+ usage_metadata=usage_metadata,
186
  ) from ve
187
  # --- END: Cải tiến logic xử lý response ---
188
  except GeminiResponseError as e:
 
191
  raise e
192
  except Exception as e:
193
  import re
194
+
195
  msg = str(e)
196
+ # Kiểm tra lỗi rate limit hoặc lỗi server (5xx)
197
+ is_rate_limit = "429" in msg or "rate limit" in msg.lower()
198
+ is_server_error = any(
199
+ code in msg for code in ["500", "502", "503", "504"]
200
+ )
201
+
202
+ if is_rate_limit or is_server_error:
203
+ retry_delay = 60 # Mặc định cho lỗi server
204
+ if is_rate_limit:
205
+ m = re.search(r"retry_delay.*?seconds: (\d+)", msg)
206
+ if m:
207
+ retry_delay = int(m.group(1))
208
+
209
+ # Log lỗi và chặn cặp key/model hiện tại trong một khoảng thời gian
210
+ self.limit_manager.log_request(
211
+ key, model, success=False, retry_delay=retry_delay
212
+ )
213
+
214
+ error_type = "Rate limit" if is_rate_limit else "Server"
215
+ logger.warning(
216
+ f"[GEMINI] {error_type} error hit, will retry with new key/model "
217
+ f"(attempt {attempt + 1}/{max_retries}). Error: {e}"
218
+ )
219
  last_error = e
220
+ continue # Tiếp tục vòng lặp để thử key/model mới
221
  else:
222
+ # Các lỗi khác không phải rate limit hoặc server error (vd: network timeout, invalid argument)
223
+ # sẽ được propagate lên để lớp llm.py/reranker.py xử lý retry với backoff.
224
+ logger.error(
225
+ f"[GEMINI] Unhandled error generating text, propagating up: {e}"
226
+ )
227
  raise e
228
+
229
  raise last_error or RuntimeError("No available Gemini API key/model")
230
 
231
  def count_tokens(self, prompt: str) -> int:
 
237
  logger.error(f"[GEMINI] Error counting tokens: {e}")
238
  return 0
239
 
240
+ def create_embedding(
241
+ self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query"
242
+ ) -> list:
243
  last_error = None
244
  max_retries = 3
245
+
246
  for attempt in range(max_retries):
247
  try:
248
  key, default_model = self.limit_manager.get_current_key_model()
249
+
250
  # Ưu tiên model được truyền vào parameter, chỉ fallback về default_model nếu không có
251
  use_model = model if model and model.strip() else default_model
252
+
253
  if not use_model:
254
  raise ValueError("No model specified for embedding")
255
+
256
+ logger.info(
257
+ f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}"
258
+ )
259
+
260
  configure(api_key=key)
261
  response = embed_content(
262
+ model=use_model, content=text, task_type=task_type
 
 
263
  )
264
+
265
  self.limit_manager.log_request(key, use_model, success=True)
266
+ logger.info(
267
+ f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}"
268
+ )
269
+ return response["embedding"]
270
+
271
  except Exception as e:
272
  import re
273
+
274
  msg = str(e)
275
  if "429" in msg or "rate limit" in msg.lower():
276
  retry_delay = 60
277
+ m_retry = re.search(r"retry_delay.*?seconds: (\d+)", msg)
278
  if m_retry:
279
  retry_delay = int(m_retry.group(1))
280
+
281
  # Log failure và trigger scan cho key/model mới
282
+ self.limit_manager.log_request(
283
+ key, use_model, success=False, retry_delay=retry_delay
284
+ )
285
+
286
+ logger.warning(
287
+ f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})"
288
+ )
289
  last_error = e
290
  continue
291
  else:
292
  logger.error(f"[GEMINI] Error creating embedding: {e}")
293
  last_error = e
294
  break
295
+
296
+ raise last_error or RuntimeError("No available Gemini API key/model")
app/reranker.py CHANGED
@@ -2,27 +2,47 @@ from typing import List, Dict
2
 
3
  from app.utils import timing_decorator_async
4
  from .config import get_settings
5
- from .gemini_client import GeminiClient
6
  from loguru import logger
 
7
  import asyncio
8
  import hashlib
9
  import time
 
 
10
  # from .constants import BATCH_STATUS_MESSAGES
11
  # from .utils import get_random_message
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class Reranker:
14
  def __init__(self):
15
  settings = get_settings()
16
- self.provider = getattr(settings, 'rerank_provider', settings.llm_provider)
17
- self.model = getattr(settings, 'rerank_model', settings.llm_model)
18
- if self.provider == 'gemini':
19
  self.client = GeminiClient()
20
  # elif self.provider == 'openai':
21
  # self.client = OpenAIClient(settings.openai_api_key, model=self.model)
22
  # elif self.provider == 'cohere':
23
  # self.client = CohereClient(settings.cohere_api_key, model=self.model)
24
  else:
25
- raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")
 
 
26
  # Cải thiện cache với TTL và quản lý memory
27
  self._rerank_cache = {}
28
  self._cache_ttl = 3600 # 1 giờ
@@ -35,65 +55,77 @@ class Reranker:
35
  """Tạo cache key từ query và docs."""
36
  # Tối ưu hóa cache key generation
37
  query_normalized = query.lower().strip()
38
- doc_ids = [str(doc.get('id', '')) for doc in docs[:15]] # Chỉ cache top 15 docs
39
  cache_content = query_normalized + "|".join(sorted(doc_ids))
40
  return hashlib.md5(cache_content.encode()).hexdigest()
41
 
42
  def _clean_cache(self):
43
  """Dọn dẹp cache cũ và quản lý memory."""
44
  current_time = time.time()
45
-
46
  # Xóa cache entries đã hết hạn
47
  expired_keys = [
48
- key for key, timestamp in self._cache_timestamps.items()
 
49
  if current_time - timestamp > self._cache_ttl
50
  ]
51
-
52
  for key in expired_keys:
53
  del self._rerank_cache[key]
54
  del self._cache_timestamps[key]
55
-
56
  # Nếu cache vẫn quá lớn, xóa entries cũ nhất
57
  if len(self._rerank_cache) > self._max_cache_size:
58
  sorted_keys = sorted(
59
- self._cache_timestamps.keys(),
60
- key=lambda k: self._cache_timestamps[k]
61
  )
62
-
63
  # Xóa 20% cache entries cũ nhất
64
- keys_to_remove = sorted_keys[:len(sorted_keys) // 5]
65
  for key in keys_to_remove:
66
  del self._rerank_cache[key]
67
  del self._cache_timestamps[key]
68
-
69
- logger.info(f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries")
 
 
70
 
71
  def _get_cached_result(self, cache_key: str, min_score: float) -> List[Dict]:
72
  """Lấy kết quả từ cache nếu có và còn hợp lệ."""
73
  if cache_key in self._rerank_cache:
74
  current_time = time.time()
75
- if current_time - self._cache_timestamps.get(cache_key, 0) <= self._cache_ttl:
 
 
 
76
  # Lọc theo điểm thay vì lấy top_k
77
  cached_docs = self._rerank_cache[cache_key]
78
- cached_result = [doc for doc in cached_docs if doc.get('rerank_score', 0) >= min_score]
79
- logger.info(f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results with score >= {min_score}")
 
 
 
 
 
 
80
  return cached_result
81
  else:
82
  # Cache đã hết hạn, xóa
83
  del self._rerank_cache[cache_key]
84
  del self._cache_timestamps[cache_key]
85
-
86
  return []
87
 
88
  def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
89
  """Lưu kết quả vào cache."""
90
  self._rerank_cache[cache_key] = scored_docs
91
  self._cache_timestamps[cache_key] = time.time()
92
-
93
  # Dọn dẹp cache nếu cần
94
  if len(self._rerank_cache) > self._max_cache_size:
95
  self._clean_cache()
96
 
 
97
  async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
98
  """
99
  Score nhiều documents cùng lúc bằng một prompt duy nhất.
@@ -101,16 +133,16 @@ class Reranker:
101
  """
102
  if not docs:
103
  return []
104
-
105
  # Không giới hạn content length, giữ nguyên nội dung luật
106
  docs_content = []
107
  for i, doc in enumerate(docs):
108
  # tieude = (doc.get('tieude') or '').strip()
109
  # noidung = (doc.get('noidung') or '').strip()
110
  # content = f"{tieude} {noidung}".strip()
111
- content = (doc.get('fullcontent') or '').strip()
112
  docs_content.append(f"{i+1}. {content}")
113
-
114
  batch_prompt = (
115
  f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
116
  f"Câu hỏi: {query}\n\n"
@@ -118,63 +150,70 @@ class Reranker:
118
  f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
119
  f"Ví dụ: 8,5,7,3,9"
120
  )
121
-
122
  try:
123
- if self.provider == 'gemini':
124
  loop = asyncio.get_event_loop()
125
- logger.info(f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs")
126
- response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt)
 
 
 
 
127
  logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
128
-
129
- # Cải thiện parsing scores
130
  scores_text = str(response).strip()
 
 
 
131
  scores = []
132
-
133
- # Xử lý nhiều format response có thể có
134
- if ',' in scores_text:
135
- score_parts = scores_text.split(',')
136
- elif ' ' in scores_text:
137
- score_parts = scores_text.split()
138
- else:
139
- score_parts = scores_text.replace('.', ',').split(',')
140
-
141
- for score_str in score_parts:
142
  try:
143
- clean_score = ''.join(c for c in score_str.strip() if c.isdigit() or c == '.')
144
- if clean_score:
145
- score = float(clean_score)
146
- score = max(0, min(10, score))
147
  scores.append(score)
148
- else:
149
- scores.append(0)
150
  except (ValueError, TypeError):
151
- scores.append(0)
152
-
153
- while len(scores) < len(docs):
154
- scores.append(0)
155
-
 
 
 
 
 
156
  for i, doc in enumerate(docs):
157
- doc['rerank_score'] = scores[i]
158
-
159
- logger.info(f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}")
 
 
160
  return docs
161
-
162
  else:
163
- raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in batch method.")
164
-
165
- except Exception as e:
166
- logger.error(f"[RERANK] Lỗi khi batch score: {e}")
 
 
 
 
167
  for doc in docs:
168
- doc['rerank_score'] = 0
169
  return docs
170
 
 
171
  async def _score_doc(self, query: str, doc: Dict) -> Dict:
172
  """
173
  Score một document với query.
174
  Không cắt bớt nội dung luật.
175
  """
176
- tieude = (doc.get('tieude') or '').strip()
177
- noidung = (doc.get('noidung') or '').strip()
178
  content = f"{tieude} {noidung}".strip()
179
  prompt = (
180
  f"Đánh giá mức độ liên quan:\n"
@@ -183,14 +222,20 @@ class Reranker:
183
  f"Điểm (0-10):"
184
  )
185
  try:
186
- if self.provider == 'gemini':
187
  loop = asyncio.get_event_loop()
188
  logger.info(f"[RERANK] Sending individual prompt to Gemini")
189
- score_response = await loop.run_in_executor(None, self.client.generate_text, prompt)
190
- logger.info(f"[RERANK] Got individual score from Gemini: {score_response}")
 
 
 
 
191
  score_text = str(score_response).strip()
192
  try:
193
- clean_score = ''.join(c for c in score_text if c.isdigit() or c == '.')
 
 
194
  if clean_score:
195
  score = float(clean_score)
196
  score = max(0, min(10, score))
@@ -198,44 +243,59 @@ class Reranker:
198
  score = 0
199
  except (ValueError, TypeError):
200
  score = 0
201
- doc['rerank_score'] = score
202
  return doc
203
  else:
204
- raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
205
- except Exception as e:
206
- logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
207
- doc['rerank_score'] = 0
 
 
 
 
 
208
  return doc
209
 
210
  @timing_decorator_async
211
- async def rerank(self, query: str, docs: List[Dict], min_score: float = 7.0) -> List[Dict]:
 
 
212
  """
213
  Rerank docs theo độ liên quan với query, trả về các docs có điểm >= min_score.
214
  Sử dụng batch processing và caching để tối ưu hiệu suất.
215
  """
216
- logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | min_score: {min_score}")
217
-
 
 
218
  if not docs:
219
  return []
220
-
221
  # Kiểm tra cache trước
222
  cache_key = self._get_cache_key(query, docs)
223
  cached_result = self._get_cached_result(cache_key, min_score)
224
-
225
  if cached_result:
226
  return cached_result
227
-
228
  # Giới hạn số lượng docs để rerank - chỉ rerank top N docs có similarity cao nhất
229
  max_docs_to_rerank = self.max_docs_to_rerank
230
  docs_to_rerank = docs[:max_docs_to_rerank]
231
- logger.info(f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})")
232
-
 
 
233
  # Sử dụng batch processing thay vì individual scoring
234
  try:
235
  scored = await self._batch_score_docs(query, docs_to_rerank)
236
- logger.info(f"[RERANK] Batch processing completed, scored {len(scored)} docs")
 
 
237
  except Exception as e:
238
- logger.error(f"[RERANK] Batch processing failed, falling back to individual scoring: {e}")
 
 
239
  # Fallback về individual scoring nếu batch processing thất bại
240
  scored = []
241
  for doc in docs_to_rerank:
@@ -244,17 +304,19 @@ class Reranker:
244
  scored.append(scored_doc)
245
  except Exception as e:
246
  logger.error(f"[RERANK] Error scoring individual doc: {e}")
247
- doc['rerank_score'] = 0
248
  scored.append(doc)
249
-
250
  # Sort theo score
251
- scored = sorted(scored, key=lambda x: x.get('rerank_score', 0), reverse=True)
252
-
253
  # Lọc theo min_score
254
- result = [doc for doc in scored if doc.get('rerank_score', 0) >= min_score]
255
-
256
  # Cache kết quả đã được chấm điểm (toàn bộ, trước khi lọc)
257
  self._set_cached_result(cache_key, scored)
258
-
259
- logger.info(f"[RERANK] Found {len(result)} docs with score >= {min_score}. Top results: {result[:2]}...{result[-2:] if len(result) > 2 else ''}")
260
- return result
 
 
 
2
 
3
  from app.utils import timing_decorator_async
4
  from .config import get_settings
5
+ from .gemini_client import GeminiClient, GeminiResponseError
6
  from loguru import logger
7
+ import re
8
  import asyncio
9
  import hashlib
10
  import time
11
+ from tenacity import retry, stop_after_attempt, wait_exponential
12
+
13
  # from .constants import BATCH_STATUS_MESSAGES
14
  # from .utils import get_random_message
15
 
16
+ # --- Retry decorator cho các lỗi tạm thời của Reranker (network, server-side) ---
17
+ retry_on_rerank_transient_error = retry(
18
+ stop=stop_after_attempt(4), # 1 lần gọi gốc + 3 lần thử lại
19
+ wait=wait_exponential(multiplier=5, min=10, max=60), # Chờ 10s, 20s, 40s
20
+ retry=lambda retry_state: (
21
+ retry_state.outcome.failed
22
+ and not isinstance(retry_state.outcome.exception(), GeminiResponseError)
23
+ ),
24
+ before_sleep=lambda retry_state: logger.warning(
25
+ f"[RERANK][RETRY] Rerank call failed with transient error, retrying... "
26
+ f"Attempt: {retry_state.attempt_number}, Error: {retry_state.outcome.exception()}"
27
+ ),
28
+ )
29
+
30
+
31
  class Reranker:
32
  def __init__(self):
33
  settings = get_settings()
34
+ self.provider = getattr(settings, "rerank_provider", settings.llm_provider)
35
+ self.model = getattr(settings, "rerank_model", settings.llm_model)
36
+ if self.provider == "gemini":
37
  self.client = GeminiClient()
38
  # elif self.provider == 'openai':
39
  # self.client = OpenAIClient(settings.openai_api_key, model=self.model)
40
  # elif self.provider == 'cohere':
41
  # self.client = CohereClient(settings.cohere_api_key, model=self.model)
42
  else:
43
+ raise NotImplementedError(
44
+ f"Rerank provider {self.provider} not supported yet."
45
+ )
46
  # Cải thiện cache với TTL và quản lý memory
47
  self._rerank_cache = {}
48
  self._cache_ttl = 3600 # 1 giờ
 
55
  """Tạo cache key từ query và docs."""
56
  # Tối ưu hóa cache key generation
57
  query_normalized = query.lower().strip()
58
+ doc_ids = [str(doc.get("id", "")) for doc in docs[:15]] # Chỉ cache top 15 docs
59
  cache_content = query_normalized + "|".join(sorted(doc_ids))
60
  return hashlib.md5(cache_content.encode()).hexdigest()
61
 
62
  def _clean_cache(self):
63
  """Dọn dẹp cache cũ và quản lý memory."""
64
  current_time = time.time()
65
+
66
  # Xóa cache entries đã hết hạn
67
  expired_keys = [
68
+ key
69
+ for key, timestamp in self._cache_timestamps.items()
70
  if current_time - timestamp > self._cache_ttl
71
  ]
72
+
73
  for key in expired_keys:
74
  del self._rerank_cache[key]
75
  del self._cache_timestamps[key]
76
+
77
  # Nếu cache vẫn quá lớn, xóa entries cũ nhất
78
  if len(self._rerank_cache) > self._max_cache_size:
79
  sorted_keys = sorted(
80
+ self._cache_timestamps.keys(), key=lambda k: self._cache_timestamps[k]
 
81
  )
82
+
83
  # Xóa 20% cache entries cũ nhất
84
+ keys_to_remove = sorted_keys[: len(sorted_keys) // 5]
85
  for key in keys_to_remove:
86
  del self._rerank_cache[key]
87
  del self._cache_timestamps[key]
88
+
89
+ logger.info(
90
+ f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries"
91
+ )
92
 
93
  def _get_cached_result(self, cache_key: str, min_score: float) -> List[Dict]:
94
  """Lấy kết quả từ cache nếu có và còn hợp lệ."""
95
  if cache_key in self._rerank_cache:
96
  current_time = time.time()
97
+ if (
98
+ current_time - self._cache_timestamps.get(cache_key, 0)
99
+ <= self._cache_ttl
100
+ ):
101
  # Lọc theo điểm thay vì lấy top_k
102
  cached_docs = self._rerank_cache[cache_key]
103
+ cached_result = [
104
+ doc
105
+ for doc in cached_docs
106
+ if doc.get("rerank_score", 0) >= min_score
107
+ ]
108
+ logger.info(
109
+ f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results with score >= {min_score}"
110
+ )
111
  return cached_result
112
  else:
113
  # Cache đã hết hạn, xóa
114
  del self._rerank_cache[cache_key]
115
  del self._cache_timestamps[cache_key]
116
+
117
  return []
118
 
119
  def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
120
  """Lưu kết quả vào cache."""
121
  self._rerank_cache[cache_key] = scored_docs
122
  self._cache_timestamps[cache_key] = time.time()
123
+
124
  # Dọn dẹp cache nếu cần
125
  if len(self._rerank_cache) > self._max_cache_size:
126
  self._clean_cache()
127
 
128
+ @retry_on_rerank_transient_error
129
  async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
130
  """
131
  Score nhiều documents cùng lúc bằng một prompt duy nhất.
 
133
  """
134
  if not docs:
135
  return []
136
+
137
  # Không giới hạn content length, giữ nguyên nội dung luật
138
  docs_content = []
139
  for i, doc in enumerate(docs):
140
  # tieude = (doc.get('tieude') or '').strip()
141
  # noidung = (doc.get('noidung') or '').strip()
142
  # content = f"{tieude} {noidung}".strip()
143
+ content = (doc.get("fullcontent") or "").strip()
144
  docs_content.append(f"{i+1}. {content}")
145
+
146
  batch_prompt = (
147
  f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
148
  f"Câu hỏi: {query}\n\n"
 
150
  f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
151
  f"Ví dụ: 8,5,7,3,9"
152
  )
153
+
154
  try:
155
+ if self.provider == "gemini":
156
  loop = asyncio.get_event_loop()
157
+ logger.info(
158
+ f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs"
159
+ )
160
+ response = await loop.run_in_executor(
161
+ None, self.client.generate_text, batch_prompt
162
+ )
163
  logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
164
+
165
+ # Cải thiện parsing scores bằng regex để chỉ lấy các số hợp lệ
166
  scores_text = str(response).strip()
167
+ # Tìm tất cả các chuỗi số (integer hoặc float) trong văn bản trả về
168
+ score_strings = re.findall(r"\b\d+(?:\.\d+)?\b", scores_text)
169
+
170
  scores = []
171
+ for s in score_strings:
 
 
 
 
 
 
 
 
 
172
  try:
173
+ score = float(s)
174
+ # Chỉ chấp nhận các điểm số trong khoảng 0-10 để tăng độ chính xác
175
+ if 0 <= score <= 10:
 
176
  scores.append(score)
 
 
177
  except (ValueError, TypeError):
178
+ # Bỏ qua các giá trị không phải là số hợp lệ
179
+ continue
180
+
181
+ # Đảm bảo số lượng điểm khớp với số lượng văn bản
182
+ # Nếu thiếu, thêm điểm 0. Nếu thừa, cắt bớt.
183
+ if len(scores) < len(docs):
184
+ scores.extend([0.0] * (len(docs) - len(scores)))
185
+ else:
186
+ scores = scores[: len(docs)]
187
+
188
  for i, doc in enumerate(docs):
189
+ doc["rerank_score"] = scores[i]
190
+
191
+ logger.info(
192
+ f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}"
193
+ )
194
  return docs
195
+
196
  else:
197
+ raise NotImplementedError(
198
+ f"Rerank provider {self.provider} not supported yet in batch method."
199
+ )
200
+
201
+ except GeminiResponseError as e:
202
+ # Lỗi nội dung không thể retry (safety, max_tokens), gán điểm 0 và trả về.
203
+ # Các lỗi khác (network, 500) sẽ được decorator retry.
204
+ logger.error(f"[RERANK] Lỗi nội dung không thể retry khi batch score: {e}")
205
  for doc in docs:
206
+ doc["rerank_score"] = 0
207
  return docs
208
 
209
+ @retry_on_rerank_transient_error
210
  async def _score_doc(self, query: str, doc: Dict) -> Dict:
211
  """
212
  Score một document với query.
213
  Không cắt bớt nội dung luật.
214
  """
215
+ tieude = (doc.get("tieude") or "").strip()
216
+ noidung = (doc.get("noidung") or "").strip()
217
  content = f"{tieude} {noidung}".strip()
218
  prompt = (
219
  f"Đánh giá mức độ liên quan:\n"
 
222
  f"Điểm (0-10):"
223
  )
224
  try:
225
+ if self.provider == "gemini":
226
  loop = asyncio.get_event_loop()
227
  logger.info(f"[RERANK] Sending individual prompt to Gemini")
228
+ score_response = await loop.run_in_executor(
229
+ None, self.client.generate_text, prompt
230
+ )
231
+ logger.info(
232
+ f"[RERANK] Got individual score from Gemini: {score_response}"
233
+ )
234
  score_text = str(score_response).strip()
235
  try:
236
+ clean_score = "".join(
237
+ c for c in score_text if c.isdigit() or c == "."
238
+ )
239
  if clean_score:
240
  score = float(clean_score)
241
  score = max(0, min(10, score))
 
243
  score = 0
244
  except (ValueError, TypeError):
245
  score = 0
246
+ doc["rerank_score"] = score
247
  return doc
248
  else:
249
+ raise NotImplementedError(
250
+ f"Rerank provider {self.provider} not supported yet in rerank method."
251
+ )
252
+ except GeminiResponseError as e:
253
+ # Lỗi nội dung không thể retry (safety, max_tokens), gán điểm 0 và trả về.
254
+ logger.error(
255
+ f"[RERANK] Lỗi nội dung không thể retry khi tính score: {e} | doc: {doc}"
256
+ )
257
+ doc["rerank_score"] = 0
258
  return doc
259
 
260
  @timing_decorator_async
261
+ async def rerank(
262
+ self, query: str, docs: List[Dict], min_score: float = 7.0
263
+ ) -> List[Dict]:
264
  """
265
  Rerank docs theo độ liên quan với query, trả về các docs có điểm >= min_score.
266
  Sử dụng batch processing và caching để tối ưu hiệu suất.
267
  """
268
+ logger.info(
269
+ f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | min_score: {min_score}"
270
+ )
271
+
272
  if not docs:
273
  return []
274
+
275
  # Kiểm tra cache trước
276
  cache_key = self._get_cache_key(query, docs)
277
  cached_result = self._get_cached_result(cache_key, min_score)
278
+
279
  if cached_result:
280
  return cached_result
281
+
282
  # Giới hạn số lượng docs để rerank - chỉ rerank top N docs có similarity cao nhất
283
  max_docs_to_rerank = self.max_docs_to_rerank
284
  docs_to_rerank = docs[:max_docs_to_rerank]
285
+ logger.info(
286
+ f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})"
287
+ )
288
+
289
  # Sử dụng batch processing thay vì individual scoring
290
  try:
291
  scored = await self._batch_score_docs(query, docs_to_rerank)
292
+ logger.info(
293
+ f"[RERANK] Batch processing completed, scored {len(scored)} docs"
294
+ )
295
  except Exception as e:
296
+ logger.error(
297
+ f"[RERANK] Batch processing failed, falling back to individual scoring: {e}"
298
+ )
299
  # Fallback về individual scoring nếu batch processing thất bại
300
  scored = []
301
  for doc in docs_to_rerank:
 
304
  scored.append(scored_doc)
305
  except Exception as e:
306
  logger.error(f"[RERANK] Error scoring individual doc: {e}")
307
+ doc["rerank_score"] = 0
308
  scored.append(doc)
309
+
310
  # Sort theo score
311
+ scored = sorted(scored, key=lambda x: x.get("rerank_score", 0), reverse=True)
312
+
313
  # Lọc theo min_score
314
+ result = [doc for doc in scored if doc.get("rerank_score", 0) >= min_score]
315
+
316
  # Cache kết quả đã được chấm điểm (toàn bộ, trước khi lọc)
317
  self._set_cached_result(cache_key, scored)
318
+
319
+ logger.info(
320
+ f"[RERANK] Found {len(result)} docs with score >= {min_score}. Top results: {result[:2]}...{result[-2:] if len(result) > 2 else ''}"
321
+ )
322
+ return result