subashpoudel's picture
next commit
9f72bcf
raw
history blame
1.6 kB
from langgraph.graph import StateGraph, START, END , MessagesState
from .utils.state import State
from .utils.nodes import RetrieverNode, IdeatorNode , CriticNode , ValidatorNode , RoutingAfterValidation, JudgeNode1 , JudgeNode2, Aggregrator
from langgraph.checkpoint.memory import MemorySaver
class IdeationAgent:
    """Build the LangGraph workflow for the ideation pipeline.

    The graph runs retriever -> ideator -> critic, fans out from the critic
    to two judges, joins both judges in the aggregator and then runs the
    validator. The validator's successor is chosen by
    ``RoutingAfterValidation().route``: ``False`` loops back to the critic,
    ``True`` ends the run.
    """

    def __init__(self):
        # Checkpointer handed to the compiled graph in ``ideation_graph``.
        self.memory = MemorySaver()

    def ideation_graph(self):
        """Assemble and compile the ideation state graph.

        Returns:
            The compiled graph, using ``self.memory`` as its checkpointer.
        """
        graph_builder = StateGraph(State)

        # Register every processing step under the name the edges refer to.
        graph_builder.add_node("retriever", RetrieverNode().run)
        graph_builder.add_node("ideator", IdeatorNode().run)
        graph_builder.add_node("critic", CriticNode().run)
        graph_builder.add_node("judge1", JudgeNode1().run)
        graph_builder.add_node("judge2", JudgeNode2().run)
        graph_builder.add_node("aggregrator", Aggregrator().run)
        graph_builder.add_node("validator", ValidatorNode().run)

        # Linear prefix of the pipeline.
        graph_builder.add_edge(START, "retriever")
        graph_builder.add_edge("retriever", "ideator")
        graph_builder.add_edge("ideator", "critic")

        # Fan out from the critic to both judges, then join in the aggregator.
        graph_builder.add_edge("critic", "judge1")
        graph_builder.add_edge("critic", "judge2")
        graph_builder.add_edge("judge1", "aggregrator")
        graph_builder.add_edge("judge2", "aggregrator")
        graph_builder.add_edge("aggregrator", "validator")

        # The validator has only conditional successors. An additional
        # unconditional edge to END would contradict the False -> critic
        # retry path, so the router alone decides between looping and ending.
        graph_builder.add_conditional_edges(
            "validator",
            RoutingAfterValidation().route,
            {False: "critic", True: END},
        )

        return graph_builder.compile(checkpointer=self.memory)