File size: 5,099 Bytes
5b20995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from typing import Annotated, TypedDict

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import tools_condition

# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating LangGraph workflow execution with interruptions and tool invocation.")

# Fetch API Keys.
# The values are only read here to fail fast with a readable message; they are
# already present in os.environ (that is where os.getenv found them), so there
# is nothing to write back for the OpenAI / Tavily clients.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if not openai_api_key or not tavily_api_key:
    st.error("API keys are missing! Set OPENAI_API_KEY and TAVILY_API_KEY in Hugging Face Spaces Secrets.")
    st.stop()

# Define State Class
class State(TypedDict):
    """Graph state shared by every node.

    ``messages`` carries the conversation. It is annotated with the
    ``add_messages`` reducer, so a node that returns ``{"messages": [...]}``
    has those messages merged into the existing history rather than
    replacing it.
    """

    messages: Annotated[list, add_messages]

# Initialize LLM and Tools.
# Web-search tool returning at most two results per query.
tool = TavilySearchResults(max_results=2)
# Base chat model, plus a copy of it that may emit tool calls for `tool`.
llm = ChatOpenAI(model="gpt-4o-mini")
llm_with_tools = llm.bind_tools(tools=[tool])

# Agent Node
def Agent(state: State):
    """Run the tool-enabled LLM over the conversation so far.

    The incoming history and the model's reply are both logged to the
    sidebar; the reply is returned as a one-element ``messages`` update.
    """
    history = state["messages"]
    st.sidebar.write("Agent Input State:", history)

    reply = llm_with_tools.invoke(history)
    st.sidebar.write("Agent Response:", reply)

    return {"messages": [reply]}

# Tools Execution Node
def ExecuteTools(state: State):
    """Execute every tool call requested by the most recent message.

    Returns a ``messages`` update holding exactly one ``ToolMessage`` per
    tool call, in request order. A call that names a tool this node does not
    provide still gets a ``ToolMessage`` (with an error text): the chat API
    requires every ``tool_call_id`` issued by the model to be answered before
    the model is called again.
    """
    tool_calls = getattr(state["messages"][-1], "tool_calls", None) or []
    responses = []

    for call in tool_calls:
        tool_name = call["name"]
        args = call["args"]
        st.sidebar.write("Tool Call Detected:", tool_name, args)

        if tool_name == tool.name:
            # Hand the model-generated arguments to the tool as-is instead of
            # assuming a "query" key is present.
            tool_response = tool.invoke(args)
            st.sidebar.write("Tool Response:", tool_response)
            content = str(tool_response)
        else:
            content = f"Error: unknown tool '{tool_name}'."

        responses.append(ToolMessage(content=content, tool_call_id=call["id"]))
    return {"messages": responses}

# Memory Checkpoint.
# Streamlit re-executes this whole script on every widget interaction, so a
# bare `MemorySaver()` here would be recreated -- and every saved thread lost --
# on each rerun (e.g. when "Update State" is clicked). Keeping the saver in
# st.session_state preserves checkpoints across reruns while still isolating
# each browser session's threads from other users.
if "memory" not in st.session_state:
    st.session_state["memory"] = MemorySaver()
memory = st.session_state["memory"]

# Build the Graph: one node for the LLM agent, one for executing its tool calls.
graph = StateGraph(State)
for _node_name, _node_fn in (("Agent", Agent), ("ExecuteTools", ExecuteTools)):
    graph.add_node(_node_name, _node_fn)

# Add Conditional Edge to Check for Tools
def custom_tools_condition(state: State):
    """Decide where to route after the Agent node.

    Returns the string "True" when the latest message has truthy
    ``tool_calls`` and "False" otherwise; these strings are the routing keys
    of the conditional edge leaving "Agent".
    """
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "True"
    return "False"

# Route Agent -> ExecuteTools while the model keeps requesting tools, and
# Agent -> END once it answers without tool calls. Routing "False" back to
# "Agent" (as before) re-invoked the model on a history that already ends in
# its final answer, looping until LangGraph's recursion limit failed the run.
graph.add_conditional_edges("Agent", custom_tools_condition, {"True": "ExecuteTools", "False": END})
graph.add_edge("ExecuteTools", "Agent")
graph.set_entry_point("Agent")

# Compile the Graph. interrupt_before pauses each run just before tools execute;
# the checkpointer keeps the paused state per thread_id so it can be inspected
# and manually updated afterwards.
app = graph.compile(checkpointer=memory, interrupt_before=["ExecuteTools"])

# Display Graph Visualization
st.subheader("Graph Visualization")
try:
    # With default arguments draw_mermaid_png() renders via a remote Mermaid
    # service; a network failure there must not take the whole app down.
    graph_png = app.get_graph().draw_mermaid_png()
except Exception as exc:  # UI boundary: degrade to the textual Mermaid source
    st.warning(f"Could not render the graph image ({exc}); showing the Mermaid source instead.")
    st.code(app.get_graph().draw_mermaid())
else:
    st.image(graph_png, caption="Workflow Graph", use_container_width=True)

# Run the Workflow
st.subheader("Run the Workflow")
user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
thread_id = st.text_input("Thread ID", "1")
thread = {"configurable": {"thread_id": thread_id}}

if st.button("Execute Workflow"):
    input_message = {"messages": [HumanMessage(content=user_input)]}

    st.write("### Execution Outputs")
    outputs = []

    try:
        # Stream the graph execution; stream_mode="values" yields the full state after each step.
        for event in app.stream(input_message, thread, stream_mode="values"):
            output_message = event["messages"][-1]
            st.code(output_message.content)
            outputs.append(output_message.content)
            st.sidebar.write("Intermediate State:", event["messages"])
    except Exception as e:
        st.error(f"Error during execution: {e}")
    else:
        # Remember which thread was executed. st.button() is True only on the
        # rerun triggered by its own click, so the snapshot / "Update State"
        # section must live outside this branch -- nested inside it, the
        # "Update State" click would rerun with this branch False and never fire.
        st.session_state["executed_thread_id"] = thread_id

        # Display Intermediate Outputs
        if outputs:
            st.subheader("Intermediate Outputs")
            for idx, output in enumerate(outputs, start=1):
                st.write(f"**Step {idx}:**")
                st.code(output)
        else:
            st.warning("No outputs generated. Check the workflow or tool calls.")

if st.session_state.get("executed_thread_id") == thread_id:
    try:
        # Snapshot of Current State
        st.subheader("Current State Snapshot")
        snapshot = app.get_state(thread)
        saved_messages = snapshot.values.get("messages", [])
        if not saved_messages:
            st.warning("No saved state found for this thread.")
        else:
            current_message = saved_messages[-1]
            # pretty_print() writes to stdout and returns None; pretty_repr() returns the text.
            st.code(current_message.pretty_repr())

            # Manual Update for Interrupted State
            pending_calls = getattr(current_message, "tool_calls", None)
            if pending_calls:
                manual_response = st.text_area("Manual Tool Response", "Enter response to continue execution...")
                if st.button("Update State"):
                    # Answer every pending tool call (the chat API expects one
                    # ToolMessage per tool_call_id), then append the same text
                    # as the assistant's reply.
                    new_messages = [
                        ToolMessage(content=manual_response, tool_call_id=call["id"])
                        for call in pending_calls
                    ]
                    new_messages.append(AIMessage(content=manual_response))
                    app.update_state(thread, {"messages": new_messages})
                    st.success("State updated successfully!")
                    st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
            else:
                st.warning("No tool calls detected to update the state.")
    except Exception as e:
        st.error(f"Error during execution: {e}")