# File size: 6,958 Bytes
# 45b200f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from globals import *
from tools import *

# ------------------------------------------------------ #
# MODELS
# ------------------------------------------------------ #
# Alternative local backend, kept for quick switching:
# init_chat_llm = ChatOllama(model=model_name)
init_chat_llm = ChatTogether(
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    # Key is read from the environment; os.getenv yields None when it is unset.
    api_key=os.getenv("TOGETHER_API_KEY"),
)

# ------------------------------------------------------ #
# BENDING TO TOOLS
# ------------------------------------------------------ #
# Tools offered to the model. The same list is handed to ToolNode below, so every
# tool the model is told about can also be executed by the graph.
tools = list((
    search_tool,
    describe_image_tool,
    describe_audio_tool,
    python_repl_tool,
    excel_repl_tool,
    youtube_extractor_tool,
    wikipedia_tool,
))
# Chat model with the tool schemas attached, so its replies may carry tool calls.
chat_llm = init_chat_llm.bind_tools(tools)

# ------------------------------------------------------ #
# STATE
# ------------------------------------------------------ #
class AgentState(TypedDict):
    """Graph state shared by every node of the agent workflow."""

    # Full conversation history. No reducer is attached (a reducer-based variant would be
    # Annotated[list[AnyMessage], add_messages]), so a returned "messages" value replaces
    # the stored list; the nodes therefore return the complete, extended history.
    messages: list[AnyMessage]
    # Name of a file attached to the task; preprocessing appends it to the first message
    # when it is not the empty string.
    file_name: str
    # Written by checker_final_answer: True when the last message contains "FINAL ANSWER: ".
    final_output_is_good: bool

# ------------------------------------------------------ #
# HELP FUNCTIONS
# ------------------------------------------------------ #
def step_print(state: AgentState | None, step_label: str):
    """Log entry into the graph node called *step_label*.

    When a truthy state is supplied, its current message count is shown inside the
    leading brackets; otherwise the brackets stay empty.
    """
    count = len(state["messages"]) if state else ''
    print(f'<<--- [{count}] Entering ``{step_label}`` Node... --->>')


def messages_print(messages_to_print: List[AnyMessage]):
    """Print each message's type, name and content between a header and a footer line."""
    print('--- Message/s ---')
    for message in messages_to_print:
        header = f'{message.type} ({message.name}): '
        print(f'{header}\n{message.content}')
    print('<<--- *** --->>')

# ------------------------------------------------------ #
# NODES
# ------------------------------------------------------ #
def preprocessing(state: AgentState):
    """Prepare the initial conversation for the assistant.

    If the state carries a non-empty ``file_name``, it is appended to the content of the
    first message so the model knows which attached file the task refers to. The system
    prompt is then prepended to the history.

    Returns:
        A state update whose ``messages`` is ``[SystemMessage(DEFAULT_SYSTEM_PROMPT)]``
        followed by the incoming messages.
    """
    step_print(None, 'Preprocessing')
    messages = state['messages']
    # ``file_name`` may be absent from the input state (keys that were never set are not
    # present in the state dict), and None must not be rendered as the text "None".
    file_name = state.get('file_name') or ''
    if file_name and messages:
        # NOTE: this updates the caller's first message object in place.
        messages[0].content += f"\nfile_name: {file_name}"
    messages_print(messages)
    return {
        "messages": [SystemMessage(content=DEFAULT_SYSTEM_PROMPT)] + messages
    }


def assistant(state: AgentState):
    """Run the tool-bound chat model on the current history and append its reply.

    The complete history plus the new AI message is returned, because the ``messages``
    key has no reducer and a returned value replaces the stored list.
    """
    history = state["messages"]
    step_print(state, 'assistant')
    reply = chat_llm.invoke(history)
    messages_print([reply])
    updated_history = history + [reply]
    return {'messages': updated_history}


# Prebuilt LangGraph ToolNode over the same ``tools`` list that is bound to the model.
base_tool_node = ToolNode(tools)


def wrapped_tool_node(state: AgentState):
    """Execute pending tool calls and append the resulting messages to the history.

    Wraps ``base_tool_node`` to add logging and to return the *full* history. The
    ``messages`` key has no reducer, so returning only the tool results would overwrite
    the conversation.
    """
    step_print(state, 'Tools')
    result = base_tool_node.invoke(state)
    tool_messages = result["messages"]
    messages_print(tool_messages)
    # Build a new list rather than using ``+=``, so the input state is not mutated in place.
    return {"messages": [*state["messages"], *tool_messages]}


def checker_final_answer(state: AgentState):
    """Check whether the latest message contains a properly formatted final answer.

    The answer is accepted when its text contains the marker ``"FINAL ANSWER: "``.
    Otherwise a short human reminder is appended before control returns to the
    assistant. Re-sending the identical history would give the model no reason to
    answer differently.

    Returns:
        ``{'final_output_is_good': True}`` when the marker is present; otherwise the
        extended ``messages`` list together with ``final_output_is_good=False``.
    """
    step_print(state, 'Final Check')
    content = state['messages'][-1].content
    # Content can be a list of parts (multimodal messages); search its text form then.
    text = content if isinstance(content, str) else str(content)
    if "FINAL ANSWER: " in text:
        return {
            'final_output_is_good': True
        }
    reminder = HumanMessage(
        content="Your previous reply did not contain a final answer. "
                "Please end your reply with a line of the form 'FINAL ANSWER: <answer>'."
    )
    return {
        'messages': state["messages"] + [reminder],
        'final_output_is_good': False
    }

