Spaces:
Sleeping
Sleeping
File size: 1,500 Bytes
6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 7f15e1c c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from langgraph.graph import StateGraph, START, END
from .func import State, trim_history, call_roleplay, call_guiding_agent
from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import InMemorySaver
class RolePlayAgent:
    """Build and compile the two-agent role-play LangGraph.

    Every run enters at the ``trim_history`` node and then branches to the
    agent node named by ``state["active_agent"]``. Calling an instance
    returns a freshly built, compiled graph.
    """

    # Node names. These are also the only values accepted in
    # ``state["active_agent"]`` by ``route_to_active_agent``.
    ROLEPLAY_AGENT = "Roleplay Agent"
    GUIDING_AGENT = "Guiding Agent"

    # Sentinel distinguishing "checkpointer omitted" from an explicit
    # ``checkpointer=None``, which is forwarded to ``compile`` unchanged.
    _UNSET = object()

    def __init__(self):
        # Created here instead of as a default argument: a default would be
        # evaluated once at class-definition time and silently shared by
        # every instance. Calls on the same instance that omit a
        # checkpointer share this saver.
        self._default_checkpointer = InMemorySaver()

    @staticmethod
    def route_to_active_agent(state: State) -> str:
        """Return the agent node name stored in ``state["active_agent"]``.

        Used as the router of the conditional edge leaving ``trim_history``.

        Raises:
            ValueError: if the value is not one of the two agent node names,
                rather than returning ``None`` and letting the graph fail
                later with a less specific error.
        """
        active = state["active_agent"]
        if active == RolePlayAgent.ROLEPLAY_AGENT:
            return RolePlayAgent.ROLEPLAY_AGENT
        if active == RolePlayAgent.GUIDING_AGENT:
            return RolePlayAgent.GUIDING_AGENT
        raise ValueError(
            f"Unknown active_agent {active!r}; expected "
            f"{RolePlayAgent.ROLEPLAY_AGENT!r} or "
            f"{RolePlayAgent.GUIDING_AGENT!r}"
        )

    def node(self, graph: StateGraph) -> StateGraph:
        """Register the history-trimming node and both agent nodes on *graph*.

        ``destinations`` declares that each agent may hand off to the other.
        NOTE(review): presumably ``call_roleplay`` / ``call_guiding_agent``
        route via ``Command`` objects; confirm in ``.func``.

        Returns:
            The same *graph* object, for chaining.
        """
        graph.add_node("trim_history", trim_history)
        graph.add_node(
            self.ROLEPLAY_AGENT,
            call_roleplay,
            destinations=(self.GUIDING_AGENT,),
        )
        graph.add_node(
            self.GUIDING_AGENT,
            call_guiding_agent,
            destinations=(self.ROLEPLAY_AGENT,),
        )
        return graph

    def edge(self, graph: StateGraph) -> StateGraph:
        """Wire START -> trim_history -> (agent picked by the router).

        Returns:
            The same *graph* object, for chaining.
        """
        graph.add_edge(START, "trim_history")
        graph.add_conditional_edges(
            "trim_history",
            self.route_to_active_agent,
            {
                self.ROLEPLAY_AGENT: self.ROLEPLAY_AGENT,
                self.GUIDING_AGENT: self.GUIDING_AGENT,
            },
        )
        return graph

    def __call__(self, checkpointer=_UNSET) -> CompiledStateGraph:
        """Build a new ``StateGraph`` over ``State`` and compile it.

        Args:
            checkpointer: Forwarded to ``StateGraph.compile``. When omitted,
                this instance's ``InMemorySaver`` is used.

        Returns:
            The compiled graph.
        """
        if checkpointer is self._UNSET:
            checkpointer = self._default_checkpointer
        graph = self.edge(self.node(StateGraph(State)))
        return graph.compile(checkpointer=checkpointer)
# Module-level builder instance; call it (optionally passing a checkpointer) to get a compiled graph.
role_play_agent = RolePlayAgent()
|