File size: 3,130 Bytes
5ec1ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from src.components.internetSearchAgent import InternetSearchAgent
from src.components.synthesizerAgent import SynthesizerAgent
from src.components.reasoningAgent import ReasoningAgent
from src.components.sqlAgent import PostgreSQLAgent
from langgraph.graph import START, END, StateGraph
from utils.exceptions import CustomException
from src.components.ragAgent import RAGAgent
from utils.logger import logger
from typing import TypedDict

class AgentState(TypedDict):
    """Shared state that flows through the workflow graph.

    Each node returns a partial update containing the single key it owns;
    ``query`` is the only key supplied when the graph is invoked.
    """

    # Written by the internet search node.
    internetResults: str
    # Written by the reasoning node.
    reasoningResults: str
    # Written by the SQL node.
    sqlResults: str
    # Written by the RAG node.
    ragResults: str
    # The user query; provided as the initial state by Workflow.run().
    query: str
    # Written by the synthesizer node and returned by Workflow.run().
    finalAnswer: str

class Workflow:
    """Wire the individual agents into one fan-out / fan-in LangGraph graph.

    Four worker nodes (internet search, reasoning, RAG, SQL) each read the
    query from the shared state and write their own result key.  A
    synthesizer node then combines those results into ``finalAnswer``.
    """

    def __init__(self) -> None:
        """Instantiate every agent; the graph is built by createWorkflow()."""
        self.internetSearchAgentObj = InternetSearchAgent()
        self.reasoningAgentObj = ReasoningAgent()
        self.ragAgentObj = RAGAgent()
        self.sqlAgentObj = PostgreSQLAgent()
        self.synthesizerAgentObj = SynthesizerAgent()
        # Compiled graph; stays None until createWorkflow() has run so that
        # run() can detect an uncompiled workflow instead of failing with an
        # AttributeError.
        self.graph = None

    def _internetSearchAgent(self, state: AgentState) -> dict:
        """Run the internet search agent on the query and store its result."""
        return {
            "internetResults": self.internetSearchAgentObj.query(
                query=state.get("query")
            )
        }

    def _reasoningAgent(self, state: AgentState) -> dict:
        """Run the reasoning agent on the query and store its result."""
        return {
            "reasoningResults": self.reasoningAgentObj.query(
                query=state.get("query")
            )
        }

    def _ragAgent(self, state: AgentState) -> dict:
        """Run the RAG agent on the query and store its result."""
        return {"ragResults": self.ragAgentObj.query(query=state.get("query"))}

    def _sqlAgent(self, state: AgentState) -> dict:
        """Run the SQL agent on the query and store its result."""
        return {"sqlResults": self.sqlAgentObj.query(query=state.get("query"))}

    def _synthesizerAgent(self, state: AgentState) -> dict:
        """Combine all worker outputs into the final answer.

        Raises:
            KeyError: if any worker result (or the query) is missing from
                the state.
        """
        payload = {
            "query": state["query"],
            "reasoningOutput": state["reasoningResults"],
            "webOutput": state["internetResults"],
            "ragOutput": state["ragResults"],
            "sqlOutput": state["sqlResults"],
        }
        return {"finalAnswer": self.synthesizerAgentObj.query(payload)}

    def createWorkflow(self) -> None:
        """Build and compile the graph, storing it on ``self.graph``.

        Raises:
            CustomException: wraps any error raised while building or
                compiling the graph; the original error is chained as the
                cause.
        """
        try:
            logger.info("INITIALIZING LANGGRAPH WORKFLOW")
            # Every worker is wired START -> worker -> synthesizer.
            workers = {
                "internetSearchAgent": self._internetSearchAgent,
                "reasoningAgent": self._reasoningAgent,
                "ragAgent": self._ragAgent,
                "sqlAgent": self._sqlAgent,
            }
            graph = StateGraph(AgentState)
            for name, action in workers.items():
                graph.add_node(name, action)
            # defer=True asks LangGraph to hold the synthesizer back until
            # the other pending nodes are done (see add_node documentation).
            graph.add_node("synthesizerAgent", self._synthesizerAgent, defer=True)
            for name in workers:
                graph.add_edge(START, name)
            for name in workers:
                graph.add_edge(name, "synthesizerAgent")
            graph.add_edge("synthesizerAgent", END)
            self.graph = graph.compile()
        except Exception as e:
            exception = CustomException(e)
            logger.error(exception)
            raise exception from e

    def run(self, query: str) -> str:
        """Execute the workflow for ``query`` and return the final answer.

        The graph is compiled on first use if createWorkflow() has not been
        called yet.
        """
        if self.graph is None:
            self.createWorkflow()
        return self.graph.invoke({"query": query})["finalAnswer"]
    
# Module-level Workflow instance: the agents are instantiated and the graph is
# compiled at import time, so the object is ready for run() calls.
workflow = Workflow()
workflow.createWorkflow()