File size: 1,467 Bytes
3c1150c
 
6f57d05
583f6dd
8ce97f0
38cf703
 
6874dac
3c1150c
6874dac
38cf703
 
 
 
 
 
 
 
8ce97f0
 
 
6f57d05
8ce97f0
 
 
 
 
 
 
 
 
 
38cf703
 
 
 
3c1150c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.checkpoint.memory import MemorySaver
from .utils.nodes import ToolReturnNode, ExtractUserReferenceNode , ImageCaptionNode , QueryResponseNode
from src.genai.utils.models_loader import llm_gpt
from .utils.state import State
from .utils.utils import ImageCaptioner, ResponseBlockExtractor
from .utils.tools import InfluencerRetrievalTool

import re
from langchain_core.messages import SystemMessage

class OrchestrationAgent:
    """Own a checkpointer and the compiled orchestration workflow.

    On construction the instance gets three attributes: ``memory`` (a
    ``MemorySaver``), ``agent`` (the compiled graph returned by
    ``orchestration_graph``) and ``user_input_history`` (an empty list).
    """

    def __init__(self):
        """Create the checkpointer, compile the graph, start an empty history."""
        # The checkpointer must exist first: orchestration_graph() reads
        # self.memory when it compiles the workflow.
        self.memory = MemorySaver()
        self.agent = self.orchestration_graph()
        self.user_input_history = []

    def orchestration_graph(self):
        """Assemble the workflow over ``State`` and compile it.

        Graph wiring:
            START -> image_caption, tool_return, extract_reference
            image_caption, tool_return, extract_reference -> query_response
            query_response -> END

        Returns:
            The object returned by ``StateGraph.compile`` when given
            ``self.memory`` as the checkpointer.
        """
        graph = StateGraph(State)

        # (node name, node class) pairs, registered in this exact order.
        # Each class is instantiated right before its node is added and the
        # instance's bound ``run`` method is used as the node callable.
        node_table = (
            ("image_caption", ImageCaptionNode),
            ("tool_return", ToolReturnNode),
            ("query_response", QueryResponseNode),
            ("extract_reference", ExtractUserReferenceNode),
        )
        for node_name, node_cls in node_table:
            graph.add_node(node_name, node_cls().run)

        # Nodes that hang directly off START and all feed "query_response".
        first_stage = ("image_caption", "tool_return", "extract_reference")
        for node_name in first_stage:
            graph.add_edge(START, node_name)
        for node_name in first_stage:
            graph.add_edge(node_name, "query_response")

        graph.add_edge("query_response", END)

        return graph.compile(checkpointer=self.memory)