# ChatbotRAG / agent_service.py
"""
Agent Service - Central Brain for Sales & Feedback Agents
Manages the LLM conversation loop with tool calling
"""
import os
from typing import Any, Dict, List, Optional

from tools_service import ToolsService


class AgentService:
    """
    Manages the conversation loop between User -> LLM -> Tools -> Response
    """

    def __init__(
        self,
        tools_service: ToolsService,
        embedding_service,
        qdrant_service,
        advanced_rag,
        hf_token: str
    ):
        self.tools_service = tools_service
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service
        self.advanced_rag = advanced_rag
        self.hf_token = hf_token

        # Load system prompts
        self.prompts = self._load_prompts()

    def _load_prompts(self) -> Dict[str, str]:
        """Load system prompts from files"""
        prompts = {}
        prompts_dir = "prompts"

        for mode in ["sales_agent", "feedback_agent"]:
            filepath = os.path.join(prompts_dir, f"{mode}.txt")
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    prompts[mode] = f.read()
                print(f"✓ Loaded prompt: {mode}")
            except Exception as e:
                print(f"⚠️ Error loading {mode} prompt: {e}")
                prompts[mode] = ""

        return prompts
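
    # Expected on-disk layout, inferred from the loader above (illustrative):
    #
    #   prompts/
    #   ├── sales_agent.txt       # system prompt for mode="sales"
    #   └── feedback_agent.txt    # system prompt for mode="feedback"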

    async def chat(
        self,
        user_message: str,
        conversation_history: List[Dict],
        mode: str = "sales",  # "sales" or "feedback"
        user_id: Optional[str] = None,
        max_iterations: int = 3
    ) -> Dict[str, Any]:
        """
        Main conversation loop.

        Args:
            user_message: The user's input
            conversation_history: Previous messages [{"role": "user", "content": ...}, ...]
            mode: "sales" or "feedback"
            user_id: User ID (used in feedback mode to check purchase history)
            max_iterations: Maximum number of tool-call iterations, to prevent infinite loops

        Returns:
            {
                "message": "Bot response",
                "tool_calls": [...],  # List of tools called (for debugging)
                "mode": mode
            }
        """
        print(f"\n🤖 Agent Mode: {mode}")
        print(f"👤 User Message: {user_message}")

        # Select the system prompt
        system_prompt = self._get_system_prompt(mode)

        # Build the conversation context
        messages = self._build_messages(system_prompt, conversation_history, user_message)

        # Agentic loop: the LLM may call tools multiple times
        tool_calls_made = []
        current_response = None
        llm_response = ""  # Fallback in case the loop never runs

        for iteration in range(max_iterations):
            print(f"\n🔄 Iteration {iteration + 1}")

            # Call the LLM
            llm_response = await self._call_llm(messages)
            print(f"🧠 LLM Response: {llm_response[:200]}...")

            # Check whether the LLM wants to call a tool
            tool_result = await self.tools_service.parse_and_execute(llm_response)

            if not tool_result:
                # No tool call -> this is the final response
                current_response = llm_response
                break

            # A tool was called
            tool_calls_made.append(tool_result)
            print(f"🔧 Tool Called: {tool_result.get('function')}")

            # Feed the tool result back into the conversation
            messages.append({
                "role": "assistant",
                "content": llm_response
            })
            messages.append({
                "role": "system",
                "content": f"Tool Result:\n{self._format_tool_result(tool_result)}"
            })

            # If the tool returns "run_rag_search", handle it specially
            if tool_result.get("result", {}).get("action") == "run_rag_search":
                rag_results = await self._execute_rag_search(tool_result["result"]["query"])
                messages[-1]["content"] = f"RAG Search Results:\n{rag_results}"

        # Clean up the response
        final_response = current_response or llm_response
        final_response = self._clean_response(final_response)

        return {
            "message": final_response,
            "tool_calls": tool_calls_made,
            "mode": mode
        }
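
    # Shape of a tool_result, inferred from the accesses in chat() above
    # (the authoritative schema lives in ToolsService.parse_and_execute):
    #
    #   {
    #       "function": "search_events",             # name of the tool invoked
    #       "result": {
    #           "success": True,
    #           "action": "run_rag_search",           # optional: defer to RAG search
    #           "query": {"query": ..., "vibe": ...}  # params for _execute_rag_search
    #       }
    #   }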

    def _get_system_prompt(self, mode: str) -> str:
        """Get the system prompt for the selected mode"""
        prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
        return self.prompts.get(prompt_key, "")

    def _build_messages(
        self,
        system_prompt: str,
        history: List[Dict],
        user_message: str
    ) -> List[Dict]:
        """Build the messages array for the LLM"""
        messages = [{"role": "system", "content": system_prompt}]

        # Add the conversation history
        messages.extend(history)

        # Add the current user message
        messages.append({"role": "user", "content": user_message})

        return messages

    async def _call_llm(self, messages: List[Dict]) -> str:
        """
        Call the Hugging Face LLM.

        Uses advanced_rag's chat method.
        """
        try:
            # Flatten the messages into a single prompt string
            prompt = self._messages_to_prompt(messages)

            # Call the HF API via advanced_rag
            response = await self.advanced_rag.chat_completion(
                user_prompt=prompt,
                context="",       # Context is already in the system prompt
                chat_history=[],  # History is already in the messages
                token=self.hf_token
            )
            return response
        except Exception as e:
            print(f"⚠️ LLM Call Error: {e}")
            # "Sorry, I'm having a small technical issue. Please try again later!"
            return "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"

    def _messages_to_prompt(self, messages: List[Dict]) -> str:
        """Convert a messages array into a single prompt string"""
        prompt_parts = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                prompt_parts.append(f"[SYSTEM]\n{content}\n")
            elif role == "user":
                prompt_parts.append(f"[USER]\n{content}\n")
            elif role == "assistant":
                prompt_parts.append(f"[ASSISTANT]\n{content}\n")

        return "\n".join(prompt_parts)

    def _format_tool_result(self, tool_result: Dict) -> str:
        """Format a tool result for feeding back to the LLM"""
        result = tool_result.get("result", {})

        if isinstance(result, dict):
            # Pretty-print the key info, skipping status fields
            formatted = []
            for key, value in result.items():
                if key not in ["success", "error"]:
                    formatted.append(f"{key}: {value}")
            return "\n".join(formatted)

        return str(result)

    async def _execute_rag_search(self, query_params: Dict) -> str:
        """
        Execute a RAG search for event discovery.

        Called when the LLM wants to search_events.
        """
        query = query_params.get("query", "")
        vibe = query_params.get("vibe", "")

        # Build the search query
        search_text = f"{query} {vibe}".strip()
        print(f"🔍 RAG Search: {search_text}")

        # Embed the query and search Qdrant
        embedding = self.embedding_service.encode_text(search_text)
        results = self.qdrant_service.search(
            collection_name="events",
            query_vector=embedding,
            limit=5
        )

        # Format the results
        formatted = []
        for i, result in enumerate(results, 1):
            payload = result.payload or {}
            texts = payload.get("texts", [])
            text = texts[0] if texts else ""
            event_id = payload.get("id_use", "")
            formatted.append(f"{i}. {text[:100]}... (ID: {event_id})")

        # Fallback string: "No matching events found."
        return "\n".join(formatted) if formatted else "Không tìm thấy sự kiện phù hợp."

    def _clean_response(self, response: str) -> str:
        """Remove JSON artifacts from the final response"""
        # Drop everything from the first code fence onward
        if "```json" in response:
            response = response.split("```json")[0]
        if "```" in response:
            response = response.split("```")[0]

        # Remove inline tool-call markers
        if "{" in response and "tool_call" in response:
            # Keep only the natural-language lines before the tool-call JSON
            lines = response.split("\n")
            cleaned = []
            for line in lines:
                if "{" in line and "tool_call" in line:
                    break
                cleaned.append(line)
            response = "\n".join(cleaned)

        return response.strip()
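

if __name__ == "__main__":
    # Smoke-test sketch (illustrative addition, not the production entry
    # point): the stubs below stand in for ToolsService and the advanced RAG
    # client so the agent loop can run without Qdrant, an embedding model, or
    # a Hugging Face token. Their shapes are assumptions inferred from how
    # AgentService calls them above.
    import asyncio

    class _StubTools:
        async def parse_and_execute(self, llm_response):
            return None  # Pretend the LLM never requests a tool

    class _StubRAG:
        async def chat_completion(self, user_prompt, context, chat_history, token):
            return "Xin chào! (stub LLM reply)"  # "Hello!"

    async def _demo():
        agent = AgentService(
            tools_service=_StubTools(),  # type: ignore[arg-type]
            embedding_service=None,      # Unused unless a tool triggers RAG search
            qdrant_service=None,
            advanced_rag=_StubRAG(),
            hf_token="",
        )
        reply = await agent.chat(
            user_message="Tìm sự kiện nhạc sống cuối tuần này",  # "Find live-music events this weekend"
            conversation_history=[],
            mode="sales",
        )
        print(reply)

    asyncio.run(_demo())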