LiamKhoaLe commited on
Commit
d3ae861
·
1 Parent(s): cd574dc

Update resource-saver load balancer: key-rotation cooldown/backoff, API selection by success rate, and rate limiting

Browse files
Files changed (1) hide show
  1. utils/cloud_llm.py +255 -33
utils/cloud_llm.py CHANGED
@@ -2,6 +2,7 @@
2
  import os
3
  import logging
4
  import requests
 
5
  from typing import Optional
6
 
7
  # Dynamic import for Google GenAI (only when not in local mode)
@@ -47,23 +48,88 @@ class KeyRotator:
47
  if not keys:
48
  logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
49
  self.keys = keys
50
- self.dead = set()
 
 
51
  self.idx = 0
52
-
 
 
53
  def next_key(self) -> Optional[str]:
54
  if not self.keys:
55
  return None
 
 
 
 
 
 
 
 
 
56
  for _ in range(len(self.keys)):
57
  k = self.keys[self.idx % len(self.keys)]
58
  self.idx += 1
59
- if k not in self.dead:
60
- return k
 
 
 
 
 
 
 
 
 
 
 
61
  return None
62
 
63
- def mark_bad(self, key: Optional[str]):
64
- if key:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  self.dead.add(key)
66
- logger.warning(f"[LLM] Quarantined key (prefix hidden): {key[:6]}***")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  class GeminiClient:
69
  def __init__(self, rotator: KeyRotator, default_model: str):
@@ -91,8 +157,9 @@ class GeminiClient:
91
  logger.info(f"[LLM][Gemini] out={snip(text)}")
92
  return text
93
  except Exception as e:
94
- logger.error(f"[LLM][Gemini] {e}")
95
- self.rotator.mark_bad(key)
 
96
  return None
97
 
98
  class NvidiaClient:
@@ -138,18 +205,108 @@ class NvidiaClient:
138
  logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
139
  return clean
140
  except Exception as e:
141
- logger.error(f"[LLM][NVIDIA] {e}")
142
- self.rotator.mark_bad(key)
 
143
  return None
144
 
145
  class Paraphraser:
146
- """Prefers NVIDIA (cheap), falls back to Gemini EASY only. Also offers translate/backtranslate and a tiny consistency judge."""
147
  def __init__(self, nvidia_model: str, gemini_model_easy: str, gemini_model_hard: str):
148
  self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
149
  self.gm_easy = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_easy)
150
  # Only use GEMINI_MODEL_EASY, ignore hard model completely
151
  self.gm_hard = None # Disabled - only use easy model
152
- logger.info("Paraphraser initialized: NVIDIA -> GEMINI_EASY (GEMINI_HARD disabled)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  # Enhanced cleaning to remove conversational elements and comments
155
  def _clean_resp(self, resp: str) -> str:
@@ -253,16 +410,24 @@ class Paraphraser:
253
  temperature = 0.1 if difficulty == "easy" else 0.3
254
  max_tokens = min(600, max(128, len(text)//2))
255
 
256
- # Always try NVIDIA first (optimized for medical tasks)
257
- out = self.nv.generate(prompt, temperature=temperature, max_tokens=max_tokens)
258
- if out:
259
- return self._clean_resp(out)
 
 
 
 
 
260
 
261
- # Fallback to GEMINI with optimized parameters
262
- out = self.gm_easy.generate(prompt, max_output_tokens=max_tokens)
263
- if out:
264
- logger.info(f"[LLM][GEMINI] out={snip(self._clean_resp(out))}")
265
- return self._clean_resp(out)
 
 
 
266
  return text
267
 
268
  # ————— Translate & Backtranslate —————
@@ -281,9 +446,22 @@ class Paraphraser:
281
  f"{text}"
282
  )
283
 
284
- out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
285
- if out: return out.strip()
286
- return self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
289
  if not text: return text
@@ -302,10 +480,22 @@ class Paraphraser:
302
  f"{mid}"
303
  )
304
 
305
- out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
306
- if out: return out.strip()
307
- res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
308
- return res.strip() if res else None
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
  # ————— Consistency Judge (cheap, ratio-based) —————
311
  def consistency_check(self, user: str, output: str) -> bool:
@@ -315,10 +505,42 @@ class Paraphraser:
315
  f"Question/Context: {user}\n\n"
316
  f"Medical Answer: {output}"
317
  )
318
- out = self.nv.generate(prompt, temperature=0.0, max_tokens=5)
319
- if not out:
320
- out = self.gm_easy.generate(prompt, max_output_tokens=5)
321
- return isinstance(out, str) and "PASS" in out.upper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
  def medical_accuracy_check(self, question: str, answer: str) -> bool:
