Spaces:
Build error
Build error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import os
import sqlite3
from typing import Optional, TypedDict, Union

import gradio as gr
import openai  # Ensure you have openai==0.27.0 installed
from langgraph.graph import StateGraph, START, END
|
| 7 |
+
|
| 8 |
+
# --- Set your OpenAI API key securely ---
# The key is never hardcoded: it is read from the OPENAI_API_KEY environment
# variable (configure it as a secret in your Space). os.getenv returns None
# when the variable is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 11 |
+
|
| 12 |
+
# Define the state for SQL Execution Workflow
|
| 13 |
+
class SQLExecutionState(TypedDict):
    """Shared state passed between the nodes of the SQL execution workflow.

    Each node reads the keys it needs and returns a partial update that
    contains only the key it produces.
    """

    # Request text entered by the user; the understanding step treats it as
    # natural language, not as SQL.
    sql_query: str
    # JSON metadata produced by the understanding step, or {"error": ...}.
    structured_metadata: Optional[dict]
    # {"valid": True} or {"error": ...} from the validation step.
    validation_result: Optional[dict]
    # SQL text produced by the optimization step.
    optimized_sql: Optional[str]
    # Rows fetched by the execution step (a list) on success, or an
    # {"error": ...} dict on failure.
    execution_result: Optional[Union[dict, list]]
|
| 19 |
+
|
| 20 |
+
# Initialize the LangGraph Workflow
# SQLExecutionState is the schema of the state dict shared by every node.
graph = StateGraph(state_schema=SQLExecutionState)
|
| 22 |
+
|
| 23 |
+
# ------------- 1. Query Understanding Agent -------------
|
| 24 |
+
def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Turn the user's natural-language request into structured SQL metadata.

    Sends the request together with the database schema to the chat model
    and parses the reply as JSON.

    Args:
        state: Workflow state; ``state["sql_query"]`` holds the user's
            natural-language request.

    Returns:
        A partial state update ``{"structured_metadata": ...}`` holding the
        parsed JSON object, or an ``{"error": ...}`` dict when the request is
        empty or the reply is not a JSON object.
    """
    natural_language_query = state["sql_query"]
    if not natural_language_query or not natural_language_query.strip():
        # Nothing to translate, so skip the API call entirely.
        return {"structured_metadata": {"error": "Invalid query: No query provided"}}

    # Rendered once up front so the prompt template below stays readable.
    example_output = json.dumps(
        {
            "operation": "SELECT",
            "columns": ["customer_id", "SUM(total_amount) AS total_spent"],
            "table": "orders",
            "conditions": ["order_date BETWEEN '2024-01-01' AND '2024-12-31'"],
            "group_by": ["customer_id"],
            "order_by": ["total_spent DESC"],
            "limit": 5,
        },
        indent=4,
    )
    prompt = f"""
    Convert the following natural language query into **structured SQL metadata** based on the database schema.
    If you cannot generate a query that adheres strictly to the schema, return:
    {{ "error": "Invalid query: Tables or columns do not match schema" }}

    **Query:** "{natural_language_query}"

    **Database Schema:**
    - **orders** (order_id, customer_id, order_date, total_amount, status)
    - **order_items** (order_item_id, order_id, product_id, quantity, subtotal)
    - **products** (product_id, name, category, price, stock_quantity)
    - **customers** (customer_id, name, email, phone, address, created_at)
    - **payments** (payment_id, order_id, payment_date, amount, payment_method, status)

    **Rules:**
    - Use only the provided tables.
    - Ensure correct column names.
    - Return output strictly in JSON format.
    - Group by relevant fields when necessary.

    **Example Output Format:**
    {example_output}

    **DO NOT return explanations. Only return valid JSON.**
    """
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
    )
    content = (response["choices"][0]["message"]["content"] or "").strip()

    # Built from single backticks so the literal cannot be swallowed if this
    # file is ever pasted through a Markdown renderer.
    fence = "`" * 3
    if content.startswith(fence):
        # The model may wrap the JSON in a Markdown code fence despite the
        # instructions: drop the opening fence line and the closing fence.
        _, _, remainder = content.partition("\n")
        content = remainder.rsplit(fence, 1)[0].strip()

    invalid = {"structured_metadata": {"error": "Invalid JSON response from OpenAI"}}
    try:
        metadata = json.loads(content)
    except json.JSONDecodeError:
        return invalid
    if not isinstance(metadata, dict):
        # Later steps call dict methods on the metadata, so a JSON list,
        # string or number is as unusable as malformed JSON.
        return invalid
    return {"structured_metadata": metadata}
|
| 68 |
+
|
| 69 |
+
# Register the understanding step under the node name "Query Understanding".
graph.add_node("Query Understanding", query_understanding_agent)
|
| 70 |
+
|
| 71 |
+
# ------------- 2. Query Validation Agent -------------
|
| 72 |
+
def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Check the structured metadata for errors and destructive operations.

    Args:
        state: Workflow state; reads ``structured_metadata``.

    Returns:
        ``{"validation_result": {"valid": True}}`` when the metadata passes,
        otherwise ``{"validation_result": {"error": <message>}}``. Errors
        reported by the previous step are forwarded unchanged.
    """
    sql_metadata = state.get("structured_metadata")
    if not isinstance(sql_metadata, dict) or not sql_metadata:
        # Missing, None, empty or non-object metadata cannot describe a
        # query, so it must not be reported as valid.
        return {"validation_result": {"error": "No structured metadata to validate."}}

    if "error" in sql_metadata:
        return {"validation_result": {"error": sql_metadata["error"]}}

    # "operation" may be absent or null in the model's JSON; normalise it to
    # an upper-case string once instead of on every keyword comparison.
    operation = str(sql_metadata.get("operation") or "").upper()
    restricted_keywords = ("DROP", "DELETE", "TRUNCATE", "ALTER")
    if any(keyword in operation for keyword in restricted_keywords):
        return {"validation_result": {"error": "Potentially harmful SQL operation detected!"}}

    return {"validation_result": {"valid": True}}
