"""LangGraph workflow definition for the SQL agent, plus a small CLI demo.

Run this file directly to execute the graph once on a sample question and
print each node's state update followed by the final answer.
"""
import os
import sys

from langgraph.graph import END, START, StateGraph

# Make the directory two levels above this file importable when the module is
# run as a script, so the ``agents`` package below can be resolved.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from agents.sql_agent.states import SQLAgentState  # noqa: E402  (needs sys.path tweak)
from agents.sql_agent.nodes import (  # noqa: E402  (needs sys.path tweak)
    get_db_info,
    generate_sql,
    execute_sql,
    generate_answer,
    detect_off_topic,
    choose_visualization,
    format_data_for_visualization,
    render_visualization,
    finalize_output,
)


def _route_after_off_topic(state: SQLAgentState) -> bool:
    """Return the branch key used after the ``detect_off_topic`` node.

    ``True`` (any truthy ``error``) skips the SQL pipeline and goes straight
    to ``generate_answer``; ``False`` continues to ``get_db_info``.

    The value is coerced to ``bool`` because the path map only has ``True`` /
    ``False`` keys, while ``error`` may be ``None`` (as in the demo's initial
    state) or possibly an error message string; returning either raw value
    would make the conditional-edge lookup fail.
    """
    return bool(state.get("error"))


def build_graph(visualize: bool = True) -> StateGraph:
    """Assemble the (uncompiled) SQL-agent state graph.

    Flow: ``detect_off_topic`` -> (if ``error``) ``generate_answer``, otherwise
    ``get_db_info`` -> ``generate_sql`` -> ``execute_sql`` ->
    ``choose_visualization`` -> ``format_data_for_visualization`` ->
    ``render_visualization`` -> ``generate_answer`` -> ``finalize_output`` ->
    END.

    Args:
        visualize: Currently unused; kept for API compatibility.
            TODO: implement an optional visualization hook.

    Returns:
        The uncompiled ``StateGraph``; call ``.compile()`` before running it.
    """
    graph = StateGraph(SQLAgentState)

    # Nodes
    graph.add_node("detect_off_topic", detect_off_topic)
    graph.add_node("generate_sql", generate_sql)
    graph.add_node("get_db_info", get_db_info)
    graph.add_node("execute_sql", execute_sql)
    graph.add_node("generate_answer", generate_answer)
    graph.add_node("choose_visualization", choose_visualization)
    graph.add_node("format_data_for_visualization", format_data_for_visualization)
    graph.add_node("render_visualization", render_visualization)
    graph.add_node("finalize_output", finalize_output)

    # Entry and off-topic branch
    graph.add_edge(START, "detect_off_topic")
    graph.add_conditional_edges(
        "detect_off_topic",
        _route_after_off_topic,
        path_map={
            True: "generate_answer",
            False: "get_db_info",
        },
    )

    # Main SQL + visualization pipeline
    graph.add_edge("get_db_info", "generate_sql")
    graph.add_edge("generate_sql", "execute_sql")
    graph.add_edge("execute_sql", "choose_visualization")
    graph.add_edge("choose_visualization", "format_data_for_visualization")
    graph.add_edge("format_data_for_visualization", "render_visualization")
    graph.add_edge("render_visualization", "generate_answer")
    graph.add_edge("generate_answer", "finalize_output")
    graph.add_edge("finalize_output", END)

    return graph


def visualize_graph(graph) -> None:
    """Print a Mermaid diagram of *graph*.

    Accepts either a compiled graph or an uncompiled ``StateGraph`` (which is
    compiled first, since only compiled graphs expose ``get_graph()``).
    """
    if not hasattr(graph, "get_graph") and hasattr(graph, "compile"):
        graph = graph.compile()
    print(graph.get_graph().draw_mermaid())


# Sample question (Vietnamese): "top 3 products with the lowest price".
DEMO_QUESTION = "top 3 sản phẩm có giá thấp nhất"


def _initial_state(question: str) -> dict:
    """Build a fresh input state for the graph with every field initialised."""
    return {
        "question": question,
        "db_info": {
            "tables": [],
            "columns": {},
            "schema": ""
        },
        "sql_query": "",
        "sql_result": None,
        "error": None,
        "step": None,
        "answer": None,
        "plot_path": None,
        "response_md": None,
        "visualization": None,
        "visualization_reason": None,
        "formatted_data_for_visualization": None,
        "visualization_output": None,
        "off_topic": None
    }


def main() -> None:
    """Run the graph once, printing each node update and then the answer.

    A single ``stream`` call with both ``updates`` and ``values`` modes is
    used so the (potentially expensive) graph executes only once: ``updates``
    chunks are printed as they arrive, and the last ``values`` chunk is the
    final state.
    """
    graph = build_graph().compile()

    final_state = None
    for mode, chunk in graph.stream(
        _initial_state(DEMO_QUESTION),
        stream_mode=["updates", "values"],
    ):
        if mode == "updates":
            print("-" * 80)
            print(chunk)
        else:
            final_state = chunk

    print("=" * 80)
    print(final_state.get("answer") if final_state else None)


if __name__ == "__main__":
    main()