Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,30 +1,39 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from langgraph.graph import StateGraph, END
|
| 3 |
-
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 4 |
from langgraph.checkpoint.memory import MemorySaver
|
| 5 |
import operator
|
| 6 |
from typing import TypedDict, Annotated, Optional
|
| 7 |
import uuid
|
| 8 |
-
import os
|
| 9 |
|
| 10 |
-
# --- 1. Define your Graph State
|
| 11 |
class GraphState(TypedDict):
|
| 12 |
greeting_message: str
|
| 13 |
human_input: Annotated[Optional[str], operator.add]
|
| 14 |
final_response: str
|
| 15 |
current_node: str
|
| 16 |
|
| 17 |
-
# --- 2. Define the Nodes
|
| 18 |
def greeting_node(state: GraphState) -> GraphState:
|
| 19 |
greeting = "Hello there! I'm an AI assistant. How can I help you today?"
|
| 20 |
print("🤖 Greeting Node Executed:")
|
| 21 |
print(greeting)
|
| 22 |
return {"greeting_message": greeting, "current_node": "greeting"}
|
| 23 |
|
|
|
|
| 24 |
def human_input_node(state: GraphState) -> GraphState:
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def human_response_display_node(state: GraphState) -> GraphState:
|
| 30 |
human_response = state.get("human_input", "No human input received.")
|
|
@@ -33,7 +42,7 @@ def human_response_display_node(state: GraphState) -> GraphState:
|
|
| 33 |
print(final_message)
|
| 34 |
return {"final_response": final_message, "current_node": "human_response_display"}
|
| 35 |
|
| 36 |
-
# --- 3. Build the Graph
|
| 37 |
builder = StateGraph(GraphState)
|
| 38 |
builder.add_node("greeting", greeting_node)
|
| 39 |
builder.add_node("human_input_interrupt", human_input_node)
|
|
@@ -44,37 +53,35 @@ builder.add_edge("greeting", "human_input_interrupt")
|
|
| 44 |
builder.add_edge("human_response_display", END)
|
| 45 |
|
| 46 |
# --- Checkpointer and Graph Compilation ---
|
| 47 |
-
#
|
| 48 |
-
# It's good practice to place it in a dedicated directory or the project root.
|
| 49 |
-
#SQLITE_DB_PATH = "langgraph_checkpoints.sqlite" # This file will be created/used
|
| 50 |
-
|
| 51 |
-
# Initialize SqliteSaver with the file path
|
| 52 |
#global_memory_saver = SqliteSaver.from_conn_string(SQLITE_DB_PATH)
|
| 53 |
-
|
| 54 |
global_memory_saver = MemorySaver()
|
| 55 |
|
| 56 |
-
# Global graph instance compiled once
|
| 57 |
global_graph = builder.compile(
|
| 58 |
checkpointer=global_memory_saver,
|
|
|
|
|
|
|
| 59 |
interrupt_after=["human_input_interrupt"]
|
| 60 |
)
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# --- Gradio UI Logic (Mostly the same, but now truly persistent) ---
|
| 65 |
-
|
| 66 |
current_thread_id = gr.State("")
|
| 67 |
|
| 68 |
def start_graph(thread_id_state):
|
| 69 |
-
# Generate a new unique thread ID for each new conversation
|
| 70 |
new_thread_id = str(uuid.uuid4())
|
| 71 |
print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
|
| 72 |
|
| 73 |
# You might want to explicitly drop previous thread state if it exists for this ID
|
| 74 |
-
# if
|
|
|
|
|
|
|
| 75 |
# global_memory_saver.drop(config={"configurable": {"thread_id": new_thread_id}})
|
|
|
|
|
|
|
| 76 |
|
| 77 |
try:
|
|
|
|
|
|
|
| 78 |
for s in global_graph.stream({"greeting_message": ""}, {"configurable": {"thread_id": new_thread_id}}):
|
| 79 |
if "__end__" in s:
|
| 80 |
break
|
|
@@ -87,8 +94,9 @@ def start_graph(thread_id_state):
|
|
| 87 |
current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": new_thread_id}})
|
| 88 |
|
| 89 |
output_message = current_state_snapshot.values.get("greeting_message", "No greeting yet.")
|
| 90 |
-
output_message += "\n" + "Please
|
| 91 |
|
|
|
|
| 92 |
return (output_message, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False), new_thread_id)
|
| 93 |
|
| 94 |
except Exception as e:
|
|
@@ -97,11 +105,12 @@ def start_graph(thread_id_state):
|
|
| 97 |
|
| 98 |
def resume_graph(human_input_from_ui: str, thread_id_state):
|
| 99 |
print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
|
| 100 |
-
print(f"
|
| 101 |
|
| 102 |
try:
|
| 103 |
-
# Pass
|
| 104 |
-
|
|
|
|
| 105 |
if "__end__" in s:
|
| 106 |
break
|
| 107 |
else:
|
|
@@ -111,6 +120,7 @@ def resume_graph(human_input_from_ui: str, thread_id_state):
|
|
| 111 |
final_state_values = final_state_snapshot.values
|
| 112 |
|
| 113 |
final_message = final_state_values.get("final_response", "Graph finished without final response.")
|
|
|
|
| 114 |
return (final_message, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
|
| 115 |
|
| 116 |
except Exception as e:
|
|
@@ -119,13 +129,12 @@ def resume_graph(human_input_from_ui: str, thread_id_state):
|
|
| 119 |
|
| 120 |
# Gradio Interface setup
|
| 121 |
with gr.Blocks() as demo:
|
| 122 |
-
gr.Markdown("# LangGraph Human Approval Demo (Persistent)")
|
| 123 |
-
|
| 124 |
-
gr.Markdown(f"Graph state will be saved persistently IN-MEMORY. Click 'Start Conversation' to begin.")
|
| 125 |
-
|
| 126 |
|
| 127 |
output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
|
| 128 |
-
|
|
|
|
| 129 |
|
| 130 |
thread_id_state = gr.State("")
|
| 131 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from langgraph.graph import StateGraph, END
|
| 3 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 4 |
from langgraph.checkpoint.memory import MemorySaver
|
| 5 |
import operator
|
| 6 |
from typing import TypedDict, Annotated, Optional
|
| 7 |
import uuid
|
| 8 |
+
import os
|
| 9 |
|
| 10 |
+
# --- 1. Define your Graph State ---
|
| 11 |
class GraphState(TypedDict):
|
| 12 |
greeting_message: str
|
| 13 |
human_input: Annotated[Optional[str], operator.add]
|
| 14 |
final_response: str
|
| 15 |
current_node: str
|
| 16 |
|
| 17 |
+
# --- 2. Define the Nodes ---
|
| 18 |
def greeting_node(state: GraphState) -> GraphState:
|
| 19 |
greeting = "Hello there! I'm an AI assistant. How can I help you today?"
|
| 20 |
print("🤖 Greeting Node Executed:")
|
| 21 |
print(greeting)
|
| 22 |
return {"greeting_message": greeting, "current_node": "greeting"}
|
| 23 |
|
| 24 |
+
# IMPORTANT CHANGE HERE: Remove input() from this node!
|
| 25 |
def human_input_node(state: GraphState) -> GraphState:
|
| 26 |
+
"""
|
| 27 |
+
Node 2: This node now simply acts as the interruption point.
|
| 28 |
+
The human input will be provided by the Gradio UI *after* this node runs and interrupts.
|
| 29 |
+
"""
|
| 30 |
+
print("\n✋ Human Input Node Executed (This is where graph will interrupt AFTER running this node):")
|
| 31 |
+
# This node returns its current_node state. The actual human_input will be added
|
| 32 |
+
# to the state during the RESUME phase from the Gradio UI.
|
| 33 |
+
# Since we are using 'interrupt_after' and human_input_node doesn't take input directly,
|
| 34 |
+
# the 'human_input' will only appear in the state after 'resume_graph' sends it.
|
| 35 |
+
return {"current_node": "human_input_interrupt"}
|
| 36 |
+
|
| 37 |
|
| 38 |
def human_response_display_node(state: GraphState) -> GraphState:
|
| 39 |
human_response = state.get("human_input", "No human input received.")
|
|
|
|
| 42 |
print(final_message)
|
| 43 |
return {"final_response": final_message, "current_node": "human_response_display"}
|
| 44 |
|
| 45 |
+
# --- 3. Build the Graph ---
|
| 46 |
builder = StateGraph(GraphState)
|
| 47 |
builder.add_node("greeting", greeting_node)
|
| 48 |
builder.add_node("human_input_interrupt", human_input_node)
|
|
|
|
| 53 |
builder.add_edge("human_response_display", END)
|
| 54 |
|
| 55 |
# --- Checkpointer and Graph Compilation ---
|
| 56 |
+
#SQLITE_DB_PATH = "langgraph_checkpoints.sqlite"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
#global_memory_saver = SqliteSaver.from_conn_string(SQLITE_DB_PATH)
|
|
|
|
| 58 |
global_memory_saver = MemorySaver()
|
| 59 |
|
|
|
|
| 60 |
global_graph = builder.compile(
|
| 61 |
checkpointer=global_memory_saver,
|
| 62 |
+
# Keep interrupt_after for human_input_interrupt.
|
| 63 |
+
# This means human_input_node runs, then graph interrupts.
|
| 64 |
interrupt_after=["human_input_interrupt"]
|
| 65 |
)
|
| 66 |
|
| 67 |
+
# --- Gradio UI Logic ---
|
|
|
|
|
|
|
|
|
|
| 68 |
current_thread_id = gr.State("")
|
| 69 |
|
| 70 |
def start_graph(thread_id_state):
|
|
|
|
| 71 |
new_thread_id = str(uuid.uuid4())
|
| 72 |
print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
|
| 73 |
|
| 74 |
# You might want to explicitly drop previous thread state if it exists for this ID
|
| 75 |
+
# (Uncomment if you want to explicitly clear old data for a potentially re-used UUID,
|
| 76 |
+
# though with uuid4 it's highly unlikely)
|
| 77 |
+
# try:
|
| 78 |
# global_memory_saver.drop(config={"configurable": {"thread_id": new_thread_id}})
|
| 79 |
+
# except Exception as e:
|
| 80 |
+
# print(f"Could not drop previous state for {new_thread_id}: {e}")
|
| 81 |
|
| 82 |
try:
|
| 83 |
+
# Pass an empty input for the first run.
|
| 84 |
+
# The graph will run greeting_node, then human_input_node, then interrupt.
|
| 85 |
for s in global_graph.stream({"greeting_message": ""}, {"configurable": {"thread_id": new_thread_id}}):
|
| 86 |
if "__end__" in s:
|
| 87 |
break
|
|
|
|
| 94 |
current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": new_thread_id}})
|
| 95 |
|
| 96 |
output_message = current_state_snapshot.values.get("greeting_message", "No greeting yet.")
|
| 97 |
+
output_message += "\n\n" + "Please type your response in the 'Your Input' box and click 'Resume Graph'."
|
| 98 |
|
| 99 |
+
# Enable human input textbox and resume button, disable start button
|
| 100 |
return (output_message, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False), new_thread_id)
|
| 101 |
|
| 102 |
except Exception as e:
|
|
|
|
| 105 |
|
| 106 |
def resume_graph(human_input_from_ui: str, thread_id_state):
|
| 107 |
print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
|
| 108 |
+
print(f"Human input received from UI: {human_input_from_ui}")
|
| 109 |
|
| 110 |
try:
|
| 111 |
+
# Pass the human_input_from_ui directly to the graph stream.
|
| 112 |
+
# This input will be merged into the graph's state when it resumes.
|
| 113 |
+
for s in global_graph.stream({"human_input": human_input_from_ui}, {"configurable": {"thread_id": thread_id_state}}):
|
| 114 |
if "__end__" in s:
|
| 115 |
break
|
| 116 |
else:
|
|
|
|
| 120 |
final_state_values = final_state_snapshot.values
|
| 121 |
|
| 122 |
final_message = final_state_values.get("final_response", "Graph finished without final response.")
|
| 123 |
+
# Disable input, resume button, enable start button for new conversation
|
| 124 |
return (final_message, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
|
| 125 |
|
| 126 |
except Exception as e:
|
|
|
|
| 129 |
|
| 130 |
# Gradio Interface setup
|
| 131 |
with gr.Blocks() as demo:
|
| 132 |
+
gr.Markdown("# LangGraph Human Approval Demo (Persistent on Hugging Face)")
|
| 133 |
+
gr.Markdown(f"Graph state will be saved persistently in MEMORY.")
|
|
|
|
|
|
|
| 134 |
|
| 135 |
output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
|
| 136 |
+
# The human input now comes ONLY from this textbox
|
| 137 |
+
human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)
|
| 138 |
|
| 139 |
thread_id_state = gr.State("")
|
| 140 |
|