# LangGraph human-in-the-loop demo: Gradio UI + SQLite-backed checkpointing.
# (File-viewer chrome from the original paste removed so the module parses.)
import gradio as gr
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver # Import SqliteSaver
import operator
from typing import TypedDict, Annotated, Optional
import uuid
import os
import sqlite3 # Import sqlite3 for direct connection if needed for older versions
# --- 1. Define your Graph State ---
class GraphState(TypedDict):
    """State shared between all graph nodes.

    Channels:
        greeting_message: text produced by the greeting node.
        human_input: the user's reply, injected by the UI while the graph
            is interrupted. Annotated with ``operator.add`` so updates are
            merged by concatenation rather than overwritten.
            NOTE(review): ``operator.add`` on ``Optional[str]`` raises
            TypeError when merging ``None`` with a str — confirm intended.
        final_response: message produced by the display node.
        current_node: name of the node that last executed.
    """
    greeting_message: str
    human_input: Annotated[Optional[str], operator.add]
    final_response: str
    current_node: str
# --- 2. Define the Nodes ---
def greeting_node(state: "GraphState") -> "GraphState":
    """Node 1: emit the fixed opening message.

    Ignores the incoming state entirely; returns a partial state update
    carrying the greeting text and the current-node marker.
    """
    message = "Hello there! I'm an AI assistant. How can I help you today?"
    print("🤖 Greeting Node Executed:")
    print(message)
    return {"greeting_message": message, "current_node": "greeting"}
def human_input_node(state: "GraphState") -> "GraphState":
    """Node 2: process the human input already present in the state.

    The graph is interrupted *before* this node; the UI writes the reply
    into the checkpointed state and resumes, so this node never calls
    ``input()``. It only logs what it found.
    """
    reply = state.get("human_input", "No human input found in state when human_input_node ran.")
    print(f"\n✋ Human Input Node Executed (Processing input: '{reply}'):")
    # Validation or further processing of `reply` could live here.
    return {"current_node": "human_input_processed"}
def human_response_display_node(state: "GraphState") -> "GraphState":
    """Node 3: build and log the final acknowledgement of the user's reply."""
    reply = state.get("human_input", "No human input received for final display.")
    acknowledgement = f"You said: '{reply}'. Thank you for your input!"
    print("\n✅ Human Response Display Node Executed:")
    print(acknowledgement)
    return {"final_response": acknowledgement, "current_node": "human_response_display"}
# --- 3. Build the Graph ---
builder = StateGraph(GraphState)

# Register the three nodes.
builder.add_node("greeting", greeting_node)
builder.add_node("human_input_interrupt", human_input_node)
builder.add_node("human_response_display", human_response_display_node)

# Linear flow: greeting -> (interrupt) -> process input -> display -> END.
builder.set_entry_point("greeting")
for source, target in (
    ("greeting", "human_input_interrupt"),
    ("human_input_interrupt", "human_response_display"),
    ("human_response_display", END),
):
    builder.add_edge(source, target)

# Path of the SQLite file backing the persistent checkpointer.
SQLITE_DB_PATH = "langgraph_checkpoints.sqlite"
# --- Checkpointer and Graph Compilation ---
global_memory_saver = None  # stays None if the SQLite connection fails
try:
    # check_same_thread=False: Gradio callbacks may touch the DB from
    # worker threads other than the one that opened the connection.
    conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
    global_memory_saver = SqliteSaver(conn=conn)
    print(f"SqliteSaver initialized successfully: {type(global_memory_saver)}")
except Exception as e:
    print(f"Error initializing SqliteSaver connection: {e}")

if global_memory_saver:
    # Pause just before the human-input node so the UI can inject the
    # user's reply into the checkpointed state and resume.
    global_graph = builder.compile(
        checkpointer=global_memory_saver,
        interrupt_before=["human_input_interrupt"],
    )
    print("Graph compiled successfully.")