|
| 81 |
+
|
| 82 |
+
# Register the validation step under the node name "Query Validation".
graph.add_node("Query Validation", query_validation_agent)
|
| 83 |
+
|
| 84 |
+
# ------------- 3. Query Optimization Agent -------------
|
| 85 |
+
def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Ask the chat model to turn the structured metadata into optimized SQL.

    Args:
        state: Workflow state; reads ``structured_metadata`` and
            ``validation_result``.

    Returns:
        A partial state update ``{"optimized_sql": <sql text>}``. The text is
        empty when there is no usable metadata or an earlier step reported an
        error, in which case the model is not called.
    """
    sql_metadata = state.get("structured_metadata") or {}
    validation = state.get("validation_result") or {}
    if not sql_metadata or "error" in sql_metadata or "error" in validation:
        # An earlier step already rejected the request: do not spend an API
        # call on it and do not hand any SQL to the next step.
        return {"optimized_sql": ""}

    prompt = f"""
    Optimize the following SQL query for performance while ensuring that the output includes only the required columns and necessary joins.
    Do not include any extra columns, unnecessary joins, or records that are not required to answer the query.

    Here is the original SQL metadata:
    {json.dumps(sql_metadata, indent=4)}

    Output only the final optimized SQL query in plain text without any markdown formatting or explanations.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    optimized_query = (response["choices"][0]["message"]["content"] or "").strip()

    # Built from single backticks so the literal cannot be swallowed if this
    # file is ever pasted through a Markdown renderer.
    fence = "`" * 3
    if optimized_query.startswith(fence):
        # The model sometimes wraps the SQL in a Markdown code fence despite
        # the instructions; remove the opening fence (with its optional
        # "sql" tag) and the closing fence.
        optimized_query = optimized_query.replace(fence + "sql", "").replace(fence, "").strip()
    return {"optimized_sql": optimized_query}