324
  """Check medical accuracy of Q&A pairs using cloud APIs"""
 
2
  import os
3
  import logging
4
  import requests
5
+ import time
6
  from typing import Optional
7
 
8
  # Dynamic import for Google GenAI (only when not in local mode)
 
48
  if not keys:
49
  logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
50
  self.keys = keys
51
+ self.dead = set() # Permanently dead keys
52
+ self.temp_dead = {} # Temporarily dead keys with retry time
53
+ self.retry_counts = {} # Track retry attempts per key
54
  self.idx = 0
55
+ self.max_retries = 3 # Max retries before permanent death
56
+ self.retry_delay = 60 # Seconds to wait before retry
57
+
58
  def next_key(self) -> Optional[str]:
59
  if not self.keys:
60
  return None
61
+
62
+ # Clean up expired temporary dead keys
63
+ current_time = time.time()
64
+ expired_keys = [k for k, retry_time in self.temp_dead.items() if current_time > retry_time]
65
+ for k in expired_keys:
66
+ del self.temp_dead[k]
67
+ logger.info(f"[LLM] Key {k[:6]}*** is back in rotation after cooldown")
68
+
69
+ # Try to find an available key
70
  for _ in range(len(self.keys)):
71
  k = self.keys[self.idx % len(self.keys)]
72
  self.idx += 1
73
+
74
+ # Skip permanently dead keys
75
+ if k in self.dead:
76
+ continue
77
+
78
+ # Skip temporarily dead keys
79
+ if k in self.temp_dead and current_time < self.temp_dead[k]:
80
+ continue
81
+
82
+ return k
83
+
84
+ # All keys are dead or temporarily unavailable
85
+ logger.warning(f"[LLM] All keys for {env_prefix} are unavailable")
86
  return None
87
 
88
+ def mark_bad(self, key: Optional[str], error_type: str = "unknown"):
89
+ if not key:
90
+ return
91
+
92
+ current_time = time.time()
93
+ retry_count = self.retry_counts.get(key, 0)
94
+
95
+ # Determine if this is a temporary or permanent failure
96
+ is_temporary = self._is_temporary_error(error_type)
97
+
98
+ if is_temporary and retry_count < self.max_retries:
99
+ # Temporary failure - add to temp_dead with retry time
100
+ retry_delay = self.retry_delay * (2 ** retry_count) # Exponential backoff
101
+ self.temp_dead[key] = current_time + retry_delay
102
+ self.retry_counts[key] = retry_count + 1
103
+ logger.warning(f"[LLM] Key {key[:6]}*** temporarily quarantined for {retry_delay}s (attempt {retry_count + 1}/{self.max_retries})")
104
+ else:
105
+ # Permanent failure or max retries reached
106
  self.dead.add(key)
107
+ if key in self.temp_dead:
108
+ del self.temp_dead[key]
109
+ if key in self.retry_counts:
110
+ del self.retry_counts[key]
111
+ logger.error(f"[LLM] Key {key[:6]}*** permanently quarantined after {retry_count} retries")
112
+
113
+ def _is_temporary_error(self, error_type: str) -> bool:
114
+ """Determine if an error is temporary and worth retrying"""
115
+ temporary_errors = [
116
+ "rate_limit", "quota_exceeded", "too_many_requests", "429",
117
+ "service_unavailable", "503", "bad_gateway", "502",
118
+ "timeout", "connection_error", "network_error"
119
+ ]
120
+
121
+ error_lower = error_type.lower()
122
+ return any(temp_err in error_lower for temp_err in temporary_errors)
123
+
124
+ def get_stats(self) -> dict:
125
+ """Get rotator statistics"""
126
+ return {
127
+ "total_keys": len(self.keys),
128
+ "dead_keys": len(self.dead),
129
+ "temp_dead_keys": len(self.temp_dead),
130
+ "available_keys": len(self.keys) - len(self.dead) - len(self.temp_dead),
131
+ "retry_counts": self.retry_counts.copy()
132
+ }
133
 
134
  class GeminiClient:
135
  def __init__(self, rotator: KeyRotator, default_model: str):
 
157
  logger.info(f"[LLM][Gemini] out={snip(text)}")
158
  return text
159
  except Exception as e:
160
+ error_msg = str(e)
161
+ logger.error(f"[LLM][Gemini] {error_msg}")
162
+ self.rotator.mark_bad(key, error_msg)
163
  return None
164
 
165
  class NvidiaClient:
 
205
  logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
206
  return clean
207
  except Exception as e:
208
+ error_msg = str(e)
209
+ logger.error(f"[LLM][NVIDIA] {error_msg}")
210
+ self.rotator.mark_bad(key, error_msg)
211
  return None
212
 
213
  class Paraphraser:
