Update app.py
Browse files
app.py
CHANGED
|
@@ -37,21 +37,41 @@ if openai_api_key and tavily_api_key:
|
|
| 37 |
st.sidebar.write("Agent Response:", response)
|
| 38 |
return {"messages": [response]}
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Memory Checkpoint
|
| 41 |
memory = MemorySaver()
|
| 42 |
|
| 43 |
# Build the Graph
|
| 44 |
graph = StateGraph(State)
|
| 45 |
-
tool_node = ToolNode(tools=[tool])
|
| 46 |
-
|
| 47 |
graph.add_node("Agent", Agent)
|
| 48 |
-
graph.add_node("
|
| 49 |
-
|
| 50 |
-
graph.
|
|
|
|
| 51 |
graph.set_entry_point("Agent")
|
| 52 |
|
| 53 |
# Compile Graph
|
| 54 |
-
app = graph.compile(checkpointer=memory, interrupt_before=["
|
| 55 |
|
| 56 |
# Display Graph Visualization
|
| 57 |
st.subheader("Graph Workflow")
|
|
|
|
| 37 |
st.sidebar.write("Agent Response:", response)
|
| 38 |
return {"messages": [response]}
|
| 39 |
|
| 40 |
+
# Tool Execution Logic
|
| 41 |
+
def ToolExecutor(state: State):
|
| 42 |
+
last_message = state["messages"][-1]
|
| 43 |
+
if "tool_calls" in last_message.additional_kwargs:
|
| 44 |
+
tool_call = last_message.additional_kwargs["tool_calls"][0]
|
| 45 |
+
tool_name = tool_call["function"]["name"]
|
| 46 |
+
tool_args = tool_call["function"]["arguments"]
|
| 47 |
+
|
| 48 |
+
st.sidebar.write("Tool Call Detected:", tool_name, tool_args)
|
| 49 |
+
|
| 50 |
+
# Execute the tool with provided arguments
|
| 51 |
+
if tool_name == "tavily_search_results_json":
|
| 52 |
+
query = eval(tool_args)["query"] # Convert stringified arguments to dict
|
| 53 |
+
tool_response = tool.invoke({"query": query})
|
| 54 |
+
st.sidebar.write("Tool Response:", tool_response)
|
| 55 |
+
|
| 56 |
+
# Return the tool response in a new ToolMessage
|
| 57 |
+
return {"messages": [ToolMessage(content=str(tool_response), tool_call_id=tool_call["id"])]}
|
| 58 |
+
|
| 59 |
+
return state # Fallback
|
| 60 |
+
|
| 61 |
# Memory Checkpoint
|
| 62 |
memory = MemorySaver()
|
| 63 |
|
| 64 |
# Build the Graph
|
| 65 |
graph = StateGraph(State)
|
|
|
|
|
|
|
| 66 |
graph.add_node("Agent", Agent)
|
| 67 |
+
graph.add_node("ToolExecutor", ToolExecutor)
|
| 68 |
+
|
| 69 |
+
graph.add_conditional_edges("Agent", tools_condition, {"True": "ToolExecutor"})
|
| 70 |
+
graph.add_edge("ToolExecutor", "Agent")
|
| 71 |
graph.set_entry_point("Agent")
|
| 72 |
|
| 73 |
# Compile Graph
|
| 74 |
+
app = graph.compile(checkpointer=memory, interrupt_before=["ToolExecutor"])
|
| 75 |
|
| 76 |
# Display Graph Visualization
|
| 77 |
st.subheader("Graph Workflow")
|