# Final_Assignment_Template / langgraph_agent.py
# ArseniyPerchik's picture
# Clean state
# 45b200f
from globals import *
from tools import *
# ------------------------------------------------------ #
# MODELS
# ------------------------------------------------------ #
# init_chat_llm = ChatOllama(model=model_name)
# Base chat model served by Together; the API key is read from the
# TOGETHER_API_KEY environment variable.
init_chat_llm = ChatTogether(
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    api_key=os.getenv("TOGETHER_API_KEY"),
)
# ------------------------------------------------------ #
# BINDING TO TOOLS
# ------------------------------------------------------ #
# tools = [guest_info_tool, search_tool, weather_info_tool, hub_stats_tool]
# tools = [search_tool]
# Tools offered to the model; the same list is later handed to the ToolNode.
tools = [
    search_tool,
    describe_image_tool,
    describe_audio_tool,
    python_repl_tool,
    excel_repl_tool,
    youtube_extractor_tool,
    wikipedia_tool,
]

# Chat model with the tool list bound to it, used by the assistant node.
chat_llm = init_chat_llm.bind_tools(tools)
# ------------------------------------------------------ #
# STATE
# ------------------------------------------------------ #
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # Conversation history. Declared as a plain list (the variant with the
    # ``add_messages`` reducer is kept commented out below).
    # messages: list[AnyMessage, add_messages]
    messages: list[AnyMessage]

    # File name associated with the question; '' means there is none.
    file_name: str

    # Written by ``checker_final_answer`` and read by ``condition_output``.
    final_output_is_good: bool
# ------------------------------------------------------ #
# HELP FUNCTIONS
# ------------------------------------------------------ #
def step_print(state: AgentState | None, step_label: str):
    """Print a banner announcing that a graph node is being entered.

    Args:
        state: Current graph state, or a falsy value when no state should be
            inspected. When truthy, the number of messages it holds is shown
            inside the brackets; otherwise the brackets stay empty.
        step_label: Human-readable node name shown in the banner.
    """
    message_count = len(state["messages"]) if state else ''
    print(f'<<--- [{message_count}] Entering ``{step_label}`` Node... --->>')
def messages_print(messages_to_print: List[AnyMessage]):
    """Print a header, then each message, then a closing marker.

    Args:
        messages_to_print: Messages to dump; for each one its ``type``,
            ``name`` and ``content`` attributes are printed.
    """
    print('--- Message/s ---')
    for msg in messages_to_print:
        # First line: "<type> (<name>): ", second line onwards: the content.
        print(f'{msg.type} ({msg.name}): ', msg.content, sep='\n')
    print('<<--- *** --->>')
# ------------------------------------------------------ #
# NODES
# ------------------------------------------------------ #
def preprocessing(state: AgentState):
    """Prepare the conversation before the assistant node runs.

    When the state carries a non-empty ``file_name``, that name is appended
    to the content of the first message. The system prompt is then placed in
    front of the conversation.

    Args:
        state: Current graph state. ``messages`` is required; ``file_name``
            is optional and treated as empty when missing or ``None``.

    Returns:
        A state update whose ``messages`` is the system prompt followed by
        the incoming messages.
    """
    step_print(None, 'Preprocessing')
    messages = state['messages']

    # ``file_name`` is absent from the state when the graph is invoked with
    # only ``messages``; indexing it directly would raise KeyError.
    file_name = state.get('file_name') or ''
    if file_name and messages:
        # NOTE(review): assumes the first message's content is a plain string
        # (not a list of content blocks) -- confirm against callers.
        messages[0].content += f"\nfile_name: {file_name}"

    messages_print(messages)
    return {
        "messages": [SystemMessage(content=DEFAULT_SYSTEM_PROMPT)] + messages
    }
def assistant(state: AgentState):
    """Invoke the tool-bound chat model on the conversation so far.

    Args:
        state: Current graph state; its ``messages`` are sent to the model.

    Returns:
        A state update whose ``messages`` is the previous conversation with
        the model's reply appended.
    """
    step_print(state, 'assistant')
    history = state["messages"]
    reply = chat_llm.invoke(history)
    messages_print([reply])
    return {'messages': history + [reply]}
# Prebuilt node that executes the tool calls requested by the last AI message.
base_tool_node = ToolNode(tools)


def wrapped_tool_node(state: AgentState):
    """Run the requested tools and append their results to the conversation.

    Args:
        state: Current graph state; passed unchanged to the underlying
            ``ToolNode``.

    Returns:
        A state update whose ``messages`` is the previous conversation
        followed by the messages produced by the tool node.
    """
    step_print(state, 'Tools')
    result = base_tool_node.invoke(state)
    tool_messages = result["messages"]
    messages_print(tool_messages)
    # Build a new list rather than extending ``state["messages"]`` in place:
    # the node input should not be mutated, and this matches how the
    # ``assistant`` node produces its update.
    return {"messages": [*state["messages"], *tool_messages]}
def checker_final_answer(state: AgentState):
    """Check whether the last message contains the final-answer marker.

    Args:
        state: Current graph state; only the content of the last message is
            inspected.

    Returns:
        ``{'final_output_is_good': True}`` when the content contains
        ``"FINAL ANSWER: "``; otherwise the unchanged ``messages`` together
        with ``final_output_is_good`` set to ``False``.
    """
    step_print(state, 'Final Check')
    last_content = state['messages'][-1].content
    if "FINAL ANSWER: " in last_content:
        return {'final_output_is_good': True}
    return {
        'messages': state["messages"],
        'final_output_is_good': False,
    }
