VietCat commited on
Commit
5a29187
·
1 Parent(s): f6866df

adjust prompt

Browse files
Files changed (1) hide show
  1. app/llm.py +239 -287
app/llm.py CHANGED
@@ -1,32 +1,67 @@
 
 
1
  from typing import List, Dict, Any, Optional, Union
2
- import httpx
3
  import json
4
- from loguru import logger
5
- from tenacity import retry, stop_after_attempt, wait_exponential
6
  import os
 
 
 
 
 
7
  from .gemini_client import GeminiClient
8
  from .config import get_settings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- from .utils import timing_decorator_async, timing_decorator_sync, call_endpoint_with_retry
11
 
12
  class LLMClient:
13
  """
14
  Client để tương tác với các mô hình ngôn ngữ lớn (LLM).
15
- Hỗ trợ nhiều provider: OpenAI, HuggingFace, local models, etc.
16
  """
17
-
18
  def __init__(self, provider: str = "openai", **kwargs):
19
- """
20
- Khởi tạo LLMClient.
21
-
22
- Args:
23
- provider (str): Loại provider ("openai", "huggingface", "local", "custom")
24
- **kwargs: Các tham số cấu hình khác
25
- """
26
  self.provider = provider.lower()
27
  self._client = httpx.AsyncClient(timeout=60.0)
28
-
29
- # Cấu hình theo provider
30
  if self.provider == "openai":
31
  self._setup_openai(kwargs)
32
  elif self.provider == "huggingface":
@@ -36,41 +71,39 @@ class LLMClient:
36
  elif self.provider == "custom":
37
  self._setup_custom(kwargs)
38
  elif self.provider == "hfs":
39
- self._setup_HFS(kwargs)
40
  elif self.provider == "gemini":
41
  self._setup_gemini(kwargs)
42
  else:
43
  raise ValueError(f"Unsupported provider: {provider}")
44
-
 
 
45
  def _setup_openai(self, config: Dict[str, Any]):
46
- """Cấu hình cho OpenAI."""
47
  self.api_key = config.get("api_key") or os.getenv("OPENAI_API_KEY") or ""
48
  self.base_url = config.get("base_url", "https://api.openai.com/v1")
49
  self.model = config.get("model", "gpt-3.5-turbo")
50
  self.max_tokens = config.get("max_tokens", 1000)
51
  self.temperature = config.get("temperature", 0.7)
52
-
53
  if not self.api_key:
54
  raise ValueError("OpenAI API key is required")
55
-
56
  def _setup_huggingface(self, config: Dict[str, Any]):
57
- """Cấu hình cho HuggingFace."""
58
  self.api_key = config.get("api_key", "")
59
  self.base_url = config.get("base_url", "https://api-inference.huggingface.co")
60
  self.model = config.get("model", "microsoft/DialoGPT-medium")
61
  self.max_tokens = config.get("max_tokens", 1000)
62
  self.temperature = config.get("temperature", 0.7)
63
-
64
  def _setup_local(self, config: Dict[str, Any]):
65
- """Cấu hình cho local model."""
66
  self.api_key = ""
67
  self.base_url = config.get("base_url", "http://localhost:8000")
68
  self.model = config.get("model", "default")
69
  self.max_tokens = config.get("max_tokens", 1000)
70
  self.temperature = config.get("temperature", 0.7)
71
-
72
  def _setup_custom(self, config: Dict[str, Any]):
73
- """Cấu hình cho custom provider."""
74
  self.api_key = config.get("api_key", "")
75
  self.base_url = config.get("base_url")
76
  self.model = config.get("model", "default")
@@ -79,40 +112,33 @@ class LLMClient:
79
  if not self.base_url:
80
  raise ValueError("Custom provider requires base_url")
81
 
82
- def _setup_HFS(self, config: Dict[str, Any]):
83
- """Cấu hình cho custom provider."""
84
  self.api_key = config.get("api_key", "")
85
  self.base_url = config.get("base_url")
86
  if not self.base_url:
87
- raise ValueError("Custom provider requires base_url")
88
-
89
  def _setup_gemini(self, config: Dict[str, Any]):
90
- """Cấu hình cho Gemini."""
91
- # Sử dụng GeminiClient với RequestLimitManager
92
  self.gemini_client = GeminiClient()
93
  logger.info("[LLM] Initialized GeminiClient with RequestLimitManager")
94
-
 
 
95
  @timing_decorator_async
96
  async def generate_text(
97
- self,
98
- prompt: str,
99
  system_prompt: Optional[str] = None,
100
- **kwargs
101
  ) -> str:
102
  """
103
  Tạo text từ prompt sử dụng LLM.
104
-
105
- Args:
106
- prompt (str): Prompt đầu vào
107
- system_prompt (str, optional): System prompt
108
- **kwargs: Các tham số bổ sung
109
-
110
- Returns:
111
- str: Text được tạo ra
112
  """
113
- logger.info(f"[LLM] generate_text - provider: {self.provider} \n\t prompt: {prompt}")
 
 
114
  try:
115
- result = None
116
  if self.provider == "openai":
117
  result = await self._generate_openai(prompt, system_prompt, **kwargs)
