nikhmr1235 commited on
Commit
132bd32
·
verified ·
1 Parent(s): 7ad989d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -99
app.py CHANGED
@@ -1,17 +1,19 @@
1
  import gradio as gr
2
  from langgraph.graph import StateGraph, END
3
- from langgraph.checkpoint.memory import MemorySaver
4
  import operator
5
  from typing import TypedDict, Annotated, Optional
 
 
6
 
7
  # --- 1. Define your Graph State (Same as before) ---
8
  class GraphState(TypedDict):
9
  greeting_message: str
10
- human_input: Annotated[Optional[str], operator.add] # Make human_input optional for initial state
11
  final_response: str
12
- current_node: str # To track which node was last executed/interrupted at
13
 
14
- # --- 2. Define the Nodes (Slightly modified for Gradio interaction) ---
15
  def greeting_node(state: GraphState) -> GraphState:
16
  greeting = "Hello there! I'm an AI assistant. How can I help you today?"
17
  print("🤖 Greeting Node Executed:")
@@ -19,12 +21,9 @@ def greeting_node(state: GraphState) -> GraphState:
19
  return {"greeting_message": greeting, "current_node": "greeting"}
20
 
21
  def human_input_node(state: GraphState) -> GraphState:
22
- # In a Gradio context, the actual `input()` will come from the UI
23
- # This node primarily serves as the interruption point and where the UI input
24
- # will be injected into the state when resuming.
25
  print("\n✋ Human Input Node Executed (Graph was resumed with input):")
26
- # We won't call input() here directly, Gradio will provide it
27
- return {"current_node": "human_input_interrupt"}
28
 
29
  def human_response_display_node(state: GraphState) -> GraphState:
30
  human_response = state.get("human_input", "No human input received.")
@@ -43,126 +42,85 @@ builder.set_entry_point("greeting")
43
  builder.add_edge("greeting", "human_input_interrupt")
44
  builder.add_edge("human_response_display", END)
45
 
46
- # --- 4. Configure the interruption and checkpointer ---
47
- memory = MemorySaver()
48
- thread_id = "user_conversation_gradio" # A fixed thread ID for this demo
 
49
 
50
- # Compile the graph with checkpointer and interrupt_after for human_input_interrupt
51
- # This means the human_input_node will run, take input, then interrupt.
52
- graph = builder.compile(
53
- checkpointer=memory,
 
 
54
  interrupt_after=["human_input_interrupt"]
55
  )
56
 
57
- # --- 5. Gradio UI Logic ---
58
 
59
- # Global variable to store the current thread state, for simplicity in this demo
60
- # In a real app, you'd manage this for multiple users/sessions
61
- current_thread_state = {}
62
 
63
- def start_graph():
64
- global current_thread_state
 
 
65
 
66
- # Clear any previous state for a fresh start
67
- memory.drop(config={"configurable": {"thread_id": thread_id}})
68
- current_thread_state = {}
69
 
70
- print("\n--- Starting New Graph Execution ---")
71
-
72
- # Run the graph until the first interruption point
73
- # We pass an initial empty state; the greeting_node will populate it
74
  try:
75
- # Stream the graph, it will run 'greeting' then 'human_input_interrupt'
76
- # and interrupt *after* 'human_input_interrupt'
77
- for s in graph.stream({"greeting_message": ""}, {"configurable": {"thread_id": thread_id}}):
78
- # The 's' here represents the output of each node.
79
- # When an interrupt occurs, the last yielded item will be '__interrupt__':()
80
  if "__end__" in s:
81
- # Graph finished immediately, which means no interruption happened or it was very quick
82
  break
83
  elif "__interrupt__" in s:
84
  print(f"Graph interrupted at {s.get('__interrupt__', 'Unknown')}")
85
- break # Break the loop as soon as an interruption is detected
86
  else:
87
- # Process regular node output, if any before interrupt
88
  pass
89
 
90
- # After the stream, get the current state from the checkpointer
91
- current_state_snapshot = graph.get_state({"configurable": {"thread_id": thread_id}})
92
- current_thread_state = current_state_snapshot.values
93
 
94
- # Display the greeting message from the state
95
- output_message = current_thread_state.get("greeting_message", "No greeting yet.")
96
- output_message += "\n" + current_thread_state.get("human_input", "Please provide your input to continue.")
97
-
98
- # Enable the input box and resume button, disable start
99
- return output_message, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)
100
 
101
- except Exception as e:
102
- return f"An error occurred during graph start: {e}", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True)
103
 
104
- def resume_graph(human_input_from_ui: str):
105
- global current_thread_state
 