214
+ """Intelligent API load balancer with rate limiting and cost optimization."""
215
  def __init__(self, nvidia_model: str, gemini_model_easy: str, gemini_model_hard: str):
216
  self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
217
  self.gm_easy = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_easy)
218
  # Only use GEMINI_MODEL_EASY, ignore hard model completely
219
  self.gm_hard = None # Disabled - only use easy model
220
+
221
+ # Rate limiting and load balancing
222
+ self.last_nvidia_call = 0
223
+ self.last_gemini_call = 0
224
+ self.min_call_interval = 0.1 # Minimum 100ms between calls
225
+ self.nvidia_success_rate = 1.0 # Track success rates for load balancing
226
+ self.gemini_success_rate = 1.0
227
+ self.call_counts = {"nvidia": 0, "gemini": 0, "failures": 0}
228
+
229
+ logger.info("Paraphraser initialized with intelligent load balancing: NVIDIA -> GEMINI_EASY")
230
+
231
+ def _rate_limit(self, api_type: str):
232
+ """Apply rate limiting to prevent API exhaustion"""
233
+ current_time = time.time()
234
+ if api_type == "nvidia":
235
+ time_since_last = current_time - self.last_nvidia_call
236
+ if time_since_last < self.min_call_interval:
237
+ sleep_time = self.min_call_interval - time_since_last
238
+ time.sleep(sleep_time)
239
+ self.last_nvidia_call = time.time()
240
+ elif api_type == "gemini":
241
+ time_since_last = current_time - self.last_gemini_call
242
+ if time_since_last < self.min_call_interval:
243
+ sleep_time = self.min_call_interval - time_since_last
244
+ time.sleep(sleep_time)
245
+ self.last_gemini_call = time.time()
246
+
247
+ def _select_api(self, prefer_cheap: bool = True) -> str:
248
+ """Intelligently select API based on success rates and availability"""
249
+ nvidia_stats = self.nv.rotator.get_stats()
250
+ gemini_stats = self.gm_easy.rotator.get_stats()
251
+
252
+ nvidia_available = nvidia_stats["available_keys"] > 0
253
+ gemini_available = gemini_stats["available_keys"] > 0
254
+
255
+ if not nvidia_available and not gemini_available:
256
+ return "none"
257
+ elif not nvidia_available:
258
+ return "gemini"
259
+ elif not gemini_available:
260
+ return "nvidia"
261
+
262
+ # Both available - use intelligent selection
263
+ if prefer_cheap:
264
+ # Prefer NVIDIA (cheaper) but consider success rates
265
+ if self.nvidia_success_rate > 0.8 or self.gemini_success_rate < 0.5:
266
+ return "nvidia"
267
+ else:
268
+ return "gemini"
269
+ else:
270
+ # Prefer quality (Gemini) but consider success rates
271
+ if self.gemini_success_rate > 0.8 or self.nvidia_success_rate < 0.5:
272
+ return "gemini"
273
+ else:
274
+ return "nvidia"
275
+
276
+ def _update_success_rate(self, api_type: str, success: bool):
277
+ """Update success rate tracking for load balancing"""
278
+ if api_type == "nvidia":
279
+ # Exponential moving average
280
+ alpha = 0.1
281
+ self.nvidia_success_rate = alpha * (1.0 if success else 0.0) + (1 - alpha) * self.nvidia_success_rate
282
+ elif api_type == "gemini":
283
+ alpha = 0.1
284
+ self.gemini_success_rate = alpha * (1.0 if success else 0.0) + (1 - alpha) * self.gemini_success_rate
285
+
286
+ def _call_api(self, prompt: str, api_type: str, **kwargs) -> Optional[str]:
287
+ """Make API call with rate limiting and error tracking"""
288
+ self._rate_limit(api_type)
289
+
290
+ try:
291
+ if api_type == "nvidia":
292
+ result = self.nv.generate(prompt, **kwargs)
293
+ self.call_counts["nvidia"] += 1
294
+ success = result is not None
295
+ self._update_success_rate("nvidia", success)
296
+ return result
297
+ elif api_type == "gemini":
298
+ result = self.gm_easy.generate(prompt, **kwargs)
299
+ self.call_counts["gemini"] += 1
300
+ success = result is not None
301
+ self._update_success_rate("gemini", success)
302
+ return result
303
+ except Exception as e:
304
+ self.call_counts["failures"] += 1
305
+ self._update_success_rate(api_type, False)
306
+ logger.error(f"[LLM] API call failed for {api_type}: {e}")
307
+ return None
308
+
309
+ return None
310
 
311
  # Enhanced cleaning to remove conversational elements and comments
312
  def _clean_resp(self, resp: str) -> str:
 
