File size: 3,096 Bytes
472e1d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9016439
 
472e1d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from agents.sql_agent.states import SQLAgentState
from langgraph.graph import StateGraph, START, END
from agents.sql_agent.nodes import (
    get_db_info, 
    generate_sql, 
    execute_sql, 
    generate_answer, 
    detect_off_topic, 
    choose_visualization, 
    format_data_for_visualization,
    render_visualization,
    finalize_output
)

def build_graph(visualize: bool = True) -> StateGraph:
    """Build the (uncompiled) LangGraph workflow for the SQL agent.

    Flow: START -> detect_off_topic. If ``state["error"]`` is truthy after that
    node, the run jumps straight to generate_answer. Otherwise it follows
    get_db_info -> generate_sql -> execute_sql -> choose_visualization ->
    format_data_for_visualization -> render_visualization -> generate_answer.
    Both paths end with finalize_output -> END.

    Args:
        visualize: Currently unused; reserved for a future visualization hook.

    Returns:
        The ``StateGraph`` builder. Call ``.compile()`` on it before invoking.
    """
    graph = StateGraph(SQLAgentState)

    # Add nodes
    graph.add_node("detect_off_topic", detect_off_topic)
    graph.add_node("generate_sql", generate_sql)
    graph.add_node("get_db_info", get_db_info)
    graph.add_node("execute_sql", execute_sql)
    graph.add_node("generate_answer", generate_answer)
    graph.add_node("choose_visualization", choose_visualization)
    graph.add_node("format_data_for_visualization", format_data_for_visualization)
    graph.add_node("render_visualization", render_visualization)
    graph.add_node("finalize_output", finalize_output)

    def _route_after_off_topic(state) -> bool:
        # The initial state carries "error": None, and a node may store an error
        # message rather than a bool. Normalise to a bool so the value always
        # matches one of the path_map keys below instead of failing the lookup.
        return bool(state.get("error"))

    # Add edges
    graph.add_edge(START, "detect_off_topic")

    graph.add_conditional_edges(
        "detect_off_topic",
        _route_after_off_topic,
        path_map={
            True: "generate_answer",  # off-topic / error: skip the SQL pipeline
            False: "get_db_info",
        },
    )

    graph.add_edge("get_db_info", "generate_sql")
    graph.add_edge("generate_sql", "execute_sql")
    graph.add_edge("execute_sql", "choose_visualization")
    graph.add_edge("choose_visualization", "format_data_for_visualization")
    graph.add_edge("format_data_for_visualization", "render_visualization")
    graph.add_edge("render_visualization", "generate_answer")
    graph.add_edge("generate_answer", "finalize_output")
    graph.add_edge("finalize_output", END)

    # TODO: honour ``visualize`` once a visualization hook is implemented.
    return graph

def visualize_graph(graph) -> None:
    """Print a Mermaid description of *graph*'s structure.

    Compiled LangGraph graphs expose ``get_graph()``, whose result can render
    itself with ``draw_mermaid()``; they have no ``visualize()`` method. Objects
    without ``get_graph`` fall back to calling their own ``visualize()``, which
    preserves the previous behaviour for such objects.

    Args:
        graph: A compiled LangGraph graph, or any object with a ``visualize()``
            method.
    """
    get_graph = getattr(graph, "get_graph", None)
    if get_graph is None:
        graph.visualize()
        return
    print(get_graph().draw_mermaid())

if __name__ == "__main__":
    # Smoke-run input. The question asks (in Vietnamese) for the
    # "top 3 products with the lowest price".
    state = {
        "question": "top 3 sản phẩm có giá thấp nhất",
        "db_info": {
            "tables": [],
            "columns": {},
            "schema": ""
        },
        "sql_query": "",
        "sql_result": None,
        "error": None,
        "step": None,
        "answer": None,
        "plot_path": None,
        "response_md": None,
        "visualization": None,
        "visualization_reason": None,
        "formatted_data_for_visualization": None,
        "visualization_output": None,
        "off_topic": None
    }

    graph = build_graph().compile()
    # Call visualize_graph(graph) here to print the graph structure.

    # Run the pipeline once. The original ran invoke() and then stream(), which
    # repeated every LLM and database call. Each "updates" chunk maps a node
    # name to the partial state that node returned; keep the latest "answer".
    answer = None
    for step in graph.stream(state, stream_mode="updates"):
        print("-" * 80)
        print(step)
        for update in step.values():
            if isinstance(update, dict) and "answer" in update:
                answer = update["answer"]

    print(answer)