minhvtt commited on
Commit
35d088a
·
verified ·
1 Parent(s): 1f34f53

Update agent_service.py

Browse files
Files changed (1) hide show
  1. agent_service.py +180 -44
agent_service.py CHANGED
@@ -139,9 +139,95 @@ class AgentService:
139
  }
140
 
141
  def _get_system_prompt(self, mode: str) -> str:
142
- """Get system prompt for selected mode"""
143
  prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
144
- return self.prompts.get(prompt_key, "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  def _build_messages(
147
  self,
@@ -162,22 +248,29 @@ class AgentService:
162
 
163
  async def _call_llm(self, messages: List[Dict]) -> str:
164
  """
165
- Call HuggingFace LLM
166
- Uses advanced_rag's chat method
167
  """
168
  try:
 
 
169
  # Build prompt from messages
170
  prompt = self._messages_to_prompt(messages)
171
 
172
- # Call HF API via advanced_rag
173
- response = await self.advanced_rag.chat_completion(
174
- user_prompt=prompt,
175
- context="", # Context is already in system prompt
176
- chat_history=[], # History is in messages
177
- token=self.hf_token
178
- )
 
 
 
 
 
 
179
 
180
- return response
181
  except Exception as e:
182
  print(f"⚠️ LLM Call Error: {e}")
183
  return "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"
@@ -248,51 +341,94 @@ class AgentService:
248
 
249
  def _parse_tool_call(self, llm_response: str) -> Optional[Dict]:
250
  """
251
- Parse LLM response to detect tool calls
252
 
253
  Returns:
254
  {"tool_name": "...", "arguments": {...}} or None
255
  """
256
  import json
 
257
 
258
- # Simple heuristic: Check if response mentions tools
259
- # In a real system, LLM should output structured JSON
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
- # For now, we'll use keyword detection
262
- # TODO: Train LLM to output proper tool call JSON
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
- response_lower = llm_response.lower()
 
 
 
 
265
 
266
- # Check for search intent
267
- if any(keyword in response_lower for keyword in ["tìm kiếm", "search", "tìm event"]):
268
- # Extract query from response
 
 
 
 
269
  return {
270
- "tool_name": "search_events",
271
- "arguments": {"query": llm_response[:100]}
272
  }
273
 
274
- # Check for event details intent
275
- if "get_event_details" in response_lower or "chi tiết sự kiện" in response_lower:
276
- # Try to extract event_id
277
- # Simple extraction - in production use better parsing
278
- return None # Skip for now
 
279
 
280
- # Try to parse JSON if present
281
- try:
282
- if "{" in llm_response and "}" in llm_response:
283
- json_start = llm_response.find("{")
284
- json_end = llm_response.rfind("}") + 1
285
- json_str = llm_response[json_start:json_end]
286
- data = json.loads(json_str)
287
-
288
- # Check if it's a tool call
289
- if "tool_name" in data or "function" in data:
290
- return {
291
- "tool_name": data.get("tool_name") or data.get("function"),
292
- "arguments": data.get("arguments", {})
293
- }
294
- except:
295
- pass
296
 
297
  return None
298
 
 
139
  }
140
 
141
  def _get_system_prompt(self, mode: str) -> str:
142
+ """Get system prompt for selected mode with tools definition"""
143
  prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
144
+ base_prompt = self.prompts.get(prompt_key, "")
145
+
146
+ # Add tools definition
147
+ tools_definition = self._get_tools_definition()
148
+
149
+ return f"{base_prompt}\n\n{tools_definition}"
150
+
151
+ def _get_tools_definition(self) -> str:
152
+ """Get tools definition in text format for prompt"""
153
+ return """
154
+ # AVAILABLE TOOLS
155
+
156
+ You can call the following tools when needed. To call a tool, output a JSON block like this:
157
+
158
+ ```json
159
+ {
160
+ "tool_call": "tool_name",
161
+ "arguments": {
162
+ "arg1": "value1",
163
+ "arg2": "value2"
164
+ }
165
+ }
166
+ ```
167
+
168
+ ## Tools List:
169
+
170
+ ### 1. search_events
171
+ Search for events matching user criteria.
172
+ Arguments:
173
+ - query (string): Search keywords
174
+ - vibe (string, optional): Mood/vibe (e.g., "chill", "sôi động")
175
+ - time (string, optional): Time period (e.g., "cuối tuần này")
176
+
177
+ Example:
178
+ ```json
179
+ {"tool_call": "search_events", "arguments": {"query": "nhạc rock", "vibe": "sôi động"}}
180
+ ```
181
+
182
+ ### 2. get_event_details
183
+ Get detailed information about a specific event.
184
+ Arguments:
185
+ - event_id (string): Event ID from search results
186
+
187
+ Example:
188
+ ```json
189
+ {"tool_call": "get_event_details", "arguments": {"event_id": "6900ae38eb03f29702c7fd1d"}}
190
+ ```
191
+
192
+ ### 3. get_purchased_events (Feedback mode only)
193
+ Check which events the user has attended.
194
+ Arguments:
195
+ - user_id (string): User ID
196
+
197
+ Example:
198
+ ```json
199
+ {"tool_call": "get_purchased_events", "arguments": {"user_id": "user_123"}}
200
+ ```
201
+
202
+ ### 4. save_feedback
203
+ Save user's feedback/review for an event.
204
+ Arguments:
205
+ - event_id (string): Event ID
206
+ - rating (integer): 1-5 stars
207
+ - comment (string, optional): User's comment
208
+
209
+ Example:
210
+ ```json
211
+ {"tool_call": "save_feedback", "arguments": {"event_id": "abc123", "rating": 5, "comment": "Tuyệt vời!"}}
212
+ ```
213
+
214
+ ### 5. save_lead
215
+ Save customer contact information.
216
+ Arguments:
217
+ - email (string, optional): Email address
218
+ - phone (string, optional): Phone number
219
+ - interest (string, optional): What they're interested in
220
+
221
+ Example:
222
+ ```json
223
+ {"tool_call": "save_lead", "arguments": {"email": "user@example.com", "interest": "Rock show"}}
224
+ ```
225
+
226
+ **IMPORTANT:**
227
+ - Call tools ONLY when you need real-time data
228
+ - After receiving tool results, respond naturally to the user
229
+ - Don't expose raw JSON to users - always format nicely
230
+ """
231
 
