File size: 2,020 Bytes
67a532b
 
a6a9e0f
67a532b
a6a9e0f
 
 
 
 
 
67a532b
a6a9e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67a532b
a6a9e0f
67a532b
 
 
 
 
 
 
 
 
 
a6a9e0f
 
 
 
 
 
 
 
 
 
 
67a532b
 
a6a9e0f
 
 
 
 
 
 
 
 
 
 
 
67a532b
 
a6a9e0f
67a532b
a6a9e0f
67a532b
 
 
 
a6a9e0f
67a532b
 
a6a9e0f
67a532b
 
 
a6a9e0f
67a532b
 
 
 
 
 
 
a6a9e0f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph, MessagesState
from langchain_openai import AzureChatOpenAI

from config import (
    MODEL_ENDPOINT,
    MODEL_KEY,
    MODEL_NAME,
    MODEL_API_VERSION,
)

from tools import (
    wiki_search,
    tavily_search,
    arxiv_search,
    add,
    subtract,
    multiply,
    divide,
    power,
    sqrt,
    modulus,
    scrape_webpage,
    analyze_image,
    is_commutative,
    commutativity_counterexample_pairs,
    commutativity_counterexample_elements,
    find_identity_element,
    find_inverses,
    transcribe_audio,
    execute_source_file,
    interact_tabular,
)

# Define tools
# Registry of every tool exposed to the agent. This list is both bound to the
# LLM (so it can emit tool calls) and handed to the ToolNode (so those calls
# can be executed) — keep the two uses in sync by editing only this list.
TOOLS = [
    # Search / retrieval
    wiki_search,
    tavily_search,
    arxiv_search,
    # Arithmetic
    add,
    subtract,
    multiply,
    divide,
    power,
    sqrt,
    modulus,
    # Web / media
    scrape_webpage,
    analyze_image,
    # Group-theory helpers
    is_commutative,
    commutativity_counterexample_pairs,
    commutativity_counterexample_elements,
    find_identity_element,
    find_inverses,
    # Files / audio / code
    transcribe_audio,
    execute_source_file,
    interact_tabular,  # trailing comma for consistency with the import block above
]


def build_agent() -> StateGraph:
    """Construct and compile the tool-calling agent graph.

    Wires a single LLM "assistant" node to a ToolNode in the standard
    ReAct-style loop: the assistant runs, `tools_condition` routes to the
    tool node whenever the last message contains tool calls, and tool
    results are fed back to the assistant.

    Returns:
        StateGraph: The compiled agent graph, ready to be invoked.
    """
    # Azure-hosted chat model, configured entirely from config constants.
    model = AzureChatOpenAI(
        azure_deployment=MODEL_NAME,
        api_version=MODEL_API_VERSION,
        azure_endpoint=MODEL_ENDPOINT,
        api_key=MODEL_KEY,
    )
    model_with_tools = model.bind_tools(TOOLS)

    def assistant(state: MessagesState):
        """Run the tool-aware model over the accumulated messages."""
        reply = model_with_tools.invoke(state["messages"])
        return {"messages": [reply]}

    graph = StateGraph(MessagesState)
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(TOOLS))

    # Entry point -> assistant; assistant conditionally branches to the
    # tool node (or END) via tools_condition; tools always loop back.
    graph.add_edge(START, "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")

    return graph.compile()