File size: 4,654 Bytes
778116a
 
 
 
 
8073bab
 
43199e3
778116a
8073bab
a605490
778116a
7d310bb
778116a
2e38934
778116a
a4b0424
8073bab
 
a605490
 
 
 
 
 
 
43199e3
 
a605490
8073bab
43199e3
 
8073bab
5813885
778116a
5813885
 
 
 
 
778116a
 
0f45d0b
778116a
 
 
 
 
5813885
 
 
8073bab
 
 
 
 
 
 
 
 
778116a
 
 
 
 
 
 
2e38934
2e8bb22
8073bab
 
 
 
b4f9800
2e8bb22
8073bab
 
 
 
 
 
 
 
 
5813885
8073bab
 
 
 
 
 
 
 
 
 
 
 
 
 
5813885
8073bab
 
 
 
 
2e8bb22
fc1b83d
 
 
 
 
 
 
 
 
8073bab
 
2e8bb22
 
a605490
2e8bb22
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import mimetypes
import pathlib
import time

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage
from langchain_openai import ChatOpenAI

from config.settings import config
from core.messages import attachmentHandler
from core.state import State
from nodes.chunking_handler import OversizedContentHandler
from tools.audio_tool import query_audio
from tools.chess_tool import chess_analysis_tool
from tools.excel_tool import query_excel_file
from tools.math_agent import math_tool
from tools.python_executor import execute_python_code
from tools.tavily_tools import web_search_tools
from utils.prompt_manager import prompt_mgmt

# Tools exposed to the main assistant model, in this order: the web-search
# tools first, followed by the individual task-specific tools.
agent_tools = [
    *web_search_tools,
    query_audio,
    query_excel_file,
    execute_python_code,
    math_tool,
    chess_analysis_tool,
]

# Main reasoning model with the agent tools bound. parallel_tool_calls=False
# asks the provider for at most one tool call per model response.
model = ChatOpenAI(model=config.model_name).bind_tools(agent_tools, parallel_tool_calls=False)

# Separate (tool-less) model used only to post-process the final answer.
response_processing_model = ChatOpenAI(model=config.response_processing_model_name)


# Node
def pre_processor(state: State):
    """Resolve the task question and inline a PNG attachment into the first message.

    The question is taken from ``state["question"]`` when it is non-empty,
    otherwise from the content of the first message. When ``file_reference``
    names a ``.png`` file (suffix compared case-insensitively), the file is
    fetched through ``attachmentHandler`` and the first message's content is
    replaced with a text part holding the question plus an image part.

    Args:
        state: Graph state; reads ``question``, ``messages`` and
            ``file_reference`` (which may be missing, empty or ``None``).

    Returns:
        A state update containing ``question``.
    """
    question = state.get("question", "")
    if not question:
        question = state["messages"][0].content

    # ``or ""`` also covers an explicit ``None`` value, which pathlib.Path rejects.
    file_reference = state.get("file_reference") or ""
    if pathlib.Path(file_reference).suffix.lower() == ".png":
        content_bytes = attachmentHandler.fetch_file_from_reference(file_reference)
        mime_type = mimetypes.guess_type(file_reference)[0]
        # NOTE(review): this mutates the message object held in state instead of
        # returning a messages update; confirm the graph relies on that.
        state["messages"][0].content = [
            {"type": "text", "text": question},
            attachmentHandler.get_representation("image", content_bytes, "png", mime_type),
        ]

    return {"question": question}


