DrishtiSharma committed on
Commit
36fafc5
·
verified ·
1 Parent(s): 3b5459d

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +71 -62
interim.py CHANGED
@@ -11,95 +11,104 @@ import os
11
 
12
  # Streamlit UI Header
13
  st.title("Checkpoints and Breakpoints")
14
- st.caption("Demonstrating workflow execution with checkpoints and tool invocation.")
15
 
16
  # Fetch API Keys
17
  openai_api_key = os.getenv("OPENAI_API_KEY")
18
  tavily_api_key = os.getenv("TAVILY_API_KEY")
19
 
20
- if openai_api_key and tavily_api_key:
21
- os.environ["OPENAI_API_KEY"] = openai_api_key
22
- os.environ["TAVILY_API_KEY"] = tavily_api_key
23
-
24
- # Define State Class
25
- class State(TypedDict):
26
- messages: Annotated[list, add_messages]
27
-
28
- # Initialize LLM and Tools
29
- llm = ChatOpenAI(model="gpt-4o-mini")
30
- tool = TavilySearchResults(max_results=2)
31
- llm_with_tools = llm.bind_tools([tool])
32
-
33
- # Agent Function
34
- def Agent(state: State):
35
- st.sidebar.write("Agent received state:", state["messages"])
36
- response = llm_with_tools.invoke(state["messages"])
37
- st.sidebar.write("Agent Response:", response)
38
- return {"messages": [response]}
39
-
40
- # Memory Checkpoint
41
- memory = MemorySaver()
42
-
43
- # Build the Graph
44
- graph = StateGraph(State)
45
- tool_node = ToolNode(tools=[tool])
46
-
47
- graph.add_node("Agent", Agent)
48
- graph.add_node("tools", tool_node)
49
- graph.add_conditional_edges("Agent", tools_condition)
50
- graph.add_edge("tools", "Agent")
51
- graph.set_entry_point("Agent")
52
-
53
- # Compile Graph
54
- app = graph.compile(checkpointer=memory, interrupt_before=["tools"])
55
-
56
- # Display Graph Visualization
57
- st.subheader("Graph Workflow")
58
- st.image(app.get_graph().draw_mermaid_png(), caption="Graph Visualization", use_container_width=True)
59
-
60
- # Input Section
61
- st.subheader("Run the Workflow")
62
- user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
63
- thread_id = st.text_input("Thread ID", "1")
64
-
65
- if st.button("Execute Workflow"):
66
- thread = {"configurable": {"thread_id": thread_id}}
67
- input_message = {"messages": [HumanMessage(content=user_input)]}
68
-
69
- st.write("### Execution Outputs")
70
- outputs = []
 
 
 
 
 
 
 
 
71
  for event in app.stream(input_message, thread, stream_mode="values"):
72
  st.code(event["messages"][-1].content)
73
  outputs.append(event["messages"][-1].content)
74
- st.sidebar.write("Intermediate State:", event["messages"])
75
 
 
76
  if outputs:
77
  st.subheader("Intermediate Outputs")
78
  for idx, output in enumerate(outputs, start=1):
79
  st.write(f"**Step {idx}:**")
80
  st.code(output)
81
  else:
82
- st.warning("No outputs generated. Adjust your input to trigger tools.")
83
 
84
- # Display Snapshot of State
85
  st.subheader("Current State Snapshot")
86
  snapshot = app.get_state(thread)
87
  current_message = snapshot.values["messages"][-1]
88
  st.code(current_message.pretty_print())
89
 
90
- # Manual Update Section
91
  if hasattr(current_message, "tool_calls") and current_message.tool_calls:
92
  tool_call_id = current_message.tool_calls[0]["id"]
93
- manual_response = st.text_area("Manual Tool Response", "Enter your response to continue execution...")
94
- if st.button("Update State"):
 
95
  new_messages = [
96
  ToolMessage(content=manual_response, tool_call_id=tool_call_id),
97
  AIMessage(content=manual_response),
98
  ]
99
  app.update_state(thread, {"messages": new_messages})
100
- st.success("State updated successfully!")
101
  st.code(app.get_state(thread).values["messages"][-1].pretty_print())
102
  else:
103
- st.warning("No tool calls available for manual updates.")
104
- else:
105
- st.error("API keys are missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in Hugging Face Spaces Secrets.")
 
11
 
12
  # Streamlit UI Header
13
  st.title("Checkpoints and Breakpoints")
