VietCat committed on
Commit
8812f42
·
1 Parent(s): f0e68b1

add quota manager

Browse files
Files changed (2) hide show
  1. app/embedding.py +7 -7
  2. app/gemini_client.py +14 -8
app/embedding.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List
2
  import numpy as np
3
  from loguru import logger
4
  import httpx
@@ -18,10 +18,7 @@ class EmbeddingClient:
18
  settings = get_settings()
19
  self.provider = getattr(settings, 'embedding_provider', 'default')
20
  self.model = getattr(settings, 'embedding_model', 'models/embedding-001')
21
- if self.provider == 'gemini':
22
- self.gemini_client = GeminiClient(settings.gemini_api_key, model=self.model)
23
- else:
24
- self.gemini_client = None
25
 
26
  @timing_decorator_async
27
  async def create_embedding(self, text: str) -> List[float]:
@@ -30,12 +27,15 @@ class EmbeddingClient:
30
  Input: text (str)
31
  Output: list[float] embedding vector.
32
  """
33
- if self.provider == 'gemini' and self.gemini_client:
 
 
34
  try:
35
  # GeminiClient.create_embedding là hàm sync, chạy trong executor
36
  import asyncio
37
  loop = asyncio.get_event_loop()
38
- embedding = await loop.run_in_executor(None, self.gemini_client.create_embedding, text)
 
39
  # Kiểm tra kiểu dữ liệu trả về
40
  if isinstance(embedding, list):
41
  preview = f"{embedding[:10]}...{embedding[-10:]}" if len(embedding) > 20 else str(embedding)
 
1
+ from typing import List, Optional
2
  import numpy as np
3
  from loguru import logger
4
  import httpx
 
18
  settings = get_settings()
19
  self.provider = getattr(settings, 'embedding_provider', 'default')
20
  self.model = getattr(settings, 'embedding_model', 'models/embedding-001')
21
+ self.gemini_client: Optional[GeminiClient] = GeminiClient() if self.provider == 'gemini' else None
 
 
 
22
 
23
  @timing_decorator_async
24
  async def create_embedding(self, text: str) -> List[float]:
 
27
  Input: text (str)
28
  Output: list[float] embedding vector.
29
  """
30
+ if self.provider == 'gemini':
31
+ if not self.gemini_client:
32
+ raise RuntimeError("GeminiClient is not initialized")
33
  try:
34
  # GeminiClient.create_embedding là hàm sync, chạy trong executor
35
  import asyncio
36
  loop = asyncio.get_event_loop()
37
+ gemini_client = self.gemini_client # type: ignore
38
+ embedding = await loop.run_in_executor(None, lambda: gemini_client.create_embedding(text, model=self.model))
39
  # Kiểm tra kiểu dữ liệu trả về
40
  if isinstance(embedding, list):
41
  preview = f"{embedding[:10]}...{embedding[-10:]}" if len(embedding) > 20 else str(embedding)
app/gemini_client.py CHANGED
@@ -3,6 +3,7 @@ from google.generativeai.client import configure
3
  from google.generativeai.generative_models import GenerativeModel
4
  from loguru import logger
5
  from .request_limit_manager import RequestLimitManager
 
6
 
7
  class GeminiClient:
8
  def __init__(self):
@@ -49,17 +50,22 @@ class GeminiClient:
49
  continue
50
  return 0
51
 
52
- def create_embedding(self, text: str) -> list:
53
  last_error = None
54
- for key, model in self.limit_manager.iterate_key_model():
 
 
 
 
 
55
  try:
56
  configure(api_key=key)
57
  response = embed_content(
58
- model=model,
59
  content=text,
60
  task_type="retrieval_query"
61
  )
62
- self.limit_manager.log_request(key, model, success=True)
63
  logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}")
64
  return response['embedding']
65
  except Exception as e:
@@ -67,10 +73,10 @@ class GeminiClient:
67
  msg = str(e)
68
  if "429" in msg or "rate limit" in msg.lower():
69
  retry_delay = 60
70
- m = re.search(r'retry_delay.*?seconds: (\d+)', msg)
71
- if m:
72
- retry_delay = int(m.group(1))
73
- self.limit_manager.log_request(key, model, success=False, retry_delay=retry_delay)
74
  last_error = e
75
  continue
76
  raise last_error or RuntimeError("No available Gemini API key/model")
 
3
  from google.generativeai.generative_models import GenerativeModel
4
  from loguru import logger
5
  from .request_limit_manager import RequestLimitManager
6
+ from typing import List, Optional
7
 
8
  class GeminiClient:
9
  def __init__(self):
 
50
  continue
51
  return 0
52
 
53
+ def create_embedding(self, text: str, model: Optional[str] = None) -> list:
54
  last_error = None
55
+ for key, m in self.limit_manager.iterate_key_model():
56
+ m = m or ""
57
+ use_model = model if model not in (None, "") else m
58
+ if not use_model:
59
+ continue
60
+ use_model = str(use_model)
61
  try:
62
  configure(api_key=key)
63
  response = embed_content(
64
+ model=use_model,
65
  content=text,
66
  task_type="retrieval_query"
67
  )
68
+ self.limit_manager.log_request(key, use_model, success=True)
69
  logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}")
70
  return response['embedding']
71
  except Exception as e:
 
73
  msg = str(e)
74
  if "429" in msg or "rate limit" in msg.lower():
75
  retry_delay = 60
76
+ m_retry = re.search(r'retry_delay.*?seconds: (\d+)', msg)
77
+ if m_retry:
78
+ retry_delay = int(m_retry.group(1))
79
+ self.limit_manager.log_request(key, use_model, success=False, retry_delay=retry_delay)
80
  last_error = e
81
  continue
82
  raise last_error or RuntimeError("No available Gemini API key/model")