nikhmr1235's picture
Update app.py
1c2b585 verified
import gradio as gr
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver # Import SqliteSaver
import operator
from typing import TypedDict, Annotated, Optional
import uuid
import os
import sqlite3 # Import sqlite3 for direct connection if needed for older versions
# --- 1. Define your Graph State ---
class GraphState(TypedDict):
    """State dictionary shared by the nodes of the greeting workflow."""

    # Opening message written by the greeting node.
    greeting_message: str

    # Reply supplied by the human through the UI while the graph is paused.
    # `operator.add` is attached as annotation metadata for this key.
    # NOTE(review): presumably LangGraph uses it as the merge function, in
    # which case repeated writes on one thread are concatenated rather than
    # replaced -- confirm that this is intended.
    human_input: Annotated[Optional[str], operator.add]

    # Closing message written by the response-display node.
    final_response: str

    # Marker each node writes to identify the step that ran last.
    current_node: str
# --- 2. Define the Nodes ---
def greeting_node(state: GraphState) -> GraphState:
    """Entry node: emit the assistant's fixed opening message.

    The incoming state is not read. The greeting is echoed to stdout and
    returned as a partial state update that also records this node as the
    current step.
    """
    greeting = "Hello there! I'm an AI assistant. How can I help you today?"

    # Same two log lines as before: a banner, then the greeting itself.
    for line in ("🤖 Greeting Node Executed:", greeting):
        print(line)

    update = {}
    update["greeting_message"] = greeting
    update["current_node"] = "greeting"
    return update
def human_input_node(state: GraphState) -> GraphState:
    """Acknowledge the human reply that is already stored in the state.

    This node never prompts for input itself. The reply is written into the
    state from outside the graph (the UI resume handler calls `update_state`
    before streaming again), so all this node does is log whatever is stored
    under "human_input" and mark the step as processed.
    """
    fallback = "No human input found in state when human_input_node ran."
    human_response = state.get("human_input", fallback)
    print(f"\n✋ Human Input Node Executed (Processing input: '{human_response}'):")

    # Hook for validating or post-processing the reply; nothing is done with
    # it yet, so only the step marker is returned.
    return dict(current_node="human_input_processed")
def human_response_display_node(state: GraphState) -> GraphState:
    """Final node: build and log the message that echoes the human reply.

    Reads "human_input" from the state (a placeholder sentence is used when
    the key is absent) and returns the closing message together with this
    node's step marker.
    """
    placeholder = "No human input received for final display."
    human_response = state.get("human_input", placeholder)
    final_message = f"You said: '{human_response}'. Thank you for your input!"

    for line in ("\n✅ Human Response Display Node Executed:", final_message):
        print(line)

    return dict(final_response=final_message, current_node="human_response_display")
# --- 3. Build the Graph ---
# Assemble a strictly linear workflow over GraphState:
#   greeting -> human_input_interrupt -> human_response_display -> END
builder = StateGraph(GraphState)
# Register each node function under the name the edges below refer to.
builder.add_node("greeting", greeting_node)
# "human_input_interrupt" is the name listed in `interrupt_before` when the
# graph is compiled further down, so that is the node the run pauses before.
builder.add_node("human_input_interrupt", human_input_node)
builder.add_node("human_response_display", human_response_display_node)
# Execution always starts at the greeting node.
builder.set_entry_point("greeting")
# Unconditional edges: there is no branching anywhere in this graph.
builder.add_edge("greeting", "human_input_interrupt")
builder.add_edge("human_input_interrupt", "human_response_display")
builder.add_edge("human_response_display", END)
# Location of the SQLite file that stores the graph checkpoints.
SQLITE_DB_PATH = "langgraph_checkpoints.sqlite"

# --- Checkpointer and Graph Compilation ---
# Both globals stay None when the checkpointer cannot be created; the UI
# handlers test for that and report an error instead of running the graph.
global_memory_saver = None
global_graph = None

try:
    # check_same_thread=False lets the one connection be used from whichever
    # thread ends up serving a Gradio request.
    conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
    global_memory_saver = SqliteSaver(conn=conn)
    print(f"SqliteSaver initialized successfully: {type(global_memory_saver)}")
except Exception as e:
    # Best effort: log the problem and continue without a checkpointer.
    print(f"Error initializing SqliteSaver connection: {e}")

if not global_memory_saver:
    print("SqliteSaver not initialized, graph compilation skipped.")
else:
    # Pause before the human-input node so the UI can inject the reply.
    global_graph = builder.compile(
        checkpointer=global_memory_saver,
        interrupt_before=["human_input_interrupt"],
    )
    print("Graph compiled successfully.")
