File size: 4,051 Bytes
b4395cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2021b6e
b4395cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
from dotenv import load_dotenv
from langgraph.prebuilt import ToolNode
from typing import TypedDict, Annotated, Literal
from langchain.chat_models import init_chat_model
from langgraph.graph import add_messages, StateGraph, START, END
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage

from tools import (
    default_file_reader,
    image_reader,
    excel_column_reader,
    excel_find_column_values_sum,
    wiki_search,
    archive_search,
    get_ioc_code,
    check_commutativity,
    audio_to_text,
    video_to_text
)

# Load environment variables (API keys etc.) from a local .env file, if present.
load_dotenv()
# NOTE(review): emptying CURL_CA_BUNDLE disables TLS certificate verification
# for libraries that honor it (e.g. requests) — presumably a workaround for a
# corporate proxy; confirm this is intended, as it is insecure in general.
os.environ['CURL_CA_BUNDLE'] = ''

class AgentState(TypedDict):
    """Graph state shared by all nodes.

    ``messages`` is the running conversation history; the ``add_messages``
    reducer merges each node's returned messages into it instead of
    overwriting the list.
    """
    messages: Annotated[list[AnyMessage], add_messages]


def start_agent(question: str, question_filepath: str):
    """Answer a question with a tool-using LangGraph agent and return the text.

    Builds a small graph (START -> assistant <-> tools, assistant -> validate
    -> END) around a Claude model bound to research tools, runs it on the
    question, and returns the content of the final message.

    Args:
        question: The user's question.
        question_filepath: Path to a file attached to the question, or a falsy
            value when there is no attachment. When provided, file-handling
            tools are added and the path is embedded in the prompt.

    Returns:
        The final answer text, or ``None`` when ``system_prompt.txt`` cannot
        be read.
    """

    def _read_prompt(path: str):
        # Both prompt files are read the same way; a missing file is reported
        # and signalled with None rather than raised, so callers can degrade
        # gracefully.
        try:
            with open(path, 'r', encoding='utf-8') as prompt_file:
                return prompt_file.read()
        except FileNotFoundError:
            print(f"Error: unable to open {path}")
            return None

    chat = init_chat_model("claude-3-5-sonnet-20241022", model_provider="anthropic", temperature=0)

    # Tools that are useful regardless of whether a file is attached.
    tools = [wiki_search, archive_search, get_ioc_code, check_commutativity, video_to_text]

    if question_filepath:
        # An attachment is present: also expose the file-reading tools.
        tools = tools + [default_file_reader, image_reader, excel_column_reader,
                         excel_find_column_values_sum, audio_to_text]

    chat_with_tools = chat.bind_tools(tools)

    system_prompt = _read_prompt("system_prompt.txt")
    if system_prompt is None:
        return None

    user_content = (f"{question} File located at: {question_filepath}"
                    if question_filepath else f"{question}")
    messages = [
        SystemMessage(system_prompt),
        HumanMessage(content=user_content),
    ]

    def assistant(state: AgentState):
        # One model step; add_messages appends the AI reply to the history.
        return {"messages": [chat_with_tools.invoke(state["messages"])]}

    def validate_answer_format(state: AgentState):
        # Ask the model to double-check its FINAL ANSWER against the user's
        # question and the format rules, without invoking any tools.
        final_answer_validation_prompt = _read_prompt("final_answer_validation_prompt.txt")
        if final_answer_validation_prompt is None:
            # Missing rules file: skip validation (None means "no state
            # update") instead of aborting the whole run.
            return None

        # Build the follow-up turns locally instead of mutating
        # state["messages"] in place, and return them in the update so the
        # add_messages reducer records them alongside the model's reply.
        follow_ups = [
            HumanMessage(content=f"Verify your FINAL ANSWER again so it meets user question requirements: {question}"),
            HumanMessage(content=f"Verify your FINAL ANSWER again so it meets these requirements: {final_answer_validation_prompt}. "
                                 f"Do not use any tool here, just validate format of the final answer."),
        ]
        response = chat_with_tools.invoke(state["messages"] + follow_ups)
        return {"messages": follow_ups + [response]}

    def custom_tool_condition(state: AgentState, messages_key: str = "messages") -> Literal["tools", "validate"]:
        # Route to the tool node while the last AI message requests tool
        # calls; otherwise proceed to final-answer validation.
        if isinstance(state, list):
            ai_message = state[-1]
        elif isinstance(state, dict) and (messages := state.get(messages_key, [])):
            ai_message = messages[-1]
        elif messages := getattr(state, messages_key, []):
            ai_message = messages[-1]
        else:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
        if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
            return "tools"
        return "validate"

    builder = StateGraph(AgentState)

    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_node("validate", validate_answer_format)

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", custom_tool_condition)
    builder.add_edge("tools", "assistant")
    builder.add_edge("validate", END)

    agent = builder.compile()
    response = agent.invoke(AgentState(messages=messages))

    return response['messages'][-1].content