File size: 7,753 Bytes
7ad989d
 
94ababa
7ad989d
 
132bd32
f46b2b1
94ababa
7ad989d
f46b2b1
7ad989d
 
94ababa
7ad989d
132bd32
7ad989d
f46b2b1
7ad989d
 
 
 
 
 
 
f46b2b1
8177102
94ababa
f46b2b1
8177102
 
94ababa
8177102
7ad989d
 
8177102
7ad989d
 
 
 
 
f46b2b1
7ad989d
 
94ababa
7ad989d
 
 
94ababa
1c2b585
7ad989d
 
94ababa
 
ea6e91e
94ababa
c7ca645
 
94ababa
c7ca645
94ababa
 
 
c7ca645
94ababa
 
c7ca645
94ababa
 
 
 
 
 
c7ca645
94ababa
 
 
132bd32
f46b2b1
132bd32
7ad989d
132bd32
 
 
7ad989d
94ababa
 
 
7ad989d
cea9d6a
7ad989d
 
 
8177102
132bd32
7ad989d
cea9d6a
7ad989d
132bd32
7ad989d
132bd32
f46b2b1
7ad989d
132bd32
7ad989d
132bd32
 
 
7ad989d
132bd32
 
f46b2b1
7ad989d
94ababa
 
 
7ad989d
8177102
 
 
 
 
c7ca645
 
94ababa
209f479
8177102
209f479
 
 
 
 
7ad989d
 
 
cea9d6a
7ad989d
209f479
132bd32
7ad989d
132bd32
 
7ad989d
 
132bd32
 
7ad989d
209f479
7ad989d
 
94ababa
 
 
7ad989d
 
f46b2b1
132bd32
 
7ad989d
 
 
 
 
 
 
132bd32
 
7ad989d
 
 
 
132bd32
 
7ad989d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import gradio as gr
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver # Import SqliteSaver
import operator
from typing import TypedDict, Annotated, Optional
import uuid
import os
import sqlite3 # Import sqlite3 for direct connection if needed for older versions

# --- 1. Define your Graph State ---
class GraphState(TypedDict, total=False):
    """Shared state passed between the nodes of the conversation graph.

    ``total=False`` because every node returns only the keys it updates, and
    ``human_input`` is absent until the UI supplies it when the graph resumes.
    """

    # Greeting text written by ``greeting_node``.
    greeting_message: str
    # The user's single reply, injected while the graph is paused. A plain
    # last-value channel: an ``operator.add`` reducer would concatenate
    # repeated string writes ("hi" + "there" -> "hithere") instead of
    # replacing the reply.
    human_input: Optional[str]
    # Closing message produced by ``human_response_display_node``.
    final_response: str
    # Tag naming the node that most recently ran (each node sets it).
    current_node: str

# --- 2. Define the Nodes ---
def greeting_node(state: GraphState) -> GraphState:
    """Emit the assistant's opening line and record it in the state.

    The incoming ``state`` is not read: the node always produces the same
    greeting and tags ``current_node`` as ``"greeting"``.
    """
    message = "Hello there! I'm an AI assistant. How can I help you today?"
    for line in ("🤖 Greeting Node Executed:", message):
        print(line)
    return dict(greeting_message=message, current_node="greeting")

def human_input_node(state: GraphState) -> GraphState:
    """Node 2: log the human reply that was injected while the graph was paused.

    The graph is compiled with ``interrupt_before=["human_input_interrupt"]``,
    so execution stops before this node runs. ``resume_graph`` then writes the
    UI text into ``human_input`` with ``update_state`` and resumes the stream.
    This node only reads that value; it never prompts for input itself (no
    ``input()`` call).

    Returns:
        A partial state update tagging ``current_node`` as
        ``"human_input_processed"``.
    """
    human_response = state.get(
        "human_input",
        "No human input found in state when human_input_node ran.",
    )
    print(f"\n✋ Human Input Node Executed (Processing input: '{human_response}'):")
    # Validation or further processing of ``human_response`` would go here.
    return {"current_node": "human_input_processed"}

def human_response_display_node(state: GraphState) -> GraphState:
    """Echo the human reply back as the conversation's closing message.

    Reads ``human_input`` from ``state``, using a placeholder when the key is
    missing. Prints the closing text and returns it as ``final_response``
    together with the ``current_node`` tag.
    """
    fallback = "No human input received for final display."
    reply = state.get("human_input", fallback)
    closing = f"You said: '{reply}'. Thank you for your input!"

    print("\n✅ Human Response Display Node Executed:")
    print(closing)

    return dict(final_response=closing, current_node="human_response_display")

# --- 3. Build the Graph ---
builder = StateGraph(GraphState)

# Register the nodes, in order, as (name, callable) pairs.
for _node_name, _node_fn in (
    ("greeting", greeting_node),
    ("human_input_interrupt", human_input_node),
    ("human_response_display", human_response_display_node),
):
    builder.add_node(_node_name, _node_fn)

builder.set_entry_point("greeting")

# Linear flow: greeting -> human_input_interrupt -> human_response_display -> END.
for _edge_src, _edge_dst in (
    ("greeting", "human_input_interrupt"),
    ("human_input_interrupt", "human_response_display"),
    ("human_response_display", END),
):
    builder.add_edge(_edge_src, _edge_dst)

# Drop the loop variables so they do not linger as module globals.
del _node_name, _node_fn, _edge_src, _edge_dst

# Define the path for the SQLite database file
SQLITE_DB_PATH = "langgraph_checkpoints.sqlite"

# --- Checkpointer and Graph Compilation ---
# These stay None when setup fails; the UI handlers check for that and report
# an error to the user instead of the module crashing at import time.
conn = None
global_memory_saver = None
global_graph = None

