# Spaces: Build error
import json
import os
import sqlite3
from typing import Optional, TypedDict, Union

import gradio as gr
import openai  # Ensure you have openai==0.27.0 installed
from langgraph.graph import END, START, StateGraph
# --- Set your OpenAI API key securely ---
# The key is read from the OPENAI_API_KEY environment variable (configure it
# in your Space) so it is never hard-coded here. `os.getenv` returns None when
# the variable is unset.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Define the state for SQL Execution Workflow
class SQLExecutionState(TypedDict, total=False):
    """State dictionary threaded through the workflow nodes.

    ``total=False`` because no caller ever supplies the full mapping: the
    pipeline is invoked with only ``sql_query`` and each agent returns just
    the single key it produces.
    """

    # Natural-language request entered by the user (despite the name, this is
    # not SQL text).
    sql_query: str
    # Parsed JSON metadata from the query-understanding agent, or an
    # ``{"error": ...}`` mapping.
    structured_metadata: Optional[dict]
    # ``{"valid": True}`` or ``{"error": ...}`` from the validation agent.
    validation_result: Optional[dict]
    # SQL text produced by the optimization agent.
    optimized_sql: Optional[str]
    # Rows fetched from SQLite (a list) on success, or an ``{"error": ...}``
    # mapping on failure.
    execution_result: Optional[Union[dict, list]]
# Initialize the LangGraph Workflow
graph = StateGraph(state_schema=SQLExecutionState)
# ------------- 1. Query Understanding Agent -------------
def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Turn the user's natural-language request into structured SQL metadata.

    Sends a schema-constrained prompt to the chat model and parses the reply
    as JSON.

    Args:
        state: Workflow state; ``state["sql_query"]`` holds the user's
            natural-language request.

    Returns:
        A partial state update ``{"structured_metadata": ...}``. The value is
        the parsed JSON object, or ``{"error": <message>}`` when the request
        is blank or the reply is not a JSON object.

    Raises:
        KeyError: If ``state`` has no ``"sql_query"`` key.
        Errors raised by the OpenAI client are not caught here.
    """
    natural_language_query = state["sql_query"]
    # Do not spend an API call on an empty request.
    if not str(natural_language_query or "").strip():
        return {"structured_metadata": {"error": "No query provided."}}

    example_output = json.dumps(
        {
            "operation": "SELECT",
            "columns": ["customer_id", "SUM(total_amount) AS total_spent"],
            "table": "orders",
            "conditions": ["order_date BETWEEN '2024-01-01' AND '2024-12-31'"],
            "group_by": ["customer_id"],
            "order_by": ["total_spent DESC"],
            "limit": 5,
        },
        indent=4,
    )
    prompt = f"""
Convert the following natural language query into **structured SQL metadata** based on the database schema.
If you cannot generate a query that adheres strictly to the schema, return:
{{ "error": "Invalid query: Tables or columns do not match schema" }}
**Query:** "{natural_language_query}"
**Database Schema:**
- **orders** (order_id, customer_id, order_date, total_amount, status)
- **order_items** (order_item_id, order_id, product_id, quantity, subtotal)
- **products** (product_id, name, category, price, stock_quantity)
- **customers** (customer_id, name, email, phone, address, created_at)
- **payments** (payment_id, order_id, payment_date, amount, payment_method, status)
**Rules:**
- Use only the provided tables.
- Ensure correct column names.
- Return output strictly in JSON format.
- Group by relevant fields when necessary.
**Example Output Format:**
{example_output}
**DO NOT return explanations. Only return valid JSON.**
"""
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        # Same setting as the optimization call: structured output should not
        # vary between runs.
        temperature=0,
    )
    content = (response["choices"][0]["message"]["content"] or "").strip()
    # The model may wrap its JSON in a Markdown code fence despite the
    # instructions; unwrap it before parsing.
    if content.startswith("```"):
        content = content.strip("`").strip()
        if content[:4].lower() == "json":
            content = content[4:]

    invalid_reply = {"structured_metadata": {"error": "Invalid JSON response from OpenAI"}}
    try:
        metadata = json.loads(content)
    except json.JSONDecodeError:
        return invalid_reply
    # Later steps use dict operations on the metadata, so a JSON list/scalar
    # is as unusable as unparseable text.
    if not isinstance(metadata, dict):
        return invalid_reply
    return {"structured_metadata": metadata}
graph.add_node("Query Understanding", query_understanding_agent)
# ------------- 2. Query Validation Agent -------------
def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Check the structured metadata before any SQL is generated from it.

    Args:
        state: Workflow state; ``structured_metadata`` is read from it.

    Returns:
        A partial state update ``{"validation_result": ...}`` whose value is
        ``{"valid": True}`` when the metadata passes, or
        ``{"error": <message>}`` when the metadata is missing, is not a
        non-empty JSON object, already carries an error, or names a
        restricted operation.
    """
    sql_metadata = state.get("structured_metadata")
    # Absent, None, empty or non-dict metadata cannot describe a query.
    if not isinstance(sql_metadata, dict) or not sql_metadata:
        return {"validation_result": {"error": "Missing or malformed structured metadata"}}
    if "error" in sql_metadata:
        # Propagate the failure reported by the previous step.
        return {"validation_result": {"error": sql_metadata["error"]}}

    # `operation` comes from model output, so it may be missing or non-string.
    operation = str(sql_metadata.get("operation") or "").upper()
    restricted_keywords = ("DROP", "DELETE", "TRUNCATE", "ALTER")
    if any(keyword in operation for keyword in restricted_keywords):
        return {"validation_result": {"error": "Potentially harmful SQL operation detected!"}}
    return {"validation_result": {"valid": True}}
