minhvtt committed on
Commit
1f34f53
·
verified ·
1 Parent(s): a2f6bc2

Update agent_service.py

Browse files
Files changed (1) hide show
  1. agent_service.py +318 -258
agent_service.py CHANGED
@@ -1,258 +1,318 @@
1
- """
2
- Agent Service - Central Brain for Sales & Feedback Agents
3
- Manages LLM conversation loop with tool calling
4
- """
5
- from typing import Dict, Any, List, Optional
6
- import os
7
- from tools_service import ToolsService
8
-
9
-
10
class AgentService:
    """
    Drive the conversation loop User -> LLM -> Tools -> Response.

    One system prompt per agent mode is read from the ``prompts`` directory
    when the service is constructed. Each chat turn is flattened into a
    single prompt string, sent to the LLM, and any tool / RAG result is fed
    back to the LLM until it answers without requesting a tool.
    """

    def __init__(
        self,
        tools_service: "ToolsService",
        embedding_service,
        qdrant_service,
        advanced_rag,
        hf_token: str
    ):
        """
        Store the collaborating services and load the system prompts.

        Args:
            tools_service: Object whose ``parse_and_execute`` coroutine is
                awaited with every raw LLM reply.
            embedding_service: Object providing ``encode_text(text)``.
            qdrant_service: Object providing ``search(...)``.
            advanced_rag: Object providing the ``chat_completion`` coroutine.
            hf_token: Token forwarded to ``chat_completion`` as ``token``.
        """
        self.tools_service = tools_service
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service
        self.advanced_rag = advanced_rag
        self.hf_token = hf_token

        # Load system prompts
        self.prompts = self._load_prompts()

    def _load_prompts(self) -> Dict[str, str]:
        """
        Read ``prompts/sales_agent.txt`` and ``prompts/feedback_agent.txt``.

        Returns:
            Mapping of prompt key to file content. A prompt that cannot be
            read is logged and stored as an empty string so that
            construction never fails because of a missing file.
        """
        prompts = {}
        prompts_dir = "prompts"

        for mode in ["sales_agent", "feedback_agent"]:
            filepath = os.path.join(prompts_dir, f"{mode}.txt")
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    prompts[mode] = f.read()
            except (OSError, UnicodeError) as e:
                # Best effort: a missing/unreadable prompt degrades to "".
                print(f"⚠️ Error loading {mode} prompt: {e}")
                prompts[mode] = ""
            else:
                print(f"✓ Loaded prompt: {mode}")

        return prompts

    async def chat(
        self,
        user_message: str,
        conversation_history: List[Dict],
        mode: str = "sales",  # "sales" or "feedback"
        user_id: Optional[str] = None,
        max_iterations: int = 3
    ) -> Dict[str, Any]:
        """
        Run the main conversation loop for one user message.

        Args:
            user_message: User's input.
            conversation_history: Previous messages
                ``[{"role": "user", "content": ...}, ...]``.
            mode: "sales" or "feedback"; any other value uses the sales prompt.
            user_id: Accepted for callers; not read by this method.
            max_iterations: Maximum number of LLM calls, to prevent
                infinite tool-call loops.

        Returns:
            ``{"message": <cleaned reply>, "tool_calls": [...], "mode": mode}``
            where ``tool_calls`` holds every record returned by
            ``tools_service.parse_and_execute``.
        """
        print(f"\n🤖 Agent Mode: {mode}")
        print(f"👤 User Message: {user_message}")

        system_prompt = self._get_system_prompt(mode)
        messages = self._build_messages(system_prompt, conversation_history, user_message)

        tool_calls_made = []
        current_response = None
        # Pre-bound so a non-positive max_iterations yields an empty reply
        # instead of an UnboundLocalError after the loop.
        llm_response = ""

        for iteration in range(max_iterations):
            print(f"\n🔄 Iteration {iteration + 1}")

            llm_response = await self._call_llm(messages)
            print(f"🧠 LLM Response: {llm_response[:200]}...")

            # A falsy result means the LLM did not ask for a tool.
            tool_result = await self.tools_service.parse_and_execute(llm_response)
            if not tool_result:
                current_response = llm_response
                break

            tool_calls_made.append(tool_result)
            print(f"🔧 Tool Called: {tool_result.get('function')}")

            # A "run_rag_search" action is resolved here and its hits replace
            # the plain tool result. Guard the type: the tool payload is not
            # guaranteed to be a dict.
            result = tool_result.get("result")
            if isinstance(result, dict) and result.get("action") == "run_rag_search":
                rag_results = await self._execute_rag_search(result.get("query", {}))
                feedback = f"RAG Search Results:\n{rag_results}"
            else:
                feedback = f"Tool Result:\n{self._format_tool_result(tool_result)}"

            messages.append({"role": "assistant", "content": llm_response})
            messages.append({"role": "system", "content": feedback})

        # NOTE: when every iteration ends in a tool call, the last raw LLM
        # output is what gets cleaned and returned.
        final_response = self._clean_response(current_response or llm_response)

        return {
            "message": final_response,
            "tool_calls": tool_calls_made,
            "mode": mode
        }

    def _get_system_prompt(self, mode: str) -> str:
        """Return the prompt for ``mode``; unknown modes use the sales prompt."""
        prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
        return self.prompts.get(prompt_key, "")

    def _build_messages(
        self,
        system_prompt: str,
        history: List[Dict],
        user_message: str
    ) -> List[Dict]:
        """
        Build the message list: system prompt, history, current user message.

        A ``None`` history is treated as empty. The returned list is new, but
        the history dicts themselves are shared with the caller.
        """
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(history or [])
        messages.append({"role": "user", "content": user_message})
        return messages

    async def _call_llm(self, messages: List[Dict]) -> str:
        """
        Send the flattened conversation to ``advanced_rag.chat_completion``.

        Returns:
            The completion, or a fixed apology message if the call raises.
        """
        try:
            prompt = self._messages_to_prompt(messages)

            response = await self.advanced_rag.chat_completion(
                user_prompt=prompt,
                context="",  # Context is already in system prompt
                chat_history=[],  # History is in messages
                token=self.hf_token
            )

            return response
        except Exception as e:
            # Boundary handler: any failure becomes a user-facing apology.
            print(f"⚠️ LLM Call Error: {e}")
            return "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"

    def _messages_to_prompt(self, messages: List[Dict]) -> str:
        """
        Convert the message list into one prompt string.

        Each system/user/assistant message becomes a ``[ROLE]`` header
        followed by its content; messages with any other role are skipped.
        """
        labels = {"system": "SYSTEM", "user": "USER", "assistant": "ASSISTANT"}
        prompt_parts = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            label = labels.get(role)
            if label is not None:
                prompt_parts.append(f"[{label}]\n{content}\n")

        return "\n".join(prompt_parts)

    def _format_tool_result(self, tool_result: Dict) -> str:
        """
        Format a tool record for feeding back to the LLM.

        A dict under ``"result"`` is rendered as ``key: value`` lines with
        the ``success`` and ``error`` keys left out; anything else is
        converted with ``str``.
        """
        result = tool_result.get("result", {})

        if isinstance(result, dict):
            return "\n".join(
                f"{key}: {value}"
                for key, value in result.items()
                if key not in ["success", "error"]
            )

        return str(result)

    async def _execute_rag_search(self, query_params: Dict) -> str:
        """
        Execute a vector search over the "events" collection.

        Args:
            query_params: Dict with optional ``query`` and ``vibe`` entries.
                A bare string is tolerated and used as the query text.

        Returns:
            Up to five numbered lines (first 100 characters of the first
            ``texts`` entry plus the ``id_use`` value), or a fixed
            "nothing found" message when there are no hits.
        """
        if isinstance(query_params, dict):
            query = query_params.get("query") or ""
            vibe = query_params.get("vibe") or ""
        else:
            query = query_params or ""
            vibe = ""

        search_text = f"{query} {vibe}".strip()

        print(f"🔍 RAG Search: {search_text}")

        embedding = self.embedding_service.encode_text(search_text)
        results = self.qdrant_service.search(
            collection_name="events",
            query_vector=embedding,
            limit=5
        )

        formatted = []
        for i, result in enumerate(results, 1):
            payload = result.payload or {}
            texts = payload.get("texts", [])
            text = texts[0] if texts else ""
            event_id = payload.get("id_use", "")

            formatted.append(f"{i}. {text[:100]}... (ID: {event_id})")

        return "\n".join(formatted) if formatted else "Không tìm thấy sự kiện phù hợp."

    def _clean_response(self, response: str) -> str:
        """
        Remove JSON artifacts from the final response.

        Everything from the first code fence onward is dropped, then
        everything from the first line that contains both ``{`` and
        ``tool_call``; the remainder is stripped of surrounding whitespace.
        """
        # Splitting on the bare fence also covers the "```json" variant.
        response = response.split("```")[0]

        if "{" in response and "tool_call" in response:
            cleaned = []
            for line in response.split("\n"):
                if "{" in line and "tool_call" in line:
                    break
                cleaned.append(line)
            response = "\n".join(cleaned)

        return response.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Service - Central Brain for Sales & Feedback Agents
