from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.checkpoint.memory import MemorySaver
from .utils.nodes import ToolReturnNode, ExtractUserReferenceNode , ImageCaptionNode , QueryResponseNode
from src.genai.utils.models_loader import llm_gpt
from .utils.state import State
from .utils.utils import ImageCaptioner, ResponseBlockExtractor
from .utils.tools import InfluencerRetrievalTool
import re
from langchain_core.messages import SystemMessage
class OrchestrationAgent:
    """Owns the compiled LangGraph workflow that orchestrates a user query.

    Three preprocessing nodes (image captioning, tool-result handling,
    user-reference extraction) run in parallel from START and all feed
    into a single ``query_response`` node, which terminates the graph.
    """

    def __init__(self):
        # Checkpointer so conversation state survives across invocations.
        self.memory = MemorySaver()
        self.agent = self.orchestration_graph()
        self.user_input_history = []

    def orchestration_graph(self):
        """Build and compile the fan-out / fan-in orchestration graph.

        Returns:
            The compiled graph, checkpointed with this agent's memory.
        """
        graph = StateGraph(State)

        # Register each node's run method under its graph name.
        node_runners = {
            "image_caption": ImageCaptionNode(),
            "tool_return": ToolReturnNode(),
            "query_response": QueryResponseNode(),
            "extract_reference": ExtractUserReferenceNode(),
        }
        for node_name, node in node_runners.items():
            graph.add_node(node_name, node.run)

        # Fan out from START to the three preprocessing nodes, then fan
        # all of them back into query_response.
        for preprocessor in ("image_caption", "tool_return", "extract_reference"):
            graph.add_edge(START, preprocessor)
            graph.add_edge(preprocessor, "query_response")
        graph.add_edge("query_response", END)

        return graph.compile(checkpointer=self.memory)