106
 
107
- print("\n--- Resuming Graph Execution ---")
108
- print(f"Received human input from UI: {human_input_from_ui}")
 
109
 
110
  try:
111
- # Resume the graph with the human input from Gradio
112
- # The graph will pick up from where it left off (after human_input_interrupt)
113
- # The human_input_node itself already captured the input, so this input is just
114
- # confirming/triggering the next step, but if using interrupt_before, this is where
115
- # you'd pass the actual input. For interrupt_after, the input is already in state.
116
-
117
- # However, for consistency and to ensure flow, we re-run and pass the input
118
- # which will effectively update the 'human_input' again (harmless due to operator.add)
119
- # or simply trigger the next steps if human_input_interrupt already ran.
120
-
121
- # We need to explicitly re-provide the human_input from the UI
122
- # in case the human_input_node itself wasn't designed to wait for input *after* its execution.
123
- # Given our current human_input_node structure (it calls input()),
124
- # we'll let it call input(), and this resume just triggers the graph forward.
125
-
126
- # If human_input_node was only a marker, you'd pass:
127
- # for s in graph.stream({"human_input": human_input_from_ui}, {"configurable": {"thread_id": thread_id}}):
128
-
129
- # But since human_input_node has the input() call, let's just resume
130
- # to process the *next* step assuming human_input_node *already* updated the state
131
- # correctly during the initial interrupted run.
132
- # If the input was taken inside the node that interrupts *after*, the state is already there.
133
- # So we just trigger continuation.
134
-
135
- # However, a cleaner way for interrupt_after is often:
136
- # Node takes input -> Node updates state -> Interrupt -> Resume (no new input needed)
137
- # If input was taken *outside* the node (interrupt_before), then resume provides it.
138
-
139
- # Given the previous successful run with interrupt_after, the human_input_node
140
- # already captured the input. We just need to drive the graph forward.
141
- # LangGraph automatically loads the state from the checkpointer based on thread_id
142
- for s in graph.stream(None, {"configurable": {"thread_id": thread_id}}): # Pass None as input to just continue
143
  if "__end__" in s:
144
  break
145
  else:
146
- pass # Process intermediate outputs if needed
147
 
148
- # Get the final state
149
- final_state_snapshot = graph.get_state({"configurable": {"thread_id": thread_id}})
150
- current_thread_state = final_state_snapshot.values
151
 
152
- final_message = current_thread_state.get("final_response", "Graph finished.")
153
- # Disable input and resume, enable start for a new conversation
154
- return final_message, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True)
155
 
156
  except Exception as e:
157
- return f"An error occurred during graph resumption: {e}", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True)
 
158
 
159
  # Gradio Interface setup
160
  with gr.Blocks() as demo:
161
- gr.Markdown("# LangGraph Human Approval Demo")
162
- gr.Markdown("Click 'Start Conversation' to begin. The graph will interrupt for your input.")
163
 
164
  output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
165
- human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)
 
 
166
 
167
  with gr.Row():
168
  start_button = gr.Button("Start Conversation")
@@ -170,14 +128,14 @@ with gr.Blocks() as demo:
170
 
171
  start_button.click(
172
  start_graph,
173
- inputs=[],
174
- outputs=[output_textbox, human_input_textbox, resume_button, start_button]
175
  )
176
 
177
  resume_button.click(
178
  resume_graph,
179
- inputs=[human_input_textbox],
180
- outputs=[output_textbox, human_input_textbox, resume_button, start_button]
181
  )
182
 
183
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from langgraph.graph import StateGraph, END
3
+ from langgraph.checkpoint.sqlite import SqliteSaver # Changed import
4
  import operator
5
  from typing import TypedDict, Annotated, Optional
6
+ import uuid
7
+ import os # Import os for path handling
8
 
9
  # --- 1. Define your Graph State (Same as before) ---
10
  class GraphState(TypedDict):
11
  greeting_message: str
12
+ human_input: Annotated[Optional[str], operator.add]
13
  final_response: str
14
+ current_node: str
15
 
16
+ # --- 2. Define the Nodes (Same as before) ---
17
  def greeting_node(state: GraphState) -> GraphState:
18
  greeting = "Hello there! I'm an AI assistant. How can I help you today?"
19
  print("🤖 Greeting Node Executed:")
 
21
  return {"greeting_message": greeting, "current_node": "greeting"}
22
 
23
  def human_input_node(state: GraphState) -> GraphState:
 
 
 
24
  print("\n✋ Human Input Node Executed (Graph was resumed with input):")