14
+ st.caption("Demonstrating LangGraph workflow execution with interruptions and tool invocation.")
15
 
16
  # Fetch API Keys
17
  openai_api_key = os.getenv("OPENAI_API_KEY")
18
  tavily_api_key = os.getenv("TAVILY_API_KEY")
19
 
20
+ if not openai_api_key or not tavily_api_key:
21
+ st.error("API keys are missing! Set OPENAI_API_KEY and TAVILY_API_KEY in Hugging Face Spaces Secrets.")
22
+ st.stop()
23
+
24
+ os.environ["OPENAI_API_KEY"] = openai_api_key
25
+ os.environ["TAVILY_API_KEY"] = tavily_api_key
26
+
27
+ # Define State Class
28
+ class State(TypedDict):
29
+ messages: Annotated[list, add_messages]
30
+
31
+ # Initialize LLM and Tools
32
+ llm = ChatOpenAI(model="gpt-4o-mini")
33
+ tool = TavilySearchResults(max_results=2)
34
+ tools = [tool]
35
+ llm_with_tools = llm.bind_tools(tools)
36
+
37
+ # Agent Node
38
+ def Agent(state: State):
39
+ st.sidebar.write("Agent received input:", state["messages"])
40
+ response = llm_with_tools.invoke(state["messages"])
41
+ st.sidebar.write("Agent Response:", response)
42
+ return {"messages": [response]}
43
+
44
+ # Set up Graph
45
+ memory = MemorySaver()
46
+ graph = StateGraph(State)
47
+
48
+ # Add nodes
49
+ graph.add_node("Agent", Agent)
50
+ tool_node = ToolNode(tools=[tool])
51
+ graph.add_node("tools", tool_node)
52
+
53
+ # Add edges
54
+ graph.add_conditional_edges("Agent", tools_condition)
55
+ graph.add_edge("tools", "Agent")
56
+ graph.set_entry_point("Agent")
57
+
58
+ # Compile with Breakpoint
59
+ app = graph.compile(checkpointer=memory, interrupt_before=["tools"])
60
+
61
+ # Display Graph Visualization
62
+ st.subheader("Graph Visualization")
63
+ st.image(app.get_graph().draw_mermaid_png(), caption="Workflow Graph", use_container_width=True)
64
+
65
+ # Input Section
66
+ st.subheader("Run the Workflow")
67
+ user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
68
+ thread_id = st.text_input("Thread ID", "1")
69
+
70
+ if st.button("Execute Workflow"):
71
+ thread = {"configurable": {"thread_id": thread_id}}
72
+ input_message = {"messages": [HumanMessage(content=user_input)]}
73
+
74
+ st.write("### Execution Outputs")
75
+ outputs = []
76
+
77
+ # Execute the workflow
78
+ try:
79
  for event in app.stream(input_message, thread, stream_mode="values"):
80
  st.code(event["messages"][-1].content)
81
  outputs.append(event["messages"][-1].content)
 
82
 
83
+ # Display Intermediate Outputs
84
  if outputs:
85
  st.subheader("Intermediate Outputs")
86
  for idx, output in enumerate(outputs, start=1):
87
  st.write(f"**Step {idx}:**")
88
  st.code(output)
89
  else:
90
+ st.warning("No outputs generated yet.")
91
 
92
+ # Show State Snapshot
93
  st.subheader("Current State Snapshot")
94
  snapshot = app.get_state(thread)
95
  current_message = snapshot.values["messages"][-1]
96
  st.code(current_message.pretty_print())
97
 
98
+ # Handle Tool Calls with Manual Input
99
  if hasattr(current_message, "tool_calls") and current_message.tool_calls:
100
  tool_call_id = current_message.tool_calls[0]["id"]
101
+ st.warning("Execution paused before tool execution. Provide manual input to resume.")
102
+ manual_response = st.text_area("Manual Tool Response", "Enter the tool's response here...")
103
+ if st.button("Resume Execution"):
104
  new_messages = [
105
  ToolMessage(content=manual_response, tool_call_id=tool_call_id),
106
  AIMessage(content=manual_response),
107
  ]
108
  app.update_state(thread, {"messages": new_messages})
109
+ st.success("State updated! Rerun the workflow to continue.")
110
  st.code(app.get_state(thread).values["messages"][-1].pretty_print())
111
  else:
112
+ st.info("No tool calls detected at this step.")
113
+ except Exception as e:
114
+ st.error(f"Error during execution: {e}")