VietCat commited on
Commit
06dc89b
·
1 Parent(s): 7958587

add gemini

Browse files
Files changed (3) hide show
  1. app/config.py +5 -0
  2. app/llm.py +43 -12
  3. app/main.py +9 -2
app/config.py CHANGED
@@ -30,6 +30,11 @@ class Settings(BaseSettings):
30
  # Logging Configuration
31
  log_level: str = os.getenv("LOG_LEVEL", "INFO") or "INFO"
32
 
 
 
 
 
 
33
  class Config:
34
  env_file = ".env"
35
 
 
30
  # Logging Configuration
31
  log_level: str = os.getenv("LOG_LEVEL", "INFO") or "INFO"
32
 
33
+ # Gemini Configuration
34
+ gemini_api_key: str = os.getenv("GEMINI_API_KEY") or ""
35
+ gemini_base_url: str = os.getenv("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1/models/gemini-2.5-flash:generateContent") or ""
36
+ gemini_model: str = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") or ""
37
+
38
  class Config:
39
  env_file = ".env"
40
 
app/llm.py CHANGED
@@ -35,6 +35,8 @@ class LLMClient:
35
  self._setup_custom(kwargs)
36
  elif self.provider == "hfs":
37
  self._setup_HFS(kwargs)
 
 
38
  else:
39
  raise ValueError(f"Unsupported provider: {provider}")
40
 
@@ -82,6 +84,14 @@ class LLMClient:
82
  if not self.base_url:
83
  raise ValueError("Custom provider requires base_url")
84
 
 
 
 
 
 
 
 
 
85
  @timing_decorator_async