else:
    print("SqliteSaver not initialized, graph compilation skipped.")
    global_graph = None  # callbacks below check for this sentinel
# --- Gradio UI Logic ---
# NOTE(review): this module-level gr.State appears unused — the Blocks
# context below creates its own `thread_id_state`, which is what the
# callbacks actually wire up. Confirm `current_thread_id` can be removed.
current_thread_id = gr.State("")
def start_graph(thread_id_state):
    """Kick off a fresh graph run under a brand-new thread id.

    Streams the graph until it ends or hits the pre-node interrupt, then
    returns a 5-tuple of Gradio updates: (output text, input-box update,
    resume-button update, start-button update, thread id).
    """
    new_thread_id = str(uuid.uuid4())
    print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")

    if not global_graph:
        return (
            f"Error: Graph not initialized. Check server logs.",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )

    config = {"configurable": {"thread_id": new_thread_id}}
    try:
        for event in global_graph.stream({}, config):
            if "__end__" in event:
                break
            if "__interrupt__" in event:
                print(f"Graph interrupted BEFORE {event.get('__interrupt__', 'Unknown')} node.")
                break

        snapshot = global_graph.get_state(config)
        output_message = snapshot.values.get("greeting_message", "No greeting yet.")
        output_message += "\n\n" + "Please type your response in the 'Your Input' box and click 'Resume Graph'."
        # Enable the reply box + resume button; disable start until done.
        return (
            output_message,
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=False),
            new_thread_id,
        )
    except Exception as e:
        return (
            f"An error occurred during graph start: {e}",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )
def resume_graph(human_input_from_ui: str, thread_id_state):
    """Resume an interrupted run, injecting the user's reply.

    Writes the reply into the checkpointed state via ``update_state``,
    streams the graph to completion, and returns the same 5-tuple of
    Gradio updates as ``start_graph``.
    """
    print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
    print(f"Human input received from UI: {human_input_from_ui}")

    if not global_graph or not global_memory_saver:
        return (f"Error: Graph or checkpointer not initialized. Check server logs.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)

    try:
        config = {"configurable": {"thread_id": thread_id_state}}
        # BUG FIX: update_state re-applies channel reducers, and
        # `human_input` is annotated with operator.add. Feeding back the
        # *entire* snapshot (as before) would concatenate any existing
        # value with itself — or raise TypeError on None + str. Write
        # only the delta instead.
        global_graph.update_state(config, {"human_input": human_input_from_ui})

        # Stream with None as input to continue from the checkpoint.
        for event in global_graph.stream(None, config):
            if "__end__" in event:
                break

        final_state_values = global_graph.get_state(config).values
        final_message = final_state_values.get("final_response", "Graph finished without final response.")
        # Disable inputs; re-enable the start button for a new round.
        return (final_message, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
    except Exception as e:
        return (f"An error occurred during graph resumption: {e}",
                gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), thread_id_state)
# Gradio Interface setup
with gr.Blocks() as demo:
    gr.Markdown("# LangGraph Human-in-the-Loop Demo (SqliteSaver)")
    gr.Markdown(f"Graph state will be saved persistently in `{SQLITE_DB_PATH}`. Click 'Start Conversation' to begin. Input is taken via Gradio UI.")

    # AI output (read-only), the user's reply box (enabled once the graph
    # pauses), and the per-session thread id.
    output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
    human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)
    thread_id_state = gr.State("")

    with gr.Row():
        start_button = gr.Button("Start Conversation")
        resume_button = gr.Button("Resume Graph", interactive=False)

    # Both callbacks return the same 5-tuple of component updates.
    shared_outputs = [output_textbox, human_input_textbox, resume_button, start_button, thread_id_state]
    start_button.click(
        start_graph,
        inputs=[thread_id_state],
        outputs=shared_outputs,
    )
    resume_button.click(
        resume_graph,
        inputs=[human_input_textbox, thread_id_state],
        outputs=shared_outputs,
    )

if __name__ == "__main__":
    demo.launch()