118
  elif self.provider == "huggingface":
@@ -127,24 +153,38 @@ class LLMClient:
127
  result = await self._generate_gemini(prompt, **kwargs)
128
  else:
129
  raise ValueError(f"Unsupported provider: {self.provider}")
130
- logger.info(f"[LLM] generate_text - provider: {self.provider}\n\t result: {result}")
 
131
  return result
132
  except Exception as e:
133
- logger.error(f"[LLM] Error generating text with {self.provider}: {e}")
134
  raise
135
-
136
- async def _generate_openai(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> str:
 
 
137
  url = f"{self.base_url}/chat/completions"
138
- payload = {"model": kwargs.get("model", self.model), "messages": [{"role": "system", "content": system_prompt or ""}, {"role": "user", "content": prompt}], "max_tokens": kwargs.get("max_tokens", self.max_tokens), "temperature": kwargs.get("temperature", self.temperature), "stream": False}
139
- headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
 
 
 
 
 
 
 
 
 
 
 
140
  response = await call_endpoint_with_retry(self._client, url, payload, headers=headers)
141
  if response is not None:
142
  data = response.json()
143
  return data["choices"][0]["message"]["content"]
144
- else:
145
- logger.error("OpenAI API response is None")
146
- raise RuntimeError("OpenAI API response is None")
147
-
148
  async def _generate_huggingface(self, prompt: str, **kwargs) -> str:
149
  url = f"{self.base_url}/generate"
150
  payload = {"inputs": prompt}
@@ -152,10 +192,9 @@ class LLMClient:
152
  if response is not None:
153
  data = response.json()
154
  return data[0]["generated_text"]
155
- else:
156
- logger.error("HuggingFace API response is None")
157
- raise RuntimeError("HuggingFace API response is None")
158
-
159
  async def _generate_local(self, prompt: str, **kwargs) -> str:
160
  url = f"{self.base_url}/generate"
161
  payload = {"prompt": prompt}
@@ -163,10 +202,9 @@ class LLMClient:
163
  if response is not None:
164
  data = response.json()
165
  return data.get("text", "")
166
- else:
167
- logger.error("Local API response is None")
168
- raise RuntimeError("Local API response is None")
169
-
170
  async def _generate_custom(self, prompt: str, **kwargs) -> str:
171
  url = f"{self.base_url}/custom"
172
  payload = {"prompt": prompt}
@@ -174,130 +212,101 @@ class LLMClient:
174
  if response is not None:
175
  data = response.json()
176
  return data.get("text", "")
177
- else:
178
- logger.error("Custom API response is None")
179
- raise RuntimeError("Custom API response is None")
180
-
181
  async def _generate_hfs(self, prompt: str, **kwargs) -> str:
 
182
  endpoint = f"{self.base_url}/purechat"
183
  payload = {"prompt": prompt}
184
  headers = {}
185
- if self.api_key:
186
  headers["Authorization"] = f"Bearer {self.api_key}"
187
- response = await call_endpoint_with_retry(self._client, endpoint, payload, 3, 500, headers=headers)
188
- logger.info(f"[LLM] generate_text - provider: {self.provider}\n\t response: {response}")
 
 
 
 
189
  try:
190
- import json as _json
191
- logger.info(f"[LLM][RAW_RESPONSE] { _json.dumps(response, ensure_ascii=False, indent=2) }")
 
192
  except Exception:
193
  logger.info(f"[LLM][RAW_RESPONSE] {str(response)}")
 
194
  if response is not None:
195
  data = response.json()
196
- if 'response' in data:
197
- return data['response']
198
- elif 'result' in data:
199
- return data['result']
200
- elif 'data' in data and isinstance(data['data'], list):
201
- return data['data'][0]
202
  return str(data)
203
- else:
204
- logger.error("HFS API response is None")
205
- raise RuntimeError("HFS API response is None")
206
 
207
  async def _generate_gemini(self, prompt: str, **kwargs) -> str:
208
- import asyncio
209
  loop = asyncio.get_event_loop()
210
- return await loop.run_in_executor(None, self.gemini_client.generate_text, prompt)
 
211
 
212
  @timing_decorator_async
213
- async def chat(
214
- self,
215
- messages: List[Dict[str, str]],
216
- **kwargs
217
- ) -> str:
218
- """
219
- Chat với LLM sử dụng conversation history.
220
-
221
- Args:
222
- messages (List[Dict]): List các message với format [{"role": "user", "content": "..."}]
223
- **kwargs: Các tham số bổ sung
224
-
225
- Returns:
226
- str: Response từ LLM
227
- """
228
- logger.info(f"[LLM] chat - messages: {messages}")
229
  if self.provider == "openai":
230
  return await self._chat_openai(messages, **kwargs)
231
- else:
232
- # Với các provider khác, convert messages thành prompt
233
- prompt = self._messages_to_prompt(messages)
234
- return await self.generate_text(prompt, **kwargs)
235
-
236
  async def _chat_openai(self, messages: List[Dict[str, str]], **kwargs) -> str:
237
- """Chat với OpenAI API."""
238
  payload = {
239
  "model": kwargs.get("model", self.model),
240
  "messages": messages,
241
  "max_tokens": kwargs.get("max_tokens", self.max_tokens),
242
  "temperature": kwargs.get("temperature", self.temperature),
243
- "stream": False
244
  }
245
-
246
  headers = {
247
  "Authorization": f"Bearer {self.api_key}",
248
- "Content-Type": "application/json"
249
  }
250
-
251
  response = await self._client.post(
252
- f"{self.base_url}/chat/completions",
253
- headers=headers,
254
- json=payload
255
  )
256
  response.raise_for_status()
257
-
258
  data = response.json()
259
  return data["choices"][0]["message"]["content"]
260
-
261
  def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
262
- """Convert conversation messages thành prompt string."""
263
- prompt = ""
264
  for msg in messages:
265
  role = msg.get("role", "user")
266
  content = msg.get("content", "")
267
-
268
  if role == "system":
269
- prompt += f"System: {content}\n\n"
270
  elif role == "user":
271
- prompt += f"User: {content}\n"
272
  elif role == "assistant":
273
- prompt += f"Assistant: {content}\n"
274
-
275
- prompt += "Assistant: "
276
- return prompt
277
-
 
278
  @timing_decorator_async
279
  async def classify_text(
280
- self,
281
- text: str,
282
- categories: List[str],
283
- **kwargs
284
  ) -> Dict[str, Any]:
285
- """
286
- Phân loại text vào các categories.
287
-
288
- Args:
289
- text (str): Text cần phân loại
290
- categories (List[str]): List các categories
291
- **kwargs: Các tham số bổ sung
292
-
293
- Returns:
294
- Dict: Kết quả phân loại
295
- """
296
  prompt = f"""
297
  Phân loại text sau vào một trong các categories: {', '.join(categories)}
298
-
299
  Text: {text}
300
-
301
  Trả về kết quả theo format JSON:
302
  {{
303
  "category": "tên_category",
@@ -305,55 +314,31 @@ class LLMClient:
305
  "reasoning": "lý do phân loại"
306
  }}
307
  """
308
-
309
  response = await self.generate_text(prompt, **kwargs)
310
-
311
- try:
312
- # Tìm JSON trong response
313
- import re
314
- json_match = re.search(r'\{.*\}', response, re.DOTALL)
315
- if json_match:
316
- result = json.loads(json_match.group())
317
- return result
318
- else:
319
- return {
320
- "category": "unknown",
321
- "confidence": 0.0,
322
- "reasoning": "Không thể parse JSON response"
323
- }
324
- except json.JSONDecodeError:
325
- return {
326
- "category": "unknown",
327
- "confidence": 0.0,
328
- "reasoning": f"JSON parse error: {response}"
329
- }
330
-
331
  @timing_decorator_async
332
  async def extract_entities(
333
- self,
334
- text: str,
335
- entity_types: Optional[List[str]] = None,
336
- **kwargs
337
  ) -> List[Dict[str, Any]]:
338
- """
339
- Trích xuất entities từ text.
340
-
341
- Args:
342
- text (str): Text cần trích xuất
343
- entity_types (List[str]): Các loại entity cần tìm
344
- **kwargs: Các tham số bổ sung
345
-
346
- Returns:
347
- List[Dict]: List các entities được tìm thấy
348
- """
349
  if entity_types is None:
350
  entity_types = ["PERSON", "ORGANIZATION", "LOCATION", "MONEY", "DATE"]
351
-
352
  prompt = f"""
353
  Trích xuất các entities từ text sau. Tìm các entities thuộc types: {', '.join(entity_types)}
354
-
355
  Text: {text}
356
-
357
  Trả về kết quả theo format JSON:
358
  [
359
  {{
@@ -364,158 +349,125 @@ class LLMClient:
364
  }}
365
  ]
366
  """
367
-
368
  response = await self.generate_text(prompt, **kwargs)
369
-
370
  try:
371
- import re
372
- # Log toàn bộ response dưới dạng text dễ đọc
373
- try:
374
- import json as _json
375
- logger.info(f"[LLM][RAW_RESPONSE] { _json.dumps(response, ensure_ascii=False, indent=2) }")
376
- except Exception:
377
- logger.info(f"[LLM][RAW_RESPONSE] {str(response)}")
378
- # Ưu tiên parse object JSON nếu có
379
- json_match_obj = re.search(r'\{[\s\S]*?\}', response)
380
- json_match_list = re.search(r'\[[\s\S]*?\]', response)
381
- if json_match_list:
382
- entities = json.loads(json_match_list.group())
383
- return entities
384
- elif json_match_obj:
385
- entity = json.loads(json_match_obj.group())
386
- return [entity]
387
- else:
388
- return []
389
- except json.JSONDecodeError:
390
- logger.error(f"Error parsing entities JSON: {response}")
391
  return []
392
 
393
  @timing_decorator_async
394
  async def analyze(
395
- self,
396
- text: str,
397
- conversation_context: str,
398
- **kwargs
399
  ) -> List[Dict[str, Any]]:
400
  """
401
- Trích xuất entities từ text.
402
-
403
- Args:
404
- text (str): Text cần trích xuất
405
- **kwargs: Các tham số bổ sung
406
-
407
- Returns:
408
- List[Dict]: List các entities được tìm thấy
409
  """
410
-
411
  prompt = f"""
412
- Bạn là một chuyên gia phân tích ngôn ngữ tự nhiên (NLP) chuyên xử lý các câu hỏi về luật giao thông Việt Nam. Nhiệm vụ của bạn là đọc kỹ **lịch sử trò chuyện** và **câu hỏi mới nhất** của người dùng để trích xuất thông tin vào một cấu trúc JSON duy nhất. Chỉ trả về đối tượng JSON, không thêm bất kỳ giải thích nào.
413
- Định dạng JSON bắt buộc:
414
 
415
- {{
416
- "muc_dich": "...",
417
- "phuong_tien": "...",
418
- "hanh_vi": "...",
419
- "cau_hoi": "..."
420
- }}
421
 
422
- Hướng dẫn chi tiết cho từng trường:
 
 
 
 
 
423
 
424
- **muc_dich**: Phải một trong các giá trị sau: "hỏi về mức phạt", "hỏi về quy tắc giao thông", "hỏi về báo hiệu đường bộ", "hỏi về quy trình xử lý vi phạm giao thông", "thông tin cá nhân của AI", "khác". **Phải dựa vào câu hỏi mới nhất để xác định.**
425
- **phuong_tien**: Tên phương tiện được đề cập trong câu hỏi mới hoặc trong lịch sử gần nhất. Nếu không có, để chuỗi rỗng "".
426
- **hanh_vi**: Tên gọi pháp lý của hành vi. **Sử dụng lịch sử trò chuyện để xác định hành vi nếu câu hỏi mới không đề cập đến.** Nếu không có hành vi cụ thể, để chuỗi rỗng "".
427
- **cau_hoi**: Diễn đạt lại câu hỏi mới nhất của người dùng thành một câu hỏi hoàn chỉnh, kết hợp ngữ cảnh từ lịch sử nếu cần, sử dụng đúng thuật ngữ pháp lý.
428
 
429
- DỤ MẪU:
 
 
 
 
 
 
430
 
431
- Câu hỏi đầu vào: vượt đèn đỏ phạt nhiêu?"
432
- Kết quả JSON mong muốn:
433
- {{
434
- "muc_dich": "hỏi về mức phạt",
435
- "phuong_tien": tô",
436
- "hanh_vi": "Không chấp hành hiệu lệnh của đèn tín hiệu giao thông",
437
- "cau_hoi": "Mức xử phạt cho hành vi ô không chấp hành hiệu lệnh của đèn tín hiệu giao thông bao nhiêu?"
438
- }}
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
- Bây giờ, hãy phân tích lịch sử câu hỏi sau và chỉ trả về đối tượng JSON.
441
- Lịch sử trò chuyện:
442
- \"{conversation_context}\"
 
 
 
443
 
444
- Câu hỏi:
445
- \"{text}\"
446
- """.strip()
447
-
448
  response = await self.generate_text(prompt, **kwargs)
449
-
450
- logger.info(f"[LLM][RAW] Kết quả trả về từ generate_text: {response}")
451
 
452
  try:
453
- import re
454
- # Log toàn bộ response dưới dạng text dễ đọc
455
- try:
456
- import json as _json
457
- logger.info(f"[LLM][RAW_RESPONSE] { _json.dumps(response, ensure_ascii=False, indent=2) }")
458
- except Exception:
459
- logger.info(f"[LLM][RAW_RESPONSE] {str(response)}")
460
- # Ưu tiên parse object JSON nếu có
461
- json_match_obj = re.search(r'\{[\s\S]*?\}', response)
462
- json_match_list = re.search(r'\[[\s\S]*?\]', response)
463
- if json_match_list:
464
- entities = json.loads(json_match_list.group())
465
- return entities
466
- elif json_match_obj:
467
- entity = json.loads(json_match_obj.group())
468
- return [entity]
469
- else:
470
- return []
471
- except json.JSONDecodeError:
472
- logger.error(f"Error parsing entities JSON: {response}")
473
  return []
474
-
 
 
 
475
  async def close(self):
476
  """Đóng client connection."""
477
- await self._client.aclose()
 
478
 
479
 
480
  # Factory function để tạo LLMClient dễ dàng
481
  def create_llm_client(provider: str = "openai", **kwargs) -> LLMClient:
482
- """
483
- Factory function để tạo LLMClient.
484
-
485
- Args:
486
- provider (str): Loại provider
487
- **kwargs: Các tham số cấu hình
488
-
489
- Returns:
490
- LLMClient: Instance của LLMClient
491
- """
492
  return LLMClient(provider, **kwargs)
493
 
494
 
495
  # Ví dụ sử dụng
496
  if __name__ == "__main__":
497
- import asyncio
498
-
499
  async def test_llm():
500
- # Test với OpenAI
501
  settings = get_settings()
502
  llm_client = create_llm_client(
503
  provider=settings.llm_provider,
504
  model=settings.llm_model,
505
  # ... các config khác nếu cần ...
506
  )
507
-
508
  # Generate text
509
  response = await llm_client.generate_text("Xin chào, bạn có khỏe không?")
510
  print(f"Response: {response}")
511
-
512
  # Chat
513
  messages = [
514
  {"role": "user", "content": "Bạn có thể giúp tôi không?"}
515
  ]
516
  chat_response = await llm_client.chat(messages)
517
  print(f"Chat response: {chat_response}")
518
-
519
  await llm_client.close()
520
-
521
- asyncio.run(test_llm())
 
1
+ from __future__ import annotations
2
+
3
  from typing import List, Dict, Any, Optional, Union
 
4
  import json
5
+ import re
 
6
  import os
7
+ import asyncio
8
+
9
+ import httpx
10
+ from loguru import logger
11
+
12
  from .gemini_client import GeminiClient
13
  from .config import get_settings
14
+ from .utils import (
15
+ timing_decorator_async,
16
+ timing_decorator_sync, # kept for compatibility even if unused here
17
+ call_endpoint_with_retry,
18
+ )
19
+
20
+
21
+ def _safe_truncate(s: str, n: int = 1000) -> str:
22
+ """Truncate long strings for logging purposes."""
23
+ if not isinstance(s, str):
24
+ s = str(s)
25
+ return s if len(s) <= n else s[:n] + "... [truncated]"
26
+
27
+
28
+ def _parse_json_from_text(text: str) -> Optional[Union[List[Dict[str, Any]], Dict[str, Any]]]:
29
+ """Best-effort JSON extractor from LLM free-form responses.
30
+
31
+ Strategy:
32
+ 1) Try json.loads() on the whole string first.
33
+ 2) Fallback to regex to find the first JSON list/object snippet.
34
+ """
35
+ if not text:
36
+ return None
37
+
38
+ # 1) try direct load
39
+ try:
40
+ return json.loads(text)
41
+ except Exception:
42
+ pass
43
+
44
+ # 2) find first JSON array or object
45
+ match = re.search(r"(\[[\s\S]+?\]|\{[\s\S]+?\})", text)
46
+ if match:
47
+ try:
48
+ return json.loads(match.group(1))
49
+ except Exception:
50
+ return None
51
+ return None
52
 
 
53
 
54
  class LLMClient:
55
  """
56
  Client để tương tác với các mô hình ngôn ngữ lớn (LLM).
57
+ Hỗ trợ nhiều provider: OpenAI, HuggingFace, local models, custom, HFS, Gemini.
58
  """
59
+
60
  def __init__(self, provider: str = "openai", **kwargs):
 
 
 
 
 
 
 
61
  self.provider = provider.lower()
62
  self._client = httpx.AsyncClient(timeout=60.0)
63
+
64
+ # Dispatch provider setup
65
  if self.provider == "openai":
66
  self._setup_openai(kwargs)
67
  elif self.provider == "huggingface":
 
71
  elif self.provider == "custom":
72
  self._setup_custom(kwargs)
73
  elif self.provider == "hfs":
74
+ self._setup_hfs(kwargs)
75
  elif self.provider == "gemini":
76
  self._setup_gemini(kwargs)
77
  else:
78
  raise ValueError(f"Unsupported provider: {provider}")
79
+
80
+ # ---------- Provider setups ---------- #
81
+
82
  def _setup_openai(self, config: Dict[str, Any]):
 
83
  self.api_key = config.get("api_key") or os.getenv("OPENAI_API_KEY") or ""
84
  self.base_url = config.get("base_url", "https://api.openai.com/v1")
85
  self.model = config.get("model", "gpt-3.5-turbo")
86
  self.max_tokens = config.get("max_tokens", 1000)
87
  self.temperature = config.get("temperature", 0.7)
88
+
89
  if not self.api_key:
90
  raise ValueError("OpenAI API key is required")
91
+
92
  def _setup_huggingface(self, config: Dict[str, Any]):
 
93
  self.api_key = config.get("api_key", "")
94
  self.base_url = config.get("base_url", "https://api-inference.huggingface.co")
95
  self.model = config.get("model", "microsoft/DialoGPT-medium")
96
  self.max_tokens = config.get("max_tokens", 1000)
97
  self.temperature = config.get("temperature", 0.7)
98
+
99
  def _setup_local(self, config: Dict[str, Any]):
 
100
  self.api_key = ""
101
  self.base_url = config.get("base_url", "http://localhost:8000")
102
  self.model = config.get("model", "default")
103
  self.max_tokens = config.get("max_tokens", 1000)
104
  self.temperature = config.get("temperature", 0.7)
105
+
106
  def _setup_custom(self, config: Dict[str, Any]):
 
107
  self.api_key = config.get("api_key", "")
108
  self.base_url = config.get("base_url")
109
  self.model = config.get("model", "default")
 
112
  if not self.base_url:
113
  raise ValueError("Custom provider requires base_url")
114
 
115
+ def _setup_hfs(self, config: Dict[str, Any]):
 
116
  self.api_key = config.get("api_key", "")
117
  self.base_url = config.get("base_url")
118
  if not self.base_url:
119
+ raise ValueError("HFS provider requires base_url")
120
+
121
  def _setup_gemini(self, config: Dict[str, Any]):
122
+ # Sử dụng GeminiClient với RequestLimitManager (theo thiết kế của bạn)
 
123
  self.gemini_client = GeminiClient()
124
  logger.info("[LLM] Initialized GeminiClient with RequestLimitManager")
125
+
126
+ # ---------- Core APIs ---------- #
127
+
128
  @timing_decorator_async
129
  async def generate_text(
130
+ self,
131
+ prompt: str,
132
  system_prompt: Optional[str] = None,
133
+ **kwargs,
134
  ) -> str:
135
  """
136
  Tạo text từ prompt sử dụng LLM.
 
 
 
 
 
 
 
 
137
  """
138
+ logger.info(
139
+ f"[LLM] generate_text - provider: {self.provider}\n\t prompt: {_safe_truncate(prompt, 1200)}"
140
+ )
141
  try:
 
142
  if self.provider == "openai":
143
  result = await self._generate_openai(prompt, system_prompt, **kwargs)
144
  elif self.provider == "huggingface":
 
153
  result = await self._generate_gemini(prompt, **kwargs)
154
  else:
155
  raise ValueError(f"Unsupported provider: {self.provider}")
156
+
157
+ logger.info(f"[LLM] generate_text - provider: {self.provider}\n\t result: {_safe_truncate(result, 1200)}")
158
  return result
159
  except Exception as e:
160
+ logger.exception(f"[LLM] Error generating text with {self.provider}: {e}")
161
  raise
162
+
163
+ async def _generate_openai(
164
+ self, prompt: str, system_prompt: Optional[str] = None, **kwargs
165
+ ) -> str:
166
  url = f"{self.base_url}/chat/completions"
167
+ payload = {
168
+ "model": kwargs.get("model", self.model),
169
+ "messages": [
170
+ {"role": "system", "content": system_prompt or ""},
171
+ {"role": "user", "content": prompt},
172
+ ],
173
+ "max_tokens": kwargs.get("max_tokens", self.max_tokens),
174
+ "temperature": kwargs.get("temperature", self.temperature),
175
+ "stream": False,
176
+ }
177
+ headers = {
178
+ "Authorization": f"Bearer {self.api_key}",
179
+ "Content-Type": "application/json",
180
+ }
181
  response = await call_endpoint_with_retry(self._client, url, payload, headers=headers)
182
  if response is not None:
183
  data = response.json()
184
  return data["choices"][0]["message"]["content"]
185
+ logger.error("OpenAI API response is None")
186
+ raise RuntimeError("OpenAI API response is None")
187
+
 
188
  async def _generate_huggingface(self, prompt: str, **kwargs) -> str:
189
  url = f"{self.base_url}/generate"
190
  payload = {"inputs": prompt}
 
192
  if response is not None:
193
  data = response.json()
194
  return data[0]["generated_text"]
195
+ logger.error("HuggingFace API response is None")
196
+ raise RuntimeError("HuggingFace API response is None")
197
+
 
198
  async def _generate_local(self, prompt: str, **kwargs) -> str:
199
  url = f"{self.base_url}/generate"
200
  payload = {"prompt": prompt}
 
202
  if response is not None:
203
  data = response.json()
204
  return data.get("text", "")
205
+ logger.error("Local API response is None")
206
+ raise RuntimeError("Local API response is None")
207
+
 
208
  async def _generate_custom(self, prompt: str, **kwargs) -> str:
209
  url = f"{self.base_url}/custom"
210
  payload = {"prompt": prompt}
 
212
  if response is not None:
213
  data = response.json()
214
  return data.get("text", "")
215
+ logger.error("Custom API response is None")
216
+ raise RuntimeError("Custom API response is None")
217
+
 
218
  async def _generate_hfs(self, prompt: str, **kwargs) -> str:
219
+ # Giữ nguyên chữ ký call_endpoint_with_retry như bạn đã dùng
220
  endpoint = f"{self.base_url}/purechat"
221
  payload = {"prompt": prompt}
222
  headers = {}
223
+ if hasattr(self, "api_key") and self.api_key:
224
  headers["Authorization"] = f"Bearer {self.api_key}"
225
+ response = await call_endpoint_with_retry(
226
+ self._client, endpoint, payload, 3, 500, headers=headers
227
+ )
228
+ logger.info(
229
+ f"[LLM] generate_text - provider: {self.provider}\n\t response: {_safe_truncate(str(response), 1200)}"
230
+ )
231
  try:
232
+ logger.info(
233
+ f"[LLM][RAW_RESPONSE] {json.dumps(response, ensure_ascii=False, indent=2) if hasattr(response, 'json') else str(response)}"
234
+ )
235
  except Exception:
236
  logger.info(f"[LLM][RAW_RESPONSE] {str(response)}")
237
+
238
  if response is not None:
239
  data = response.json()
240
+ if "response" in data:
241
+ return data["response"]
242
+ if "result" in data:
243
+ return data["result"]
244
+ if "data" in data and isinstance(data["data"], list) and data["data"]:
245
+ return data["data"][0]
246
  return str(data)
247
+ logger.error("HFS API response is None")
248
+ raise RuntimeError("HFS API response is None")
 
249
 
250
  async def _generate_gemini(self, prompt: str, **kwargs) -> str:
 
251
  loop = asyncio.get_event_loop()
252
+ # Đảm bảo kwargs được truyền nếu GeminiClient hỗ trợ
253
+ return await loop.run_in_executor(None, lambda: self.gemini_client.generate_text(prompt, **kwargs))
254
 
255
  @timing_decorator_async
256
+ async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
257
+ """Chat với LLM sử dụng conversation history."""
258
+ logger.info(f"[LLM] chat - provider: {self.provider} - messages: {messages}")
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  if self.provider == "openai":
260
  return await self._chat_openai(messages, **kwargs)
261
+ # Convert messages -> prompt cho các provider khác
262
+ prompt = self._messages_to_prompt(messages)
263
+ return await self.generate_text(prompt, **kwargs)
264
+
 
265
  async def _chat_openai(self, messages: List[Dict[str, str]], **kwargs) -> str:
 
266
  payload = {
267
  "model": kwargs.get("model", self.model),
268
  "messages": messages,
269
  "max_tokens": kwargs.get("max_tokens", self.max_tokens),
270
  "temperature": kwargs.get("temperature", self.temperature),
271
+ "stream": False,
272
  }
 
273
  headers = {
274
  "Authorization": f"Bearer {self.api_key}",
275
+ "Content-Type": "application/json",
276
  }
 
277
  response = await self._client.post(
278
+ f"{self.base_url}/chat/completions", headers=headers, json=payload
 
 
279
  )
280
  response.raise_for_status()
 
281
  data = response.json()
282
  return data["choices"][0]["message"]["content"]
283
+
284
  def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
285
+ prompt_lines = []
 
286
  for msg in messages:
287
  role = msg.get("role", "user")
288
  content = msg.get("content", "")
 
289
  if role == "system":
290
+ prompt_lines.append(f"System: {content}\n")
291
  elif role == "user":
292
+ prompt_lines.append(f"User: {content}\n")
293
  elif role == "assistant":
294
+ prompt_lines.append(f"Assistant: {content}\n")
295
+ prompt_lines.append("Assistant: ")
296
+ return "".join(prompt_lines)
297
+
298
+ # ---------- Utility tasks ---------- #
299
+
300
  @timing_decorator_async
301
  async def classify_text(
302
+ self, text: str, categories: List[str], **kwargs
 
 
 
303
  ) -> Dict[str, Any]:
304
+ """Phân loại text vào các categories."""
 
 
 
 
 
 
 
 
 
 
305
  prompt = f"""
306
  Phân loại text sau vào một trong các categories: {', '.join(categories)}
307
+
308
  Text: {text}
309
+
310
  Trả về kết quả theo format JSON:
311
  {{
312
  "category": "tên_category",
 
314
  "reasoning": "lý do phân loại"
315
  }}
316
  """
 
317
  response = await self.generate_text(prompt, **kwargs)
318
+
319
+ result = _parse_json_from_text(response or "")
320
+ if isinstance(result, dict):
321
+ return result
322
+ # fallback default
323
+ return {
324
+ "category": "unknown",
325
+ "confidence": 0.0,
326
+ "reasoning": f"Cannot parse JSON from response: {_safe_truncate(response, 500)}",
327
+ }
328
+
 
 
 
 
 
 
 
 
 
 
329
  @timing_decorator_async
330
  async def extract_entities(
331
+ self, text: str, entity_types: Optional[List[str]] = None, **kwargs
 
 
 
332
  ) -> List[Dict[str, Any]]:
333
+ """Trích xuất entities từ text."""
 
 
 
 
 
 
 
 
 
 
334
  if entity_types is None:
335
  entity_types = ["PERSON", "ORGANIZATION", "LOCATION", "MONEY", "DATE"]
336
+
337
  prompt = f"""
338
  Trích xuất các entities từ text sau. Tìm các entities thuộc types: {', '.join(entity_types)}
339
+
340
  Text: {text}
341
+
342
  Trả về kết quả theo format JSON:
343
  [
344
  {{
 
349
  }}
350
  ]
351
  """
 
352
  response = await self.generate_text(prompt, **kwargs)
353
+
354
  try:
355
+ logger.info(
356
+ f"[LLM][RAW_RESPONSE][extract_entities] {_safe_truncate(response, 2000)}"
357
+ )
358
+ parsed = _parse_json_from_text(response or "")
359
+ if isinstance(parsed, list):
360
+ return parsed
361
+ if isinstance(parsed, dict):
362
+ return [parsed]
363
+ return []
364
+ except Exception as e:
365
+ logger.error(f"Error parsing entities JSON: {e} | Raw: {response}")
 
 
 
 
 
 
 
 
 
366
  return []
367
 
368
  @timing_decorator_async
369
  async def analyze(
370
+ self, text: str, conversation_context: str, **kwargs
 
 
 
371
  ) -> List[Dict[str, Any]]:
372
  """
373
+ Phân tích câu hỏi về luật giao thông Việt Nam và chuẩn hóa thành JSON.
 
 
 
 
 
 
 
374
  """
 
375
  prompt = f"""
376
+ Bạn là một chuyên gia phân tích ngôn ngữ tự nhiên (NLP) chuyên xử lý các câu hỏi về luật giao thông Việt Nam. Nhiệm vụ của bạn là đọc kỹ **lịch sử trò chuyện** và **câu hỏi mới nhất** của người dùng để trích xuất thông tin vào một cấu trúc JSON duy nhất. Chỉ trả về đối tượng JSON, không thêm bất kỳ giải thích nào.
 
377
 
378
+ Định dạng JSON bắt buộc:
 
 
 
 
 
379
 
380
+ {{
381
+ "muc_dich": "...",
382
+ "phuong_tien": "...",
383
+ "hanh_vi": "...",
384
+ "cau_hoi": "..."
385
+ }}
386
 
387
+ Hướng dẫn chi tiết cho từng trường:
 
 
 
388
 
389
+ **muc_dich**: Phải là một trong các giá trị sau:
390
+ - "hỏi về mức phạt"
391
+ - "hỏi về quy tắc giao thông"
392
+ - "hỏi về báo hiệu đường bộ"
393
+ - "hỏi về quy trình xử lý vi phạm giao thông"
394
+ - "thông tin cá nhân của AI"
395
+ - "khác"
396
 
397
+ **Phải dựa vào câu hỏi mới nhất để xác định.**
398
+
399
+ **phuong_tien**: Tên phương tiện được đề cập trong câu hỏi mới hoặc trong lịch sử gần nhất. Nếu không có, để chuỗi rỗng "".
400
+
401
+ **hanh_vi**: cụm từ hoặc từ khóa ngắn gọn và phù hợp nhất để **tìm kiếm nội dung liên quan đến câu hỏi**. Có thể là tên hành vi vi phạm, thuật ngữ pháp lý, hoặc khái niệm về quy tắc/báo hiệu/vi phạm. Nếu không có thông tin rõ ràng, để chuỗi rỗng "".
402
+
403
+ **cau_hoi**: Diễn đạt lại câu hỏi mới nhất của người dùng thành một câu hỏi hoàn chỉnh, kết hợp ngữ cảnh từ lịch sử nếu cần, sử dụng đúng thuật ngữ pháp lý.
404
+
405
+ VÍ DỤ MẪU:
406
+
407
+ Câu hỏi đầu vào: "ô tô vượt đèn đỏ phạt nhiêu?"
408
+ Kết quả JSON mong muốn:
409
+ {{
410
+ "muc_dich": "hỏi về mức phạt",
411
+ "phuong_tien": "Ô tô",
412
+ "hanh_vi": "Không chấp hành hiệu lệnh của đèn tín hiệu giao thông",
413
+ "cau_hoi": "Mức xử phạt cho hành vi ô tô không chấp hành hiệu lệnh của đèn tín hiệu giao thông là bao nhiêu?"
414
+ }}
415
+
416
+ Bây giờ, hãy phân tích lịch sử và câu hỏi sau và chỉ trả về đối tượng JSON.
417
 
418
+ Lịch sử trò chuyện:
419
+ "{conversation_context}"
420
+
421
+ Câu hỏi:
422
+ "{text}"
423
+ """.strip()
424
 
 
 
 
 
425
  response = await self.generate_text(prompt, **kwargs)
426
+ logger.info(f"[LLM][RAW][analyze] Kết quả trả về từ generate_text: {_safe_truncate(response, 2000)}")
 
427
 
428
  try:
429
+ parsed = _parse_json_from_text(response or "")
430
+ if isinstance(parsed, list):
431
+ return parsed
432
+ if isinstance(parsed, dict):
433
+ return [parsed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  return []
435
+ except Exception as e:
436
+ logger.error(f"Error parsing analyze JSON: {e} | Raw: {response}")
437
+ return []
438
+
439
  async def close(self):
440
  """Đóng client connection."""
441
+ if hasattr(self, "_client") and self._client:
442
+ await self._client.aclose()
443
 
444
 
445
  # Factory function để tạo LLMClient dễ dàng
446
  def create_llm_client(provider: str = "openai", **kwargs) -> LLMClient:
 
 
 
 
 
 
 
 
 
 
447
  return LLMClient(provider, **kwargs)
448
 
449
 
450
  # Ví dụ sử dụng
451
  if __name__ == "__main__":
 
 
452
  async def test_llm():
 
453
  settings = get_settings()
454
  llm_client = create_llm_client(
455
  provider=settings.llm_provider,
456
  model=settings.llm_model,
457
  # ... các config khác nếu cần ...
458
  )
459
+
460
  # Generate text
461
  response = await llm_client.generate_text("Xin chào, bạn có khỏe không?")
462
  print(f"Response: {response}")
463
+
464
  # Chat
465
  messages = [
466
  {"role": "user", "content": "Bạn có thể giúp tôi không?"}
467
  ]
468
  chat_response = await llm_client.chat(messages)
469
  print(f"Chat response: {chat_response}")
470
+
471
  await llm_client.close()
472
+
473
+ asyncio.run(test_llm())