nikhmr1235 committed on
Commit
94ababa
·
verified ·
1 Parent(s): ea6e91e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -24
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import gradio as gr
2
  from langgraph.graph import StateGraph, END
3
- from langgraph.checkpoint.memory import MemorySaver
4
- from langgraph.checkpoint.sqlite import SqliteSaver
5
  import operator
6
  from typing import TypedDict, Annotated, Optional
7
  import uuid
8
  import os
 
9
 
10
  # --- 1. Define your Graph State ---
11
  class GraphState(TypedDict):
12
  greeting_message: str
13
- human_input: Annotated[Optional[str], operator.add]
14
  final_response: str
15
  current_node: str
16
 
@@ -24,11 +24,11 @@ def greeting_node(state: GraphState) -> GraphState:
24
  def human_input_node(state: GraphState) -> GraphState:
25
  """
26
  Node 2: This node processes the human input that was added to the state
27
- before its execution (during the resume phase via put_state).
28
  """
29
  human_response = state.get("human_input", "No human input found in state when human_input_node ran.")
30
  print(f"\n✋ Human Input Node Executed (Processing input: '{human_response}'):")
31
- # You could add validation or further processing of state['human_input'] here if needed.
32
  return {"current_node": "human_input_processed"}
33
 
34
  def human_response_display_node(state: GraphState) -> GraphState:
@@ -41,27 +41,44 @@ def human_response_display_node(state: GraphState) -> GraphState:
41
  # --- 3. Build the Graph ---
42
  builder = StateGraph(GraphState)
43
  builder.add_node("greeting", greeting_node)
44
- builder.add_node("human_input_interrupt", human_input_node) # Node name remains for clarity
45
  builder.add_node("human_response_display", human_response_display_node)
46
 
47
  builder.set_entry_point("greeting")
48
- builder.add_edge("greeting", "human_input_interrupt") # Still connects normally
49
  builder.add_edge("human_response_display", END)
50
 
51
- # --- Checkpointer and Graph Compilation ---
52
- # Using MemorySaver as requested
53
- #global_memory_saver = MemorySaver()
54
-
55
- SQLITE_DB_PATH = "langgraph_checkpoints.sqlite" # This will be a file in your Space
56
- global_memory_saver = SqliteSaver.from_conn_string(SQLITE_DB_PATH)
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- global_graph = builder.compile(
60
- checkpointer=global_memory_saver,
61
- # CRITICAL CHANGE: Interrupt BEFORE human_input_interrupt
62
- # This means the graph will pause *before* executing the human_input_node
63
- interrupt_before=["human_input_interrupt"]
64
- )
65
 
66
  # --- Gradio UI Logic ---
67
  current_thread_id = gr.State("")
@@ -70,9 +87,10 @@ def start_graph(thread_id_state):
70
  new_thread_id = str(uuid.uuid4())
71
  print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
72
 
 
 
 
73
  try:
74
- # CRITICAL FIX: Pass an empty dictionary {} as input for the first stream call.
75
- # This tells LangGraph to start from the entry point with an initial empty state.
76
  for s in global_graph.stream({}, {"configurable": {"thread_id": new_thread_id}}):
77
  if "__end__" in s:
78
  break
@@ -97,15 +115,18 @@ def resume_graph(human_input_from_ui: str, thread_id_state):
97
  print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
98
  print(f"Human input received from UI: {human_input_from_ui}")
99
 
 
 
 
100
  try:
101
  current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": thread_id_state}})
102
  current_state_values = current_state_snapshot.values
103
 
104
  current_state_values["human_input"] = human_input_from_ui
105
 
 
106
  global_memory_saver.put_state(current_state_values, {"configurable": {"thread_id": thread_id_state}})
107
 
108
- # Resume the graph with None as input after manually updating the state.
109
  for s in global_graph.stream(None, {"configurable": {"thread_id": thread_id_state}}):
110
  if "__end__" in s:
111
  break
@@ -124,8 +145,9 @@ def resume_graph(human_input_from_ui: str, thread_id_state):
124
 
125
  # Gradio Interface setup
126
  with gr.Blocks() as demo:
127
- gr.Markdown("# LangGraph Human-in-the-Loop Demo (MemorySaver)")
128
- gr.Markdown("Graph state will be saved persistently in **memory** for the current session. Click 'Start Conversation' to begin. Input is taken via Gradio UI.")
 
129
 
130
  output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
131
  human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)
 
1
  import gradio as gr
2
  from langgraph.graph import StateGraph, END
3
+ from langgraph.checkpoint.sqlite import SqliteSaver # Import SqliteSaver
 
4
  import operator
5
  from typing import TypedDict, Annotated, Optional
6
  import uuid
7
  import os
8
+ import sqlite3 # Import sqlite3 for direct connection if needed for older versions
9
 