# Node
def assistant(state: State):
    """Run one turn of the tool-calling assistant model.

    Renders the ``base_system`` prompt from the running summary, the original
    question, an optional attachment hint and the ``chunked_last_tool_call``
    flag. It then invokes the tool-bound ``model`` on that system message
    followed by the conversation so far. If the first call raises an error whose
    text contains ``"429"``, it waits 20 seconds and retries once. Any other
    error, or a failure of the retry, propagates.

    Args:
        state: Graph state; reads ``summary``, ``question``, ``messages``,
            ``file_reference`` and ``chunked_last_tool_call``.

    Returns:
        A state update with the resolved ``question`` and the model reply
        under ``messages`` (on both the first-try and the retry path).
    """

    def _text_of(content):
        # Message content is either a plain string or a list of content parts
        # (e.g. a {"type": "text", ...} part followed by an image part).
        if isinstance(content, str):
            return content
        for part in content or []:
            if isinstance(part, str):
                return part
            if isinstance(part, dict) and part.get("type") == "text":
                return part.get("text", "")
        return ""

    summary = state.get("summary", "")

    # Prefer the stored question; otherwise recover the text of the first
    # message. Indexing ``.content[0]`` directly would yield only the first
    # character of a plain-string content.
    question = state.get("question", "") or _text_of(state["messages"][0].content)

    attachment = ""
    file_reference = state.get("file_reference", "")
    if file_reference:
        attachment = f" you have access to the file with the following reference {file_reference}"
    prompt_params = {"summary": summary, "chunked_last_tool_call": state.get("chunked_last_tool_call", False),
                     "attachment": attachment, "question": question}
    sys_msg = SystemMessage(content=prompt_mgmt.render_template("base_system", prompt_params))
    conversation = [sys_msg] + state["messages"]

    try:
        response = model.invoke(conversation)
    except Exception as e:
        # Only rate-limit errors (identified by "429" in the message) are retried.
        if "429" not in str(e):
            raise
        time.sleep(20)
        print("Retrying after receiving 429 error")
        response = model.invoke(conversation)

    return {"question": question, "messages": [response]}


def response_processing(state: State):
    """Reformat the assistant's last reply with the final-answer processor model.

    Builds a fresh conversation for ``response_processing_model``. It contains
    the ``final_answer_processor`` system prompt, the stored question as a human
    turn and the content of the last state message as an AI turn.

    Args:
        state: Graph state; reads ``question`` (defaults to ``""``) and the last
            entry of ``messages``.

    Returns:
        A state update whose ``messages`` holds the processor model's reply.
    """
    last_message = state["messages"][-1]
    system_prompt = prompt_mgmt.render_template("final_answer_processor", {})
    conversation = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=state.get("question", "")),
        AIMessage(content=last_message.content),
    ]
    processed = response_processing_model.invoke(conversation)
    return {"messages": [processed]}


def optimize_memory(state: State):
    """Fold older conversation history into the running summary and prune it.

    All messages except the two most recent ones are summarized. An existing
    summary is extended through the ``summarization`` template. Those older
    messages are then removed with ``RemoveMessage``. If there is nothing older
    than the last two messages, the existing summary is kept unchanged and no
    model call is made. The last message is always passed to
    ``OversizedContentHandler`` so an oversized tool output can be chunked.

    Args:
        state: Graph state; reads ``summary``, ``messages`` and ``question``.

    Returns:
        A state update with ``summary``, the ``RemoveMessage`` list under
        ``messages``, and the handler's result under ``chunked_last_tool_call``.
    """
    summary = state.get("summary", "")
    older_messages = state["messages"][:-2]

    if older_messages:
        if summary:
            # A summary already exists: ask the model to extend it.
            summary_message = prompt_mgmt.render_template("summarization", {"summary": summary})
        else:
            summary_message = "Create a summary of the conversation above:"

        messages = older_messages + [HumanMessage(content=summary_message)]
        # ``model`` has the agent tools bound; tool_choice="none" requests a
        # plain-text reply so the summary cannot be replaced by a tool call
        # with empty content.
        try:
            response = model.invoke(messages, tool_choice="none")
        except Exception as e:
            # Only rate-limit errors (identified by "429" in the message) are retried.
            if "429" not in str(e):
                raise
            time.sleep(20)
            print("Retrying after receiving 429 error")
            response = model.invoke(messages, tool_choice="none")
        summary = response.content

    # Remove every message except the two most recent ones (the first message
    # is removed as well).
    removals = [RemoveMessage(id=m.id) for m in older_messages]

    # If the last message returned from a tool is oversized, let the handler chunk it.
    content_handler = OversizedContentHandler()
    chunked = content_handler.process_oversized_message(state["messages"][-1], state.get("question"))

    return {"summary": summary, "messages": removals, "chunked_last_tool_call": chunked}