# ------------------------------------------------------ #
# CONDITIONAL FUNCTIONS
# ------------------------------------------------------ #
def condition_output(state: AgentState) -> Literal["assistant", "__end__"]:
    """Route after the final-answer check: end the run if accepted, else retry the assistant."""
    return END if state['final_output_is_good'] else "assistant"


def condition_tools_or_continue(
    state: Union[list[AnyMessage], dict[str, Any], BaseModel],
    messages_key: str = "messages",
) -> Literal["tools", "checker_final_answer"]:
    """Route from the assistant node.

    Goes to ``"tools"`` when the last message requests tool calls, and to
    ``"checker_final_answer"`` otherwise. The state may be a bare message list, a dict
    holding ``messages_key``, or an object exposing ``messages_key`` as an attribute
    (the same shape as LangGraph's prebuilt ``tools_condition``).

    Raises:
        ValueError: if no messages can be found in ``state``.
    """
    if isinstance(state, list):
        messages = state
    elif isinstance(state, dict):
        messages = state.get(messages_key, [])
    else:
        messages = getattr(state, messages_key, [])
    # An empty list is rejected here as well, instead of failing with IndexError below.
    if not messages:
        raise ValueError(
            f"No messages found in input state to condition_tools_or_continue: {state}"
        )
    ai_message = messages[-1]
    # ``tool_calls`` may be missing (non-AI messages) or None; both mean "no tools".
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return "checker_final_answer"


# ------------------------------------------------------ #
# BUILDERS
# ------------------------------------------------------ #
def workflow_simple() -> Tuple[StateGraph, str]:
    """Build a minimal graph without tool execution: preprocessing -> assistant -> END.

    Returns:
        The uncompiled builder and the workflow's name.
    """
    i_builder = StateGraph(AgentState)
    # Nodes
    i_builder.add_node('preprocessing', preprocessing)
    i_builder.add_node('assistant', assistant)

    # Edges
    i_builder.add_edge(START, 'preprocessing')
    i_builder.add_edge('preprocessing', 'assistant')
    # Explicit terminal edge. Without it 'assistant' is a dead-end node, which graph
    # validation may reject depending on the LangGraph version.
    i_builder.add_edge('assistant', END)
    return i_builder, 'workflow_simple'


def workflow_tools() -> Tuple[StateGraph, str]:
    """Build the tool-using agent graph.

    Flow: START -> preprocessing -> assistant. The assistant routes to ``tools`` when it
    requested tool calls (tools then loop back to the assistant), otherwise to
    ``checker_final_answer``. The checker either ends the run or sends control back to
    the assistant.

    Returns:
        The uncompiled builder and the workflow's name.
    """
    graph = StateGraph(AgentState)

    node_table = (
        ('preprocessing', preprocessing),
        ('assistant', assistant),
        ('tools', wrapped_tool_node),
        ('checker_final_answer', checker_final_answer),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    fixed_edges = (
        (START, 'preprocessing'),
        ('preprocessing', 'assistant'),
        ('tools', 'assistant'),
    )
    for source, target in fixed_edges:
        graph.add_edge(source, target)

    # No path_map is given: the strings returned by the routers are used directly as
    # target node names.
    graph.add_conditional_edges('assistant', condition_tools_or_continue)
    graph.add_conditional_edges('checker_final_answer', condition_output)
    return graph, 'workflow_tools'


@traceable
def main():
    """Compile the tool-using workflow, run one example question and print the answer."""
    # Laminar.initialize(project_api_key=LAMINAR_API_KEY)
    # ------------------------------------------------------ #
    # COMPILATION
    # ------------------------------------------------------ #
    # builder, builder_name = workflow_simple()
    builder, builder_name = workflow_tools()

    alfred = builder.compile()
    # print(alfred.get_graph().draw_ascii())
    # print(alfred.get_graph().draw_mermaid())
    # with open(f"{builder_name}.png", "wb") as f:
    #     f.write(alfred.get_graph().draw_mermaid_png())

    # ------------------------------------------------------ #
    # EXAMPLES
    # ------------------------------------------------------ #
    # response = alfred.invoke({'messages': "What is an apple?"})
    # ---
    # Alternative benchmark question, used by the commented-out invoke below.
    question = """
        If Eliud Kipchoge could maintain his record-making marathon pace indefinitely,
        how many thousand hours would it take him to run the distance between the Earth and the Moon its closest approach?
        Please use the minimum perigee value on the Wikipedia page for the Moon when carrying out your calculation.
        Round your result to the nearest 1000 hours and do not use any comma separators if necessary.
        """
    # response = alfred.invoke({'messages': [HumanMessage(content=question.replace('\n', ""))]})
    response = alfred.invoke({
        'messages': [HumanMessage(content="Who is the president of USA in 2025?")],
        # preprocessing reads file_name; supply it explicitly (empty = no attached file)
        # because keys that are never set are absent from the state passed to nodes.
        'file_name': '',
    })

    print(f"--- OUTPUT --- \n{response['messages'][-1].content}\n--- --- ---")


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()