# subashpoudel's picture
# Converted code to OOP
# ef9fa4b
from langgraph.graph import StateGraph, START, END
from .utils.state import State
from .utils.nodes import Retriever , ImageCaptioner , StoryGenerator, BrainstromTopicGenerator
from langgraph.checkpoint.memory import MemorySaver
# Module-level MemorySaver instance. Nothing in the visible code reads it:
# BrainstormingAgent creates its own MemorySaver in __init__.
# NOTE(review): presumably kept for other modules that import `memory`;
# confirm before removing.
memory = MemorySaver()
class BrainstormingAgent:
    """Builds the linear LangGraph pipeline used for brainstorming.

    The pipeline runs four nodes in a fixed order:
    ``caption_image`` -> ``retrieve`` -> ``generate_story`` ->
    ``generate_brainstroming``. Each instance owns one ``MemorySaver`` that
    is passed as the checkpointer to every graph it compiles.
    """

    def __init__(self):
        """Create the checkpointer shared by graphs built from this agent."""
        self.memory = MemorySaver()

    def brainstorming_graph(self):
        """Assemble and compile a fresh brainstorming graph.

        A new ``StateGraph`` and new node objects are created on every call;
        only ``self.memory`` is reused between calls.

        Returns:
            The result of ``StateGraph.compile`` with ``self.memory`` as the
            checkpointer.
        """
        graph = StateGraph(State)

        # (node name, node class) in execution order. The names are runtime
        # identifiers, so their spelling is kept exactly as-is.
        stages = (
            ("caption_image", ImageCaptioner),
            ("retrieve", Retriever),
            ("generate_story", StoryGenerator),
            ("generate_brainstroming", BrainstromTopicGenerator),
        )

        # Register every node first; each one is the bound ``run`` method of
        # a freshly constructed node object.
        for node_name, node_cls in stages:
            graph.add_node(node_name, node_cls().run)

        # Chain the nodes linearly: START -> stage 1 -> ... -> stage 4 -> END.
        route = [START, *(node_name for node_name, _ in stages), END]
        for source, target in zip(route, route[1:]):
            graph.add_edge(source, target)

        return graph.compile(checkpointer=self.memory)