File size: 4,245 Bytes
20bc95a
 
03bda09
20bc95a
 
 
 
 
 
 
 
 
 
36fafc5
20bc95a
 
 
 
 
36fafc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20bc95a
03bda09
 
20bc95a
36fafc5
20bc95a
 
03bda09
 
20bc95a
 
36fafc5
20bc95a
36fafc5
20bc95a
 
03bda09
 
 
36fafc5
03bda09
 
36fafc5
 
 
03bda09
 
 
 
 
36fafc5
03bda09
20bc95a
36fafc5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
from typing import TypedDict, Annotated
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
from langgraph.prebuilt import ToolNode, tools_condition
import os

# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating LangGraph workflow execution with interruptions and tool invocation.")

# Fetch API Keys
# Both keys are read from the process environment; the app cannot run
# without either of them, so rendering is halted early when one is absent.
openai_api_key = os.environ.get("OPENAI_API_KEY")
tavily_api_key = os.environ.get("TAVILY_API_KEY")

if not (openai_api_key and tavily_api_key):
    st.error("API keys are missing! Set OPENAI_API_KEY and TAVILY_API_KEY in Hugging Face Spaces Secrets.")
    st.stop()

# Write the validated values back into the environment under the same names.
os.environ.update(
    OPENAI_API_KEY=openai_api_key,
    TAVILY_API_KEY=tavily_api_key,
)

# Define State Class
class State(TypedDict):
    """Typed dictionary used as the state schema of the ``StateGraph`` below."""

    # Conversation messages. The ``add_messages`` annotation attaches
    # LangGraph's message reducer to this key; per the LangGraph docs it
    # merges a node's returned messages into the existing list instead of
    # overwriting it (NOTE(review): library behavior, not shown in this file).
    messages: Annotated[list, add_messages]

# Initialize LLM and Tools
# Chat model that the Agent node invokes.
llm = ChatOpenAI(model="gpt-4o-mini")
# Tavily search tool, configured with max_results=2.
tool = TavilySearchResults(max_results=2)
# Single-element tool list handed to the model binding below.
tools = [tool]
# Model handle with the tool list bound; this is what the Agent node calls.
llm_with_tools = llm.bind_tools(tools)

# Agent Node
def Agent(state: State):
    """Run the tool-bound model on the current messages and return its reply.

    The incoming messages and the model's reply are both echoed to the
    Streamlit sidebar.

    Args:
        state: Graph state; only its ``"messages"`` entry is read.

    Returns:
        A state update of the form ``{"messages": [reply]}``.
    """
    history = state["messages"]
    st.sidebar.write("Agent received input:", history)

    reply = llm_with_tools.invoke(history)
    st.sidebar.write("Agent Response:", reply)

    return {"messages": [reply]}

# Set up Graph
def _build_workflow():
    """Assemble and compile the Agent/tools graph with an in-memory checkpointer.

    Returns:
        tuple: ``(memory, graph, tool_node, app)`` -- the ``MemorySaver``
        checkpointer, the ``StateGraph`` builder, the ``ToolNode`` and the
        compiled graph, which is configured to interrupt before ``"tools"``.
    """
    checkpointer = MemorySaver()
    builder = StateGraph(State)

    # Add nodes
    builder.add_node("Agent", Agent)
    search_node = ToolNode(tools=tools)  # same tool list that was bound to the LLM
    builder.add_node("tools", search_node)

    # Add edges
    builder.add_conditional_edges("Agent", tools_condition)
    builder.add_edge("tools", "Agent")
    builder.set_entry_point("Agent")

    # Compile with Breakpoint
    compiled = builder.compile(checkpointer=checkpointer, interrupt_before=["tools"])
    return checkpointer, builder, search_node, compiled


# Streamlit re-executes this whole script on every widget interaction. Building
# a fresh MemorySaver each time would discard every checkpoint between two
# clicks, so the compiled workflow is kept in the per-browser-session state.
if "langgraph_workflow" not in st.session_state:
    st.session_state["langgraph_workflow"] = _build_workflow()
memory, graph, tool_node, app = st.session_state["langgraph_workflow"]

# Display Graph Visualization
st.subheader("Graph Visualization")
st.image(app.get_graph().draw_mermaid_png(), caption="Workflow Graph", use_container_width=True)

# Input Section
st.subheader("Run the Workflow")
user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
thread_id = st.text_input("Thread ID", "1")
thread = {"configurable": {"thread_id": thread_id}}


def _latest_message(config):
    """Return the newest checkpointed message for *config*, or None if the thread has none."""
    history = app.get_state(config).values.get("messages")
    return history[-1] if history else None


if st.button("Execute Workflow"):
    st.write("### Execution Outputs")

    # Execute the workflow
    try:
        if getattr(_latest_message(thread), "tool_calls", None):
            # Appending a new human message after an unanswered tool call would
            # leave the thread with an invalid message sequence.
            st.warning(
                "This thread is paused on a pending tool call. "
                "Resume it below or choose a different Thread ID."
            )
        else:
            input_message = {"messages": [HumanMessage(content=user_input)]}
            outputs = []
            for event in app.stream(input_message, thread, stream_mode="values"):
                latest_content = event["messages"][-1].content
                st.code(latest_content)
                outputs.append(latest_content)

            # Display Intermediate Outputs
            if outputs:
                st.subheader("Intermediate Outputs")
                for idx, output in enumerate(outputs, start=1):
                    st.write(f"**Step {idx}:**")
                    st.code(output)
            else:
                st.warning("No outputs generated yet.")
    except Exception as e:
        st.error(f"Error during execution: {e}")

# The snapshot and resume controls live OUTSIDE the "Execute Workflow" branch:
# a button is only True during the single rerun triggered by its own click, so
# a "Resume Execution" button nested inside that branch could never fire.
try:
    current_message = _latest_message(thread)
    if current_message is not None:
        # Show State Snapshot
        st.subheader("Current State Snapshot")
        # pretty_repr() returns the text; pretty_print() prints it and returns None.
        st.code(current_message.pretty_repr())

        # Handle Tool Calls with Manual Input
        pending_calls = getattr(current_message, "tool_calls", None)
        if pending_calls:
            st.warning("Execution paused before tool execution. Provide manual input to resume.")
            manual_response = st.text_area(
                "Manual Tool Response",
                "Enter the tool's response here...",
                key="manual_tool_response",
            )
            if st.button("Resume Execution"):
                # Answer every pending tool call, not just the first one, so no
                # call is left without a matching ToolMessage.
                new_messages = [
                    ToolMessage(content=manual_response, tool_call_id=call["id"])
                    for call in pending_calls
                ]
                new_messages.append(AIMessage(content=manual_response))
                app.update_state(thread, {"messages": new_messages})
                st.success("State updated! Rerun the workflow to continue.")
                st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
        else:
            st.info("No tool calls detected at this step.")
except Exception as e:
    st.error(f"Error during execution: {e}")