# ------------------------------------------------------ #
# CONDITIONAL FUNCTIONS
# ------------------------------------------------------ #
def condition_output(state: AgentState) -> Literal["assistant", "__end__"]:
    """Route after the final-answer check.

    Returns:
        ``END`` when ``final_output_is_good`` is truthy, otherwise
        ``"assistant"`` so the model is asked again.
    """
    return END if state['final_output_is_good'] else "assistant"
def condition_tools_or_continue(
    state: Union[list[AnyMessage], dict[str, Any], BaseModel],
    messages_key: str = "messages",
) -> Literal["tools", "checker_final_answer"]:
    """Route to the tool node when the last message requests tool calls.

    Args:
        state: Either a list of messages, a dict holding the messages under
            ``messages_key``, or an object exposing them as an attribute of
            that name.
        messages_key: Key / attribute name under which messages are stored.

    Returns:
        ``"tools"`` if the last message has a non-empty ``tool_calls``,
        otherwise ``"checker_final_answer"``.

    Raises:
        ValueError: If ``state`` is not a list and no messages can be found
            in it.
    """
    if isinstance(state, list):
        history = state
    else:
        # Try the dict lookup first, then fall back to attribute access.
        history = state.get(messages_key, []) if isinstance(state, dict) else []
        if not history:
            history = getattr(state, messages_key, [])
        if not history:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")

    last_message = history[-1]
    if not hasattr(last_message, "tool_calls"):
        return "checker_final_answer"
    return "tools" if len(last_message.tool_calls) > 0 else "checker_final_answer"
# ------------------------------------------------------ #
# BUILDERS
# ------------------------------------------------------ #
def workflow_simple() -> Tuple[StateGraph, str]:
    """Build the minimal graph: START -> preprocessing -> assistant.

    Returns:
        The (uncompiled) graph builder and the workflow name
        ``'workflow_simple'``.
    """
    graph = StateGraph(AgentState)

    # Nodes
    for node_name, node_fn in (
        ('preprocessing', preprocessing),
        ('assistant', assistant),
    ):
        graph.add_node(node_name, node_fn)

    # Edges
    for source, target in (
        (START, 'preprocessing'),
        ('preprocessing', 'assistant'),
    ):
        graph.add_edge(source, target)

    return graph, 'workflow_simple'
def workflow_tools() -> Tuple[StateGraph, str]:
    """Build the tool-using graph with a final-answer check.

    Flow: START -> preprocessing -> assistant. After ``assistant`` the graph
    goes either to ``tools`` (which returns to ``assistant``) or to
    ``checker_final_answer``, which in turn either ends the run or sends
    control back to ``assistant``.

    Returns:
        The (uncompiled) graph builder and the workflow name
        ``'workflow_tools'``.
    """
    graph = StateGraph(AgentState)

    # Nodes
    for node_name, node_fn in (
        ('preprocessing', preprocessing),
        ('assistant', assistant),
        ('tools', wrapped_tool_node),
        ('checker_final_answer', checker_final_answer),
    ):
        graph.add_node(node_name, node_fn)

    # Edges (registered in the same order as before)
    graph.add_edge(START, 'preprocessing')
    graph.add_edge('preprocessing', 'assistant')
    graph.add_conditional_edges('assistant', condition_tools_or_continue)
    graph.add_edge('tools', 'assistant')
    graph.add_conditional_edges('checker_final_answer', condition_output)

    return graph, 'workflow_tools'
@traceable
def main():
    """Compile the tool-enabled workflow and run it on one sample question.

    The graph is built with ``workflow_tools``, compiled, invoked with a
    single human message, and the content of the last message of the result
    is printed.
    """
    # Laminar.initialize(project_api_key=LAMINAR_API_KEY)
    # ------------------------------------------------------ #
    # COMPILATION
    # ------------------------------------------------------ #
    # builder, builder_name = workflow_simple()
    builder, builder_name = workflow_tools()
    alfred = builder.compile()
    # print(alfred.get_graph().draw_ascii())
    # print(alfred.get_graph().draw_mermaid())
    # with open(f"{builder_name}.png", "wb") as f:
    #     f.write(alfred.get_graph().draw_mermaid_png())
    # ------------------------------------------------------ #
    # EXAMPLES
    # ------------------------------------------------------ #
    # response = alfred.invoke({'messages': "What is an apple?"})
    # ---
    # Longer sample question; only used by the commented-out invocation below.
    question = """
If Eliud Kipchoge could maintain his record-making marathon pace indefinitely,
how many thousand hours would it take him to run the distance between the Earth and the Moon its closest approach?
Please use the minimum perigee value on the Wikipedia page for the Moon when carrying out your calculation.
Round your result to the nearest 1000 hours and do not use any comma separators if necessary.
"""
    # response = alfred.invoke({'messages': [HumanMessage(content=question.replace('\n', ""))], 'file_name': ''})
    initial_state = {
        'messages': [HumanMessage(content="Who is the president of USA in 2025?")],
        # The preprocessing node reads ``file_name``; '' means that no file
        # accompanies the question. Omitting the key leaves it out of the
        # state handed to the node.
        'file_name': '',
    }
    response = alfred.invoke(initial_state)
    print(f"--- OUTPUT --- \n{response['messages'][-1].content}\n--- --- ---")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()