86
  async def generate_text(
87
  self,
@@ -113,6 +123,8 @@ class LLMClient:
113
  result = await self._generate_custom(prompt, **kwargs)
114
  elif self.provider == "hfs":
115
  result = await self._generate_hfs(prompt, **kwargs)
 
 
116
  else:
117
  raise ValueError(f"Unsupported provider: {self.provider}")
118
  logger.info(f"[LLM] generate_text - provider: {self.provider}\n\t result: {result}")
@@ -192,6 +204,34 @@ class LLMClient:
192
  logger.error("HFS API response is None")
193
  raise RuntimeError("HFS API response is None")
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  @timing_decorator_async
196
  async def chat(
197
  self,
@@ -391,23 +431,14 @@ class LLMClient:
391
  """
392
 
393
  prompt = f"""
394
- Phân tích ngữ nghĩa câu sau: \"{text}\"
395
-
396
- Trả lời dưới dạng JSON với 3 trường sau:
397
  {{
398
  "muc_dich": "mục đích của câu hỏi",
399
- "phuong_tien": "loại phương tiện giao thông",
400
  "hanh_vi_vi_pham": "hành vi vi phạm luật giao thông"
401
  }}
402
 
403
- Ví dụ:
404
- "Tôi chạy xe hơi không bật đèn vào ban đêm thì có bị sao không?"
405
- {{
406
- "muc_dich": "Hỏi về hậu quả/hình phạt khi không bật đèn xe hơi ban đêm",
407
- "phuong_tien": "Xe hơi",
408
- "hanh_vi_vi_pham": "Không bật đèn khi lái xe vào ban đêm"
409
- }}
410
-
411
  Câu bạn cần phân tích:
412
  \"{text}\"
413
  """.strip()
 
35
  self._setup_custom(kwargs)
36
  elif self.provider == "hfs":
37
  self._setup_HFS(kwargs)
38
+ elif self.provider == "gemini":
39
+ self._setup_gemini(kwargs)
40
  else:
41
  raise ValueError(f"Unsupported provider: {provider}")
42
 
 
84
  if not self.base_url:
85
  raise ValueError("Custom provider requires base_url")
86
 
87
+ def _setup_gemini(self, config: Dict[str, Any]):
88
+ """Cấu hình cho Gemini."""
89
+ self.api_key = config.get("api_key", "")
90
+ self.base_url = config.get("base_url", "")
91
+ self.model = config.get("model", "")
92
+ self.max_tokens = config.get("max_tokens", 1024)
93
+ self.temperature = config.get("temperature", 0.7)
94
+
95
  @timing_decorator_async
96
  async def generate_text(
97
  self,
 
123
  result = await self._generate_custom(prompt, **kwargs)
124
  elif self.provider == "hfs":
125
  result = await self._generate_hfs(prompt, **kwargs)
126
+ elif self.provider == "gemini":
127
+ result = await self._generate_gemini(prompt, **kwargs)
128
  else:
129
  raise ValueError(f"Unsupported provider: {self.provider}")
130
  logger.info(f"[LLM] generate_text - provider: {self.provider}\n\t result: {result}")
 
204
  logger.error("HFS API response is None")
205
  raise RuntimeError("HFS API response is None")
206
 
207
+ async def _generate_gemini(self, prompt: str, **kwargs) -> str:
208
+ """Gọi Gemini API để sinh text từ prompt."""
209
+ url = self.base_url
210
+ headers = {"Content-Type": "application/json"}
211
+ if self.api_key:
212
+ headers["Authorization"] = f"Bearer {self.api_key}"
213
+ # Gemini API expects {"contents": [{"parts": [{"text": prompt}]}]}
214
+ payload = {"contents": [{"parts": [{"text": prompt}]}]}
215
+ response = await call_endpoint_with_retry(self._client, url, payload, headers=headers)
216
+ if response is not None and hasattr(response, 'text'):
217
+ logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {response.text}")
218
+ else:
219
+ logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {str(response)}")
220
+ if response is not None:
221
+ data = response.json()
222
+ # Log token usage nếu có
223
+ usage = data.get('usage') or data.get('usageMetadata')
224
+ if usage:
225
+ logger.info(f"[LLM][GEMINI][USAGE] {usage}")
226
+ # Gemini trả về: {'candidates': [{'content': {'parts': [{'text': '...'}]}}]}
227
+ try:
228
+ return data['candidates'][0]['content']['parts'][0]['text']
229
+ except Exception:
230
+ return str(data)
231
+ else:
232
+ logger.error("Gemini API response is None")
233
+ raise RuntimeError("Gemini API response is None")
234
+
235
  @timing_decorator_async
236
  async def chat(
237
  self,
 
431
  """
432
 
433
  prompt = f"""
434
+ Bạn là một AI chuyên phân tích ngữ nghĩa câu hỏi về giao thông đường bộ.
435
+ Với mỗi câu đầu vào, hãy trích xuất 3 thông tin sau và trả lời đúng định dạng JSON:
 
436
  {{
437
  "muc_dich": "mục đích của câu hỏi",
438
+ "phuong_tien": "loại phương tiện giao thông (nếu có)",
439
  "hanh_vi_vi_pham": "hành vi vi phạm luật giao thông"
440
  }}
441
 
 
 
 
 
 
 
 
 
442
  Câu bạn cần phân tích:
443
  \"{text}\"
444
  """.strip()
app/main.py CHANGED
@@ -54,9 +54,16 @@ embedding_client = EmbeddingClient()
54
  VEHICLE_KEYWORDS = ["xe máy", "ô tô", "xe đạp", "xe hơi"]
55
 
56
  # Khởi tạo LLM client (ví dụ dùng HFS, bạn có thể đổi provider tuỳ ý)
 
 
 
 
 
57
  llm_client = create_llm_client(
58
- provider="hfs",
59
- base_url="https://vietcat-gemma34b.hf.space"
 
 
60
  )
61
 
62
  logger.info("[STARTUP] Mount health router...")
 
54
  VEHICLE_KEYWORDS = ["xe máy", "ô tô", "xe đạp", "xe hơi"]
55
 
56
  # Khởi tạo LLM client (ví dụ dùng HFS, bạn có thể đổi provider tuỳ ý)
57
+ # llm_client = create_llm_client(
58
+ # provider="hfs",
59
+ # base_url="https://vietcat-gemma34b.hf.space"
60
+ # )
61
+ # Khởi tạo LLM client Gemini
62
  llm_client = create_llm_client(
63
+ provider="gemini",
64
+ api_key=settings.gemini_api_key,
65
+ base_url=settings.gemini_base_url,
66
+ model=settings.gemini_model
67
  )
68
 
69
  logger.info("[STARTUP] Mount health router...")