10
# --- 1. Define your Graph State ---
class GraphState(TypedDict):
    """Shared state flowing between the LangGraph nodes.

    Keys:
        greeting_message: message produced by the greeting node.
        human_input: reply injected into the state when the graph resumes
            after the human-in-the-loop interrupt.
        final_response: text produced for display at the end of the run.
        current_node: marker of the most recently executed node.
    """
    greeting_message: str
    # NOTE(review): operator.add as the reducer concatenates successive
    # string updates and raises TypeError if the stored value is still None.
    # Confirm a last-value-wins (replace) reducer wasn't the intent here.
    human_input: Annotated[Optional[str], operator.add]
    final_response: str
    current_node: str
16
 
 
24
def human_input_node(state: GraphState) -> GraphState:
    """Consume the human reply previously written into the graph state.

    The reply is expected to have been merged into ``state["human_input"]``
    while the graph was paused; this node never blocks on ``input()``.

    Returns a partial state update marking the input as processed.
    """
    fallback = "No human input found in state when human_input_node ran."
    reply = state.get("human_input", fallback)
    print(f"\n✋ Human Input Node Executed (Processing input: '{reply}'):")
    # Validation or further processing of the reply could be inserted here.
    return {"current_node": "human_input_processed"}
33
 
34
  def human_response_display_node(state: GraphState) -> GraphState:
 
41
# --- 3. Build the Graph ---
# Flow: greeting -> human_input_interrupt (pauses here) -> display -> END
builder = StateGraph(GraphState)
builder.add_node("greeting", greeting_node)
builder.add_node("human_input_interrupt", human_input_node)
builder.add_node("human_response_display", human_response_display_node)

builder.set_entry_point("greeting")
builder.add_edge("greeting", "human_input_interrupt")
# BUG FIX: this edge was missing, leaving "human_input_interrupt" a dead end —
# after resuming from the interrupt the graph could never reach the display
# node (and LangGraph rejects dead-end nodes at compile time).
builder.add_edge("human_input_interrupt", "human_response_display")
builder.add_edge("human_response_display", END)
50
 
51
# --- Checkpointer and Graph Compilation ---
# Checkpoints are persisted to a SQLite file stored alongside the app.
SQLITE_DB_PATH = "langgraph_checkpoints.sqlite"

# A single module-level connection is shared by all Gradio callbacks;
# check_same_thread=False is required because those callbacks may run on
# different worker threads than the one that opened the connection.
try:
    conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
    global_memory_saver = SqliteSaver(conn=conn)
    # Table creation is handled by SqliteSaver itself on first use.
except Exception as e:
    print(f"Error initializing SqliteSaver connection: {e}")
    # Leave the saver unset; the UI handlers report the failure to the user.
    global_memory_saver = None

if global_memory_saver is None:
    print("SqliteSaver not initialized, graph compilation skipped.")
    global_graph = None
else:
    # interrupt_before pauses the run just before the human-input node so the
    # UI can collect a reply and resume the thread.
    global_graph = builder.compile(
        checkpointer=global_memory_saver,
        interrupt_before=["human_input_interrupt"],
    )
81
 
 
 
 
 
 
 
82
 
83
  # --- Gradio UI Logic ---
84
  current_thread_id = gr.State("")
 
87
  new_thread_id = str(uuid.uuid4())
88
  print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
89
 
90
+ if not global_graph:
91
+ return (f"Error: Graph not initialized. Check server logs.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
92
+
93
  try:
 
 
94
  for s in global_graph.stream({}, {"configurable": {"thread_id": new_thread_id}}):
95
  if "__end__" in s:
96
  break
 
115
  print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
116
  print(f"Human input received from UI: {human_input_from_ui}")
117
 
118
+ if not global_graph or not global_memory_saver:
119
+ return (f"Error: Graph or checkpointer not initialized. Check server logs.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
120
+
121
  try:
122
  current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": thread_id_state}})
123
  current_state_values = current_state_snapshot.values
124
 
125
  current_state_values["human_input"] = human_input_from_ui
126
 
127
+ # Use put_state() from the SqliteSaver instance
128
  global_memory_saver.put_state(current_state_values, {"configurable": {"thread_id": thread_id_state}})
129
 
 
130
  for s in global_graph.stream(None, {"configurable": {"thread_id": thread_id_state}}):
131
  if "__end__" in s:
132
  break
 
145
 
146
  # Gradio Interface setup
147
  with gr.Blocks() as demo:
148
+ gr.Markdown("# LangGraph Human-in-the-Loop Demo (SqliteSaver)")
149
+ gr.Markdown(f"Graph state will be saved persistently in `{SQLITE_DB_PATH}`. Click 'Start Conversation' to begin. Input is taken via Gradio UI.")
150
+
151
 
152
  output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
153
  human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)