krishanusinha20 committed on
Commit
1e23bc6
·
verified ·
1 Parent(s): a5c6271

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
import sqlite3
from typing import TypedDict, Optional, Union

import gradio as gr
import openai  # Ensure you have openai==0.27.0 installed
from langgraph.graph import StateGraph, START, END
7
+
8
# --- Set your OpenAI API key securely ---
# The key is read from the OPENAI_API_KEY environment variable (for example a
# secret configured on the hosting Space) instead of being hard-coded here.
# os.getenv returns None when the variable is unset.
openai.api_key = os.getenv("OPENAI_API_KEY")
11
+
12
# Define the state for SQL Execution Workflow
class SQLExecutionState(TypedDict):
    """State dictionary shared by the nodes of the SQL workflow graph."""

    # Text supplied by the caller; query_understanding_agent reads it as the
    # natural-language request (despite the field name).
    sql_query: str
    # JSON metadata parsed from the model reply, or {"error": ...}.
    structured_metadata: Optional[dict]
    # {"valid": True} or {"error": ...}, written by query_validation_agent.
    validation_result: Optional[dict]
    # SQL text written by query_optimization_agent.
    optimized_sql: Optional[str]
    # Rows from cursor.fetchall() (a list) on success, or {"error": ...}.
    execution_result: Optional[Union[dict, list]]
19
+
20
# Initialize the LangGraph Workflow
# Nodes and edges are registered on this builder further down the module.
graph = StateGraph(state_schema=SQLExecutionState)
22
+
23
# ------------- 1. Query Understanding Agent -------------
def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if present."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (it may carry a language tag such as
        # "json") and everything from the closing fence onwards.
        _, _, cleaned = cleaned.partition("\n")
        cleaned = cleaned.rsplit("```", 1)[0]
    return cleaned.strip()


def query_understanding_agent(state: "SQLExecutionState") -> "SQLExecutionState":
    """Turn the user's natural-language request into structured SQL metadata.

    Reads ``state["sql_query"]``, asks the chat model to describe the query as
    a JSON object, and returns a partial state update of the form
    ``{"structured_metadata": <dict>}``. The dict is either the parsed model
    reply or ``{"error": <message>}`` when the input is blank or the reply is
    not a JSON object.
    """
    natural_language_query = (state.get("sql_query") or "").strip()
    if not natural_language_query:
        # Nothing to translate; avoid a pointless API call.
        return {"structured_metadata": {"error": "Invalid query: no input provided"}}

    example_output = json.dumps({
        "operation": "SELECT",
        "columns": ["customer_id", "SUM(total_amount) AS total_spent"],
        "table": "orders",
        "conditions": ["order_date BETWEEN '2024-01-01' AND '2024-12-31'"],
        "group_by": ["customer_id"],
        "order_by": ["total_spent DESC"],
        "limit": 5,
    }, indent=4)

    prompt = f"""
    Convert the following natural language query into **structured SQL metadata** based on the database schema.
    If you cannot generate a query that adheres strictly to the schema, return:
    {{ "error": "Invalid query: Tables or columns do not match schema" }}

    **Query:** "{natural_language_query}"

    **Database Schema:**
    - **orders** (order_id, customer_id, order_date, total_amount, status)
    - **order_items** (order_item_id, order_id, product_id, quantity, subtotal)
    - **products** (product_id, name, category, price, stock_quantity)
    - **customers** (customer_id, name, email, phone, address, created_at)
    - **payments** (payment_id, order_id, payment_date, amount, payment_method, status)

    **Rules:**
    - Use only the provided tables.
    - Ensure correct column names.
    - Return output strictly in JSON format.
    - Group by relevant fields when necessary.

    **Example Output Format:**
    {example_output}

    **DO NOT return explanations. Only return valid JSON.**
    """
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
    )
    content = response["choices"][0]["message"]["content"] or ""

    invalid_reply = {"structured_metadata": {"error": "Invalid JSON response from OpenAI"}}
    try:
        # Models frequently wrap JSON in ```json fences despite instructions.
        metadata = json.loads(_strip_code_fence(content))
    except json.JSONDecodeError:
        return invalid_reply
    if not isinstance(metadata, dict):
        # Downstream nodes call dict methods on the metadata.
        return invalid_reply
    return {"structured_metadata": metadata}
68
+
69
# Register the understanding step under the node name used by the edges below.
graph.add_node("Query Understanding", query_understanding_agent)
70
+
71
# ------------- 2. Query Validation Agent -------------
def query_validation_agent(state: "SQLExecutionState") -> "SQLExecutionState":
    """Check the structured metadata for errors and destructive operations.

    Returns a partial state update ``{"validation_result": ...}`` whose value
    is ``{"valid": True}`` when the metadata passes, or ``{"error": <message>}``
    when the metadata is missing, already carries an error, or names a
    restricted operation (DROP, DELETE, TRUNCATE, ALTER).
    """
    sql_metadata = state.get("structured_metadata")
    if not isinstance(sql_metadata, dict):
        # Missing or malformed metadata must not be reported as valid.
        return {"validation_result": {"error": "No structured metadata to validate."}}
    if "error" in sql_metadata:
        return {"validation_result": {"error": sql_metadata["error"]}}

    # The model may emit null or a non-string here; coerce before upper().
    operation = str(sql_metadata.get("operation") or "").upper()
    restricted_keywords = ("DROP", "DELETE", "TRUNCATE", "ALTER")
    if any(keyword in operation for keyword in restricted_keywords):
        return {"validation_result": {"error": "Potentially harmful SQL operation detected!"}}
    return {"validation_result": {"valid": True}}