3
+ Manages LLM conversation loop with tool calling
4
+ """
5
+ from typing import Dict, Any, List, Optional
6
+ import os
7
+ from tools_service import ToolsService
8
+
9
+
10
class AgentService:
    """
    Drive the conversation loop User -> LLM -> Tools -> Response.

    One system prompt per agent mode is read from the ``prompts`` directory
    when the service is constructed. Each chat turn is flattened into a
    single prompt string, sent to the LLM, and any tool / RAG result is fed
    back to the LLM until it answers without requesting a tool.
    """

    def __init__(
        self,
        tools_service: "ToolsService",
        embedding_service,
        qdrant_service,
        advanced_rag,
        hf_token: str
    ):
        """
        Store the collaborating services and load the system prompts.

        Args:
            tools_service: Object whose ``execute_tool`` coroutine is awaited
                with the tool name and arguments of every detected call.
            embedding_service: Object providing ``encode_text(text)``.
            qdrant_service: Object providing ``search(...)``.
            advanced_rag: Object providing the ``chat_completion`` coroutine.
            hf_token: Token forwarded to ``chat_completion`` as ``token``.
        """
        self.tools_service = tools_service
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service
        self.advanced_rag = advanced_rag
        self.hf_token = hf_token

        # Load system prompts
        self.prompts = self._load_prompts()

    def _load_prompts(self) -> Dict[str, str]:
        """
        Read ``prompts/sales_agent.txt`` and ``prompts/feedback_agent.txt``.

        Returns:
            Mapping of prompt key to file content. A prompt that cannot be
            read is logged and stored as an empty string so that
            construction never fails because of a missing file.
        """
        prompts = {}
        prompts_dir = "prompts"

        for mode in ["sales_agent", "feedback_agent"]:
            filepath = os.path.join(prompts_dir, f"{mode}.txt")
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    prompts[mode] = f.read()
            except (OSError, UnicodeError) as e:
                # Best effort: a missing/unreadable prompt degrades to "".
                print(f"⚠️ Error loading {mode} prompt: {e}")
                prompts[mode] = ""
            else:
                print(f"✓ Loaded prompt: {mode}")

        return prompts

    async def chat(
        self,
        user_message: str,
        conversation_history: List[Dict],
        mode: str = "sales",  # "sales" or "feedback"
        user_id: Optional[str] = None,
        max_iterations: int = 3
    ) -> Dict[str, Any]:
        """
        Run the main conversation loop for one user message.

        Args:
            user_message: User's input.
            conversation_history: Previous messages
                ``[{"role": "user", "content": ...}, ...]``.
            mode: "sales" or "feedback"; any other value uses the sales prompt.
            user_id: Accepted for callers; not read by this method.
            max_iterations: Maximum number of LLM calls, to prevent
                infinite tool-call loops.

        Returns:
            ``{"message": <cleaned reply>, "tool_calls": [...], "mode": mode}``
            where each ``tool_calls`` entry has the keys ``function``,
            ``arguments`` and ``result``.
        """
        print(f"\n🤖 Agent Mode: {mode}")
        print(f"👤 User Message: {user_message}")

        system_prompt = self._get_system_prompt(mode)
        messages = self._build_messages(system_prompt, conversation_history, user_message)

        tool_calls_made = []
        current_response = None
        # Pre-bound so a non-positive max_iterations yields an empty reply
        # instead of an UnboundLocalError after the loop.
        llm_response = ""

        for iteration in range(max_iterations):
            print(f"\n🔄 Iteration {iteration + 1}")

            llm_response = await self._call_llm(messages)
            print(f"🧠 LLM Response: {llm_response[:200]}...")

            # No detected tool call -> this is the final response.
            tool_call = self._parse_tool_call(llm_response)
            if not tool_call:
                current_response = llm_response
                break

            tool_name = tool_call['tool_name']
            arguments = tool_call['arguments']

            print(f"🔧 Tool Called: {tool_name}")
            tool_result = await self.tools_service.execute_tool(tool_name, arguments)

            tool_calls_made.append({
                "function": tool_name,
                "arguments": arguments,
                "result": tool_result
            })

            # A "run_rag_search" action is resolved here and its hits replace
            # the plain tool result. ``.get`` avoids a KeyError when the tool
            # omits the query.
            if isinstance(tool_result, dict) and tool_result.get("action") == "run_rag_search":
                rag_results = await self._execute_rag_search(tool_result.get("query", {}))
                feedback = f"RAG Search Results:\n{rag_results}"
            else:
                feedback = f"Tool Result:\n{self._format_tool_result({'result': tool_result})}"

            messages.append({"role": "assistant", "content": llm_response})
            messages.append({"role": "system", "content": feedback})

        # NOTE: when every iteration ends in a tool call, the last raw LLM
        # output is what gets cleaned and returned.
        final_response = self._clean_response(current_response or llm_response)

        return {
            "message": final_response,
            "tool_calls": tool_calls_made,
            "mode": mode
        }

    def _get_system_prompt(self, mode: str) -> str:
        """Return the prompt for ``mode``; unknown modes use the sales prompt."""
        prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
        return self.prompts.get(prompt_key, "")

    def _build_messages(
        self,
        system_prompt: str,
        history: List[Dict],
        user_message: str
    ) -> List[Dict]:
        """
        Build the message list: system prompt, history, current user message.

        A ``None`` history is treated as empty. The returned list is new, but
        the history dicts themselves are shared with the caller.
        """
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(history or [])
        messages.append({"role": "user", "content": user_message})
        return messages

    async def _call_llm(self, messages: List[Dict]) -> str:
        """
        Send the flattened conversation to ``advanced_rag.chat_completion``.

        Returns:
            The completion, or a fixed apology message if the call raises.
        """
        try:
            prompt = self._messages_to_prompt(messages)

            response = await self.advanced_rag.chat_completion(
                user_prompt=prompt,
                context="",  # Context is already in system prompt
                chat_history=[],  # History is in messages
                token=self.hf_token
            )

            return response
        except Exception as e:
            # Boundary handler: any failure becomes a user-facing apology.
            print(f"⚠️ LLM Call Error: {e}")
            return "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"

    def _messages_to_prompt(self, messages: List[Dict]) -> str:
        """
        Convert the message list into one prompt string.

        Each system/user/assistant message becomes a ``[ROLE]`` header
        followed by its content; messages with any other role are skipped.
        """
        labels = {"system": "SYSTEM", "user": "USER", "assistant": "ASSISTANT"}
        prompt_parts = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            label = labels.get(role)
            if label is not None:
                prompt_parts.append(f"[{label}]\n{content}\n")

        return "\n".join(prompt_parts)

    def _format_tool_result(self, tool_result: Dict) -> str:
        """
        Format a tool record for feeding back to the LLM.

        A dict under ``"result"`` is rendered as ``key: value`` lines with
        the ``success`` and ``error`` keys left out; anything else is
        converted with ``str``.
        """
        result = tool_result.get("result", {})

        if isinstance(result, dict):
            return "\n".join(
                f"{key}: {value}"
                for key, value in result.items()
                if key not in ["success", "error"]
            )

        return str(result)

    async def _execute_rag_search(self, query_params: Dict) -> str:
        """
        Execute a vector search over the "events" collection.

        Args:
            query_params: Dict with optional ``query`` and ``vibe`` entries.
                A bare string is tolerated and used as the query text.

        Returns:
            Up to five numbered lines (first 100 characters of the first
            ``texts`` entry plus the ``id_use`` value), or a fixed
            "nothing found" message when there are no hits.
        """
        if isinstance(query_params, dict):
            query = query_params.get("query") or ""
            vibe = query_params.get("vibe") or ""
        else:
            query = query_params or ""
            vibe = ""

        search_text = f"{query} {vibe}".strip()

        print(f"🔍 RAG Search: {search_text}")

        embedding = self.embedding_service.encode_text(search_text)
        results = self.qdrant_service.search(
            collection_name="events",
            query_vector=embedding,
            limit=5
        )

        formatted = []
        for i, result in enumerate(results, 1):
            payload = result.payload or {}
            texts = payload.get("texts", [])
            text = texts[0] if texts else ""
            event_id = payload.get("id_use", "")

            formatted.append(f"{i}. {text[:100]}... (ID: {event_id})")

        return "\n".join(formatted) if formatted else "Không tìm thấy sự kiện phù hợp."

    def _parse_tool_call(self, llm_response: str) -> Optional[Dict]:
        """
        Detect a tool call in an LLM response.

        A structured JSON object (spanning the first ``{`` to the last ``}``)
        that names a tool via ``tool_name`` or ``function`` takes priority.
        Only when no such object is present does the keyword heuristic run,
        so an explicit call is never replaced by a guessed one.

        Returns:
            ``{"tool_name": "...", "arguments": {...}}`` or ``None``.
        """
        import json

        json_start = llm_response.find("{")
        json_end = llm_response.rfind("}")
        if json_start != -1 and json_end > json_start:
            try:
                data = json.loads(llm_response[json_start:json_end + 1])
            except (ValueError, RecursionError):
                # Braces in ordinary prose: not a structured call.
                data = None

            if isinstance(data, dict):
                tool_name = data.get("tool_name") or data.get("function")
                if tool_name:
                    return {
                        "tool_name": tool_name,
                        "arguments": data.get("arguments") or {}
                    }

        # Keyword fallback for replies without a structured call.
        # TODO: Train LLM to output proper tool call JSON
        response_lower = llm_response.lower()

        if any(keyword in response_lower for keyword in ["tìm kiếm", "search", "tìm event"]):
            return {
                "tool_name": "search_events",
                "arguments": {"query": llm_response[:100]}
            }

        # Free-text mentions of event details ("get_event_details",
        # "chi tiết sự kiện") are not turned into a call: extracting the
        # event id from prose is not implemented.
        return None

    def _clean_response(self, response: str) -> str:
        """
        Remove JSON artifacts from the final response.

        Everything from the first code fence onward is dropped, then
        everything from the first line that contains both ``{`` and
        ``tool_call``; the remainder is stripped of surrounding whitespace.
        """
        # Splitting on the bare fence also covers the "```json" variant.
        response = response.split("```")[0]

        if "{" in response and "tool_call" in response:
            cleaned = []
            for line in response.split("\n"):
                if "{" in line and "tool_call" in line:
                    break
                cleaned.append(line)
            response = "\n".join(cleaned)

        return response.strip()