"""Streamlit demo of LangGraph checkpoints and breakpoints.

The graph alternates between an LLM agent and a Tavily search tool. A
``MemorySaver`` checkpointer records state per thread id, and the graph is
compiled with ``interrupt_before=["ToolExecutor"]`` so execution pauses before
any tool runs. From that breakpoint the user can either resume execution or
inject a manual tool response into the saved state.
"""
import os
from typing import Annotated, TypedDict

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating workflow execution with checkpoints and tool invocation.")

# Fetch API Keys. ChatOpenAI and TavilySearchResults read these same
# environment variables themselves, so they only need to be present.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if not (openai_api_key and tavily_api_key):
    st.error("API keys are missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in Hugging Face Spaces Secrets.")
    st.stop()


class State(TypedDict):
    """Graph state: the running conversation, merged by ``add_messages``."""

    messages: Annotated[list, add_messages]


# Initialize LLM and Tools
llm = ChatOpenAI(model="gpt-4o-mini")
tool = TavilySearchResults(max_results=2)
llm_with_tools = llm.bind_tools([tool])


def Agent(state: State):
    """Invoke the tool-bound LLM on the conversation and append its reply."""
    st.sidebar.write("Agent received state:", state["messages"])
    response = llm_with_tools.invoke(state["messages"])
    st.sidebar.write("Agent Response:", response)
    return {"messages": [response]}


def ToolExecutor(state: State):
    """Execute every tool call requested by the last message in the state.

    Returns one ``ToolMessage`` per tool call. Every ``tool_call_id`` must be
    answered, because the chat API rejects a history with unanswered tool calls.
    If the last message requested no tools, returns an empty update.
    """
    last_message = state["messages"][-1]
    # ``tool_calls`` holds already-parsed arguments (a dict per call), so the
    # model's raw JSON argument string never has to be eval()-ed.
    tool_calls = getattr(last_message, "tool_calls", None) or []
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["args"]
        st.sidebar.write("Tool Call Detected:", tool_name, tool_args)
        if tool_name == tool.name:
            tool_response = tool.invoke(tool_args)
        else:
            # Still answer the call so the message history stays valid.
            tool_response = f"Unknown tool: {tool_name}"
        st.sidebar.write("Tool Response:", tool_response)
        results.append(ToolMessage(content=str(tool_response), tool_call_id=tool_call["id"]))
    return {"messages": results}


# Memory Checkpoint. Streamlit re-executes this script on every interaction,
# so a plain module-level MemorySaver would be recreated and all checkpoints
# lost on each click. Keep one checkpointer per browser session instead.
if "memory" not in st.session_state:
    st.session_state["memory"] = MemorySaver()
memory = st.session_state["memory"]

# Build the Graph
graph = StateGraph(State)
graph.add_node("Agent", Agent)
graph.add_node("ToolExecutor", ToolExecutor)
# tools_condition returns "tools" when the last AI message requests a tool and
# END otherwise; the path map must be keyed by exactly those values.
graph.add_conditional_edges("Agent", tools_condition, {"tools": "ToolExecutor", END: END})
graph.add_edge("ToolExecutor", "Agent")
graph.set_entry_point("Agent")

# Compile Graph with a breakpoint before every tool execution.
app = graph.compile(checkpointer=memory, interrupt_before=["ToolExecutor"])

# Display Graph Visualization
st.subheader("Graph Workflow")
try:
    # PNG rendering may depend on an external renderer; fall back to the
    # Mermaid source so the rest of the page still loads if it fails.
    st.image(app.get_graph().draw_mermaid_png(), caption="Graph Visualization", use_container_width=True)
except Exception as exc:
    st.warning(f"Could not render graph image ({exc}); showing Mermaid source instead.")
    st.code(app.get_graph().draw_mermaid())

# Input Section
st.subheader("Run the Workflow")
user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
thread_id = st.text_input("Thread ID", "1")
thread = {"configurable": {"thread_id": thread_id}}


def _stream_and_display(stream_input):
    """Stream the graph for ``thread`` and render the last message of each step.

    ``stream_input`` is either a new input dict, or None to resume from the
    thread's saved checkpoint.
    """
    st.write("### Execution Outputs")
    outputs = []
    for event in app.stream(stream_input, thread, stream_mode="values"):
        st.code(event["messages"][-1].content)
        outputs.append(event["messages"][-1].content)
        st.sidebar.write("Intermediate State:", event["messages"])
    if outputs:
        st.subheader("Intermediate Outputs")
        for idx, output in enumerate(outputs, start=1):
            st.write(f"**Step {idx}:**")
            st.code(output)
    else:
        st.warning("No outputs generated. Adjust your input to trigger tools.")


if st.button("Execute Workflow"):
    _stream_and_display({"messages": [HumanMessage(content=user_input)]})

# Display Snapshot of State. This section lives outside the "Execute" button
# on purpose. A button nested under another button never fires, because
# clicking it reruns the script with the outer button reset to False.
snapshot = app.get_state(thread)
saved_messages = snapshot.values.get("messages", [])
if saved_messages:
    st.subheader("Current State Snapshot")
    current_message = saved_messages[-1]
    # pretty_print() writes to stdout and returns None; pretty_repr() returns the text.
    st.code(current_message.pretty_repr())

    # snapshot.next is non-empty while the graph is paused at the breakpoint.
    if snapshot.next:
        st.info(f"Paused before: {', '.join(snapshot.next)}")
        if st.button("Continue Execution"):
            _stream_and_display(None)

    # Manual Update Section
    pending_calls = getattr(current_message, "tool_calls", None) or []
    if pending_calls:
        manual_response = st.text_area("Manual Tool Response", "Enter your response to continue execution...")
        if st.button("Update State"):
            # Answer every pending tool call so the history stays valid for the LLM.
            new_messages = [
                ToolMessage(content=manual_response, tool_call_id=call["id"])
                for call in pending_calls
            ]
            new_messages.append(AIMessage(content=manual_response))
            app.update_state(thread, {"messages": new_messages})
            st.success("State updated successfully!")
            st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
    else:
        st.warning("No tool calls available for manual updates.")