Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.sqlite import SqliteSaver # Import SqliteSaver | |
| import operator | |
| from typing import TypedDict, Annotated, Optional | |
| import uuid | |
| import os | |
| import sqlite3 # Import sqlite3 for direct connection if needed for older versions | |
# --- 1. Define your Graph State ---
class GraphState(TypedDict):
    """Shared state passed between the graph's nodes.

    Each node returns a dict holding only the keys it changes; LangGraph
    merges those partial updates into this state.
    """

    # Greeting text produced by ``greeting_node``.
    greeting_message: str
    # Text the user typed in the Gradio UI, written into the checkpointed
    # state by ``resume_graph`` before the graph resumes. Last write wins: an
    # ``operator.add`` reducer here would concatenate repeated submissions and
    # raise ``TypeError`` if ``None`` were ever written after a string.
    human_input: Optional[str]
    # Message built by ``human_response_display_node`` for the UI.
    final_response: str
    # Progress label that each node sets when it runs.
    current_node: str
# --- 2. Define the Nodes ---
def greeting_node(state: GraphState) -> GraphState:
    """Entry node: print a fixed greeting and record that this node ran.

    ``state`` is not read. Returns a partial update containing the greeting
    text and the ``current_node`` marker.
    """
    message = "Hello there! I'm an AI assistant. How can I help you today?"
    print("🤖 Greeting Node Executed:")
    print(message)
    return dict(greeting_message=message, current_node="greeting")
def human_input_node(state: GraphState) -> GraphState:
    """Node 2: log the human input that was injected into the state.

    The graph is compiled with ``interrupt_before=["human_input_interrupt"]``,
    so execution pauses before this node. While paused, the Gradio handler
    writes the user's text into ``state["human_input"]`` with
    ``graph.update_state`` and then resumes the stream. This node only reads
    that value; it never calls ``input()``.

    Returns:
        A partial update marking this node as processed.
    """
    human_response = state.get(
        "human_input", "No human input found in state when human_input_node ran."
    )
    print(f"\n✋ Human Input Node Executed (Processing input: '{human_response}'):")
    # Hook point: validation or further processing of human_response goes here.
    return {"current_node": "human_input_processed"}
def human_response_display_node(state: GraphState) -> GraphState:
    """Final node: echo the user's input back and store it as the final response.

    Falls back to a placeholder when ``human_input`` is absent from ``state``.
    Returns a partial update with ``final_response`` and ``current_node``.
    """
    said = state.get("human_input", "No human input received for final display.")
    message = "You said: '{}'. Thank you for your input!".format(said)
    print("\n✅ Human Response Display Node Executed:")
    print(message)
    return dict(final_response=message, current_node="human_response_display")
# --- 3. Build the Graph ---
# Linear pipeline: greeting -> human_input_interrupt -> human_response_display -> END.
builder = StateGraph(GraphState)
for _node_name, _node_fn in (
    ("greeting", greeting_node),
    ("human_input_interrupt", human_input_node),
    ("human_response_display", human_response_display_node),
):
    builder.add_node(_node_name, _node_fn)
builder.set_entry_point("greeting")
for _src, _dst in (
    ("greeting", "human_input_interrupt"),
    ("human_input_interrupt", "human_response_display"),
    ("human_response_display", END),
):
    builder.add_edge(_src, _dst)
# Keep the loop variables from lingering as module attributes.
del _node_name, _node_fn, _src, _dst
# Define the path for the SQLite database file
SQLITE_DB_PATH = "langgraph_checkpoints.sqlite"

# --- Checkpointer and Graph Compilation ---
# Both globals stay None if the checkpointer cannot be created; the Gradio
# handlers test for that and show an error instead of crashing.
global_memory_saver = None
global_graph = None
conn = None
try:
    # check_same_thread=False: Gradio may invoke handlers from threads other
    # than the one that opened the connection.
    conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
    global_memory_saver = SqliteSaver(conn=conn)
    print(f"SqliteSaver initialized successfully: {type(global_memory_saver)}")
except Exception as e:
    print(f"Error initializing SqliteSaver connection: {e}")
    # Don't leak the connection when it opened but SqliteSaver construction failed.
    if conn is not None:
        conn.close()
        conn = None

if global_memory_saver is not None:
    # Pause before the human-input node so the UI can collect text and write
    # it into the checkpointed state before the run is resumed.
    global_graph = builder.compile(
        checkpointer=global_memory_saver,
        interrupt_before=["human_input_interrupt"],
    )
    print("Graph compiled successfully.")
else:
    print("SqliteSaver not initialized, graph compilation skipped.")
