|
|
import mimetypes |
|
|
import pathlib |
|
|
import time |
|
|
|
|
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage |
|
|
from langchain_openai import ChatOpenAI |
|
|
|
|
|
from config.settings import config |
|
|
from core.messages import attachmentHandler |
|
|
from core.state import State |
|
|
from nodes.chunking_handler import OversizedContentHandler |
|
|
from tools.audio_tool import query_audio |
|
|
from tools.chess_tool import chess_analysis_tool |
|
|
from tools.excel_tool import query_excel_file |
|
|
from tools.math_agent import math_tool |
|
|
from tools.python_executor import execute_python_code |
|
|
from tools.tavily_tools import web_search_tools |
|
|
from utils.prompt_manager import prompt_mgmt |
|
|
|
|
|
# Tools offered to the assistant model: the web-search tool collection first,
# then the single-purpose tools, in this order.
agent_tools = [
    *web_search_tools,
    query_audio,
    query_excel_file,
    execute_python_code,
    math_tool,
    chess_analysis_tool,
]

# Main reasoning model, bound to the tools above. parallel_tool_calls=False asks
# the provider not to emit several tool calls in a single turn.
model = ChatOpenAI(model=config.model_name).bind_tools(agent_tools, parallel_tool_calls=False)

# Separate, unbound model used only to post-process the final answer.
response_processing_model = ChatOpenAI(model=config.response_processing_model_name)
|
|
|
|
|
|
|
|
|
|
|
def pre_processor(state: State):
    """Determine the question and inline a PNG attachment into the first message.

    The question is ``state["question"]`` or, when that is empty, the text of the
    first message. Both plain-string content and a list of content parts are
    supported. When ``state["file_reference"]`` has a ``.png`` suffix
    (case-insensitive), the file is fetched through ``attachmentHandler``. The
    first message's content is then replaced with a two-part list: the question
    text followed by the image representation.

    Args:
        state: Graph state; reads ``question``, ``file_reference`` and ``messages``.

    Returns:
        A partial state update ``{"question": question}``.
    """
    question = state.get("question", "")
    if not question:
        content = state["messages"][0].content
        if isinstance(content, list):
            # Multimodal content: keep only the text parts so ``question`` stays a str.
            content = "\n".join(
                part if isinstance(part, str) else part.get("text", "")
                for part in content
                if isinstance(part, str) or part.get("type") == "text"
            )
        question = content

    # ``or ""`` also covers an explicit None value, which pathlib.Path() rejects.
    file_reference = state.get("file_reference") or ""

    # Compare case-insensitively so "image.PNG" is treated like "image.png".
    if pathlib.Path(file_reference).suffix.lower() == ".png":
        content_bytes = attachmentHandler.fetch_file_from_reference(file_reference)
        # guess_type may return None for unusual references; the suffix says PNG.
        mime_type = mimetypes.guess_type(file_reference)[0] or "image/png"
        # NOTE(review): this mutates the first message object in place instead of
        # returning a replacement message through the state update.
        state["messages"][0].content = [
            {"type": "text", "text": question},
            attachmentHandler.get_representation("image", content_bytes, "png", mime_type),
        ]

    return {"question": question}
