"""Common service for handling user queries across different agents.""" from mistralai import MessageOutputEvent, FunctionCallEvent, ResponseErrorEvent import orjson from app.services.base_agent import BaseAgent from app.utils.config import tool_to_server_map from app.utils.logging_config import get_logger from app.utils.utils_log import get_error_msg from app.models.schemas import QueryOutputFormat import weave logger = get_logger(__name__) # logger.setLevel("DEBUG") # TODO DEBUG ONLY @weave.op() async def handle_user_query(user_query: str, agent: BaseAgent) -> QueryOutputFormat: """ Handle user query processing for any agent type. Args: user_query: The user's question or request agent: The agent instance to process the query with Returns: The processed response as a dictionary """ if not user_query: return QueryOutputFormat(response={}, status="empty_query", mcp_tools_used=[], query=user_query) try: out_txt = "" mcp_tools_used = [] errors = [] result_events = await agent.process_query(user_query) async for event in result_events: # logger.debug(event) if hasattr(event, "data") and event.data: match event.data: # Stream agent text responses to the UI case MessageOutputEvent(): match event.data.content: case str(): # logger.debug(event.data.content) out_txt += event.data.content # Display which MCP server and tool is being executed case FunctionCallEvent(): server_tool = f"{tool_to_server_map[event.data.name]} - {event.data.name}" if server_tool not in mcp_tools_used: mcp_tools_used.append(server_tool) logger.debug(server_tool) # Handle and display any errors from the agent case ResponseErrorEvent(): logger.debug(event.data) error_msg = f"Error: {event.data.message}" errors.append(error_msg) out_txt += f"\n\n 🔴 {error_msg}\n\n" result = QueryOutputFormat( response=orjson.loads(out_txt), status="success" if not errors else "partial_success", mcp_tools_used=mcp_tools_used, query=user_query, errors=errors, ) except Exception as e: result = QueryOutputFormat( response={"error": str(e)}, status="error", mcp_tools_used=mcp_tools_used, query=user_query, errors=[get_error_msg(e)], ) logger.debug(f"result:\n{result.model_dump_json()}") return result