VietCat committed
Commit 44013a5 · 1 Parent(s): 2dfd2c2

refactor request limiter

app/embedding.py CHANGED
@@ -19,6 +19,8 @@ class EmbeddingClient:
         self.provider = getattr(settings, 'embedding_provider', 'default')
         self.model = getattr(settings, 'embedding_model', 'models/embedding-001')
         self.gemini_client: Optional[GeminiClient] = GeminiClient() if self.provider == 'gemini' else None
+
+        logger.info(f"[EMBEDDING] Initialized with provider={self.provider}, model={self.model}")
 
     @timing_decorator_async
     async def create_embedding(self, text: str) -> List[float]:
@@ -35,31 +37,37 @@ class EmbeddingClient:
             import asyncio
             loop = asyncio.get_event_loop()
             gemini_client = self.gemini_client  # type: ignore
+
+            # Always use the model from config, independent of the key/model from RequestLimitManager
+            logger.info(f"[EMBEDDING] Creating embedding with model={self.model}")
             embedding = await loop.run_in_executor(None, lambda: gemini_client.create_embedding(text, model=self.model))
+
             # Check the type of the returned data
             if isinstance(embedding, list):
                 preview = f"{embedding[:10]}...{embedding[-10:]}" if len(embedding) > 20 else str(embedding)
-                logger.info(f"[DEBUG] Embedding API response: {preview}")
+                logger.info(f"[EMBEDDING] API response: {preview}")
                 return embedding
             else:
-                logger.error(f"[DEBUG] Unknown embedding type: {type(embedding)} - value: {embedding}")
+                logger.error(f"[EMBEDDING] Unknown embedding type: {type(embedding)} - value: {embedding}")
                 raise RuntimeError(f"Embedding returned unexpected type: {type(embedding)}")
         except Exception as e:
-            logger.error(f"Error creating embedding with Gemini: {e}")
+            logger.error(f"[EMBEDDING] Error creating embedding with Gemini: {e}")
             raise
+
+        # Fallback to HuggingFace embedding
         url = "https://vietcat-vietnameseembeddingv2.hf.space/embed"
         payload = {"text": text}
         try:
            response = await call_endpoint_with_retry(self._client, url, payload)
            if response is not None:
                data = response.json()
-               logger.info(f"[DEBUG] Embedding API response: {data['embedding'][:10]}...{data['embedding'][-10:]}")
+               logger.info(f"[EMBEDDING] HuggingFace API response: {data['embedding'][:10]}...{data['embedding'][-10:]}")
                return data["embedding"]
            else:
-               logger.error("Embedding API response is None")
-               raise RuntimeError("Embedding API response is None")
+               logger.error("[EMBEDDING] HuggingFace API response is None")
+               raise RuntimeError("HuggingFace API response is None")
        except Exception as e:
-           logger.error(f"Error creating embedding: {e}")
+           logger.error(f"[EMBEDDING] Error creating embedding with HuggingFace: {e}")
            raise
 
     def cosine_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
@@ -73,5 +81,12 @@ class EmbeddingClient:
             b = np.array(embedding2)
             return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
         except Exception as e:
-            logger.error(f"Error calculating similarity: {e}")
-            return 0.0
+            logger.error(f"[EMBEDDING] Error calculating similarity: {e}")
+            return 0.0
+
+    def get_embedding_model(self) -> str:
+        """
+        Return the model configured for embedding.
+        Used to verify that the correct model is being used.
+        """
+        return self.model
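Taken together, these hunks pin embedding to the configured model while key rotation stays inside RequestLimitManager. A minimal usage sketch (not part of the commit; it assumes the default settings shown above and a reachable embedding backend):

import asyncio

async def demo():
    client = EmbeddingClient()
    # get_embedding_model() exposes the configured model for verification
    assert client.get_embedding_model() == 'models/embedding-001'
    vector = await client.create_embedding("Mức phạt nồng độ cồn xe máy?")
    print(f"embedding dimension: {len(vector)}")

asyncio.run(demo())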
app/gemini_client.py CHANGED
@@ -8,25 +8,56 @@ from typing import List, Optional
 class GeminiClient:
     def __init__(self):
         self.limit_manager = RequestLimitManager("gemini")
+        self._cached_model = None
+        self._cached_key = None
+        self._cached_model_instance = None
+
+    def _get_model_instance(self, key: str, model: str):
+        """
+        Cache the model instance to avoid recreating it on every call.
+        """
+        if (self._cached_key == key and
+            self._cached_model == model and
+            self._cached_model_instance is not None):
+            return self._cached_model_instance
+
+        # Configure and create a new model instance
+        configure(api_key=key)
+        self._cached_model_instance = GenerativeModel(model)
+        self._cached_key = key
+        self._cached_model = model
+
+        logger.info(f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}")
+        return self._cached_model_instance
 
     def generate_text(self, prompt: str, **kwargs) -> str:
         last_error = None
-        for key, model in self.limit_manager.iterate_key_model():
+        max_retries = 3
+
+        for attempt in range(max_retries):
             try:
-                configure(api_key=key)
-                _model = GenerativeModel(model)
+                # Get the current key/model from the manager
+                key, model = self.limit_manager.get_current_key_model()
+
+                # Use the cached model instance
+                _model = self._get_model_instance(key, model)
+
                 response = _model.generate_content(prompt, **kwargs)
                 self.limit_manager.log_request(key, model, success=True)
+
                 if hasattr(response, 'usage_metadata'):
                     logger.info(f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}")
+
                 if hasattr(response, 'text'):
                     logger.info(f"[GEMINI][TEXT_RESPONSE] {response.text}")
                     return response.text
                 elif hasattr(response, 'candidates') and response.candidates:
                     logger.info(f"[GEMINI][CANDIDATES_RESPONSE] {response.candidates[0].content.parts[0].text}")
                     return response.candidates[0].content.parts[0].text
+
                 logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
                 return str(response)
+
             except Exception as e:
                 import re
                 msg = str(e)
@@ -35,39 +66,62 @@ class GeminiClient:
                 m = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                 if m:
                     retry_delay = int(m.group(1))
+
+                    # Log the failure and trigger a scan for a new key/model
                     self.limit_manager.log_request(key, model, success=False, retry_delay=retry_delay)
-                last_error = e
-                continue
+
+                    # Clear the cache to force a new model instance with the new key/model
+                    self._cached_model_instance = None
+                    self._cached_key = None
+                    self._cached_model = None
+
+                    logger.warning(f"[GEMINI] Rate limit hit, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
+                    last_error = e
+                    continue
+                else:
+                    # Other error, not a rate limit
+                    logger.error(f"[GEMINI] Error generating text: {e}")
+                    last_error = e
+                    break
+
         raise last_error or RuntimeError("No available Gemini API key/model")
 
     def count_tokens(self, prompt: str) -> int:
-        for key, model in self.limit_manager.iterate_key_model():
-            try:
-                configure(api_key=key)
-                _model = GenerativeModel(model)
-                return _model.count_tokens(prompt).total_tokens
-            except Exception:
-                continue
-        return 0
+        try:
+            key, model = self.limit_manager.get_current_key_model()
+            _model = self._get_model_instance(key, model)
+            return _model.count_tokens(prompt).total_tokens
+        except Exception as e:
+            logger.error(f"[GEMINI] Error counting tokens: {e}")
+            return 0
 
     def create_embedding(self, text: str, model: Optional[str] = None) -> list:
         last_error = None
-        for key, m in self.limit_manager.iterate_key_model():
-            m = m or ""
-            use_model = model if model not in (None, "") else m
-            if not use_model:
-                continue
-            use_model = str(use_model)
+        max_retries = 3
+
+        for attempt in range(max_retries):
             try:
+                key, default_model = self.limit_manager.get_current_key_model()
+
+                # Prefer the model passed as a parameter; fall back to default_model only when none is given
+                use_model = model if model and model.strip() else default_model
+
+                if not use_model:
+                    raise ValueError("No model specified for embedding")
+
+                logger.info(f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model})")
+
                 configure(api_key=key)
                 response = embed_content(
                     model=use_model,
                     content=text,
                     task_type="retrieval_query"
                 )
+
                 self.limit_manager.log_request(key, use_model, success=True)
                 logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}")
                 return response['embedding']
+
             except Exception as e:
                 import re
                 msg = str(e)
@@ -76,7 +130,16 @@ class GeminiClient:
                 m_retry = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                 if m_retry:
                     retry_delay = int(m_retry.group(1))
+
+                    # Log the failure and trigger a scan for a new key/model
                     self.limit_manager.log_request(key, use_model, success=False, retry_delay=retry_delay)
-                last_error = e
-                continue
+
+                    logger.warning(f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
+                    last_error = e
+                    continue
+                else:
+                    logger.error(f"[GEMINI] Error creating embedding: {e}")
+                    last_error = e
+                    break
+
         raise last_error or RuntimeError("No available Gemini API key/model")
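The retry loop now pulls its key/model from the limiter on each attempt and reuses a cached GenerativeModel until either changes. A self-contained sketch of that caching pattern (FakeModel and CachedModelFactory are illustrative stand-ins, not part of the commit):

from typing import Optional, Tuple

class FakeModel:
    """Stands in for google.generativeai.GenerativeModel."""
    def __init__(self, name: str):
        self.name = name

class CachedModelFactory:
    def __init__(self):
        self._cached: Optional[Tuple[str, str, FakeModel]] = None

    def get(self, key: str, model: str) -> FakeModel:
        # Reuse the instance only while both key and model are unchanged
        if self._cached and self._cached[:2] == (key, model):
            return self._cached[2]
        instance = FakeModel(model)  # here: GenerativeModel(model) after configure(api_key=key)
        self._cached = (key, model, instance)
        return instance

factory = CachedModelFactory()
a = factory.get("key-1", "gemini-pro")
b = factory.get("key-1", "gemini-pro")
c = factory.get("key-2", "gemini-pro")  # key rotated -> fresh instance
assert a is b and a is not c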
app/main.py CHANGED
@@ -19,6 +19,7 @@ from .constants import VEHICLE_KEYWORDS, SHEET_RANGE, VEHICLE_KEYWORD_TO_COLUMN
 from .health import router as health_router
 from .llm import create_llm_client
 from .reranker import Reranker
+from .request_limit_manager import RequestLimitManager
 
 app = FastAPI(title="WeBot Facebook Messenger API")
 
@@ -39,6 +40,10 @@ logger.info("[STARTUP] Đang lấy PORT từ biến môi trường hoặc config
 port = int(os.environ.get("PORT", settings.port if hasattr(settings, 'port') else 7860))
 logger.info(f"[STARTUP] PORT sử dụng: {port}")
 
+logger.info("[STARTUP] Khởi tạo global RequestLimitManager...")
+# Global RequestLimitManager instance - singleton
+request_limit_manager = RequestLimitManager("gemini")
+
 logger.info("[STARTUP] Khởi tạo FacebookClient...")
 facebook_client = FacebookClient(settings.facebook_app_secret)
 logger.info("[STARTUP] Khởi tạo SheetsClient...")
app/request_limit_manager.py CHANGED
@@ -5,10 +5,23 @@ from app.config import get_settings
 from loguru import logger
 
 class RequestLimitManager:
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls, provider: str):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
     def __init__(self, provider: str):
+        if hasattr(self, 'initialized'):
+            return
         self.provider = provider
         self.lock = threading.Lock()
         self._init_keys_models()
+        self.initialized = True
 
     def _init_keys_models(self):
         settings = get_settings()
@@ -22,16 +35,82 @@ class RequestLimitManager:
             self.status[key] = {}
             for model in self.models:
                 self.status[key][model] = {"status": "active", "timestamp": now}
-        self.default_key: Optional[str] = self.api_keys[0] if self.api_keys else None
-        self.default_model: Optional[str] = self.models[0] if self.models else None
+        self.current_key: Optional[str] = self.api_keys[0] if self.api_keys else None
+        self.current_model: Optional[str] = self.models[0] if self.models else None
+        key_display = f"{self.current_key[:5]}...{self.current_key[-5:]}" if self.current_key else "None"
+        logger.info(f"[LIMIT] Initialized with current key={key_display} model={self.current_model}")
+
+    def get_current_key_model(self) -> Tuple[str, str]:
+        """
+        Return the key/model pair that is currently active.
+        Only scan for a new key/model when the current pair is blocked.
+        """
+        with self.lock:
+            now = time.time()
+
+            # Check if current pair is still available
+            if self.current_key and self.current_model:
+                info = self.status.get(self.current_key, {}).get(self.current_model, {})
+                status = info.get("status", "active")
+                ts = float(info.get("timestamp", 0.0))
+
+                if status == "active" or (status == "blocked" and now > ts):
+                    logger.info(f"[LIMIT] Using current key={self.current_key[:5]}...{self.current_key[-5:]} model={self.current_model}")
+                    return self.current_key, self.current_model
+
+            # Current pair not available, scan for new one
+            logger.warning(f"[LIMIT] Current pair not available, scanning for new key/model...")
+            new_key, new_model = self._find_available_key_model()
+
+            if new_key and new_model:
+                self.current_key = new_key
+                self.current_model = new_model
+                logger.info(f"[LIMIT] Switched to new key={self.current_key[:5]}...{self.current_key[-5:]} model={self.current_model}")
+                return self.current_key, self.current_model
+            else:
+                logger.error(f"[LIMIT] No available key/model found for provider {self.provider}")
+                raise RuntimeError(f"No available key/model for provider {self.provider}")
+
+    def _find_available_key_model(self) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Find the nearest available key/model pair.
+        """
+        now = time.time()
+        keys = self.api_keys[:]
+        models = self.models[:]
+
+        # Prefer the current key/model when present
+        if self.current_key and self.current_key in keys:
+            keys.remove(self.current_key)
+            keys = [self.current_key] + keys
+        if self.current_model and self.current_model in models:
+            models.remove(self.current_model)
+            models = [self.current_model] + models
+
+        for key in keys:
+            for model in models:
+                info = self.status.get(key, {}).get(model, {"status": "active", "timestamp": 0.0})
+                status = info.get("status", "active")
+                ts = float(info.get("timestamp", 0.0))
+
+                if status == "active" or (status == "blocked" and now > ts):
+                    logger.info(f"[LIMIT] Found available key={key[:5]}...{key[-5:]} model={model}")
+                    return key, model
+
+        return None, None
 
     def log_request(self, key: str, model: str, success: bool, retry_delay: Optional[int] = None):
+        """
+        Log the request result and update the status.
+        If the request fails with a 429, trigger a scan for a new key/model.
+        """
         with self.lock:
             now = time.time()
             if key not in self.status:
                 self.status[key] = {}
             if model not in self.status[key]:
                 self.status[key][model] = {"status": "active", "timestamp": now}
+
             if success:
                 logger.info(f"[LIMIT] Mark key={key[:5]}...{key[-5:]} - model={model} as active at {now}")
                 self.status[key][model]["status"] = "active"
@@ -40,34 +119,17 @@ class RequestLimitManager:
                 logger.warning(f"[LIMIT] Mark key={key[:5]}...{key[-5:]} - model={model} as blocked until {now + (retry_delay or 60)} (retry_delay={retry_delay})")
                 self.status[key][model]["status"] = "blocked"
                 self.status[key][model]["timestamp"] = now + (retry_delay or 60)
+
+                # If the current pair got blocked, trigger a scan for a new pair
+                if key == self.current_key and model == self.current_model:
+                    logger.warning(f"[LIMIT] Current pair blocked, will scan for new pair on next request")
+                    self.current_key = None
+                    self.current_model = None
 
     def iterate_key_model(self) -> Iterator[Tuple[str, str]]:
-        now = time.time()
-        keys = self.api_keys[:]
-        models = self.models[:]
-        # Prefer the default key/model when present
-        if self.default_key and self.default_key in keys:
-            keys.remove(self.default_key)
-            keys = [self.default_key] + keys
-        if self.default_model and self.default_model in models:
-            models.remove(self.default_model)
-            models = [self.default_model] + models
-        logger.info(f"[LIMIT] Trying key/model candidates: {[(k[:6]+'...', m) for k in keys for m in models]}")
-        found = False
-        for key in keys:
-            for model in models:
-                info = self.status.get(key, {}).get(model, {"status": "active", "timestamp": 0.0})
-                status = info.get("status", "active")
-                ts = float(info.get("timestamp", 0.0))
-                if status == "active":
-                    logger.info(f"[LIMIT] Use key={key[:5]}...{key[-5:]} - model={model} (active)")
-                    found = True
-                    yield key, model
-                elif status == "blocked" and now > ts:
-                    logger.info(f"[LIMIT] Use key={key[:5]}...{key[-5:]} - model={model} (was blocked, now retry)")
-                    found = True
-                    yield key, model
-        if not found:
-            logger.warning(f"[LIMIT] No available key/model for provider {self.provider}")
-        # If no key/model is valid, yield nothing
+        """
+        Legacy method - returns only the current pair.
+        Kept for compatibility with old code.
+        """
+        key, model = self.get_current_key_model()
+        yield key, model
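The intended flow: a 429 marks the current pair blocked and clears it, and the next get_current_key_model() call scans and switches. A hedged walkthrough (assumes a manager initialized from config with at least two keys):

manager = RequestLimitManager("gemini")
key, model = manager.get_current_key_model()  # first configured pair

# Simulate a 429 with a 30s retry hint for the current pair
manager.log_request(key, model, success=False, retry_delay=30)

# current_key/current_model were reset above, so this call scans and
# either switches pairs or reuses the old one once its block expires
new_key, new_model = manager.get_current_key_model()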
app/reranker.py CHANGED
@@ -18,34 +18,76 @@ class Reranker:
         else:
             raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")
 
+    async def _score_doc(self, query: str, doc: Dict) -> Dict:
+        """
+        Score a single document against the query.
+        """
+        content = (doc.get('tieude', '') or '') + ' ' + (doc.get('noidung', '') or '')
+        prompt = (
+            f"Đoạn luật: {content}\n"
+            f"Câu hỏi: {query}\n"
+            "Hãy đánh giá mức độ liên quan giữa đoạn luật và câu hỏi trên thang điểm 0-10. "
+            "Chỉ trả về một số duy nhất."
+        )
+
+        try:
+            if self.provider == 'gemini':
+                loop = asyncio.get_event_loop()
+                logger.info(f"[RERANK] Sending prompt to Gemini: {prompt}")
+                score = await loop.run_in_executor(None, self.client.generate_text, prompt)
+                logger.info(f"[RERANK] Got score from Gemini: {score}")
+            else:
+                raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
+
+            score = float(str(score).strip().split()[0])
+            doc['rerank_score'] = score
+            return doc
+
+        except Exception as e:
+            logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
+            doc['rerank_score'] = 0
+            return doc
+
     async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
         """
         Rerank docs by relevance to the query and return the top_k docs.
+        Uses concurrency to process multiple docs at once.
         """
         logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")
+
+        if not docs:
+            return []
+
+        # Limit the number of docs to rerank (at most 10)
+        docs_to_rerank = docs[:10] if len(docs) > 10 else docs
+        logger.info(f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited from {len(docs)})")
+
+        # Process docs concurrently
+        batch_size = 5  # Process 5 docs at a time
         scored = []
-        for doc in docs:
-            content = (doc.get('tieude', '') or '') + ' ' + (doc.get('noidung', '') or '')
-            prompt = (
-                f"Đoạn luật: {content}\n"
-                f"Câu hỏi: {query}\n"
-                "Hãy đánh giá mức độ liên quan giữa đoạn luật và câu hỏi trên thang điểm 0-10. "
-                "Chỉ trả về một số duy nhất."
-            )
-            try:
-                if self.provider == 'gemini':
-                    loop = asyncio.get_event_loop()
-                    logger.info(f"[RERANK] Sending prompt to Gemini: {prompt}")
-                    score = await loop.run_in_executor(None, self.client.generate_text, prompt)
-                    logger.info(f"[RERANK] Got score from Gemini: {score}")
-                else:
-                    raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
-                score = float(str(score).strip().split()[0])
-            except Exception as e:
-                logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
-                score = 0
-            doc['rerank_score'] = score
-            scored.append(doc)
+
+        for i in range(0, len(docs_to_rerank), batch_size):
+            batch = docs_to_rerank[i:i + batch_size]
+            logger.info(f"[RERANK] Processing batch {i//batch_size + 1}: {len(batch)} docs")
+
+            # Create tasks for the current batch
+            tasks = [self._score_doc(query, doc) for doc in batch]
+
+            # Run the batch concurrently
+            batch_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Process the results
+            for result in batch_results:
+                if isinstance(result, Exception):
+                    logger.error(f"[RERANK] Batch processing error: {result}")
+                    continue
+                scored.append(result)
+
+            logger.info(f"[RERANK] Completed batch {i//batch_size + 1}, processed {len(scored)} docs so far")
+
+        # Sort by score and return the top_k
         scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True)
-        logger.info(f"[RERANK] Top reranked docs: {scored[:top_k]}")
-        return scored[:top_k]
+        result = scored[:top_k]
+
+        logger.info(f"[RERANK] Top reranked docs: {result}")
+        return result
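A self-contained sketch of the batching pattern rerank() now uses: score documents concurrently in groups of batch_size via asyncio.gather, dropping failures. fake_score is an illustrative stand-in for the Gemini scoring call:

import asyncio
from typing import Dict, List

async def fake_score(query: str, doc: Dict) -> Dict:
    await asyncio.sleep(0.01)  # stands in for the Gemini scoring call
    doc['rerank_score'] = float(len(doc['noidung']) % 10)
    return doc

async def batched_rerank(query: str, docs: List[Dict], batch_size: int = 5) -> List[Dict]:
    scored: List[Dict] = []
    for i in range(0, len(docs), batch_size):
        batch = docs[i:i + batch_size]
        results = await asyncio.gather(*(fake_score(query, d) for d in batch),
                                       return_exceptions=True)
        # Keep successful results, skip exceptions (mirrors rerank() above)
        scored.extend(r for r in results if not isinstance(r, Exception))
    return sorted(scored, key=lambda d: d['rerank_score'], reverse=True)

docs = [{'tieude': f't{i}', 'noidung': 'x' * i} for i in range(12)]
top = asyncio.run(batched_rerank("mức phạt?", docs))[:5]
print([d['rerank_score'] for d in top])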