krishanusinha20 commited on
Commit
cd844b4
·
verified ·
1 Parent(s): f7c8431

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -24
app.py CHANGED
@@ -1,12 +1,11 @@
1
- import gradio as gr
2
  import sqlite3
3
- import json,os
4
- import openai # Ensure you have openai==0.27.0 installed
5
  from langgraph.graph import StateGraph, START, END
6
  from typing import TypedDict, Optional
7
 
8
- # --- Set your OpenAI API key securely ---
9
- # Remove hardcoding and rely on environment variables in your Space.
10
  openai.api_key = os.getenv("OPENAI_API_KEY")
11
 
12
  # Define the state for SQL Execution Workflow
@@ -20,7 +19,11 @@ class SQLExecutionState(TypedDict):
20
  # Initialize the LangGraph Workflow
21
  graph = StateGraph(state_schema=SQLExecutionState)
22
 
23
- # ------------- 1. Query Understanding Agent -------------
 
 
 
 
24
  def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
25
  natural_language_query = state["sql_query"]
26
  prompt = f"""
@@ -68,7 +71,7 @@ def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
68
 
69
  graph.add_node("Query Understanding", query_understanding_agent)
70
 
71
- # ------------- 2. Query Validation Agent -------------
72
  def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
73
  sql_metadata = state.get("structured_metadata", {})
74
  if "error" in sql_metadata:
@@ -81,7 +84,7 @@ def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
81
 
82
  graph.add_node("Query Validation", query_validation_agent)
83
 
84
- # ------------- 3. Query Optimization Agent -------------
85
  def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
86
  sql_metadata = state.get("structured_metadata", {})
87
  prompt = f"""
@@ -103,10 +106,9 @@ def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
103
  optimized_query = optimized_query.replace("```sql", "").replace("```", "").strip()
104
  return {"optimized_sql": optimized_query}
105
 
106
-
107
  graph.add_node("Query Optimization", query_optimization_agent)
108
 
109
- # ------------- 4. SQL Execution Agent -------------
110
  def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
111
  query = state.get("optimized_sql", "").strip()
112
  if not query:
@@ -126,28 +128,20 @@ def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
126
 
127
  graph.add_node("SQL Execution", execution_agent)
128
 
129
- # Define Execution Flow
130
- graph.add_edge("START", "Query Understanding")
131
  graph.add_edge("Query Understanding", "Query Validation")
132
  graph.add_edge("Query Validation", "Query Optimization")
133
  graph.add_edge("Query Optimization", "SQL Execution")
134
- graph.add_edge("SQL Execution", "END")
135
 
136
  compiled_pipeline = graph.compile()
137
 
138
- # Wrap your multi-agent query execution into a callable function
139
  def run_multi_agent_query(natural_language_query):
140
  result = compiled_pipeline.invoke({"sql_query": natural_language_query})
141
  return json.dumps(result.get("execution_result", {}), indent=2)
142
 
143
- # Gradio Interface
144
- iface = gr.Interface(
145
- fn=run_multi_agent_query,
146
- inputs=gr.inputs.Textbox(lines=2, placeholder="Enter your natural language SQL query here..."),
147
- outputs="text",
148
- title="Multi-Agent SQL Generator",
149
- description="Enter a natural language query to generate and execute SQL."
150
- )
151
-
152
  if __name__ == "__main__":
153
- iface.launch()
 
 
 
1
  import sqlite3
2
+ import json
3
+ import openai
4
  from langgraph.graph import StateGraph, START, END
5
  from typing import TypedDict, Optional
6
 
7
+ # Set your OpenAI API key from environment variable
8
+ import os
9
  openai.api_key = os.getenv("OPENAI_API_KEY")
10
 
11
  # Define the state for SQL Execution Workflow
 
19
  # Initialize the LangGraph Workflow
20
  graph = StateGraph(state_schema=SQLExecutionState)
21
 
22
+ # Add dummy nodes for START and END
23
+ graph.add_node(START, lambda state: state)
24
+ graph.add_node(END, lambda state: state)
25
+
26
+ # ------------------ 1. Query Understanding Agent ------------------
27
  def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
28
  natural_language_query = state["sql_query"]
29
  prompt = f"""
 
71
 
72
  graph.add_node("Query Understanding", query_understanding_agent)
73
 
74
+ # ------------------ 2. Query Validation Agent ------------------
75
  def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
76
  sql_metadata = state.get("structured_metadata", {})
77
  if "error" in sql_metadata:
 
84
 
85
  graph.add_node("Query Validation", query_validation_agent)
86
 
87
+ # ------------------ 3. Query Optimization Agent ------------------
88
  def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
89
  sql_metadata = state.get("structured_metadata", {})
90
  prompt = f"""
 
106
  optimized_query = optimized_query.replace("```sql", "").replace("```", "").strip()
107
  return {"optimized_sql": optimized_query}
108
 
 
109
  graph.add_node("Query Optimization", query_optimization_agent)
110
 
111
+ # ------------------ 4. SQL Execution Agent ------------------
112
  def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
113
  query = state.get("optimized_sql", "").strip()
114
  if not query:
 
128
 
129
  graph.add_node("SQL Execution", execution_agent)
130
 
131
+ # ------------------ Define Execution Flow ------------------
132
+ graph.add_edge(START, "Query Understanding")
133
  graph.add_edge("Query Understanding", "Query Validation")
134
  graph.add_edge("Query Validation", "Query Optimization")
135
  graph.add_edge("Query Optimization", "SQL Execution")
136
+ graph.add_edge("SQL Execution", END)
137
 
138
  compiled_pipeline = graph.compile()
139
 
140
+ # ------------------ Example Execution ------------------
141
  def run_multi_agent_query(natural_language_query):
142
  result = compiled_pipeline.invoke({"sql_query": natural_language_query})
143
  return json.dumps(result.get("execution_result", {}), indent=2)
144
 
 
 
 
 
 
 
 
 
 
145
  if __name__ == "__main__":
146
+ # For testing outside of Hugging Face Spaces
147
+ print(run_multi_agent_query("Find the email_id of the top 5 customers who spent the most in 2024."))