DrishtiSharma commited on
Commit
bfb9e67
Β·
verified Β·
1 Parent(s): 03bda09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -6
app.py CHANGED
@@ -37,21 +37,41 @@ if openai_api_key and tavily_api_key:
37
  st.sidebar.write("Agent Response:", response)
38
  return {"messages": [response]}
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Memory Checkpoint
41
  memory = MemorySaver()
42
 
43
  # Build the Graph
44
  graph = StateGraph(State)
45
- tool_node = ToolNode(tools=[tool])
46
-
47
  graph.add_node("Agent", Agent)
48
- graph.add_node("tools", tool_node)
49
- graph.add_conditional_edges("Agent", tools_condition)
50
- graph.add_edge("tools", "Agent")
 
51
  graph.set_entry_point("Agent")
52
 
53
  # Compile Graph
54
- app = graph.compile(checkpointer=memory, interrupt_before=["tools"])
55
 
56
  # Display Graph Visualization
57
  st.subheader("Graph Workflow")
 
37
  st.sidebar.write("Agent Response:", response)
38
  return {"messages": [response]}
39
 
40
+ # Tool Execution Logic
41
+ def ToolExecutor(state: State):
42
+ last_message = state["messages"][-1]
43
+ if "tool_calls" in last_message.additional_kwargs:
44
+ tool_call = last_message.additional_kwargs["tool_calls"][0]
45
+ tool_name = tool_call["function"]["name"]
46
+ tool_args = tool_call["function"]["arguments"]
47
+
48
+ st.sidebar.write("Tool Call Detected:", tool_name, tool_args)
49
+
50
+ # Execute the tool with provided arguments
51
+ if tool_name == "tavily_search_results_json":
52
+ query = eval(tool_args)["query"] # Convert stringified arguments to dict
53
+ tool_response = tool.invoke({"query": query})
54
+ st.sidebar.write("Tool Response:", tool_response)
55
+
56
+ # Return the tool response in a new ToolMessage
57
+ return {"messages": [ToolMessage(content=str(tool_response), tool_call_id=tool_call["id"])]}
58
+
59
+ return state # Fallback
60
+
61
  # Memory Checkpoint
62
  memory = MemorySaver()
63
 
64
  # Build the Graph
65
  graph = StateGraph(State)
 
 
66
  graph.add_node("Agent", Agent)
67
+ graph.add_node("ToolExecutor", ToolExecutor)
68
+
69
+ graph.add_conditional_edges("Agent", tools_condition, {"True": "ToolExecutor"})
70
+ graph.add_edge("ToolExecutor", "Agent")
71
  graph.set_entry_point("Agent")
72
 
73
  # Compile Graph
74
+ app = graph.compile(checkpointer=memory, interrupt_before=["ToolExecutor"])
75
 
76
  # Display Graph Visualization
77
  st.subheader("Graph Workflow")