File size: 1,500 Bytes
6cbca40
c0827a3
6cbca40
c0827a3
6cbca40
 
c0827a3
6cbca40
c0827a3
6cbca40
 
7f15e1c
c0827a3
 
 
 
6cbca40
c0827a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cbca40
 
c0827a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from langgraph.graph import StateGraph, START, END
from .func import State, trim_history, call_roleplay, call_guiding_agent
from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import InMemorySaver


class RolePlayAgent:
    """Builder for the role-play LangGraph workflow.

    Calling an instance compiles a ``StateGraph(State)`` with the flow::

        START -> "trim_history" -> ("Roleplay Agent" | "Guiding Agent")

    The branch taken after ``trim_history`` is chosen from
    ``state["active_agent"]``. Each agent node is registered with the other
    agent as its declared destination.
    """

    # Node names, shared by node registration, routing and the edge path map so
    # the three places cannot drift apart.
    TRIM_HISTORY = "trim_history"
    ROLEPLAY_AGENT = "Roleplay Agent"
    GUIDING_AGENT = "Guiding Agent"

    # Sentinel default for ``__call__``: distinguishes "argument omitted" from
    # an explicit ``checkpointer=None``, which is forwarded to ``compile`` as-is.
    _DEFAULT_CHECKPOINTER = object()

    def __init__(self):
        # One saver per builder instance. A saver created as a default argument
        # value would be built once at class-definition time and shared by every
        # RolePlayAgent instance.
        self._checkpointer = InMemorySaver()

    @staticmethod
    def route_to_active_agent(state: State) -> str:
        """Return the name of the agent node that should run after trimming.

        Args:
            state: Graph state whose ``"active_agent"`` entry names one of the
                two agent nodes.

        Returns:
            ``"Roleplay Agent"`` or ``"Guiding Agent"``.

        Raises:
            KeyError: if ``state`` has no ``"active_agent"`` entry.
            ValueError: if ``"active_agent"`` is any other value. Returning
                ``None`` instead would not match any key of the path map
                registered in ``edge``.
        """
        active_agent = state["active_agent"]
        for name in (RolePlayAgent.ROLEPLAY_AGENT, RolePlayAgent.GUIDING_AGENT):
            if active_agent == name:
                return name
        raise ValueError(
            f"Unknown active_agent {active_agent!r}; expected "
            f"{RolePlayAgent.ROLEPLAY_AGENT!r} or {RolePlayAgent.GUIDING_AGENT!r}"
        )

    def node(self, graph: StateGraph) -> StateGraph:
        """Register the history-trimming node and both agent nodes on ``graph``.

        Returns the same ``graph`` object (mutated in place).
        """
        graph.add_node(self.TRIM_HISTORY, trim_history)
        # ``destinations`` declares where each agent may hand control to; the
        # agents presumably route to each other from inside call_roleplay /
        # call_guiding_agent (e.g. via a Command) -- confirm in .func.
        graph.add_node(
            self.ROLEPLAY_AGENT, call_roleplay, destinations=(self.GUIDING_AGENT,)
        )
        graph.add_node(
            self.GUIDING_AGENT,
            call_guiding_agent,
            destinations=(self.ROLEPLAY_AGENT,),
        )
        return graph

    def edge(self, graph: StateGraph) -> StateGraph:
        """Wire START -> trim_history and the conditional branch to the agents.

        Returns the same ``graph`` object (mutated in place).
        """
        graph.add_edge(START, self.TRIM_HISTORY)
        agents = (self.ROLEPLAY_AGENT, self.GUIDING_AGENT)
        graph.add_conditional_edges(
            self.TRIM_HISTORY,
            self.route_to_active_agent,
            # The router already returns node names, so the map is the identity.
            {name: name for name in agents},
        )
        return graph

    def __call__(self, checkpointer=_DEFAULT_CHECKPOINTER) -> CompiledStateGraph:
        """Build and compile the role-play graph.

        Args:
            checkpointer: Forwarded unchanged to ``StateGraph.compile``. When
                omitted, the ``InMemorySaver`` owned by this instance is used,
                so all graphs compiled from the same instance share it.

        Returns:
            The compiled graph.
        """
        if checkpointer is self._DEFAULT_CHECKPOINTER:
            checkpointer = self._checkpointer
        graph = StateGraph(State)
        graph = self.node(graph)
        graph = self.edge(graph)
        return graph.compile(checkpointer=checkpointer)


# Shared module-level builder; call it (optionally passing a checkpointer) to
# obtain a compiled graph.
role_play_agent: RolePlayAgent = RolePlayAgent()