add gemini
Browse files- app/config.py +5 -0
- app/llm.py +43 -12
- app/main.py +9 -2
app/config.py
CHANGED
|
@@ -30,6 +30,11 @@ class Settings(BaseSettings):
|
|
| 30 |
# Logging Configuration
|
| 31 |
log_level: str = os.getenv("LOG_LEVEL", "INFO") or "INFO"
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
class Config:
|
| 34 |
env_file = ".env"
|
| 35 |
|
|
|
|
| 30 |
# Logging Configuration
|
| 31 |
log_level: str = os.getenv("LOG_LEVEL", "INFO") or "INFO"
|
| 32 |
|
| 33 |
+
# Gemini Configuration
|
| 34 |
+
gemini_api_key: str = os.getenv("GEMINI_API_KEY") or ""
|
| 35 |
+
gemini_base_url: str = os.getenv("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1/models/gemini-2.5-flash:generateContent") or ""
|
| 36 |
+
gemini_model: str = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") or ""
|
| 37 |
+
|
| 38 |
class Config:
|
| 39 |
env_file = ".env"
|
| 40 |
|
app/llm.py
CHANGED
|
@@ -35,6 +35,8 @@ class LLMClient:
|
|
| 35 |
self._setup_custom(kwargs)
|
| 36 |
elif self.provider == "hfs":
|
| 37 |
self._setup_HFS(kwargs)
|
|
|
|
|
|
|
| 38 |
else:
|
| 39 |
raise ValueError(f"Unsupported provider: {provider}")
|
| 40 |
|
|
@@ -82,6 +84,14 @@ class LLMClient:
|
|
| 82 |
if not self.base_url:
|
| 83 |
raise ValueError("Custom provider requires base_url")
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
@timing_decorator_async
|
| 86 |
async def generate_text(
|
| 87 |
self,
|
|
@@ -113,6 +123,8 @@ class LLMClient:
|
|
| 113 |
result = await self._generate_custom(prompt, **kwargs)
|
| 114 |
elif self.provider == "hfs":
|
| 115 |
result = await self._generate_hfs(prompt, **kwargs)
|
|
|
|
|
|
|
| 116 |
else:
|
| 117 |
raise ValueError(f"Unsupported provider: {self.provider}")
|
| 118 |
logger.info(f"[LLM] generate_text - provider: {self.provider}\n\t result: {result}")
|
|
@@ -192,6 +204,34 @@ class LLMClient:
|
|
| 192 |
logger.error("HFS API response is None")
|
| 193 |
raise RuntimeError("HFS API response is None")
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
@timing_decorator_async
|
| 196 |
async def chat(
|
| 197 |
self,
|
|
@@ -391,23 +431,14 @@ class LLMClient:
|
|
| 391 |
"""
|
| 392 |
|
| 393 |
prompt = f"""
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
Trả lời dưới dạng JSON với 3 trường sau:
|
| 397 |
{{
|
| 398 |
"muc_dich": "mục đích của câu hỏi",
|
| 399 |
-
"phuong_tien": "loại phương tiện giao thông",
|
| 400 |
"hanh_vi_vi_pham": "hành vi vi phạm luật giao thông"
|
| 401 |
}}
|
| 402 |
|
| 403 |
-
Ví dụ:
|
| 404 |
-
"Tôi chạy xe hơi không bật đèn vào ban đêm thì có bị sao không?"
|
| 405 |
-
{{
|
| 406 |
-
"muc_dich": "Hỏi về hậu quả/hình phạt khi không bật đèn xe hơi ban đêm",
|
| 407 |
-
"phuong_tien": "Xe hơi",
|
| 408 |
-
"hanh_vi_vi_pham": "Không bật đèn khi lái xe vào ban đêm"
|
| 409 |
-
}}
|
| 410 |
-
|
| 411 |
Câu bạn cần phân tích:
|
| 412 |
\"{text}\"
|
| 413 |
""".strip()
|
|
|
|
| 35 |
self._setup_custom(kwargs)
|
| 36 |
elif self.provider == "hfs":
|
| 37 |
self._setup_HFS(kwargs)
|
| 38 |
+
elif self.provider == "gemini":
|
| 39 |
+
self._setup_gemini(kwargs)
|
| 40 |
else:
|
| 41 |
raise ValueError(f"Unsupported provider: {provider}")
|
| 42 |
|
|
|
|
| 84 |
if not self.base_url:
|
| 85 |
raise ValueError("Custom provider requires base_url")
|
| 86 |
|
| 87 |
+
def _setup_gemini(self, config: Dict[str, Any]):
|
| 88 |
+
"""Cấu hình cho Gemini."""
|
| 89 |
+
self.api_key = config.get("api_key", "")
|
| 90 |
+
self.base_url = config.get("base_url", "")
|
| 91 |
+
self.model = config.get("model", "")
|
| 92 |
+
self.max_tokens = config.get("max_tokens", 1024)
|
| 93 |
+
self.temperature = config.get("temperature", 0.7)
|
| 94 |
+
|
| 95 |
@timing_decorator_async
|
| 96 |
async def generate_text(
|
| 97 |
self,
|
|
|
|
| 123 |
result = await self._generate_custom(prompt, **kwargs)
|
| 124 |
elif self.provider == "hfs":
|
| 125 |
result = await self._generate_hfs(prompt, **kwargs)
|
| 126 |
+
elif self.provider == "gemini":
|
| 127 |
+
result = await self._generate_gemini(prompt, **kwargs)
|
| 128 |
else:
|
| 129 |
raise ValueError(f"Unsupported provider: {self.provider}")
|
| 130 |
logger.info(f"[LLM] generate_text - provider: {self.provider}\n\t result: {result}")
|
|
|
|
| 204 |
logger.error("HFS API response is None")
|
| 205 |
raise RuntimeError("HFS API response is None")
|
| 206 |
|
| 207 |
+
async def _generate_gemini(self, prompt: str, **kwargs) -> str:
|
| 208 |
+
"""Gọi Gemini API để sinh text từ prompt."""
|
| 209 |
+
url = self.base_url
|
| 210 |
+
headers = {"Content-Type": "application/json"}
|
| 211 |
+
if self.api_key:
|
| 212 |
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
| 213 |
+
# Gemini API expects {"contents": [{"parts": [{"text": prompt}]}]}
|
| 214 |
+
payload = {"contents": [{"parts": [{"text": prompt}]}]}
|
| 215 |
+
response = await call_endpoint_with_retry(self._client, url, payload, headers=headers)
|
| 216 |
+
if response is not None and hasattr(response, 'text'):
|
| 217 |
+
logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {response.text}")
|
| 218 |
+
else:
|
| 219 |
+
logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {str(response)}")
|
| 220 |
+
if response is not None:
|
| 221 |
+
data = response.json()
|
| 222 |
+
# Log token usage nếu có
|
| 223 |
+
usage = data.get('usage') or data.get('usageMetadata')
|
| 224 |
+
if usage:
|
| 225 |
+
logger.info(f"[LLM][GEMINI][USAGE] {usage}")
|
| 226 |
+
# Gemini trả về: {'candidates': [{'content': {'parts': [{'text': '...'}]}}]}
|
| 227 |
+
try:
|
| 228 |
+
return data['candidates'][0]['content']['parts'][0]['text']
|
| 229 |
+
except Exception:
|
| 230 |
+
return str(data)
|
| 231 |
+
else:
|
| 232 |
+
logger.error("Gemini API response is None")
|
| 233 |
+
raise RuntimeError("Gemini API response is None")
|
| 234 |
+
|
| 235 |
@timing_decorator_async
|
| 236 |
async def chat(
|
| 237 |
self,
|
|
|
|
| 431 |
"""
|
| 432 |
|
| 433 |
prompt = f"""
|
| 434 |
+
Bạn là một AI chuyên phân tích ngữ nghĩa câu hỏi về giao thông đường bộ.
|
| 435 |
+
Với mỗi câu đầu vào, hãy trích xuất 3 thông tin sau và trả lời đúng định dạng JSON:
|
|
|
|
| 436 |
{{
|
| 437 |
"muc_dich": "mục đích của câu hỏi",
|
| 438 |
+
"phuong_tien": "loại phương tiện giao thông (nếu có)",
|
| 439 |
"hanh_vi_vi_pham": "hành vi vi phạm luật giao thông"
|
| 440 |
}}
|
| 441 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
Câu bạn cần phân tích:
|
| 443 |
\"{text}\"
|
| 444 |
""".strip()
|
app/main.py
CHANGED
|
@@ -54,9 +54,16 @@ embedding_client = EmbeddingClient()
|
|
| 54 |
VEHICLE_KEYWORDS = ["xe máy", "ô tô", "xe đạp", "xe hơi"]
|
| 55 |
|
| 56 |
# Khởi tạo LLM client (ví dụ dùng HFS, bạn có thể đổi provider tuỳ ý)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
llm_client = create_llm_client(
|
| 58 |
-
provider="
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
|
| 62 |
logger.info("[STARTUP] Mount health router...")
|
|
|
|
| 54 |
VEHICLE_KEYWORDS = ["xe máy", "ô tô", "xe đạp", "xe hơi"]
|
| 55 |
|
| 56 |
# Khởi tạo LLM client (ví dụ dùng HFS, bạn có thể đổi provider tuỳ ý)
|
| 57 |
+
# llm_client = create_llm_client(
|
| 58 |
+
# provider="hfs",
|
| 59 |
+
# base_url="https://vietcat-gemma34b.hf.space"
|
| 60 |
+
# )
|
| 61 |
+
# Khởi tạo LLM client Gemini
|
| 62 |
llm_client = create_llm_client(
|
| 63 |
+
provider="gemini",
|
| 64 |
+
api_key=settings.gemini_api_key,
|
| 65 |
+
base_url=settings.gemini_base_url,
|
| 66 |
+
model=settings.gemini_model
|
| 67 |
)
|
| 68 |
|
| 69 |
logger.info("[STARTUP] Mount health router...")
|