Spaces:
Sleeping
Sleeping
| from langgraph.graph import StateGraph, START, END | |
| from .func import ( | |
| State, | |
| trim_history, | |
| execute_tool, | |
| generate_answer_rag, | |
| ) | |
| from langgraph.graph.state import CompiledStateGraph | |
| from langgraph.checkpoint.memory import InMemorySaver | |
class RAGAgentTemplate:
    """Build a LangGraph retrieval-augmented-generation agent with a tool loop.

    Topology wired by :meth:`edge`::

        START -> trim_history -> generate_answer_rag
        generate_answer_rag --(last message has tool calls)--> execute_tool
        generate_answer_rag --(otherwise)-------------------> END
        execute_tool -> generate_answer_rag

    Calling an instance registers the nodes and edges on ``self.builder`` and
    returns the compiled graph.
    """

    def __init__(self):
        # Empty graph over the project's ``State`` schema; nodes and edges are
        # added later by ``node()`` and ``edge()``.
        self.builder = StateGraph(State)

    @staticmethod
    def should_continue(state: State) -> str:
        """Choose the next step after ``generate_answer_rag``.

        Declared static so that ``self.should_continue`` is a callable taking
        only the graph state. As a plain method without ``self``, the bound
        instance was passed in as ``state`` and the real state caused a
        ``TypeError``.

        Args:
            state: Graph state; only its ``"messages"`` entry is read.

        Returns:
            ``"execute_tool"`` when the most recent message carries tool
            calls, otherwise ``END``.
        """
        messages = state["messages"]
        if not messages:
            # Nothing to inspect, so there is nothing to route to a tool.
            return END
        # Not every message type defines ``tool_calls``; treat a missing
        # attribute the same as "no tool calls requested".
        if getattr(messages[-1], "tool_calls", None):
            return "execute_tool"
        return END

    def node(self):
        """Register the three graph nodes on ``self.builder``."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("execute_tool", execute_tool)

    def edge(self):
        """Wire the control flow between the nodes registered by ``node()``."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "generate_answer_rag")
        # The conditional edge alone covers both exits of the generator
        # (tool execution or END). No extra static edge to END is added, so
        # the graph does not read as if the node always terminates.
        self.builder.add_conditional_edges(
            "generate_answer_rag",
            self.should_continue,
            {
                END: END,
                "execute_tool": "execute_tool",
            },
        )
        # After tools run, loop back to the generator for another turn.
        self.builder.add_edge("execute_tool", "generate_answer_rag")

    def __call__(self) -> CompiledStateGraph:
        """Register nodes and edges, then compile the graph.

        Returns:
            The compiled graph, using an ``InMemorySaver`` checkpointer.

        NOTE(review): this mutates ``self.builder``. Calling it twice on the
        same instance re-registers the same node names, which LangGraph
        presumably rejects, so create a new ``RAGAgentTemplate`` per graph.
        """
        self.node()
        self.edge()
        return self.builder.compile(checkpointer=InMemorySaver())
# Construct the template and compile its graph once, when this module is
# imported; the resulting compiled graph is published under this name.
rag_agent_template_agent = RAGAgentTemplate().__call__()