|
| 108 |
+
|
| 109 |
+
# Register the optimization step under the node name "Query Optimization".
graph.add_node("Query Optimization", query_optimization_agent)
|
| 110 |
+
|
| 111 |
+
# ------------- 4. SQL Execution Agent -------------
|
| 112 |
+
def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Run the optimized SQL against the SQLite database and return the rows.

    Args:
        state: Workflow state; reads ``validation_result`` and
            ``optimized_sql``.

    Returns:
        A partial state update ``{"execution_result": ...}`` holding the list
        of fetched rows, or an ``{"error": <message>}`` dict when validation
        failed, there is no query, the query returned no rows, or SQLite
        reported an error.
    """
    validation = state.get("validation_result") or {}
    if "error" in validation:
        # Never execute a statement that the validation step rejected.
        return {"execution_result": {"error": validation["error"]}}

    query = (state.get("optimized_sql") or "").strip()
    if not query:
        return {"execution_result": {"error": "No SQL query to execute."}}

    # NOTE(review): the SQL text is model-generated from user input and is
    # executed as-is; consider opening the database read-only.
    conn = None
    try:
        conn = sqlite3.connect("complex_test_db.sqlite", timeout=20)
        cursor = conn.cursor()
        try:
            cursor.execute(query)
            result = cursor.fetchall()
        finally:
            cursor.close()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error; older Python
        # versions raise it when the text holds more than one statement.
        return {"execution_result": {"error": str(e)}}
    finally:
        # Close the connection on every path, including failed queries.
        if conn is not None:
            conn.close()

    if not result:
        return {"execution_result": {"error": "Query executed successfully but returned no results."}}
    return {"execution_result": result}
|
| 128 |
+
|
| 129 |
+
# Register the execution step under the node name "SQL Execution".
graph.add_node("SQL Execution", execution_agent)
|
| 130 |
+
|
| 131 |
+
# Define Execution Flow
# START and END are the sentinel node names imported from langgraph.graph.
# The plain strings "START" and "END" do not name any registered node, so
# the entry and exit edges must use the constants.
graph.add_edge(START, "Query Understanding")
graph.add_edge("Query Understanding", "Query Validation")
graph.add_edge("Query Validation", "Query Optimization")
graph.add_edge("Query Optimization", "SQL Execution")
graph.add_edge("SQL Execution", END)

compiled_pipeline = graph.compile()
|
| 139 |
+
|
| 140 |
+
# Wrap your multi-agent query execution into a callable function
|
| 141 |
+
def run_multi_agent_query(natural_language_query):
    """Run the whole pipeline for one request and return the result as text.

    Args:
        natural_language_query: The request typed by the user.

    Returns:
        The pipeline's ``execution_result`` (rows or an error dict) rendered
        as an indented JSON string; ``{}`` when the pipeline produced none.
    """
    result = compiled_pipeline.invoke({"sql_query": natural_language_query})
    execution_result = result.get("execution_result", {})
    # default=str is deliberate: the output is only displayed, and row values
    # that JSON cannot encode (e.g. bytes) would otherwise raise TypeError.
    return json.dumps(execution_result, indent=2, default=str)
|
| 144 |
+
|
| 145 |
+
# Gradio Interface
# gr.Textbox is the component class; the legacy gr.inputs namespace used
# previously is no longer available in current Gradio releases.
iface = gr.Interface(
    fn=run_multi_agent_query,
    inputs=gr.Textbox(lines=2, placeholder="Enter your natural language SQL query here..."),
    outputs="text",
    title="Multi-Agent SQL Generator",
    description="Enter a natural language query to generate and execute SQL.",
)

if __name__ == "__main__":
    iface.launch()
|