232
  def _build_messages(
233
  self,
 
248
 
249
async def _call_llm(self, messages: List[Dict]) -> str:
    """Call the HuggingFace LLM directly via AsyncInferenceClient.

    Converts the chat messages to a single prompt, streams the
    generation, and returns the concatenated text.

    Args:
        messages: Chat messages as [{"role": ..., "content": ...}, ...]
            dicts, flattened by _messages_to_prompt.

    Returns:
        The generated text, or a Vietnamese fallback apology if the
        call fails for any reason (best-effort: errors never propagate).
    """
    try:
        # Local import keeps huggingface_hub optional until first use.
        from huggingface_hub import AsyncInferenceClient

        # Build prompt from messages
        prompt = self._messages_to_prompt(messages)

        # Create async client per call.
        # NOTE(review): could be hoisted to __init__ to reuse the
        # underlying HTTP session — confirm lifecycle before changing.
        client = AsyncInferenceClient(token=self.hf_token)

        # Collect streamed chunks in a list and join once, instead of
        # repeated string concatenation (quadratic in the worst case).
        chunks = []
        stream = await client.text_generation(
            prompt=prompt,
            model="openai/gpt-oss-20b",
            max_new_tokens=512,
            temperature=0.7,
            stream=True,
        )
        async for chunk in stream:
            chunks.append(chunk)

        return "".join(chunks)
    except Exception as e:
        # Deliberate broad catch: the caller expects a string, never an
        # exception, so degrade to a friendly apology.
        print(f"⚠️ LLM Call Error: {e}")
        return "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"
 
341
 
342
  def _parse_tool_call(self, llm_response: str) -> Optional[Dict]:
343
  """
344
+ Parse LLM response to detect tool calls using structured JSON
345
 
346
  Returns:
347
  {"tool_name": "...", "arguments": {...}} or None
348
  """
349
  import json
350
+ import re
351
 
352
+ # Method 1: Look for JSON code block
353
+ json_match = re.search(r'```json\s*(\{.*?\})\s*```', llm_response, re.DOTALL)
354
+ if json_match:
355
+ try:
356
+ data = json.loads(json_match.group(1))
357
+ return self._extract_tool_from_json(data)
358
+ except json.JSONDecodeError:
359
+ pass
360
+
361
+ # Method 2: Look for inline JSON object
362
+ # Find all potential JSON objects
363
+ json_objects = re.findall(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', llm_response)
364
+ for json_str in json_objects:
365
+ try:
366
+ data = json.loads(json_str)
367
+ tool_call = self._extract_tool_from_json(data)
368
+ if tool_call:
369
+ return tool_call
370
+ except json.JSONDecodeError:
371
+ continue
372
 
373
+ # Method 3: Nested JSON (for complex structures)
374
+ try:
375
+ # Find outermost curly braces
376
+ if '{' in llm_response and '}' in llm_response:
377
+ start = llm_response.find('{')
378
+ # Find matching closing brace
379
+ count = 0
380
+ for i, char in enumerate(llm_response[start:], start):
381
+ if char == '{':
382
+ count += 1
383
+ elif char == '}':
384
+ count -= 1
385
+ if count == 0:
386
+ json_str = llm_response[start:i+1]
387
+ data = json.loads(json_str)
388
+ return self._extract_tool_from_json(data)
389
+ except (json.JSONDecodeError, ValueError):
390
+ pass
391
 
392
+ return None
393
+
394
+ def _extract_tool_from_json(self, data: dict) -> Optional[Dict]:
395
+ """
396
+ Extract tool call information from parsed JSON
397
 
398
+ Supports multiple formats:
399
+ - {"tool_call": "search_events", "arguments": {...}}
400
+ - {"function": "search_events", "parameters": {...}}
401
+ - {"name": "search_events", "args": {...}}
402
+ """
403
+ # Format 1: tool_call + arguments
404
+ if "tool_call" in data and isinstance(data["tool_call"], str):
405
  return {
406
+ "tool_name": data["tool_call"],
407
+ "arguments": data.get("arguments", {})
408
  }
409
 
410
+ # Format 2: function + parameters
411
+ if "function" in data:
412
+ return {
413
+ "tool_name": data["function"],
414
+ "arguments": data.get("parameters", data.get("arguments", {}))
415
+ }
416
 
417
+ # Format 3: name + args
418
+ if "name" in data:
419
+ return {
420
+ "tool_name": data["name"],
421
+ "arguments": data.get("args", data.get("arguments", {}))
422
+ }
423
+
424
+ # Format 4: Direct tool name as key
425
+ valid_tools = ["search_events", "get_event_details", "get_purchased_events", "save_feedback", "save_lead"]
426
+ for tool in valid_tools:
427
+ if tool in data:
428
+ return {
429
+ "tool_name": tool,
430
+ "arguments": data[tool] if isinstance(data[tool], dict) else {}
431
+ }
 
432
 
433
  return None
434