Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from langgraph.graph import StateGraph, END
|
| 3 |
-
from langgraph.checkpoint.
|
| 4 |
-
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 5 |
import operator
|
| 6 |
from typing import TypedDict, Annotated, Optional
|
| 7 |
import uuid
|
| 8 |
import os
|
|
|
|
| 9 |
|
| 10 |
# --- 1. Define your Graph State ---
|
| 11 |
class GraphState(TypedDict):
|
| 12 |
greeting_message: str
|
| 13 |
-
human_input: Annotated[Optional[str], operator.add]
|
| 14 |
final_response: str
|
| 15 |
current_node: str
|
| 16 |
|
|
@@ -24,11 +24,11 @@ def greeting_node(state: GraphState) -> GraphState:
|
|
| 24 |
def human_input_node(state: GraphState) -> GraphState:
|
| 25 |
"""
|
| 26 |
Node 2: This node processes the human input that was added to the state
|
| 27 |
-
|
| 28 |
"""
|
| 29 |
human_response = state.get("human_input", "No human input found in state when human_input_node ran.")
|
| 30 |
print(f"\n✋ Human Input Node Executed (Processing input: '{human_response}'):")
|
| 31 |
-
#
|
| 32 |
return {"current_node": "human_input_processed"}
|
| 33 |
|
| 34 |
def human_response_display_node(state: GraphState) -> GraphState:
|
|
@@ -41,27 +41,44 @@ def human_response_display_node(state: GraphState) -> GraphState:
|
|
| 41 |
# --- 3. Build the Graph ---
|
| 42 |
builder = StateGraph(GraphState)
|
| 43 |
builder.add_node("greeting", greeting_node)
|
| 44 |
-
builder.add_node("human_input_interrupt", human_input_node)
|
| 45 |
builder.add_node("human_response_display", human_response_display_node)
|
| 46 |
|
| 47 |
builder.set_entry_point("greeting")
|
| 48 |
-
builder.add_edge("greeting", "human_input_interrupt")
|
| 49 |
builder.add_edge("human_response_display", END)
|
| 50 |
|
| 51 |
-
#
|
| 52 |
-
|
| 53 |
-
#global_memory_saver = MemorySaver()
|
| 54 |
-
|
| 55 |
-
SQLITE_DB_PATH = "langgraph_checkpoints.sqlite" # This will be a file in your Space
|
| 56 |
-
global_memory_saver = SqliteSaver.from_conn_string(SQLITE_DB_PATH)
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
global_graph = builder.compile(
|
| 60 |
-
checkpointer=global_memory_saver,
|
| 61 |
-
# CRITICAL CHANGE: Interrupt BEFORE human_input_interrupt
|
| 62 |
-
# This means the graph will pause *before* executing the human_input_node
|
| 63 |
-
interrupt_before=["human_input_interrupt"]
|
| 64 |
-
)
|
| 65 |
|
| 66 |
# --- Gradio UI Logic ---
|
| 67 |
current_thread_id = gr.State("")
|
|
@@ -70,9 +87,10 @@ def start_graph(thread_id_state):
|
|
| 70 |
new_thread_id = str(uuid.uuid4())
|
| 71 |
print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
try:
|
| 74 |
-
# CRITICAL FIX: Pass an empty dictionary {} as input for the first stream call.
|
| 75 |
-
# This tells LangGraph to start from the entry point with an initial empty state.
|
| 76 |
for s in global_graph.stream({}, {"configurable": {"thread_id": new_thread_id}}):
|
| 77 |
if "__end__" in s:
|
| 78 |
break
|
|
@@ -97,15 +115,18 @@ def resume_graph(human_input_from_ui: str, thread_id_state):
|
|
| 97 |
print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
|
| 98 |
print(f"Human input received from UI: {human_input_from_ui}")
|
| 99 |
|
|
|
|
|
|
|
|
|
|
| 100 |
try:
|
| 101 |
current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": thread_id_state}})
|
| 102 |
current_state_values = current_state_snapshot.values
|
| 103 |
|
| 104 |
current_state_values["human_input"] = human_input_from_ui
|
| 105 |
|
|
|
|
| 106 |
global_memory_saver.put_state(current_state_values, {"configurable": {"thread_id": thread_id_state}})
|
| 107 |
|
| 108 |
-
# Resume the graph with None as input after manually updating the state.
|
| 109 |
for s in global_graph.stream(None, {"configurable": {"thread_id": thread_id_state}}):
|
| 110 |
if "__end__" in s:
|
| 111 |
break
|
|
@@ -124,8 +145,9 @@ def resume_graph(human_input_from_ui: str, thread_id_state):
|
|
| 124 |
|
| 125 |
# Gradio Interface setup
|
| 126 |
with gr.Blocks() as demo:
|
| 127 |
-
gr.Markdown("# LangGraph Human-in-the-Loop Demo (
|
| 128 |
-
gr.Markdown("Graph state will be saved persistently in
|
|
|
|
| 129 |
|
| 130 |
output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
|
| 131 |
human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from langgraph.graph import StateGraph, END
|
| 3 |
+
from langgraph.checkpoint.sqlite import SqliteSaver # Import SqliteSaver
|
|
|
|
| 4 |
import operator
|
| 5 |
from typing import TypedDict, Annotated, Optional
|
| 6 |
import uuid
|
| 7 |
import os
|
| 8 |
+
import sqlite3 # Import sqlite3 for direct connection if needed for older versions
|
| 9 |
|
| 10 |
# --- 1. Define your Graph State ---
|
| 11 |
class GraphState(TypedDict):
|
| 12 |
greeting_message: str
|
| 13 |
+
human_input: Annotated[Optional[str], operator.add] # operator.add for merging
|
| 14 |
final_response: str
|
| 15 |
current_node: str
|
| 16 |
|
|
|
|
| 24 |
def human_input_node(state: GraphState) -> GraphState:
|
| 25 |
"""
|
| 26 |
Node 2: This node processes the human input that was added to the state
|
| 27 |
+
via the `put_state` method during the resume phase. It does NOT contain `input()`.
|
| 28 |
"""
|
| 29 |
human_response = state.get("human_input", "No human input found in state when human_input_node ran.")
|
| 30 |
print(f"\n✋ Human Input Node Executed (Processing input: '{human_response}'):")
|
| 31 |
+
# This node could perform validation or further processing using human_response.
|
| 32 |
return {"current_node": "human_input_processed"}
|
| 33 |
|
| 34 |
def human_response_display_node(state: GraphState) -> GraphState:
|
|
|
|
| 41 |
# --- 3. Build the Graph ---
|
| 42 |
builder = StateGraph(GraphState)
|
| 43 |
builder.add_node("greeting", greeting_node)
|
| 44 |
+
builder.add_node("human_input_interrupt", human_input_node)
|
| 45 |
builder.add_node("human_response_display", human_response_display_node)
|
| 46 |
|
| 47 |
builder.set_entry_point("greeting")
|
| 48 |
+
builder.add_edge("greeting", "human_input_interrupt")
|
| 49 |
builder.add_edge("human_response_display", END)
|
| 50 |
|
| 51 |
+
# Define the path for the SQLite database file
|
| 52 |
+
SQLITE_DB_PATH = "langgraph_checkpoints.sqlite"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
+
# --- Checkpointer and Graph Compilation ---
|
| 55 |
+
# Initialize global_graph and global_memory_saver
|
| 56 |
+
# The recommended way is to establish the connection once and keep the saver instance.
|
| 57 |
+
# For SqliteSaver, if from_conn_string is a context manager, you need to use `with`.
|
| 58 |
+
# A simpler approach that generally works better for global scope in web apps:
|
| 59 |
+
# Establish a direct SQLite connection and pass it to SqliteSaver.
|
| 60 |
+
try:
|
| 61 |
+
# Use check_same_thread=False for Gradio/web apps that might access the DB from different threads
|
| 62 |
+
# or if the connection is managed globally.
|
| 63 |
+
conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
|
| 64 |
+
global_memory_saver = SqliteSaver(conn=conn)
|
| 65 |
+
# Ensure tables are created if they don't exist.
|
| 66 |
+
# SqliteSaver usually does this on first use or you can explicitly call a setup if available.
|
| 67 |
+
# In recent versions, just passing the connection is often enough for setup.
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f"Error initializing SqliteSaver connection: {e}")
|
| 70 |
+
# Fallback or error handling if DB connection fails
|
| 71 |
+
global_memory_saver = None # Or exit the application
|
| 72 |
+
|
| 73 |
+
if global_memory_saver:
|
| 74 |
+
global_graph = builder.compile(
|
| 75 |
+
checkpointer=global_memory_saver,
|
| 76 |
+
interrupt_before=["human_input_interrupt"]
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
print("SqliteSaver not initialized, graph compilation skipped.")
|
| 80 |
+
global_graph = None # Handle the case where checkpointer couldn't be initialized
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# --- Gradio UI Logic ---
|
| 84 |
current_thread_id = gr.State("")
|
|
|
|
| 87 |
new_thread_id = str(uuid.uuid4())
|
| 88 |
print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
|
| 89 |
|
| 90 |
+
if not global_graph:
|
| 91 |
+
return (f"Error: Graph not initialized. Check server logs.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
|
| 92 |
+
|
| 93 |
try:
|
|
|
|
|
|
|
| 94 |
for s in global_graph.stream({}, {"configurable": {"thread_id": new_thread_id}}):
|
| 95 |
if "__end__" in s:
|
| 96 |
break
|
|
|
|
| 115 |
print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
|
| 116 |
print(f"Human input received from UI: {human_input_from_ui}")
|
| 117 |
|
| 118 |
+
if not global_graph or not global_memory_saver:
|
| 119 |
+
return (f"Error: Graph or checkpointer not initialized. Check server logs.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
|
| 120 |
+
|
| 121 |
try:
|
| 122 |
current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": thread_id_state}})
|
| 123 |
current_state_values = current_state_snapshot.values
|
| 124 |
|
| 125 |
current_state_values["human_input"] = human_input_from_ui
|
| 126 |
|
| 127 |
+
# Use put_state() from the SqliteSaver instance
|
| 128 |
global_memory_saver.put_state(current_state_values, {"configurable": {"thread_id": thread_id_state}})
|
| 129 |
|
|
|
|
| 130 |
for s in global_graph.stream(None, {"configurable": {"thread_id": thread_id_state}}):
|
| 131 |
if "__end__" in s:
|
| 132 |
break
|
|
|
|
| 145 |
|
| 146 |
# Gradio Interface setup
|
| 147 |
with gr.Blocks() as demo:
|
| 148 |
+
gr.Markdown("# LangGraph Human-in-the-Loop Demo (SqliteSaver)")
|
| 149 |
+
gr.Markdown(f"Graph state will be saved persistently in `{SQLITE_DB_PATH}`. Click 'Start Conversation' to begin. Input is taken via Gradio UI.")
|
| 150 |
+
|
| 151 |
|
| 152 |
output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
|
| 153 |
human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)
|