File size: 2,938 Bytes
1c9e323
 
 
 
 
 
 
 
 
 
 
 
 
5e5b370
1c9e323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Common service for handling user queries across different agents."""

from mistralai import MessageOutputEvent, FunctionCallEvent, ResponseErrorEvent

import orjson
from app.services.base_agent import BaseAgent
from app.utils.config import tool_to_server_map
from app.utils.logging_config import get_logger
from app.utils.utils_log import get_error_msg
from app.models.schemas import QueryOutputFormat
import weave

# Module-level logger for this service, obtained from the project's logging config.
logger = get_logger(__name__)
# For local debugging only (do not commit enabled): logger.setLevel("DEBUG")


@weave.op()
async def handle_user_query(user_query: str, agent: BaseAgent) -> QueryOutputFormat:
    """
    Handle user query processing for any agent type.

    Args:
        user_query: The user's question or request
        agent: The agent instance to process the query with

    Returns:
        The processed response as a dictionary
    """

    if not user_query:
        return QueryOutputFormat(response={}, status="empty_query", mcp_tools_used=[], query=user_query)

    try:
        out_txt = ""
        mcp_tools_used = []
        errors = []
        result_events = await agent.process_query(user_query)

        async for event in result_events:
            # logger.debug(event)
            if hasattr(event, "data") and event.data:
                match event.data:
                    # Stream agent text responses to the UI
                    case MessageOutputEvent():
                        match event.data.content:
                            case str():
                                # logger.debug(event.data.content)
                                out_txt += event.data.content

                    # Display which MCP server and tool is being executed
                    case FunctionCallEvent():
                        server_tool = f"{tool_to_server_map[event.data.name]} - {event.data.name}"
                        if server_tool not in mcp_tools_used:
                            mcp_tools_used.append(server_tool)
                            logger.debug(server_tool)

                    # Handle and display any errors from the agent
                    case ResponseErrorEvent():
                        logger.debug(event.data)
                        error_msg = f"Error: {event.data.message}"
                        errors.append(error_msg)
                        out_txt += f"\n\n 🔴 {error_msg}\n\n"

        result = QueryOutputFormat(
            response=orjson.loads(out_txt),
            status="success" if not errors else "partial_success",
            mcp_tools_used=mcp_tools_used,
            query=user_query,
            errors=errors,
        )

    except Exception as e:
        result = QueryOutputFormat(
            response={"error": str(e)},
            status="error",
            mcp_tools_used=mcp_tools_used,
            query=user_query,
            errors=[get_error_msg(e)],
        )

    logger.debug(f"result:\n{result.model_dump_json()}")
    return result