410
  temperature = 0.1 if difficulty == "easy" else 0.3
411
  max_tokens = min(600, max(128, len(text)//2))
412
 
413
+ # Intelligent API selection with fallback
414
+ api_type = self._select_api(prefer_cheap=True)
415
+
416
+ if api_type == "nvidia":
417
+ out = self._call_api(prompt, "nvidia", temperature=temperature, max_tokens=max_tokens)
418
+ if out:
419
+ return self._clean_resp(out)
420
+ # Fallback to Gemini if NVIDIA fails
421
+ api_type = self._select_api(prefer_cheap=False)
422
 
423
+ if api_type == "gemini":
424
+ out = self._call_api(prompt, "gemini", max_output_tokens=max_tokens)
425
+ if out:
426
+ logger.info(f"[LLM][GEMINI] out={snip(self._clean_resp(out))}")
427
+ return self._clean_resp(out)
428
+
429
+ # Both APIs failed
430
+ logger.warning(f"[LLM] All APIs failed for paraphrase, returning original text")
431
  return text
432
 
433
  # ————— Translate & Backtranslate —————
 
446
  f"{text}"
447
  )
448
 
449
+ # Intelligent API selection for translation
450
+ api_type = self._select_api(prefer_cheap=True)
451
+
452
+ if api_type == "nvidia":
453
+ out = self._call_api(prompt, "nvidia", temperature=0.0, max_tokens=min(800, len(text)+100))
454
+ if out:
455
+ return out.strip()
456
+ # Fallback to Gemini if NVIDIA fails
457
+ api_type = self._select_api(prefer_cheap=False)
458
+
459
+ if api_type == "gemini":
460
+ out = self._call_api(prompt, "gemini", max_output_tokens=min(800, len(text)+100))
461
+ if out:
462
+ return out.strip()
463
+
464
+ return None
465
 
466
  def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
467
  if not text: return text
 
480
  f"{mid}"
481
  )
482
 
483
+ # Intelligent API selection for backtranslation
484
+ api_type = self._select_api(prefer_cheap=True)
485
+
486
+ if api_type == "nvidia":
487
+ out = self._call_api(prompt, "nvidia", temperature=0.0, max_tokens=min(900, len(text)+150))
488
+ if out:
489
+ return out.strip()
490
+ # Fallback to Gemini if NVIDIA fails
491
+ api_type = self._select_api(prefer_cheap=False)
492
+
493
+ if api_type == "gemini":
494
+ out = self._call_api(prompt, "gemini", max_output_tokens=min(900, len(text)+150))
495
+ if out:
496
+ return out.strip()
497
+
498
+ return None
499
 
500
  # ————— Consistency Judge (cheap, ratio-based) —————
501
  def consistency_check(self, user: str, output: str) -> bool:
 
505
  f"Question/Context: {user}\n\n"
506
  f"Medical Answer: {output}"
507
  )
508
+
509
+ # Use intelligent API selection for consistency check
510
+ api_type = self._select_api(prefer_cheap=True)
511
+
512
+ if api_type == "nvidia":
513
+ out = self._call_api(prompt, "nvidia", temperature=0.0, max_tokens=5)
514
+ if out:
515
+ return isinstance(out, str) and "PASS" in out.upper()
516
+ # Fallback to Gemini if NVIDIA fails
517
+ api_type = self._select_api(prefer_cheap=False)
518
+
519
+ if api_type == "gemini":
520
+ out = self._call_api(prompt, "gemini", max_output_tokens=5)
521
+ if out:
522
+ return isinstance(out, str) and "PASS" in out.upper()
523
+
524
+ # If both APIs fail, assume consistency (conservative approach)
525
+ logger.warning("[LLM] Consistency check failed due to API unavailability, assuming consistent")
526
+ return True
527
+
528
+ def get_api_stats(self) -> dict:
529
+ """Get comprehensive API usage statistics"""
530
+ nvidia_stats = self.nv.rotator.get_stats()
531
+ gemini_stats = self.gm_easy.rotator.get_stats()
532
+
533
+ return {
534
+ "call_counts": self.call_counts.copy(),
535
+ "success_rates": {
536
+ "nvidia": self.nvidia_success_rate,
537
+ "gemini": self.gemini_success_rate
538
+ },
539
+ "nvidia_rotator": nvidia_stats,
540
+ "gemini_rotator": gemini_stats,
541
+ "total_calls": sum(self.call_counts.values()),
542
+ "failure_rate": self.call_counts["failures"] / max(1, sum(self.call_counts.values()))
543
+ }
544
 
545
  def medical_accuracy_check(self, question: str, answer: str) -> bool:
546
  """Check medical accuracy of Q&A pairs using cloud APIs"""