81
+
82
# Register the validation step under the node name used by the edges below.
graph.add_node("Query Validation", query_validation_agent)
83
+
84
# ------------- 3. Query Optimization Agent -------------
def query_optimization_agent(state: "SQLExecutionState") -> "SQLExecutionState":
    """Ask the chat model to turn the structured metadata into optimized SQL.

    Returns a partial state update ``{"optimized_sql": <str>}``. When an
    earlier step recorded an error (in ``structured_metadata`` or
    ``validation_result``) no model call is made and the SQL is the empty
    string, so nothing is generated for a rejected request.
    """
    sql_metadata = state.get("structured_metadata") or {}
    validation_result = state.get("validation_result") or {}
    if "error" in sql_metadata or "error" in validation_result:
        # Do not generate SQL for a request that already failed upstream.
        return {"optimized_sql": ""}

    prompt = f"""
    Optimize the following SQL query for performance while ensuring that the output includes only the required columns and necessary joins.
    Do not include any extra columns, unnecessary joins, or records that are not required to answer the query.

    Here is the original SQL metadata:
    {json.dumps(sql_metadata, indent=4)}

    Output only the final optimized SQL query in plain text without any markdown formatting or explanations.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    optimized_query = (response["choices"][0]["message"]["content"] or "").strip()
    # The model sometimes wraps the SQL in a Markdown fence despite the prompt;
    # remove both ```sql and bare ``` markers.
    if optimized_query.startswith("```"):
        optimized_query = optimized_query.replace("```sql", "").replace("```", "").strip()
    return {"optimized_sql": optimized_query}
108
+
109
# Register the optimization step under the node name used by the edges below.
graph.add_node("Query Optimization", query_optimization_agent)
110
+
111
# ------------- 4. SQL Execution Agent -------------
def execution_agent(state: "SQLExecutionState") -> "SQLExecutionState":
    """Run the optimized SQL against the SQLite database and return the rows.

    Returns a partial state update ``{"execution_result": ...}`` whose value is
    the list of rows from ``cursor.fetchall()``, or ``{"error": <message>}``
    when validation failed, there is no SQL to run, SQLite reports a problem,
    or the query produced no rows.
    """
    validation_result = state.get("validation_result") or {}
    if "error" in validation_result:
        # Never execute SQL for a request the validation step rejected.
        return {"execution_result": {"error": validation_result["error"]}}

    # optimized_sql may be present but None; treat that like an empty query.
    query = (state.get("optimized_sql") or "").strip()
    if not query:
        return {"execution_result": {"error": "No SQL query to execute."}}

    # NOTE(review): the SQL text is model-generated and executed verbatim;
    # consider opening the database read-only if only SELECTs are intended.
    try:
        conn = sqlite3.connect("complex_test_db.sqlite", timeout=20)
        try:
            cursor = conn.cursor()
            cursor.execute(query)
            result = cursor.fetchall()
            cursor.close()
        finally:
            # Close the connection even when execute()/fetchall() raises.
            conn.close()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error; older Python
        # versions raise it for multi-statement input.
        return {"execution_result": {"error": str(e)}}

    if not result:
        return {"execution_result": {"error": "Query executed successfully but returned no results."}}
    return {"execution_result": result}
128
+
129
graph.add_node("SQL Execution", execution_agent)

# Define Execution Flow
# START and END are the sentinel node names imported from langgraph.graph; the
# plain strings "START"/"END" would refer to nodes that were never registered.
graph.add_edge(START, "Query Understanding")
graph.add_edge("Query Understanding", "Query Validation")
graph.add_edge("Query Validation", "Query Optimization")
graph.add_edge("Query Optimization", "SQL Execution")
graph.add_edge("SQL Execution", END)

compiled_pipeline = graph.compile()
139
+
140
# Wrap your multi-agent query execution into a callable function
def run_multi_agent_query(natural_language_query):
    """Run the compiled pipeline on a request and return the result as JSON text.

    Args:
        natural_language_query: The user's request as free text.

    Returns:
        A JSON string (indented by two spaces) holding the pipeline's
        ``execution_result``, ``{}`` when the pipeline produced none, or an
        ``{"error": ...}`` object when the input is blank.
    """
    if not natural_language_query or not natural_language_query.strip():
        # Skip the whole pipeline (and its API calls) for blank input.
        return json.dumps({"error": "Please enter a query."}, indent=2)

    result = compiled_pipeline.invoke({"sql_query": natural_language_query})
    # default=str keeps display-only output from crashing on row values that
    # JSON cannot encode natively (e.g. BLOB bytes).
    return json.dumps(result.get("execution_result", {}), indent=2, default=str)
144
+
145
# Gradio Interface
# gr.Textbox is the current component class; the older gr.inputs namespace is
# deprecated and absent from recent Gradio releases.
iface = gr.Interface(
    fn=run_multi_agent_query,
    inputs=gr.Textbox(lines=2, placeholder="Enter your natural language SQL query here..."),
    outputs="text",
    title="Multi-Agent SQL Generator",
    description="Enter a natural language query to generate and execute SQL.",
)

if __name__ == "__main__":
    iface.launch()