Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
import sqlite3
|
| 3 |
-
import json
|
| 4 |
-
import openai
|
| 5 |
from langgraph.graph import StateGraph, START, END
|
| 6 |
from typing import TypedDict, Optional
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 11 |
|
| 12 |
# Define the state for SQL Execution Workflow
|
|
@@ -20,7 +19,11 @@ class SQLExecutionState(TypedDict):
|
|
| 20 |
# Initialize the LangGraph Workflow
|
| 21 |
graph = StateGraph(state_schema=SQLExecutionState)
|
| 22 |
|
| 23 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 25 |
natural_language_query = state["sql_query"]
|
| 26 |
prompt = f"""
|
|
@@ -68,7 +71,7 @@ def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
|
|
| 68 |
|
| 69 |
graph.add_node("Query Understanding", query_understanding_agent)
|
| 70 |
|
| 71 |
-
#
|
| 72 |
def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 73 |
sql_metadata = state.get("structured_metadata", {})
|
| 74 |
if "error" in sql_metadata:
|
|
@@ -81,7 +84,7 @@ def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
|
|
| 81 |
|
| 82 |
graph.add_node("Query Validation", query_validation_agent)
|
| 83 |
|
| 84 |
-
#
|
| 85 |
def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 86 |
sql_metadata = state.get("structured_metadata", {})
|
| 87 |
prompt = f"""
|
|
@@ -103,10 +106,9 @@ def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
|
|
| 103 |
optimized_query = optimized_query.replace("```sql", "").replace("```", "").strip()
|
| 104 |
return {"optimized_sql": optimized_query}
|
| 105 |
|
| 106 |
-
|
| 107 |
graph.add_node("Query Optimization", query_optimization_agent)
|
| 108 |
|
| 109 |
-
#
|
| 110 |
def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 111 |
query = state.get("optimized_sql", "").strip()
|
| 112 |
if not query:
|
|
@@ -126,28 +128,20 @@ def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
|
|
| 126 |
|
| 127 |
graph.add_node("SQL Execution", execution_agent)
|
| 128 |
|
| 129 |
-
# Define Execution Flow
|
| 130 |
-
graph.add_edge(
|
| 131 |
graph.add_edge("Query Understanding", "Query Validation")
|
| 132 |
graph.add_edge("Query Validation", "Query Optimization")
|
| 133 |
graph.add_edge("Query Optimization", "SQL Execution")
|
| 134 |
-
graph.add_edge("SQL Execution",
|
| 135 |
|
| 136 |
compiled_pipeline = graph.compile()
|
| 137 |
|
| 138 |
-
#
|
| 139 |
def run_multi_agent_query(natural_language_query):
|
| 140 |
result = compiled_pipeline.invoke({"sql_query": natural_language_query})
|
| 141 |
return json.dumps(result.get("execution_result", {}), indent=2)
|
| 142 |
|
| 143 |
-
# Gradio Interface
|
| 144 |
-
iface = gr.Interface(
|
| 145 |
-
fn=run_multi_agent_query,
|
| 146 |
-
inputs=gr.inputs.Textbox(lines=2, placeholder="Enter your natural language SQL query here..."),
|
| 147 |
-
outputs="text",
|
| 148 |
-
title="Multi-Agent SQL Generator",
|
| 149 |
-
description="Enter a natural language query to generate and execute SQL."
|
| 150 |
-
)
|
| 151 |
-
|
| 152 |
if __name__ == "__main__":
|
| 153 |
-
|
|
|
|
|
|
|
|
|
| 1 |
import sqlite3
|
| 2 |
+
import json
|
| 3 |
+
import openai
|
| 4 |
from langgraph.graph import StateGraph, START, END
|
| 5 |
from typing import TypedDict, Optional
|
| 6 |
|
| 7 |
+
# Set your OpenAI API key from environment variable
|
| 8 |
+
import os
|
| 9 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 10 |
|
| 11 |
# Define the state for SQL Execution Workflow
|
|
|
|
| 19 |
# Initialize the LangGraph Workflow
|
| 20 |
graph = StateGraph(state_schema=SQLExecutionState)
|
| 21 |
|
| 22 |
+
# Add dummy nodes for START and END
|
| 23 |
+
graph.add_node(START, lambda state: state)
|
| 24 |
+
graph.add_node(END, lambda state: state)
|
| 25 |
+
|
| 26 |
+
# ------------------ 1. Query Understanding Agent ------------------
|
| 27 |
def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 28 |
natural_language_query = state["sql_query"]
|
| 29 |
prompt = f"""
|
|
|
|
| 71 |
|
| 72 |
graph.add_node("Query Understanding", query_understanding_agent)
|
| 73 |
|
| 74 |
+
# ------------------ 2. Query Validation Agent ------------------
|
| 75 |
def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 76 |
sql_metadata = state.get("structured_metadata", {})
|
| 77 |
if "error" in sql_metadata:
|
|
|
|
| 84 |
|
| 85 |
graph.add_node("Query Validation", query_validation_agent)
|
| 86 |
|
| 87 |
+
# ------------------ 3. Query Optimization Agent ------------------
|
| 88 |
def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 89 |
sql_metadata = state.get("structured_metadata", {})
|
| 90 |
prompt = f"""
|
|
|
|
| 106 |
optimized_query = optimized_query.replace("```sql", "").replace("```", "").strip()
|
| 107 |
return {"optimized_sql": optimized_query}
|
| 108 |
|
|
|
|
| 109 |
graph.add_node("Query Optimization", query_optimization_agent)
|
| 110 |
|
| 111 |
+
# ------------------ 4. SQL Execution Agent ------------------
|
| 112 |
def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
|
| 113 |
query = state.get("optimized_sql", "").strip()
|
| 114 |
if not query:
|
|
|
|
| 128 |
|
| 129 |
graph.add_node("SQL Execution", execution_agent)
|
| 130 |
|
| 131 |
+
# ------------------ Define Execution Flow ------------------
|
| 132 |
+
graph.add_edge(START, "Query Understanding")
|
| 133 |
graph.add_edge("Query Understanding", "Query Validation")
|
| 134 |
graph.add_edge("Query Validation", "Query Optimization")
|
| 135 |
graph.add_edge("Query Optimization", "SQL Execution")
|
| 136 |
+
graph.add_edge("SQL Execution", END)
|
| 137 |
|
| 138 |
compiled_pipeline = graph.compile()
|
| 139 |
|
| 140 |
+
# ------------------ Example Execution ------------------
|
| 141 |
def run_multi_agent_query(natural_language_query):
|
| 142 |
result = compiled_pipeline.invoke({"sql_query": natural_language_query})
|
| 143 |
return json.dumps(result.get("execution_result", {}), indent=2)
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
if __name__ == "__main__":
|
| 146 |
+
# For testing outside of Hugging Face Spaces
|
| 147 |
+
print(run_multi_agent_query("Find the email_id of the top 5 customers who spent the most in 2024."))
|