# --- Gradio UI Logic ---
# NOTE(review): this module-level State is not referenced by the handlers or
# the layout below; the per-session thread id is the `thread_id_state`
# component created inside the Blocks context.
current_thread_id = gr.State(value="")
def start_graph(thread_id_state):
    """Start a new conversation thread and run it up to the human-input pause.

    Each click gets a fresh UUID as the LangGraph ``thread_id``, so its
    checkpoints are kept separate in the SQLite checkpointer. The graph runs
    the greeting node and then stops before ``human_input_interrupt``, the
    node listed in ``interrupt_before`` at compile time.

    Args:
        thread_id_state: Current value of the session's thread-id ``gr.State``;
            returned unchanged if the graph cannot be started.

    Returns:
        Values for ``(output_textbox, human_input_textbox, resume_button,
        start_button, thread_id_state)``. On success the input box and the
        resume button are enabled, the start button is disabled, and the new
        thread id is stored. On failure only the start button is enabled.
    """
    new_thread_id = str(uuid.uuid4())
    print(f"\n--- Starting New Graph Execution with thread_id: {new_thread_id} ---")
    if not global_graph:
        return (
            "Error: Graph not initialized. Check server logs.",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )
    config = {"configurable": {"thread_id": new_thread_id}}
    try:
        # Drain the stream. It ends on its own once the interrupt is reached,
        # so there is no need to look for sentinel keys in the chunks.
        for _chunk in global_graph.stream({}, config):
            pass
        snapshot = global_graph.get_state(config)
        # snapshot.next names the node(s) that will run when the graph resumes.
        pending = ", ".join(snapshot.next) or "Unknown"
        print(f"Graph interrupted BEFORE {pending} node.")
        output_message = snapshot.values.get("greeting_message", "No greeting yet.")
        output_message += "\n\n" + "Please type your response in the 'Your Input' box and click 'Resume Graph'."
        return (
            output_message,
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=False),
            new_thread_id,
        )
    except Exception as e:
        import traceback

        # The UI only shows str(e); keep the full trace in the server log.
        traceback.print_exc()
        return (
            f"An error occurred during graph start: {e}",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )
def resume_graph(human_input_from_ui: str, thread_id_state):
    """Write the user's text into a paused thread and run the graph to the end.

    Args:
        human_input_from_ui: Text from the "Your Input" textbox.
        thread_id_state: LangGraph thread id stored by ``start_graph``.

    Returns:
        Values for ``(output_textbox, human_input_textbox, resume_button,
        start_button, thread_id_state)``. After success or failure, the input
        box and the resume button are disabled and the start button is
        enabled.
    """
    print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
    print(f"Human input received from UI: {human_input_from_ui}")
    if not global_graph or not global_memory_saver:
        return (
            "Error: Graph or checkpointer not initialized. Check server logs.",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )
    config = {"configurable": {"thread_id": thread_id_state}}
    try:
        snapshot = global_graph.get_state(config)
        if not snapshot.next:
            # Nothing is waiting to run: this thread was never started or has
            # already finished, so there is no interrupt to resume from.
            return (
                "Error: No paused conversation to resume. Click 'Start Conversation' first.",
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=True),
                thread_id_state,
            )
        # Send only the key that changed. Re-sending the whole snapshot would
        # rewrite every channel, and under an additive reducer it would
        # duplicate content already stored.
        global_graph.update_state(config, {"human_input": human_input_from_ui})
        # Passing None as input continues from the saved checkpoint.
        for _chunk in global_graph.stream(None, config):
            pass
        final_values = global_graph.get_state(config).values
        final_message = final_values.get("final_response", "Graph finished without final response.")
        return (
            final_message,
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )
    except Exception as e:
        import traceback

        # The UI only shows str(e); keep the full trace in the server log.
        traceback.print_exc()
        return (
            f"An error occurred during graph resumption: {e}",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )
# Gradio Interface setup
with gr.Blocks() as demo:
    gr.Markdown("# LangGraph Human-in-the-Loop Demo (SqliteSaver)")
    gr.Markdown(
        f"Graph state will be saved persistently in `{SQLITE_DB_PATH}`. Click 'Start Conversation' to begin. Input is taken via Gradio UI."
    )
    output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
    human_input_textbox = gr.Textbox(
        label="Your Input",
        placeholder="Type your response here...",
        interactive=False,
    )
    # Session-scoped LangGraph thread id; start_graph returns the new id into it.
    thread_id_state = gr.State("")
    with gr.Row():
        start_button = gr.Button("Start Conversation")
        resume_button = gr.Button("Resume Graph", interactive=False)

    # Both handlers return a 5-tuple for these components, in this order.
    _handler_outputs = [
        output_textbox,
        human_input_textbox,
        resume_button,
        start_button,
        thread_id_state,
    ]
    start_button.click(start_graph, inputs=[thread_id_state], outputs=_handler_outputs)
    resume_button.click(
        resume_graph,
        inputs=[human_input_textbox, thread_id_state],
        outputs=_handler_outputs,
    )

if __name__ == "__main__":
    demo.launch()