Spaces:
Runtime error
| """Common service for handling user queries across different agents.""" | |
| from mistralai import MessageOutputEvent, FunctionCallEvent, ResponseErrorEvent | |
| import orjson | |
| from app.services.base_agent import BaseAgent | |
| from app.utils.config import tool_to_server_map | |
| from app.utils.logging_config import get_logger | |
| from app.utils.utils_log import get_error_msg | |
| from app.models.schemas import QueryOutputFormat | |
| import weave | |
# Module-level logger for this service, named after the module.
logger = get_logger(__name__)
# NOTE: for local debugging only, verbose output can be enabled by temporarily
# adding `logger.setLevel("DEBUG")` here; do not commit it enabled.
| async def handle_user_query(user_query: str, agent: BaseAgent) -> QueryOutputFormat: | |
| """ | |
| Handle user query processing for any agent type. | |
| Args: | |
| user_query: The user's question or request | |
| agent: The agent instance to process the query with | |
| Returns: | |
| The processed response as a dictionary | |
| """ | |
| if not user_query: | |
| return QueryOutputFormat(response={}, status="empty_query", mcp_tools_used=[], query=user_query) | |
| try: | |
| out_txt = "" | |
| mcp_tools_used = [] | |
| errors = [] | |
| result_events = await agent.process_query(user_query) | |
| async for event in result_events: | |
| # logger.debug(event) | |
| if hasattr(event, "data") and event.data: | |
| match event.data: | |
| # Stream agent text responses to the UI | |
| case MessageOutputEvent(): | |
| match event.data.content: | |
| case str(): | |
| # logger.debug(event.data.content) | |
| out_txt += event.data.content | |
| # Display which MCP server and tool is being executed | |
| case FunctionCallEvent(): | |
| server_tool = f"{tool_to_server_map[event.data.name]} - {event.data.name}" | |
| if server_tool not in mcp_tools_used: | |
| mcp_tools_used.append(server_tool) | |
| logger.debug(server_tool) | |
| # Handle and display any errors from the agent | |
| case ResponseErrorEvent(): | |
| logger.debug(event.data) | |
| error_msg = f"Error: {event.data.message}" | |
| errors.append(error_msg) | |
| out_txt += f"\n\n 🔴 {error_msg}\n\n" | |
| result = QueryOutputFormat( | |
| response=orjson.loads(out_txt), | |
| status="success" if not errors else "partial_success", | |
| mcp_tools_used=mcp_tools_used, | |
| query=user_query, | |
| errors=errors, | |
| ) | |
| except Exception as e: | |
| result = QueryOutputFormat( | |
| response={"error": str(e)}, | |
| status="error", | |
| mcp_tools_used=mcp_tools_used, | |
| query=user_query, | |
| errors=[get_error_msg(e)], | |
| ) | |
| logger.debug(f"result:\n{result.model_dump_json()}") | |
| return result | |