Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) | |
| from agents.sql_agent.states import SQLAgentState | |
| from langgraph.graph import StateGraph, START, END | |
| from agents.sql_agent.nodes import ( | |
| get_db_info, | |
| generate_sql, | |
| execute_sql, | |
| generate_answer, | |
| detect_off_topic, | |
| choose_visualization, | |
| format_data_for_visualization, | |
| render_visualization, | |
| finalize_output | |
| ) | |
def _route_after_off_topic(state) -> bool:
    """Return the branch key used on the edge leaving ``detect_off_topic``.

    ``True`` routes to ``generate_answer`` (skipping SQL generation) and
    ``False`` routes to ``get_db_info``.

    The ``error`` field is coerced to ``bool``. It can legitimately be ``None``
    (the demo's initial state sets ``"error": None``) or, presumably, an error
    message string. Neither value is a key of the ``{True, False}`` path map,
    so LangGraph would be unable to resolve the next node.
    """
    # NOTE(review): SQLAgentState also carries an ``off_topic`` field; confirm
    # whether detect_off_topic signals off-topic questions via ``error`` (as
    # routed here) or via ``off_topic``.
    return bool(state.get("error"))


def build_graph(visualize: bool = True) -> "StateGraph":
    """Build (but do not compile) the SQL agent's LangGraph workflow.

    Flow::

        START -> detect_off_topic
        detect_off_topic --(error truthy)--> generate_answer
        detect_off_topic --(otherwise)-----> get_db_info -> generate_sql
            -> execute_sql -> choose_visualization
            -> format_data_for_visualization -> render_visualization
            -> generate_answer
        generate_answer -> finalize_output -> END

    Args:
        visualize: Currently unused; kept for backward compatibility.

    Returns:
        An uncompiled ``StateGraph``; call ``.compile()`` before invoking it.
    """
    graph = StateGraph(SQLAgentState)

    # Registration order matches the original definition.
    nodes = (
        ("detect_off_topic", detect_off_topic),
        ("generate_sql", generate_sql),
        ("get_db_info", get_db_info),
        ("execute_sql", execute_sql),
        ("generate_answer", generate_answer),
        ("choose_visualization", choose_visualization),
        ("format_data_for_visualization", format_data_for_visualization),
        ("render_visualization", render_visualization),
        ("finalize_output", finalize_output),
    )
    for name, fn in nodes:
        graph.add_node(name, fn)

    graph.add_edge(START, "detect_off_topic")
    graph.add_conditional_edges(
        "detect_off_topic",
        _route_after_off_topic,
        path_map={
            True: "generate_answer",
            False: "get_db_info",
        },
    )

    # On-topic pipeline; it joins the off-topic branch at generate_answer.
    pipeline = (
        "get_db_info",
        "generate_sql",
        "execute_sql",
        "choose_visualization",
        "format_data_for_visualization",
        "render_visualization",
        "generate_answer",
        "finalize_output",
    )
    for src, dst in zip(pipeline, pipeline[1:]):
        graph.add_edge(src, dst)
    graph.add_edge("finalize_output", END)

    # TODO: honour ``visualize`` (e.g. render the graph); it is a no-op today.
    return graph
def visualize_graph(graph) -> None:
    """Print a Mermaid diagram of *graph* to stdout.

    LangGraph graphs have no ``visualize()`` method. Drawing goes through
    ``get_graph()``, which compiled graphs expose. An uncompiled
    ``StateGraph`` builder (anything without ``get_graph``) is compiled first.
    ``draw_mermaid()`` returns plain text and needs no extra rendering
    dependencies or network access.

    Args:
        graph: A compiled LangGraph graph or an uncompiled ``StateGraph``.
    """
    if not hasattr(graph, "get_graph"):
        graph = graph.compile()
    print(graph.get_graph().draw_mermaid())
| if __name__ == "__main__": | |
| state = { | |
| "question": "top 3 sản phẩm có giá thấp nhất", | |
| "db_info": { | |
| "tables": [], | |
| "columns": {}, | |
| "schema": "" | |
| }, | |
| "sql_query": "", | |
| "sql_result": None, | |
| "error": None, | |
| "step": None, | |
| "answer": None, | |
| "plot_path": None, | |
| "response_md": None, | |
| "visualization": None, | |
| "visualization_reason": None, | |
| "formatted_data_for_visualization": None, | |
| "visualization_output": None, | |
| "off_topic": None | |
| } | |
| graph = build_graph().compile() | |
| # visualize_graph(graph) | |
| result = graph.invoke(state) | |
| # print(result) | |
| answer = result['answer'] | |
| print(answer) | |
| for step in graph.stream( | |
| state, stream_mode="updates" | |
| ): | |
| print("-" * 80) | |
| # print(step['step']) | |
| print(step) | |