try:
    # check_same_thread=False: Gradio/web apps may access the DB from threads
    # other than the one that opened the connection.
    conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
    global_memory_saver = SqliteSaver(conn=conn)
    print(f"SqliteSaver initialized successfully: {type(global_memory_saver)}")
except Exception as e:
    print(f"Error initializing SqliteSaver connection: {e}")
    # Don't leak the connection if it opened but the saver could not be built.
    if conn is not None:
        conn.close()
        conn = None

if global_memory_saver is not None:
    global_graph = builder.compile(
        checkpointer=global_memory_saver,
        # Pause before the human node so the UI can inject the reply.
        interrupt_before=["human_input_interrupt"],
    )
    print("Graph compiled successfully.")
else:
    print("SqliteSaver not initialized, graph compilation skipped.")

# --- Gradio UI Logic ---
# NOTE(review): not referenced anywhere else in this file; the Blocks UI below
# creates its own ``thread_id_state``. Kept so the module attribute still exists.
current_thread_id = gr.State("")

def start_graph(thread_id_state):
    """Start a fresh conversation thread and run it up to the human-input pause.

    A new UUID thread id is generated on every call.

    On success, the UI shows the greeting plus instructions, the input box and
    resume button are enabled, the start button is disabled, and the new thread
    id is stored.

    On failure, an error message is shown, only the start button is enabled,
    and the previous ``thread_id_state`` is returned unchanged.

    Returns:
        ``(message, input_box_update, resume_button_update,
        start_button_update, thread_id)``, matching the ``outputs`` wired up
        in the Blocks UI.
    """
    thread_id = str(uuid.uuid4())
    print(f"\n--- Starting New Graph Execution with thread_id: {thread_id} ---")

    # Guard: without a compiled graph there is nothing to run.
    if not global_graph:
        return (
            "Error: Graph not initialized. Check server logs.",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )

    run_config = {"configurable": {"thread_id": thread_id}}
    try:
        # Drain the stream; execution halts at the interrupt_before pause.
        for event in global_graph.stream({}, run_config):
            if "__end__" in event:
                break
            if "__interrupt__" in event:
                print(f"Graph interrupted BEFORE {event.get('__interrupt__', 'Unknown')} node.")
                break

        snapshot = global_graph.get_state(run_config)
        greeting = snapshot.values.get("greeting_message", "No greeting yet.")
        instructions = "\n\n" + "Please type your response in the 'Your Input' box and click 'Resume Graph'."

        return (
            greeting + instructions,
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=False),
            thread_id,
        )
    except Exception as e:
        return (
            f"An error occurred during graph start: {e}",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )

def resume_graph(human_input_from_ui: str, thread_id_state):
    """Inject the user's reply into a paused thread and run the graph to the end.

    Args:
        human_input_from_ui: Text from the "Your Input" box.
        thread_id_state: Thread id stored by ``start_graph``. It is empty when
            no conversation has been started.

    Returns:
        ``(message, input_box_update, resume_button_update,
        start_button_update, thread_id)``. On every path the input box and
        resume button are disabled and the start button is enabled.
        ``message`` is the graph's final response or an error description.
    """
    print(f"\n--- Resuming Graph Execution for thread_id: {thread_id_state} ---")
    print(f"Human input received from UI: {human_input_from_ui}")

    def _finish(message):
        # Every exit path leaves the UI ready for a new conversation.
        return (
            message,
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            thread_id_state,
        )

    if global_graph is None or global_memory_saver is None:
        return _finish("Error: Graph or checkpointer not initialized. Check server logs.")

    if not thread_id_state:
        return _finish("Error: No active conversation. Click 'Start Conversation' first.")

    config = {"configurable": {"thread_id": thread_id_state}}
    try:
        # Write ONLY the new reply. Re-submitting the whole snapshot would
        # re-write every channel, and any reducer on ``human_input`` would
        # merge the reply with a previous one.
        global_graph.update_state(config, {"human_input": human_input_from_ui})

        # Resume from the saved checkpoint (None input = continue the thread).
        for event in global_graph.stream(None, config):
            if "__end__" in event:
                break

        final_values = global_graph.get_state(config).values
        final_message = final_values.get("final_response", "Graph finished without final response.")
        return _finish(final_message)
    except Exception as e:
        return _finish(f"An error occurred during graph resumption: {e}")


# Gradio Interface setup: a single page with an output pane, an input box,
# and Start / Resume buttons wired to start_graph / resume_graph.
with gr.Blocks() as demo:
    gr.Markdown("# LangGraph Human-in-the-Loop Demo (SqliteSaver)")
    gr.Markdown(f"Graph state will be saved persistently in `{SQLITE_DB_PATH}`. Click 'Start Conversation' to begin. Input is taken via Gradio UI.")

    output_textbox = gr.Textbox(label="AI Assistant Output", lines=5, interactive=False)
    human_input_textbox = gr.Textbox(label="Your Input", placeholder="Type your response here...", interactive=False)

    # Per-session thread id, written by start_graph and read by resume_graph.
    thread_id_state = gr.State("")

    with gr.Row():
        start_button = gr.Button("Start Conversation")
        resume_button = gr.Button("Resume Graph", interactive=False)

    # Both handlers return their 5-tuple in exactly this component order.
    handler_outputs = [output_textbox, human_input_textbox, resume_button, start_button, thread_id_state]

    start_button.click(start_graph, inputs=[thread_id_state], outputs=handler_outputs)
    resume_button.click(resume_graph, inputs=[human_input_textbox, thread_id_state], outputs=handler_outputs)

if __name__ == "__main__":
    demo.launch()