import streamlit as st
from typing import TypedDict, Annotated
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
import os
# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating LangGraph workflow execution with interruptions and tool invocation.")
# Fetch API Keys
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")
if not openai_api_key or not tavily_api_key:
    st.error("API keys are missing! Set OPENAI_API_KEY and TAVILY_API_KEY in Hugging Face Spaces Secrets.")
    st.stop()
os.environ["OPENAI_API_KEY"] = openai_api_key
os.environ["TAVILY_API_KEY"] = tavily_api_key
# Define State Class
class State(TypedDict):
    messages: Annotated[list, add_messages]
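# Note: add_messages is LangGraph's built-in reducer for message lists. When a
# node returns {"messages": [new_msg]}, the reducer appends new_msg to the
# existing history (merging by message ID) instead of overwriting the list.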
# Initialize LLM and Tools
llm = ChatOpenAI(model="gpt-4o-mini")
tool = TavilySearchResults(max_results=2)
llm_with_tools = llm.bind_tools([tool])
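# bind_tools attaches the Tavily tool's schema to the model so it can respond
# with structured tool_calls (default tool name: "tavily_search_results_json")
# instead of plain text; max_results=2 caps each search at two results.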
# Agent Node
def Agent(state: State):
    st.sidebar.write("Agent Input State:", state["messages"])
    response = llm_with_tools.invoke(state["messages"])
    st.sidebar.write("Agent Response:", response)
    return {"messages": [response]}
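# If the model decides to search, the returned AIMessage carries a non-empty
# tool_calls list; the conditional edge defined below routes on that field.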
# Tools Execution Node
def ExecuteTools(state: State):
    tool_calls = state["messages"][-1].tool_calls
    responses = []
    if tool_calls:
        for call in tool_calls:
            tool_name = call["name"]
            args = call["args"]
            st.sidebar.write("Tool Call Detected:", tool_name, args)
            if tool_name == "tavily_search_results_json":
                tool_response = tool.invoke({"query": args["query"]})
                st.sidebar.write("Tool Response:", tool_response)
                responses.append(ToolMessage(content=str(tool_response), tool_call_id=call["id"]))
    return {"messages": responses}
# Memory Checkpoint
memory = MemorySaver()
# Build the Graph
graph = StateGraph(State)
graph.add_node("Agent", Agent)
graph.add_node("ExecuteTools", ExecuteTools)
# Add Conditional Edge to Check for Tools
def custom_tools_condition(state: State):
    return "True" if state["messages"][-1].tool_calls else "False"

# Route to ExecuteTools when the last message requests a tool; otherwise end
# the run (routing "False" back to "Agent" would loop the LLM indefinitely).
graph.add_conditional_edges("Agent", custom_tools_condition, {"True": "ExecuteTools", "False": END})
graph.add_edge("ExecuteTools", "Agent")
graph.set_entry_point("Agent")
# Compile the Graph
app = graph.compile(checkpointer=memory, interrupt_before=["ExecuteTools"])
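# interrupt_before=["ExecuteTools"] sets a breakpoint: execution pauses after
# the Agent emits a tool call but before the tool runs, so a human can inspect
# the pending call or supply a manual result via update_state.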
# Display Graph Visualization
st.subheader("Graph Visualization")
st.image(app.get_graph().draw_mermaid_png(), caption="Workflow Graph", use_container_width=True)
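# Note: draw_mermaid_png renders via the mermaid.ink web service by default,
# so this call needs outbound network access from the Space.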
# Run the Workflow
st.subheader("Run the Workflow")
user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
thread_id = st.text_input("Thread ID", "1")
if st.button("Execute Workflow"):
    thread = {"configurable": {"thread_id": thread_id}}
    input_message = {"messages": [HumanMessage(content=user_input)]}
    st.write("### Execution Outputs")
    outputs = []
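    # The thread_id selects which checkpoint history MemorySaver reads and
    # writes; re-running with the same ID continues that conversation.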
    try:
        # Stream the graph execution
        for event in app.stream(input_message, thread, stream_mode="values"):
            output_message = event["messages"][-1]
            st.code(output_message.content)
            outputs.append(output_message.content)
            st.sidebar.write("Intermediate State:", event["messages"])
        # Display Intermediate Outputs
        if outputs:
            st.subheader("Intermediate Outputs")
            for idx, output in enumerate(outputs, start=1):
                st.write(f"**Step {idx}:**")
                st.code(output)
        else:
            st.warning("No outputs generated. Check the workflow or tool calls.")
        # Snapshot of Current State
        st.subheader("Current State Snapshot")
        snapshot = app.get_state(thread)
        current_message = snapshot.values["messages"][-1]
        # pretty_print() writes to stdout and returns None; pretty_repr()
        # returns the formatted string that st.code expects.
        st.code(current_message.pretty_repr())
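        # At the breakpoint this is the AIMessage holding the pending tool
        # call; snapshot.next would list the node waiting to run (ExecuteTools).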
        # Manual Update for Interrupted State
        if hasattr(current_message, "tool_calls") and current_message.tool_calls:
            tool_call_id = current_message.tool_calls[0]["id"]
            manual_response = st.text_area("Manual Tool Response", "Enter response to continue execution...")
            # Caveat: clicking this nested button triggers a Streamlit rerun in
            # which the outer "Execute Workflow" button reverts to False, so
            # this branch never executes; persisting the run flag in
            # st.session_state would be needed for the update to actually fire.
            if st.button("Update State"):
                new_messages = [
                    ToolMessage(content=manual_response, tool_call_id=tool_call_id),
                    AIMessage(content=manual_response),
                ]
                app.update_state(thread, {"messages": new_messages})
                st.success("State updated successfully!")
                st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
        else:
            st.warning("No tool calls detected to update the state.")
    except Exception as e:
        st.error(f"Error during execution: {e}")