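# NOTE: this file apparently began with non-code residue from the page it was
# copied from ("Spaces: / Sleeping / Sleeping"); it is kept here as a comment
# so the module parses.
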
from typing import Optional, TypedDict

from langgraph.graph import END, START, StateGraph

from agents.data_retrieval import sample_data_retrieval_agent
from agents.execution import execution_agent
from agents.sql_generation import sql_generation_agent
from agents.table_selection import table_selection_agent
from agents.validation import query_validation_and_optimization
from utils.bigquery_utils import init_bigquery_connection


# Define the state schema.
class _SQLExecutionStateRequired(TypedDict):
    """Keys that must already be present when the workflow is invoked."""

    # The user's natural-language question. Despite the name, this is NOT SQL;
    # the generated SQL lives in ``generated_sql``.
    sql_query: str


class SQLExecutionState(_SQLExecutionStateRequired, total=False):
    """Shared state passed between the nodes of the SQL workflow graph.

    Only ``sql_query`` is required up front. Every other key is populated
    later by a node; for example, ``initialize_client`` returns just
    ``{"client": ...}``. These keys are therefore declared non-total, so
    partial dicts type-check as ``SQLExecutionState``. Key names and their
    order are unchanged from the original single-class definition.
    """

    client: Optional[object]  # BigQuery client
    relevant_tables: Optional[list]  # Tables identified as relevant
    sample_data: Optional[dict]  # Sample data from relevant tables
    generated_sql: Optional[str]  # The actual SQL query (not JSON)
    # NOTE(review): the three keys below are presumably written by the
    # validation and execution nodes; confirm against those agents.
    validation_result: Optional[dict]
    optimized_sql: Optional[str]
    execution_result: Optional[dict]


def initialize_client(state: SQLExecutionState) -> SQLExecutionState:
    """Open a BigQuery connection and publish it under the ``client`` key.

    The incoming ``state`` is not read. The node returns a one-key mapping,
    which LangGraph applies as a partial state update. Any exception raised
    by ``init_bigquery_connection`` propagates unchanged.
    """
    return dict(client=init_bigquery_connection())


def create_workflow():
    """Build and compile the linear text-to-SQL LangGraph pipeline.

    The stages run strictly in the order listed below. ``START`` leads to the
    first stage, each stage leads to the next, and the last stage leads to
    ``END``. There is no branching or retry path.

    Returns:
        The result of ``StateGraph.compile()`` for the wired graph.
    """
    # (node name, node callable) pairs, in execution order.
    stages = [
        ("Initialize Client", initialize_client),
        ("Table Selection", table_selection_agent),
        ("Sample Data Retrieval", sample_data_retrieval_agent),
        ("SQL Generation", sql_generation_agent),
        ("Query Validation & Optimization", query_validation_and_optimization),
        ("SQL Execution", execution_agent),
    ]

    workflow = StateGraph(state_schema=SQLExecutionState)
    for node_name, node_fn in stages:
        workflow.add_node(node_name, node_fn)

    # Chain START -> stage 1 -> ... -> stage N -> END, one edge per pair.
    names = [node_name for node_name, _ in stages]
    for source, target in zip([START, *names], [*names, END]):
        workflow.add_edge(source, target)

    return workflow.compile()