File size: 4,554 Bytes
8c88d5f
 
 
 
 
 
 
 
 
 
 
2b42de2
 
368d3b7
8c88d5f
71a90c7
8c88d5f
 
 
 
 
 
 
2b42de2
8c88d5f
 
 
2b42de2
8c88d5f
 
 
 
 
2b42de2
8c88d5f
71a90c7
 
368d3b7
 
 
8c88d5f
2b42de2
8c88d5f
 
2b42de2
8c88d5f
 
2b42de2
8c88d5f
 
71a90c7
 
 
 
 
 
8c88d5f
 
 
2b42de2
8c88d5f
 
2b42de2
8c88d5f
2b42de2
8c88d5f
 
 
368d3b7
8c88d5f
 
 
 
 
 
 
 
 
2b42de2
 
 
 
368d3b7
2b42de2
 
 
 
 
 
368d3b7
2b42de2
 
8c88d5f
 
2b42de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c88d5f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
from typing import TypedDict, Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
from langgraph.prebuilt import ToolNode, tools_condition
import os

# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating workflow execution with checkpoints and tool invocation.")

# Fetch API Keys. They are not passed to the clients explicitly; the OpenAI
# and Tavily clients below are presumably configured from the environment.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if openai_api_key and tavily_api_key:

    class State(TypedDict):
        """Graph state: the conversation so far.

        Messages returned by a node are merged into this list by the
        ``add_messages`` reducer instead of overwriting it.
        """

        messages: Annotated[list, add_messages]

    # Initialize LLM and Tools
    llm = ChatOpenAI(model="gpt-4o-mini")
    tool = TavilySearchResults(max_results=2)
    tools = [tool]
    llm_with_tools = llm.bind_tools(tools)

    def Agent(state: State):
        """Run the tool-enabled LLM on the conversation and append its reply.

        The model itself decides whether to request a tool call; the
        conditional edge registered below routes on that decision.
        """
        print("Agent received state:", state)
        response = llm_with_tools.invoke(state["messages"])
        print("Agent Response:", response)
        return {"messages": [response]}

    # Memory Checkpoint. Streamlit re-executes this whole script on every
    # widget interaction, so a MemorySaver created inline would be a fresh,
    # empty store on each rerun and the "Update State" step below could never
    # see the checkpoint written by "Execute Workflow". Keep one per session.
    if "checkpointer" not in st.session_state:
        st.session_state["checkpointer"] = MemorySaver()
    memory = st.session_state["checkpointer"]

    # Graph Definition
    graph = StateGraph(State)
    tool_node = ToolNode(tools=[tool])

    graph.add_node("Agent", Agent)
    graph.add_node("tools", tool_node)

    # tools_condition (LangGraph prebuilt) sends the run to "tools" when the
    # Agent's last message requests a tool call and to END otherwise, so the
    # Agent <-> tools loop can terminate.
    graph.add_conditional_edges("Agent", tools_condition)
    graph.add_edge("tools", "Agent")
    graph.add_edge(START, "Agent")

    # Compile Graph. Execution pauses *before* the tools node (the breakpoint),
    # so the pending tool call can be inspected or answered manually.
    app = graph.compile(checkpointer=memory, interrupt_before=["tools"])

    # Display Graph Visualization
    st.subheader("Graph Workflow")
    try:
        st.image(app.get_graph().draw_mermaid_png(), caption="Graph Visualization", use_container_width=True)
    except Exception as err:
        # Rendering the PNG can fail (e.g. if the renderer is unreachable —
        # assumption); fall back to the Mermaid source so the page still works.
        st.warning(f"Could not render graph image ({err}); showing Mermaid source instead.")
        st.code(app.get_graph().draw_mermaid())

    # Input Section
    st.subheader("Run the Workflow")
    user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
    thread_id = st.text_input("Thread ID", "1")
    thread = {"configurable": {"thread_id": thread_id}}

    executed = st.button("Execute Workflow")
    if executed:
        st.write("### Execution Outputs")
        if app.get_state(thread).next:
            # The thread is parked at the breakpoint with an unanswered tool
            # call; appending a new user message now would leave that call
            # without a matching ToolMessage.
            st.warning(
                "This thread is paused before the tools step. Provide a manual "
                "tool response below or use a different Thread ID."
            )
        else:
            input_message = {"messages": [HumanMessage(content=user_input)]}
            outputs = []
            for event in app.stream(input_message, thread, stream_mode="values"):
                if event.get("messages"):
                    # pretty_print() writes to stdout and returns None;
                    # pretty_repr() returns the formatted text for display.
                    latest_message = event["messages"][-1].pretty_repr()
                    outputs.append(latest_message)
                    st.code(latest_message)

            if outputs:
                st.subheader("Intermediate Outputs")
                for i, output in enumerate(outputs, 1):
                    st.write(f"**Step {i}:**")
                    st.code(output)
            else:
                st.warning("No outputs generated. Adjust your input to trigger tools.")

    # Snapshot of Current State. Rendered on every rerun where the thread has
    # state (not only in the run where "Execute Workflow" was clicked):
    # clicking "Update State" starts a new rerun in which the Execute button
    # reads False, so a button nested under it could never fire.
    snapshot = app.get_state(thread)
    messages = snapshot.values.get("messages", [])
    if executed or messages:
        st.subheader("Current State Snapshot")
        if messages:
            current_message = messages[-1]
            st.code(current_message.pretty_repr())

            # Safe Access to Tool Calls (only AI messages carry them).
            tool_calls = getattr(current_message, "tool_calls", None)
            if tool_calls:
                manual_response = st.text_area("Manual Tool Response", "Enter response to continue...")
                if st.button("Update State"):
                    # Answer every pending call: a model may request several
                    # tools at once, and each tool_call_id needs a ToolMessage.
                    new_messages = [
                        ToolMessage(content=manual_response, tool_call_id=call["id"])
                        for call in tool_calls
                    ]
                    new_messages.append(AIMessage(content=manual_response))
                    app.update_state(thread, {"messages": new_messages})
                    st.success("State updated successfully!")
                    st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
            else:
                st.warning("No tool calls available for manual updates.")
        else:
            st.warning("No state messages available.")
else:
    st.error("API keys are missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in Hugging Face Spaces Secrets.")