File size: 2,814 Bytes
778116a
 
8073bab
a4b0424
8073bab
a4b0424
 
8073bab
778116a
8073bab
0c44617
8073bab
 
 
 
 
 
 
 
 
 
778116a
8073bab
0c44617
8073bab
 
 
 
778116a
 
5813885
8073bab
 
 
 
 
 
 
5813885
8073bab
 
 
 
 
778116a
0f45d0b
778116a
 
 
a4b0424
778116a
 
8073bab
5813885
 
8073bab
778116a
0f45d0b
778116a
 
8073bab
 
778116a
8073bab
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Optional

from langchain_core.messages import HumanMessage
from langgraph.graph import START, StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition

from core.messages import Attachment
from core.state import State
from nodes.nodes import assistant, optimize_memory, response_processing, pre_processor, agent_tools


class GaiaAgent:
    """ReAct-style agent built on a compiled LangGraph state graph.

    Control flow:
        START -> pre_processor -> assistant
              -> (tools -> optimize_memory -> assistant)*   # while assistant emits tool calls
              -> response_processing -> END
    """

    # Compiled once in __init__ and reused for every invocation.
    react_graph: CompiledStateGraph

    # Cap on graph super-steps; guards against runaway tool-call loops.
    # Applied to both the blocking and the streamed entry points.
    _RECURSION_LIMIT = 30

    def __init__(self):
        builder = StateGraph(State)

        # Nodes: these do the work.
        builder.add_node("pre_processor", pre_processor)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(agent_tools))
        builder.add_node("optimize_memory", optimize_memory)
        builder.add_node("response_processing", response_processing)

        # Edges: these determine how the control flow moves.
        builder.add_edge(START, "pre_processor")
        builder.add_edge("pre_processor", "assistant")

        # tools_condition routes to "tools" when the assistant's latest
        # message is a tool call; otherwise it emits "__end__", which we remap
        # to response_processing so the answer is post-processed before END.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
            {"tools": "tools", "__end__": "response_processing"},
        )

        builder.add_edge("tools", "optimize_memory")
        builder.add_edge("optimize_memory", "assistant")
        builder.add_edge("response_processing", END)
        self.react_graph = builder.compile()

    @staticmethod
    def _build_initial_state(question: str, attachment: Optional[Attachment]) -> dict:
        """Assemble the initial graph state for a question.

        Args:
            question: The user's question text.
            attachment: Optional attachment; when present its file path is
                exposed to the graph under "file_reference".

        Returns:
            The state dict fed into the compiled graph.
        """
        state = {"messages": [HumanMessage(content=question)], "question": question}
        if attachment:
            state["file_reference"] = attachment.file_path
        return state

    def __call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the graph to completion and return the final answer text.

        Args:
            question: The user's question.
            attachment: Optional file attachment forwarded to the graph.

        Returns:
            Content of the last message produced by the graph.
        """
        result = self.react_graph.invoke(
            self._build_initial_state(question, attachment),
            {"recursion_limit": self._RECURSION_LIMIT},
        )
        return result["messages"][-1].content

    def __streamed_call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Like __call__, but pretty-prints each intermediate message while streaming.

        Args:
            question: The user's question.
            attachment: Optional file attachment forwarded to the graph.

        Returns:
            Content of the last streamed message, or "" if the stream
            unexpectedly yields nothing (defensive guard — previously this
            path raised NameError).
        """
        message = None
        for step in self.react_graph.stream(
            self._build_initial_state(question, attachment),
            {"recursion_limit": self._RECURSION_LIMIT},
            stream_mode="values",
        ):
            message = step["messages"][-1]
            if isinstance(message, tuple):
                print(message)
            else:
                message.pretty_print()

        return "" if message is None else message.content