from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import InMemorySaver

from .func import State, trim_history, call_roleplay, call_guiding_agent

# Node names, used both when registering nodes and when routing to them.
ROLEPLAY_AGENT = "Roleplay Agent"
GUIDING_AGENT = "Guiding Agent"

# Default checkpointer, created once at import time. Every graph compiled
# without an explicit ``checkpointer`` argument shares this single in-memory
# saver. This is the same behavior as the former ``InMemorySaver()`` default
# argument, now made explicit so the sharing is visible rather than a
# mutable-default surprise.
_SHARED_CHECKPOINTER = InMemorySaver()


class RolePlayAgent:
    """Builder for the role-play LangGraph workflow.

    Calling an instance assembles and compiles a graph over ``State``:

    * ``START -> "trim_history"``
    * ``"trim_history"`` -> whichever agent ``state["active_agent"]`` names
    * ``"Roleplay Agent"`` and ``"Guiding Agent"`` are registered with each
      other as possible destinations (see ``node``).
    """

    @staticmethod
    def route_to_active_agent(state: State) -> str:
        """Return the agent node to run after ``trim_history``.

        Args:
            state: Graph state. It must contain an ``"active_agent"`` key.

        Returns:
            ``"Roleplay Agent"`` or ``"Guiding Agent"``.

        Raises:
            ValueError: If ``state["active_agent"]`` is not one of the known
                agent node names.
        """
        active_agent = state["active_agent"]
        if active_agent == ROLEPLAY_AGENT:
            return ROLEPLAY_AGENT
        if active_agent == GUIDING_AGENT:
            return GUIDING_AGENT
        # The original implementation fell through and returned None here,
        # which the conditional edge cannot route. Fail with a clear message.
        raise ValueError(
            f"Unknown active_agent {active_agent!r}; "
            f"expected {ROLEPLAY_AGENT!r} or {GUIDING_AGENT!r}"
        )

    def node(self, graph: StateGraph) -> StateGraph:
        """Register the history-trimming node and both agent nodes on *graph*.

        The ``destinations`` tuples declare where each agent node may hand
        off to. The hand-off itself is decided inside ``call_roleplay`` /
        ``call_guiding_agent``, which are not shown here.
        """
        graph.add_node("trim_history", trim_history)
        graph.add_node(
            ROLEPLAY_AGENT, call_roleplay, destinations=(GUIDING_AGENT,)
        )
        graph.add_node(
            GUIDING_AGENT, call_guiding_agent, destinations=(ROLEPLAY_AGENT,)
        )
        return graph

    def edge(self, graph: StateGraph) -> StateGraph:
        """Wire ``START -> trim_history`` and route onward to the active agent."""
        graph.add_edge(START, "trim_history")
        graph.add_conditional_edges(
            "trim_history",
            self.route_to_active_agent,
            {
                ROLEPLAY_AGENT: ROLEPLAY_AGENT,
                GUIDING_AGENT: GUIDING_AGENT,
            },
        )
        return graph

    def __call__(self, checkpointer=_SHARED_CHECKPOINTER) -> CompiledStateGraph:
        """Build and compile the workflow graph.

        Args:
            checkpointer: Passed straight to ``StateGraph.compile``. When
                omitted, the module-wide ``_SHARED_CHECKPOINTER`` is used, so
                all default-compiled graphs share one in-memory store.

        Returns:
            The compiled graph.
        """
        graph = StateGraph(State)
        graph = self.node(graph)
        graph = self.edge(graph)
        return graph.compile(checkpointer=checkpointer)


role_play_agent = RolePlayAgent()