"""Orchestration agent: builds a LangGraph workflow that fans three
preprocessing nodes out from START and fans them back into a single
query-response node."""

import re

from langchain_core.messages import SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END

from src.genai.utils.models_loader import llm_gpt

from .utils.nodes import (
    ExtractUserReferenceNode,
    ImageCaptionNode,
    QueryResponseNode,
    ToolReturnNode,
)
from .utils.state import State
from .utils.tools import InfluencerRetrievalTool
from .utils.utils import ImageCaptioner, ResponseBlockExtractor


class OrchestrationAgent:
    """Owns the compiled orchestration graph and its checkpointer.

    Topology: START fans out to three nodes ("image_caption",
    "tool_return", "extract_reference"); each of them feeds
    "query_response", which leads to END.
    """

    def __init__(self):
        # Checkpointer backing the compiled graph, so state survives
        # across invocations on the same thread id.
        self.memory = MemorySaver()
        # Compiled, ready-to-invoke graph.
        self.agent = self.orchestration_graph()
        # Not read or written within this class — presumably appended
        # to by callers; confirm against call sites before removing.
        self.user_input_history = []

    def orchestration_graph(self):
        """Assemble the StateGraph and compile it against self.memory.

        Returns:
            The compiled LangGraph runnable.
        """
        graph = StateGraph(State)

        # Register each node under the name referenced by the edges below.
        graph.add_node("image_caption", ImageCaptionNode().run)
        graph.add_node("tool_return", ToolReturnNode().run)
        graph.add_node("query_response", QueryResponseNode().run)
        graph.add_node("extract_reference", ExtractUserReferenceNode().run)

        # Fan out from START, then fan each preprocessing node into the
        # single response node.
        for preprocessing_node in ("image_caption", "tool_return", "extract_reference"):
            graph.add_edge(START, preprocessing_node)
            graph.add_edge(preprocessing_node, "query_response")

        graph.add_edge("query_response", END)
        return graph.compile(checkpointer=self.memory)