# app.py — "Checkpoints and Breakpoints" LangGraph demo (Hugging Face Spaces).
import json
import os
from typing import Annotated, TypedDict

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
# --- Streamlit UI header --------------------------------------------------
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating workflow execution with checkpoints and tool invocation.")

# Fetch API keys from the environment (Hugging Face Spaces secrets).
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if openai_api_key and tavily_api_key:
    # Re-export so the langchain/langgraph clients pick the keys up.
    os.environ["OPENAI_API_KEY"] = openai_api_key
    os.environ["TAVILY_API_KEY"] = tavily_api_key

    class State(TypedDict):
        """Graph state: a message list merged via langgraph's ``add_messages``."""
        messages: Annotated[list, add_messages]

    # Initialize the LLM and bind the Tavily search tool to it.
    llm = ChatOpenAI(model="gpt-4o-mini")
    tool = TavilySearchResults(max_results=2)
    llm_with_tools = llm.bind_tools([tool])

    def Agent(state: State):
        """LLM node: invoke the tool-bound model on the running message list.

        Returns a partial state update ({"messages": [response]}) that
        add_messages appends to the existing history.
        """
        st.sidebar.write("Agent received state:", state["messages"])
        response = llm_with_tools.invoke(state["messages"])
        st.sidebar.write("Agent Response:", response)
        return {"messages": [response]}

    def ToolExecutor(state: State):
        """Tool node: run the first pending Tavily tool call, if any.

        Falls through and returns the state unchanged when the last message
        carries no tool call (or an unrecognized tool name).
        """
        last_message = state["messages"][-1]
        if "tool_calls" in last_message.additional_kwargs:
            tool_call = last_message.additional_kwargs["tool_calls"][0]
            tool_name = tool_call["function"]["name"]
            tool_args = tool_call["function"]["arguments"]
            st.sidebar.write("Tool Call Detected:", tool_name, tool_args)
            if tool_name == "tavily_search_results_json":
                # BUG FIX: tool-call arguments arrive as a JSON string from the
                # OpenAI API; parse with json.loads instead of eval, which
                # would execute arbitrary model-generated code.
                query = json.loads(tool_args)["query"]
                tool_response = tool.invoke({"query": query})
                st.sidebar.write("Tool Response:", tool_response)
                # Answer the specific call so the LLM can correlate it.
                return {
                    "messages": [
                        ToolMessage(
                            content=str(tool_response),
                            tool_call_id=tool_call["id"],
                        )
                    ]
                }
        return state  # Fallback: nothing to execute.

    # Checkpointer so each thread_id keeps its own resumable history.
    memory = MemorySaver()

    # Build the graph: Agent -> (conditionally) ToolExecutor -> Agent.
    graph = StateGraph(State)
    graph.add_node("Agent", Agent)
    graph.add_node("ToolExecutor", ToolExecutor)
    # BUG FIX: tools_condition returns "tools" or "__end__", never "True".
    # Map both real outcomes so routing works and unmapped values can't raise.
    graph.add_conditional_edges(
        "Agent",
        tools_condition,
        {"tools": "ToolExecutor", "__end__": "__end__"},
    )
    graph.add_edge("ToolExecutor", "Agent")
    graph.set_entry_point("Agent")

    # interrupt_before creates a breakpoint: execution pauses before the tool
    # node so the user can inspect/override state, then resume.
    app = graph.compile(checkpointer=memory, interrupt_before=["ToolExecutor"])

    # Display the compiled graph topology.
    st.subheader("Graph Workflow")
    st.image(
        app.get_graph().draw_mermaid_png(),
        caption="Graph Visualization",
        use_container_width=True,
    )

    # --- Input section ----------------------------------------------------
    st.subheader("Run the Workflow")
    user_input = st.text_input(
        "Enter a message to start the graph:", "Search for the weather in Uttar Pradesh"
    )
    thread_id = st.text_input("Thread ID", "1")

    if st.button("Execute Workflow"):
        thread = {"configurable": {"thread_id": thread_id}}
        input_message = {"messages": [HumanMessage(content=user_input)]}

        st.write("### Execution Outputs")
        outputs = []
        # stream_mode="values" yields the full state after each step; the run
        # stops at the interrupt_before breakpoint on ToolExecutor.
        for event in app.stream(input_message, thread, stream_mode="values"):
            st.code(event["messages"][-1].content)
            outputs.append(event["messages"][-1].content)
            st.sidebar.write("Intermediate State:", event["messages"])

        if outputs:
            st.subheader("Intermediate Outputs")
            for idx, output in enumerate(outputs, start=1):
                st.write(f"**Step {idx}:**")
                st.code(output)
        else:
            st.warning("No outputs generated. Adjust your input to trigger tools.")

        # --- Snapshot of the paused state ---------------------------------
        st.subheader("Current State Snapshot")
        snapshot = app.get_state(thread)
        current_message = snapshot.values["messages"][-1]
        # BUG FIX: pretty_print() writes to stdout and returns None, so
        # st.code rendered nothing; pretty_repr() returns the string.
        st.code(current_message.pretty_repr())

        # --- Manual update section: answer the pending tool call by hand --
        if hasattr(current_message, "tool_calls") and current_message.tool_calls:
            tool_call_id = current_message.tool_calls[0]["id"]
            manual_response = st.text_area(
                "Manual Tool Response", "Enter your response to continue execution..."
            )
            if st.button("Update State"):
                new_messages = [
                    # Satisfy the pending tool call, then inject the reply
                    # as the AI's answer so the thread can end cleanly.
                    ToolMessage(content=manual_response, tool_call_id=tool_call_id),
                    AIMessage(content=manual_response),
                ]
                app.update_state(thread, {"messages": new_messages})
                st.success("State updated successfully!")
                st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
        else:
            st.warning("No tool calls available for manual updates.")
else:
    st.error(
        "API keys are missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in Hugging Face Spaces Secrets."
    )