"""Streamlit demo of LangGraph checkpoints and breakpoints.

Builds a one-agent graph with a Tavily search tool, interrupts execution
before the ``tools`` node, and lets the user inspect the checkpointed state
and inject a manual tool response for the pending tool call.
"""

import os
from typing import Annotated, TypedDict

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating workflow execution with checkpoints and tool invocation.")

# Fetch API Keys
# Both keys are read from the process environment; writing the same values
# back into os.environ would be a no-op, so nothing is re-exported here.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

# Guard clause: without both keys nothing below can work, so report and halt
# this script run instead of nesting the whole app inside an `if`.
if not (openai_api_key and tavily_api_key):
    st.error("API keys are missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in Hugging Face Spaces Secrets.")
    st.stop()


# Define State Class
class State(TypedDict):
    """Graph state holding the conversation as a list of messages."""

    # `add_messages` is attached as the reducer annotation for this key.
    messages: Annotated[list, add_messages]


# Initialize LLM and Tools
llm = ChatOpenAI(model="gpt-4o-mini")
tool = TavilySearchResults(max_results=2)
llm_with_tools = llm.bind_tools([tool])


# Agent Function
def Agent(state: State):
    """Invoke the tool-bound LLM on the current messages.

    The incoming messages and the model response are also echoed to the
    Streamlit sidebar for inspection.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update of the form ``{"messages": [response]}``.
    """
    st.sidebar.write("Agent received state:", state["messages"])
    response = llm_with_tools.invoke(state["messages"])
    st.sidebar.write("Agent Response:", response)
    return {"messages": [response]}


# Memory Checkpoint
# Streamlit re-executes this script on every widget interaction. Creating a
# fresh MemorySaver each run would discard all checkpoints, so the instance
# is kept in the per-session state and reused across reruns.
if "checkpointer" not in st.session_state:
    st.session_state["checkpointer"] = MemorySaver()
memory = st.session_state["checkpointer"]

# Build the Graph
graph = StateGraph(State)
tool_node = ToolNode(tools=[tool])
graph.add_node("Agent", Agent)
graph.add_node("tools", tool_node)
graph.add_conditional_edges("Agent", tools_condition)
graph.add_edge("tools", "Agent")
graph.set_entry_point("Agent")

# Compile Graph
# interrupt_before=["tools"] is the breakpoint: execution pauses before the
# tool node runs, leaving the pending tool call in the checkpointed state.
app = graph.compile(checkpointer=memory, interrupt_before=["tools"])

# Display Graph Visualization
st.subheader("Graph Workflow")
st.image(app.get_graph().draw_mermaid_png(), caption="Graph Visualization", use_container_width=True)

# Input Section
st.subheader("Run the Workflow")
user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
thread_id = st.text_input("Thread ID", "1")
thread = {"configurable": {"thread_id": thread_id}}

if st.button("Execute Workflow"):
    input_message = {"messages": [HumanMessage(content=user_input)]}
    st.write("### Execution Outputs")

    outputs = []
    for event in app.stream(input_message, thread, stream_mode="values"):
        latest_content = event["messages"][-1].content
        st.code(latest_content)
        outputs.append(latest_content)
        st.sidebar.write("Intermediate State:", event["messages"])

    if outputs:
        st.subheader("Intermediate Outputs")
        for idx, output in enumerate(outputs, start=1):
            st.write(f"**Step {idx}:**")
            st.code(output)
    else:
        st.warning("No outputs generated. Adjust your input to trigger tools.")

    # Remember which thread has been run. A button is only True during the
    # rerun caused by its own click, so the sections below cannot live inside
    # this branch: clicking "Update State" would rerun the script with this
    # button False and the update code would never be reached.
    st.session_state["executed_thread_id"] = thread_id

if st.session_state.get("executed_thread_id") == thread_id:
    # Display Snapshot of State
    st.subheader("Current State Snapshot")
    snapshot = app.get_state(thread)
    current_message = snapshot.values["messages"][-1]
    # pretty_repr() returns the formatted text; pretty_print() only prints to
    # stdout and returns None, which would leave the code box without content.
    st.code(current_message.pretty_repr())

    # Manual Update Section
    if getattr(current_message, "tool_calls", None):
        tool_call_id = current_message.tool_calls[0]["id"]
        manual_response = st.text_area("Manual Tool Response", "Enter your response to continue execution...")
        if st.button("Update State"):
            # Answer the pending tool call by hand, then add an AI message
            # carrying the same text.
            new_messages = [
                ToolMessage(content=manual_response, tool_call_id=tool_call_id),
                AIMessage(content=manual_response),
            ]
            app.update_state(thread, {"messages": new_messages})
            st.success("State updated successfully!")
            st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
    else:
        st.warning("No tool calls available for manual updates.")