File size: 5,185 Bytes
8c88d5f
 
6b1ac2e
8c88d5f
 
 
 
 
 
 
 
2b42de2
 
368d3b7
8c88d5f
71a90c7
8c88d5f
 
 
 
 
 
 
2b42de2
8c88d5f
 
 
2b42de2
8c88d5f
 
6b1ac2e
8c88d5f
2b42de2
8c88d5f
6b1ac2e
368d3b7
6b1ac2e
368d3b7
8c88d5f
bfb9e67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b42de2
8c88d5f
 
6b1ac2e
8c88d5f
 
bfb9e67
 
 
 
8c88d5f
 
2b42de2
bfb9e67
8c88d5f
2b42de2
8c88d5f
2b42de2
8c88d5f
 
 
368d3b7
8c88d5f
 
 
 
6b1ac2e
8c88d5f
 
 
 
6b1ac2e
 
 
65c0aec
2b42de2
 
6b1ac2e
 
2b42de2
 
368d3b7
2b42de2
6b1ac2e
8c88d5f
 
6b1ac2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b42de2
6b1ac2e
8c88d5f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
from typing import Annotated, TypedDict

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating workflow execution with checkpoints and tool invocation.")

# Fetch API keys. They are only checked for presence here; the OpenAI and
# Tavily clients read them from the environment themselves.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if openai_api_key and tavily_api_key:

    class State(TypedDict):
        """Graph state: the running conversation.

        ``add_messages`` is the reducer, so a node's ``messages`` update is
        merged into the existing list rather than replacing it.
        """

        messages: Annotated[list, add_messages]

    # Initialize the LLM and the search tool; binding the tool lets the LLM
    # emit tool calls for it.
    llm = ChatOpenAI(model="gpt-4o-mini")
    tool = TavilySearchResults(max_results=2)
    llm_with_tools = llm.bind_tools([tool])
    tools_by_name = {tool.name: tool}

    def Agent(state: State):
        """Invoke the tool-aware LLM on the conversation and append its reply."""
        st.sidebar.write("Agent received state:", state["messages"])
        response = llm_with_tools.invoke(state["messages"])
        st.sidebar.write("Agent Response:", response)
        return {"messages": [response]}

    def ToolExecutor(state: State):
        """Execute every tool call requested by the last message.

        Returns one ToolMessage per tool call. OpenAI's chat API rejects a
        history in which a tool call has no matching ToolMessage, so an unknown
        tool yields an error ToolMessage instead of being skipped.
        """
        last_message = state["messages"][-1]
        tool_calls = getattr(last_message, "tool_calls", None) or []

        results = []
        for tool_call in tool_calls:
            # ``tool_calls`` carries already-parsed arguments (a dict), so the
            # model's output never has to be eval()-ed.
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            st.sidebar.write("Tool Call Detected:", tool_name, tool_args)

            selected_tool = tools_by_name.get(tool_name)
            if selected_tool is None:
                content = f"Error: unknown tool '{tool_name}'."
            else:
                tool_response = selected_tool.invoke(tool_args)
                st.sidebar.write("Tool Response:", tool_response)
                content = str(tool_response)
            results.append(ToolMessage(content=content, tool_call_id=tool_call["id"]))

        # With no tool calls this is an empty update, which leaves the state as is.
        return {"messages": results}

    # Streamlit re-executes this whole script on every widget interaction; keep
    # the checkpointer in session state so thread checkpoints survive reruns.
    if "memory" not in st.session_state:
        st.session_state["memory"] = MemorySaver()
    memory = st.session_state["memory"]

    # Build the Graph
    graph = StateGraph(State)
    graph.add_node("Agent", Agent)
    graph.add_node("ToolExecutor", ToolExecutor)

    # tools_condition returns "tools" when the last message has tool calls and
    # END otherwise, so the path map must be keyed by exactly those values.
    graph.add_conditional_edges("Agent", tools_condition, {"tools": "ToolExecutor", END: END})
    graph.add_edge("ToolExecutor", "Agent")
    graph.set_entry_point("Agent")

    # Pause before any tool runs so the pending call can be inspected/answered.
    app = graph.compile(checkpointer=memory, interrupt_before=["ToolExecutor"])

    # Display Graph Visualization
    st.subheader("Graph Workflow")
    graph_view = app.get_graph()
    try:
        # The default PNG renderer calls an external service and may fail offline.
        st.image(graph_view.draw_mermaid_png(), caption="Graph Visualization", use_container_width=True)
    except Exception as err:  # purely cosmetic; never let it take the app down
        st.warning(f"Could not render graph image ({err}); showing Mermaid source instead.")
        st.code(graph_view.draw_mermaid())

    # Input Section
    st.subheader("Run the Workflow")
    user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
    thread_id = st.text_input("Thread ID", "1")

    if st.button("Execute Workflow"):
        # NOTE(review): sending a new message on a thread still paused before
        # ToolExecutor appends it after an unanswered tool call, which the model
        # API will likely reject; use a fresh Thread ID or answer the call first.
        thread = {"configurable": {"thread_id": thread_id}}
        input_message = {"messages": [HumanMessage(content=user_input)]}

        st.write("### Execution Outputs")
        outputs = []
        for event in app.stream(input_message, thread, stream_mode="values"):
            latest = event["messages"][-1]
            st.code(latest.content)
            outputs.append(latest.content)
            st.sidebar.write("Intermediate State:", event["messages"])

        if outputs:
            st.subheader("Intermediate Outputs")
            for idx, output in enumerate(outputs, start=1):
                st.write(f"**Step {idx}:**")
                st.code(output)
        else:
            st.warning("No outputs generated. Adjust your input to trigger tools.")

        # Remember the executed thread so the section below survives the rerun
        # triggered by clicking "Update State".
        st.session_state["active_thread_id"] = thread_id

    # Rendered outside the "Execute Workflow" branch: st.button is True only on
    # the single rerun right after its click, so a button nested inside another
    # button's branch could never fire.
    active_thread_id = st.session_state.get("active_thread_id")
    if active_thread_id is not None:
        thread = {"configurable": {"thread_id": active_thread_id}}
        messages = app.get_state(thread).values.get("messages", [])

        # Display Snapshot of State
        st.subheader("Current State Snapshot")
        if not messages:
            st.warning("No state recorded for this thread yet.")
        else:
            current_message = messages[-1]
            # pretty_repr() returns the formatted text; pretty_print() only
            # prints to stdout and returns None.
            st.code(current_message.pretty_repr())

            # Manual Update Section
            pending_calls = getattr(current_message, "tool_calls", None) or []
            if pending_calls:
                manual_response = st.text_area("Manual Tool Response", "Enter your response to continue execution...")
                if st.button("Update State"):
                    # Answer every pending call so no tool_call_id is left
                    # dangling in the thread's history.
                    new_messages = [
                        ToolMessage(content=manual_response, tool_call_id=call["id"])
                        for call in pending_calls
                    ]
                    new_messages.append(AIMessage(content=manual_response))
                    app.update_state(thread, {"messages": new_messages})
                    st.success("State updated successfully!")
                    st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
            else:
                st.warning("No tool calls available for manual updates.")
else:
    st.error("API keys are missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in Hugging Face Spaces Secrets.")