graph.add_node("Query Validation", query_validation_agent)
# ------------- 3. Query Optimization Agent -------------
def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Ask the chat model to turn the structured metadata into optimized SQL.

    Args:
        state: Workflow state; ``structured_metadata`` and
            ``validation_result`` are read from it.

    Returns:
        A partial state update ``{"optimized_sql": <sql text>}``. The text is
        an empty string when the metadata is missing, malformed, carries an
        error, or was rejected by validation; no API call is made in that
        case.

    Raises:
        Errors raised by the OpenAI client are not caught here.
    """
    sql_metadata = state.get("structured_metadata")
    validation = state.get("validation_result") or {}
    unusable = (
        not isinstance(sql_metadata, dict)
        or not sql_metadata
        or "error" in sql_metadata
        or "error" in validation
    )
    if unusable:
        # Nothing sensible can be generated from failed or rejected metadata.
        return {"optimized_sql": ""}

    prompt = f"""
Optimize the following SQL query for performance while ensuring that the output includes only the required columns and necessary joins.
Do not include any extra columns, unnecessary joins, or records that are not required to answer the query.
Here is the original SQL metadata:
{json.dumps(sql_metadata, indent=4)}
Output only the final optimized SQL query in plain text without any markdown formatting or explanations.
"""
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    optimized_query = (response["choices"][0]["message"]["content"] or "").strip()
    # The model may still wrap the SQL in a Markdown code fence (with or
    # without an "sql" language tag); unwrap it so only the statement remains.
    if optimized_query.startswith("```"):
        optimized_query = optimized_query.strip("`").strip()
        if optimized_query[:3].lower() == "sql" and optimized_query[3:4].isspace():
            optimized_query = optimized_query[3:].strip()
    return {"optimized_sql": optimized_query}
graph.add_node("Query Optimization", query_optimization_agent)
# ------------- 4. SQL Execution Agent -------------
def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
    """Run the optimized SQL against the SQLite database and return its rows.

    Args:
        state: Workflow state; ``optimized_sql`` and ``validation_result``
            are read from it.

    Returns:
        A partial state update ``{"execution_result": ...}``. The value is
        the list of fetched rows on success, or ``{"error": <message>}`` when
        validation failed, there is no query, SQLite reports an error, or the
        query returned no rows.
    """
    validation = state.get("validation_result") or {}
    if "error" in validation:
        # Never execute SQL derived from a request that failed validation.
        return {"execution_result": {"error": validation["error"]}}

    query = (state.get("optimized_sql") or "").strip()
    if not query:
        return {"execution_result": {"error": "No SQL query to execute."}}

    try:
        conn = sqlite3.connect("complex_test_db.sqlite", timeout=20)
        try:
            # The SQL text is model-generated from user input, so forbid
            # writes on this connection regardless of what the statement says.
            conn.execute("PRAGMA query_only = ON")
            cursor = conn.cursor()
            cursor.execute(query)
            result = cursor.fetchall()
        finally:
            # Close even when execute() raises, so the handle is not leaked.
            conn.close()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning covers multi-statement input on Python versions
        # that report it that way instead of as a ProgrammingError.
        return {"execution_result": {"error": str(e)}}

    if not result:
        return {"execution_result": {"error": "Query executed successfully but returned no results."}}
    return {"execution_result": result}
graph.add_node("SQL Execution", execution_agent)
# Define Execution Flow
# START and END are the sentinel node names imported from langgraph.graph.
# The literal strings "START" / "END" are not nodes registered on this graph,
# so the imported constants must be used for the entry and exit edges.
graph.add_edge(START, "Query Understanding")
graph.add_edge("Query Understanding", "Query Validation")
graph.add_edge("Query Validation", "Query Optimization")
graph.add_edge("Query Optimization", "SQL Execution")
graph.add_edge("SQL Execution", END)
compiled_pipeline = graph.compile()
# Wrap your multi-agent query execution into a callable function
def run_multi_agent_query(natural_language_query):
    """Run the compiled pipeline for one request and return JSON text.

    Args:
        natural_language_query: The user's request as typed into the UI.

    Returns:
        A JSON string (2-space indent) of the pipeline's
        ``execution_result``: the fetched rows, an ``{"error": ...}`` object,
        or ``{}`` when the pipeline produced no result. A blank request
        returns an error object without running the pipeline.
    """
    if not str(natural_language_query or "").strip():
        return json.dumps({"error": "Please enter a query."}, indent=2)

    result = compiled_pipeline.invoke({"sql_query": natural_language_query})
    execution_result = result.get("execution_result")
    if execution_result is None:
        execution_result = {}
    # default=str is deliberate: database rows can hold values json cannot
    # encode natively (e.g. BLOB bytes); for display a string form is enough.
    return json.dumps(execution_result, indent=2, default=str)
# Gradio Interface
iface = gr.Interface(
    fn=run_multi_agent_query,
    # gr.Textbox is the current component class; the legacy gr.inputs
    # namespace is not available in recent Gradio releases.
    inputs=gr.Textbox(lines=2, placeholder="Enter your natural language SQL query here..."),
    outputs="text",
    title="Multi-Agent SQL Generator",
    description="Enter a natural language query to generate and execute SQL.",
)

# Start the web UI only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()