# --- Gradio UI Logic ---
# NOTE(review): `current_thread_id` is not referenced anywhere else in the
# visible source; the event handlers receive the `thread_id_state` component
# that is created inside the `gr.Blocks()` context instead. Presumably a
# leftover -- confirm before removing.
current_thread_id = gr.State("")
def start_graph(thread_id_state):
    """Start a new conversation on a freshly generated thread id.

    Streams the compiled graph with an empty input until it reports an end or
    an interrupt, then reads the stored state back and shows the greeting
    followed by a prompt asking the user to reply.

    Args:
        thread_id_state: Thread id currently held in the Gradio session state.
            It is only handed back unchanged when the start fails.

    Returns:
        A 5-tuple matching the click handler's outputs: output text, update
        for the input textbox, update for the resume button, update for the
        start button, and the thread id to store. On success the input box
        and resume button are enabled, the start button is disabled and the
        new thread id is returned.
    """
    new_thread_id = str(uuid.uuid4())
    print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")

    def failure(message):
        # Failed start: input and resume stay disabled, start stays enabled,
        # and the previously stored thread id is kept.
        return (
            message,
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )

    if not global_graph:
        return failure(f"Error: Graph not initialized. Check server logs.")

    try:
        run_config = {"configurable": {"thread_id": new_thread_id}}

        for s in global_graph.stream({}, run_config):
            ended = "__end__" in s
            if not ended and "__interrupt__" in s:
                print(f"Graph interrupted BEFORE {s.get('__interrupt__', 'Unknown')} node.")
            if ended or "__interrupt__" in s:
                break

        snapshot = global_graph.get_state(run_config)
        greeting_text = snapshot.values.get("greeting_message", "No greeting yet.")
        output_message = (
            greeting_text
            + "\n\n"
            + "Please type your response in the 'Your Input' box and click 'Resume Graph'."
        )
        return (
            output_message,
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=False),
            new_thread_id,
        )
    except Exception as e:
        return failure(f"An error occurred during graph start: {e}")
def resume_graph(human_input_from_ui: str, thread_id_state):
    """Resume the paused graph for a thread using the text typed in the UI.

    The reply is written into the thread's state with `update_state`, the
    graph is streamed again with `None` as input so it continues from its
    checkpoint, and the "final_response" value is read back for display.

    Args:
        human_input_from_ui: Text taken from the "Your Input" textbox.
        thread_id_state: Thread id stored in the Gradio session state when the
            conversation was started.

    Returns:
        A 5-tuple matching the click handler's outputs: output text, update
        for the input textbox, update for the resume button, update for the
        start button, and the (unchanged) thread id. Every exit path,
        success or failure, disables the input box and resume button and
        re-enables the start button.
    """
    print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
    print(f"Human input received from UI: {human_input_from_ui}")

    def _finished(message):
        # One place for the UI shape shared by all exit paths.
        return (
            message,
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )

    if not global_graph or not global_memory_saver:
        return _finished("Error: Graph or checkpointer not initialized. Check server logs.")

    # Without a thread id there is no checkpoint to resume; say so directly
    # instead of surfacing whatever error the graph library would raise.
    if not thread_id_state:
        return _finished(
            "Error: No active conversation to resume. Click 'Start Conversation' first."
        )

    try:
        config = {"configurable": {"thread_id": thread_id_state}}
        print(f"Type of global_memory_saver before put_state: {type(global_memory_saver)}")

        # Write only the new key. Echoing the whole snapshot back would mutate
        # the snapshot's dict and push every already-stored value through
        # `update_state` again, including any key that has a merge function
        # attached (such as `human_input`).
        global_graph.update_state(config, {"human_input": human_input_from_ui})

        # Passing None as input continues from the stored checkpoint.
        for s in global_graph.stream(None, config):
            if "__end__" in s:
                break

        final_state_values = global_graph.get_state(config).values
        final_message = final_state_values.get(
            "final_response", "Graph finished without final response."
        )
        return _finished(final_message)
    except Exception as e:
        return _finished(f"An error occurred during graph resumption: {e}")
# Gradio Interface setup
with gr.Blocks() as demo:
    # Page header and a short usage hint.
    gr.Markdown("# LangGraph Human-in-the-Loop Demo (SqliteSaver)")
    gr.Markdown(f"Graph state will be saved persistently in `{SQLITE_DB_PATH}`. Click 'Start Conversation' to begin. Input is taken via Gradio UI.")

    # Read-only box for the assistant's text; the reply box starts disabled
    # and is enabled by the start handler.
    output_textbox = gr.Textbox(
        label="AI Assistant Output",
        lines=5,
        interactive=False,
    )
    human_input_textbox = gr.Textbox(
        label="Your Input",
        placeholder="Type your response here...",
        interactive=False,
    )

    # Per-session storage for the id of the conversation thread in progress.
    thread_id_state = gr.State("")

    with gr.Row():
        start_button = gr.Button("Start Conversation")
        resume_button = gr.Button("Resume Graph", interactive=False)

    # Both handlers return the same five outputs, in this order.
    start_button.click(
        fn=start_graph,
        inputs=[thread_id_state],
        outputs=[
            output_textbox,
            human_input_textbox,
            resume_button,
            start_button,
            thread_id_state,
        ],
    )
    resume_button.click(
        fn=resume_graph,
        inputs=[human_input_textbox, thread_id_state],
        outputs=[
            output_textbox,
            human_input_textbox,
            resume_button,
            start_button,
            thread_id_state,
        ],
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()