25
+ user_response = input("Please enter your response (e.g., 'I want to know about LangChain'): ")
26
+ return {"human_input": user_response, "current_node": "human_input_interrupt"}
27
 
28
  def human_response_display_node(state: GraphState) -> GraphState:
29
  human_response = state.get("human_input", "No human input received.")
 
42
  builder.add_edge("greeting", "human_input_interrupt")
43
  builder.add_edge("human_response_display", END)
44
 
45
+ # --- Checkpointer and Graph Compilation ---
46
+ # Define a persistent SQLite database file
47
+ # It's good practice to place it in a dedicated directory or the project root.
48
+ SQLITE_DB_PATH = "langgraph_checkpoints.sqlite" # This file will be created/used
49
 
50
+ # Initialize SqliteSaver with the file path
51
+ global_memory_saver = SqliteSaver.from_conn_string(SQLITE_DB_PATH)
52
+
53
+ # Global graph instance compiled once
54
+ global_graph = builder.compile(
55
+ checkpointer=global_memory_saver,
56
  interrupt_after=["human_input_interrupt"]
57
  )
58
 
59
+ # --- Gradio UI Logic (Mostly the same, but now truly persistent) ---
60
 
61
+ current_thread_id = gr.State("")
 
 
62
 
63
+ def start_graph(thread_id_state):
64
+ # Generate a new unique thread ID for each new conversation
65
+ new_thread_id = str(uuid.uuid4())
66
+ print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
67
 
68
+ # You might want to explicitly drop previous thread state if it exists for this ID
69
+ # if global_memory_saver.get_state({"configurable": {"thread_id": new_thread_id}}):
70
+ # global_memory_saver.drop(config={"configurable": {"thread_id": new_thread_id}})
71
 
 
 
 
 
72
  try:
73
+ for s in global_graph.stream({"greeting_message": ""}, {"configurable": {"thread_id": new_thread_id}}):
 
 
 
 
74
  if "__end__" in s:
 
75
  break
76
  elif "__interrupt__" in s:
77
  print(f"Graph interrupted at {s.get('__interrupt__', 'Unknown')}")
78
+ break
79
  else:
 
80
  pass
81
 
82
+ current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": new_thread_id}})
 
 
83
 
84
+ output_message = current_state_snapshot.values.get("greeting_message", "No greeting yet.")
85
+ output_message += "\n" + "Please provide your input to continue."
 
 
 
 
86
 
87
+ return (output_message, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False), new_thread_id)
 
88
 
89
+ except Exception as e:
90
+ return (f"An error occurred during graph start: {e}",
91
+ gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
92
 
93
+ def resume_graph(human_input_from_ui: str, thread_id_state):
94
+ print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
95
+ print(f"Received human input from UI (will be ignored if input taken in node): {human_input_from_ui}")
96
 
97
  try:
98
+ # Pass None as input as the human_input_node already handled taking input
99
+ for s in global_graph.stream(None, {"configurable": {"thread_id": thread_id_state}}):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  if "__end__" in s:
101
  break
102
  else:
103
+ pass
104
 
105
+ final_state_snapshot = global_graph.get_state({"configurable": {"thread_id": thread_id_state}})
106
+ final_state_values = final_state_snapshot.values
 
107
 
108
+ final_message = final_state_values.get("final_response", "Graph finished without final response.")
109
+ return (final_message, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
 
110
 
111
  except Exception as e:
112
+ return (f"An error occurred during graph resumption: {e}",
113
+ gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
114
 
115
  # Gradio Interface setup
116
  with gr.Blocks() as demo:
117
+ gr.Markdown("# LangGraph Human Approval Demo (Persistent)")
118
+ gr.Markdown(f"Graph state will be saved persistently in `{SQLITE_DB_PATH}`. Click 'Start Conversation' to begin.")
119
 
120
  output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
121
+ human_input_textbox = gr.Textbox(label="Your Input (Please enter in console as well)", placeholder="Type your response here...", interactive=True)
122
+
123
+ thread_id_state = gr.State("")
124
 
125
  with gr.Row():
126
  start_button = gr.Button("Start Conversation")
 
128
 
129
  start_button.click(
130
  start_graph,
131
+ inputs=[thread_id_state],
132
+ outputs=[output_textbox, human_input_textbox, resume_button, start_button, thread_id_state]
133
  )
134
 
135
  resume_button.click(
136
  resume_graph,
137
+ inputs=[human_input_textbox, thread_id_state],
138
+ outputs=[output_textbox, human_input_textbox, resume_button, start_button, thread_id_state]
139
  )
140
 
141
  if __name__ == "__main__":