|
|
|
|
|
|
|
|
|
|
|
def assistant(state: State):
    """Run one step of the tool-bound assistant model.

    Renders the ``base_system`` template with these parameters:

    - the running summary;
    - the ``chunked_last_tool_call`` flag;
    - an attachment hint (only when a file reference is present);
    - the question.

    The rendered text becomes a system message, which is sent to ``model``
    followed by the conversation messages. If the invocation fails with an error
    whose text contains "429", it is retried once after a 20 second pause. Any
    other exception, or a failure of the retry, propagates.

    Args:
        state: Graph state; reads ``summary``, ``question``, ``file_reference``,
            ``chunked_last_tool_call`` and ``messages``.

    Returns:
        ``{"question": question, "messages": [response]}``. The same shape is
        returned whether the first attempt or the retry succeeds.
    """
    summary = state.get("summary", "")

    question = state.get("question", "")
    if not question:
        # Fall back to the text of the first message. Its content is either a str
        # or a list of parts (e.g. text + image); indexing ``content[0]`` would
        # give a single character or a dict rather than the question text.
        content = state["messages"][0].content
        if isinstance(content, list):
            content = "\n".join(
                part if isinstance(part, str) else part.get("text", "")
                for part in content
                if isinstance(part, str) or part.get("type") == "text"
            )
        question = content

    file_reference = state.get("file_reference", "")
    attachment = (
        f" you have access to the file with the following reference {file_reference}"
        if file_reference
        else ""
    )

    prompt_params = {
        "summary": summary,
        "chunked_last_tool_call": state.get("chunked_last_tool_call", False),
        "attachment": attachment,
        "question": question,
    }
    sys_msg = SystemMessage(content=prompt_mgmt.render_template("base_system", prompt_params))
    messages = [sys_msg] + state["messages"]

    try:
        response = model.invoke(messages)
    except Exception as e:
        # Rate limiting is detected from the error text only; anything else is re-raised.
        if "429" not in str(e):
            raise
        time.sleep(20)
        print("Retrying after receiving 429 error")
        response = model.invoke(messages)

    return {"question": question, "messages": [response]}
|
|
|
|
|
|
|
|
def response_processing(state: State):
    """Post-process the agent's final reply with the response-processing model.

    Builds a short exchange and sends it to ``response_processing_model``:

    - the ``final_answer_processor`` system prompt;
    - the question as a human turn;
    - the content of the last message as an AI turn.

    Args:
        state: Graph state; reads ``question`` and ``messages``.

    Returns:
        ``{"messages": [processed]}`` where ``processed`` is the model's reply.
    """
    final_message = state["messages"][-1]
    system_prompt = prompt_mgmt.render_template("final_answer_processor", {})

    conversation = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=state.get("question", "")),
        AIMessage(content=final_message.content),
    ]
    processed = response_processing_model.invoke(conversation)

    return {"messages": [processed]}
|
|
|
|
|
|
|
|
def optimize_memory(state: State):
    """Fold older messages into the running summary and trim the history.

    All messages except the last two are sent to ``model`` followed by a
    summarisation request:

    - the ``summarization`` template when a summary already exists;
    - otherwise a default instruction.

    The reply becomes the new summary, and those older messages are removed
    through ``RemoveMessage`` entries. If there is nothing older than the last
    two messages, no model call is made and the existing summary is kept.

    A failure whose error text contains "429" is retried once after a 20 second
    pause; other exceptions propagate. Finally, the last message is passed to
    ``OversizedContentHandler.process_oversized_message`` together with the
    question, and its return value is stored as ``chunked_last_tool_call``.

    Args:
        state: Graph state; reads ``summary``, ``messages`` and ``question``.

    Returns:
        A partial state update with ``summary``, ``messages`` (removal markers)
        and ``chunked_last_tool_call``.
    """
    summary = state.get("summary", "")
    history = state["messages"][:-2]

    if not history:
        # Nothing to fold in: avoid asking the model to summarise an empty
        # conversation, which would overwrite the existing summary.
        new_summary = summary
        removals = []
    else:
        if summary:
            summary_message = prompt_mgmt.render_template("summarization", {"summary": summary})
        else:
            summary_message = "Create a summary of the conversation above:"

        messages = history + [HumanMessage(content=summary_message)]
        # NOTE(review): ``model`` is bound to the agent tools, so the reply could be
        # a tool call with empty content; consider an unbound model for summaries.
        try:
            response = model.invoke(messages)
        except Exception as e:
            # Rate limiting is detected from the error text only; anything else is re-raised.
            if "429" not in str(e):
                raise
            time.sleep(20)
            print("Retrying after receiving 429 error")
            response = model.invoke(messages)

        new_summary = response.content
        # The last two messages are kept; everything before them is removed.
        removals = [RemoveMessage(id=m.id) for m in history]

    content_handler = OversizedContentHandler()
    chunked = content_handler.process_oversized_message(state["messages"][-1], state.get("question"))

    return {"summary": new_summary, "messages": removals, "chunked_last_tool_call": chunked}
|
|
|