File size: 4,361 Bytes
0171bb1
3973360
6d6ae78
0171bb1
3973360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f87934b
3973360
0171bb1
6d6ae78
 
 
f87934b
6d6ae78
 
 
 
 
 
 
 
 
 
 
3973360
6d6ae78
 
 
 
 
 
 
 
0171bb1
 
 
6d6ae78
 
 
 
3973360
6d6ae78
 
 
 
 
 
 
 
 
 
f87934b
6d6ae78
 
 
3973360
6d6ae78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3973360
6d6ae78
3973360
6d6ae78
 
3973360
6d6ae78
 
 
 
 
 
 
 
 
 
 
3973360
6d6ae78
 
 
3973360
6d6ae78
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import tools_condition
from langgraph.graph import END, StateGraph
from langchain_core.messages import HumanMessage
from src.langgraph.tools.scheduling_tools import (
    create_a_activity,
    search_activities,
    update_a_activiy,
    delete_a_activity,
)
from src.langgraph.langchain.llm import llm
from src.langgraph.config.agent import Agent
from src.langgraph.utils_function.function_graph import (
    create_tool_node_with_fallback,
    human_review_node,
)
from src.langgraph.langchain.prompt import (
    scheduling_prompt,
    CompleteOrRoute,
    create_entry_node,
)
from src.langgraph.state import State
from src.utils.logger import logger
from src.langgraph.tools.plan_itinerary import plan_itinerary


class SchedulingAgent:
    """Register the scheduling sub-graph (nodes and edges) on a graph builder.

    The sub-graph consists of an entry node, the LLM-backed scheduling agent,
    a tool node for "safe" tools (run without user review), a tool node for
    "sensitive" tools (create/update/delete activity), and a human-review node.
    Calling the instance adds all of these to ``builder`` and returns it.

    The edge targets ``"save_history"`` and ``"leave_skill"`` are not created
    here; they must be added to the same builder elsewhere.
    """

    def __init__(self, builder: StateGraph):
        """Store the builder and bind the scheduling tools to the LLM.

        Args:
            builder: The graph builder the scheduling nodes/edges are added to.
        """
        self.builder = builder
        # Tools that route_scheduling sends straight to execution (no review).
        self.scheduling_safe_tools = [search_activities, plan_itinerary]
        # Tools that route_scheduling only runs after the user has confirmed.
        self.scheduling_sensitive_tools = [
            create_a_activity,
            update_a_activiy,
            delete_a_activity,
        ]
        self.scheduling_tools = (
            self.scheduling_sensitive_tools + self.scheduling_safe_tools
        )
        # CompleteOrRoute is bound as an extra "tool" so the LLM can ask to
        # leave this skill; route_scheduling detects it by its class name.
        self.scheduling_runnable = scheduling_prompt | llm.bind_tools(
            self.scheduling_tools + [CompleteOrRoute]
        )

    def route_scheduling(self, state: State) -> str:
        """Choose the next node after the scheduling agent has responded.

        Checks, in order:

        1. Confirmation turn: ``messages_history`` is non-empty, the first
           message in ``state["messages"]`` is exactly ``"y"``, and the last
           history entry contains the "Do you want to run the following
           tool(s)?" prompt. If the agent's latest message's first tool call
           is a sensitive tool, run the sensitive tools; otherwise end.
        2. ``tools_condition`` reports no tool calls: end.
        3. Any ``CompleteOrRoute`` call: leave the skill.
        4. Every tool call is a safe tool: run the safe tools.
        5. Otherwise: send the calls to user review.

        Args:
            state: Graph state; reads ``messages`` and ``messages_history``.

        Returns:
            One of ``"scheduling_sensitive_tools"``, ``"scheduling_safe_tools"``,
            ``"user_review_scheduling"``, ``"leave_skill"`` or ``END``.
        """
        logger.info("Route scheduling")
        history = state.get("messages_history")
        # A None *or empty* history means there is no pending confirmation
        # prompt (the previous ``is not None`` check crashed on ``[]``).
        if (
            history
            and state["messages"][0].content == "y"
            and "Do you want to run the following tool(s)?" in history[-1].content
        ):
            sensitive_names = {t.name for t in self.scheduling_sensitive_tools}
            pending_calls = getattr(state["messages"][-1], "tool_calls", None)
            if pending_calls and pending_calls[0]["name"] in sensitive_names:
                logger.info("Sensitive tools")
                return "scheduling_sensitive_tools"
            # NOTE(review): despite the log text, a non-sensitive tool call on a
            # confirmation turn is NOT executed -- the turn ends and the call is
            # left unanswered. Also, while messages[0] stays "y" within this
            # turn, later sensitive calls appear to bypass review. Confirm both
            # are intended.
            logger.info("Safe tools")
            return END

        route = tools_condition(state)
        if route == END:
            return END
        tool_calls = state["messages"][-1].tool_calls
        did_cancel = any(tc["name"] == CompleteOrRoute.__name__ for tc in tool_calls)
        logger.info(f"Did cancel: {did_cancel}")
        if did_cancel:
            return "leave_skill"
        safe_names = {t.name for t in self.scheduling_safe_tools}
        if all(tc["name"] in safe_names for tc in tool_calls):
            logger.info("Scheduling safe tools")
            return "scheduling_safe_tools"
        logger.info("User review")
        return "user_review_scheduling"

    def node(self):
        """Add the scheduling nodes to the builder.

        Edges between these nodes are registered in :meth:`edge`; registering
        the entry edge here as well (as before) duplicated it.
        """
        self.builder.add_node(
            "enter_schedule_activity",
            create_entry_node("Schedule Activity Assistant", "scheduling"),
        )
        self.builder.add_node("scheduling_agent", Agent(self.scheduling_runnable))
        self.builder.add_node(
            "scheduling_safe_tools",
            create_tool_node_with_fallback(self.scheduling_safe_tools),
        )
        self.builder.add_node(
            "scheduling_sensitive_tools",
            create_tool_node_with_fallback(self.scheduling_sensitive_tools),
        )
        self.builder.add_node("user_review_scheduling", human_review_node)

    def edge(self):
        """Add the scheduling edges to the builder.

        After the agent runs, :meth:`route_scheduling` picks the next node.
        Both tool nodes loop back to the agent, and user review hands off to
        ``"save_history"``.
        """
        self.builder.add_edge("enter_schedule_activity", "scheduling_agent")

        self.builder.add_conditional_edges(
            "scheduling_agent",
            self.route_scheduling,
            {
                "scheduling_sensitive_tools": "scheduling_sensitive_tools",
                END: "save_history",
                "leave_skill": "leave_skill",
                "scheduling_safe_tools": "scheduling_safe_tools",
                "user_review_scheduling": "user_review_scheduling",
            },
        )

        self.builder.add_edge("user_review_scheduling", "save_history")
        self.builder.add_edge("scheduling_sensitive_tools", "scheduling_agent")
        self.builder.add_edge("scheduling_safe_tools", "scheduling_agent")

    def __call__(self):
        """Add all scheduling nodes and edges, then return the builder."""
        self.node()
        self.edge()
        return self.builder