"""Common service for handling user queries across different agents."""
from mistralai import MessageOutputEvent, FunctionCallEvent, ResponseErrorEvent
import orjson
from app.services.base_agent import BaseAgent
from app.utils.config import tool_to_server_map
from app.utils.logging_config import get_logger
from app.utils.utils_log import get_error_msg
from app.models.schemas import QueryOutputFormat
import weave
logger = get_logger(__name__)
# logger.setLevel("DEBUG") # TODO DEBUG ONLY
@weave.op()
async def handle_user_query(user_query: str, agent: BaseAgent) -> QueryOutputFormat:
"""
Handle user query processing for any agent type.
Args:
user_query: The user's question or request
agent: The agent instance to process the query with
Returns:
The processed response as a dictionary
"""
if not user_query:
return QueryOutputFormat(response={}, status="empty_query", mcp_tools_used=[], query=user_query)
try:
out_txt = ""
mcp_tools_used = []
errors = []
result_events = await agent.process_query(user_query)
async for event in result_events:
# logger.debug(event)
if hasattr(event, "data") and event.data:
match event.data:
# Stream agent text responses to the UI
case MessageOutputEvent():
match event.data.content:
case str():
# logger.debug(event.data.content)
out_txt += event.data.content
# Display which MCP server and tool is being executed
case FunctionCallEvent():
server_tool = f"{tool_to_server_map[event.data.name]} - {event.data.name}"
if server_tool not in mcp_tools_used:
mcp_tools_used.append(server_tool)
logger.debug(server_tool)
# Handle and display any errors from the agent
case ResponseErrorEvent():
logger.debug(event.data)
error_msg = f"Error: {event.data.message}"
errors.append(error_msg)
out_txt += f"\n\n 🔴 {error_msg}\n\n"
result = QueryOutputFormat(
response=orjson.loads(out_txt),
status="success" if not errors else "partial_success",
mcp_tools_used=mcp_tools_used,
query=user_query,
errors=errors,
)
except Exception as e:
result = QueryOutputFormat(
response={"error": str(e)},
status="error",
mcp_tools_used=mcp_tools_used,
query=user_query,
errors=[get_error_msg(e)],
)
logger.debug(f"result:\n{result.model_dump_json()}")
return result