tjhalanigrid committed on
Commit
d20b967
·
1 Parent(s): 0aabca3

Restore proper Python formatting for app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -105
app.py CHANGED
@@ -35,32 +35,42 @@ SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
35
  # SQL EXPLAINER
36
  # =========================
37
  def explain_sql(sql):
 
38
  explanation = "This SQL query retrieves information from the database."
 
39
  sql_lower = sql.lower()
40
 
41
  if "join" in sql_lower:
42
  explanation += "\n• It combines data from multiple tables using JOIN."
 
43
  if "where" in sql_lower:
44
  explanation += "\n• It filters rows using a WHERE condition."
 
45
  if "group by" in sql_lower:
46
  explanation += "\n• It groups results using GROUP BY."
 
47
  if "order by" in sql_lower:
48
  explanation += "\n• It sorts the results using ORDER BY."
 
49
  if "limit" in sql_lower:
50
  explanation += "\n• It limits the number of returned rows."
51
 
52
  return explanation
53
 
 
54
  # =========================
55
  # CORE FUNCTIONS
56
  # =========================
57
  def run_query(question, db_id):
 
58
  if not question.strip():
59
- return "", None, "⚠️ Please enter a question."
60
 
61
  start_time = time.time()
 
62
  result = engine.ask(question, db_id)
63
  final_sql = result["sql"]
 
64
  end_time = time.time()
65
  latency = round(end_time - start_time, 3)
66
 
@@ -71,22 +81,29 @@ def run_query(question, db_id):
71
  # ZERO ROW HANDLING
72
  if not result["rows"]:
73
  df = pd.DataFrame(columns=result.get("columns", []))
74
- explanation = f"""✅ Query executed successfully
 
 
75
 
76
  Rows returned: 0
 
77
  Note: The query ran perfectly but no matching records exist.
78
 
79
  Execution Time: {latency} sec
80
 
81
- {explain_sql(final_sql)}"""
 
 
82
  return final_sql, df, explanation
83
 
84
  df = pd.DataFrame(result["rows"], columns=result["columns"])
85
  actual_rows = len(result["rows"])
86
 
87
- explanation = f"""✅ Query executed successfully
 
88
 
89
  Rows returned: {actual_rows}
 
90
  Execution Time: {latency} sec
91
 
92
  {explain_sql(final_sql)}
@@ -98,6 +115,7 @@ This shows the model understood:
98
  """
99
 
100
  limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
 
101
  if limit_match:
102
  requested_limit = int(limit_match.group(1))
103
  if actual_rows < requested_limit:
@@ -105,60 +123,32 @@ This shows the model understood:
105
 
106
  return final_sql, df, explanation
107
 
 
108
  def load_sample(selected_question):
 
109
  if not selected_question:
110
  return gr.update(), gr.update()
 
111
  db = next((db for q, db in SAMPLES if q == selected_question), "chinook_1")
 
112
  return gr.update(value=selected_question), gr.update(value=db)
113
 
 
114
  def clear_inputs():
115
  return gr.update(value=None), gr.update(value=""), gr.update(value="chinook_1"), "", None, ""
116
 
117
 
118
  # =========================
119
- # BEAUTIFUL UI LAYOUT
120
  # =========================
 
121
 
122
- # 1. Custom Theme definition
123
- custom_theme = gr.themes.Soft(
124
- primary_hue="indigo",
125
- secondary_hue="blue",
126
- neutral_hue="slate",
127
- font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]
128
- ).set(
129
- body_background_fill="*neutral_50",
130
- block_background_fill="white",
131
- block_border_width="1px",
132
- block_border_color="*neutral_200",
133
- block_shadow="*shadow_drop_sm"
134
- )
135
-
136
- # 2. Custom CSS for animations, highlights, and layout limits
137
- custom_css = """
138
- .gradio-container { max-width: 1300px !important; margin: auto; }
139
- .header-text { text-align: center; margin-bottom: 2rem; }
140
- .header-text h1 { color: #1e293b; font-weight: 800; font-size: 2.5rem; margin-bottom: 0.5rem; }
141
- .header-text p { color: #64748b; font-size: 1.1rem; margin-bottom: 1.5rem; }
142
- .report-btn { display: inline-block; padding: 10px 24px; background-color: #f8fafc; color: #4338ca; border: 1px solid #cbd5e1; border-radius: 8px; font-weight: bold; text-decoration: none; transition: all 0.2s; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05); }
143
- .report-btn:hover { background-color: #f1f5f9; border-color: #94a3b8; transform: translateY(-1px); }
144
- .highlight-notice { background-color: #e0e7ff; color: #3730a3; padding: 12px 16px; border-radius: 8px; font-weight: 500; font-size: 0.95rem; border-left: 4px solid #4f46e5; margin-bottom: 1rem; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);}
145
- .execute-btn { background: linear-gradient(90deg, #4f46e5, #3b82f6) !important; color: white !important; border: none !important; font-weight: bold !important; transition: all 0.2s ease-in-out !important; }
146
- .execute-btn:hover { background: linear-gradient(90deg, #4338ca, #2563eb) !important; transform: scale(1.02); box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1); }
147
- """
148
-
149
- # 🚀 FIX 1: Removed theme and css from Blocks()
150
- with gr.Blocks(title="AI Data Analyst") as demo:
151
-
152
- # --- HEADER WITH REPORT LINK ---
153
- gr.HTML("""
154
- <div class="header-text">
155
- <h1>🧠 AI Text-to-SQL Agent</h1>
156
- <p>Powered by RLHF & Execution Rewards. Convert natural language to strictly validated SQLite queries.</p>
157
- <a href="https://tjhalanigrid.github.io/Text2SQL_Project/" target="_blank" class="report-btn">
158
- 📄 View the Full Project Report
159
- </a>
160
- </div>
161
- """)
162
 
163
  DBS = sorted([
164
  "flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1",
@@ -171,64 +161,53 @@ with gr.Blocks(title="AI Data Analyst") as demo:
171
  ])
172
 
173
  with gr.Row():
174
-
175
- # --- LEFT COLUMN (Inputs) ---
176
- with gr.Column(scale=4):
177
- with gr.Group(): # Wraps inputs in a nice white card
178
- gr.Markdown("### ⚙️ Query Configuration")
179
-
180
- # Highlight Notice for Auto-Select
181
- gr.HTML("""
182
- <div class="highlight-notice">
183
- ✨ <strong>Pro Tip:</strong> Select a sample question below and the correct database will be <strong>automatically selected</strong> for you!
184
- </div>
185
- """)
186
-
187
- sample_dropdown = gr.Dropdown(
188
- choices=SAMPLE_QUESTIONS,
189
- label="💡 Quick Select a Sample Question",
190
- show_label=True
191
- )
192
-
193
- db_id = gr.Dropdown(
194
- choices=DBS,
195
- value="chinook_1",
196
- label="🗄️ Target Database",
197
- interactive=True
198
- )
199
-
200
- question = gr.Textbox(
201
- label="💬 Ask a Question",
202
- placeholder="Type your own question or select a sample above...",
203
- lines=4
204
- )
205
-
206
- with gr.Row():
207
- clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg")
208
- run_btn = gr.Button("🚀 Generate & Execute SQL", elem_classes="execute-btn", size="lg")
209
-
210
- # --- RIGHT COLUMN (Outputs) ---
211
- with gr.Column(scale=7):
212
- with gr.Group(): # Wraps outputs in a nice white card
213
- gr.Markdown("### 📊 Execution Results")
214
-
215
- final_sql = gr.Code(language="sql", label="Generated SQL Script", interactive=False)
216
-
217
- # Using a Tabbed layout makes the right side cleaner
218
- with gr.Tabs():
219
- with gr.TabItem("📋 Data Table"):
220
- # 🚀 FIX 2: Removed the unsupported 'height' parameter
221
- result_table = gr.Dataframe(
222
- label="Query Result Table",
223
- interactive=False,
224
- wrap=True
225
- )
226
- with gr.TabItem("🔍 Execution Analysis"):
227
- explanation = gr.Textbox(label="Agent Details", lines=10, interactive=False)
228
-
229
- # =========================
230
- # EVENT LISTENERS
231
- # =========================
232
  sample_dropdown.change(
233
  fn=load_sample,
234
  inputs=[sample_dropdown],
@@ -248,5 +227,4 @@ with gr.Blocks(title="AI Data Analyst") as demo:
248
  )
249
 
250
  if __name__ == "__main__":
251
- # 🚀 FIX 3: Moved theme and css into launch()
252
- demo.launch(theme=custom_theme, css=custom_css)
 
35
  # SQL EXPLAINER
36
  # =========================
37
  def explain_sql(sql):
38
+
39
  explanation = "This SQL query retrieves information from the database."
40
+
41
  sql_lower = sql.lower()
42
 
43
  if "join" in sql_lower:
44
  explanation += "\n• It combines data from multiple tables using JOIN."
45
+
46
  if "where" in sql_lower:
47
  explanation += "\n• It filters rows using a WHERE condition."
48
+
49
  if "group by" in sql_lower:
50
  explanation += "\n• It groups results using GROUP BY."
51
+
52
  if "order by" in sql_lower:
53
  explanation += "\n• It sorts the results using ORDER BY."
54
+
55
  if "limit" in sql_lower:
56
  explanation += "\n• It limits the number of returned rows."
57
 
58
  return explanation
59
 
60
+
61
  # =========================
62
  # CORE FUNCTIONS
63
  # =========================
64
  def run_query(question, db_id):
65
+
66
  if not question.strip():
67
+ return "", None, " Please enter a question."
68
 
69
  start_time = time.time()
70
+
71
  result = engine.ask(question, db_id)
72
  final_sql = result["sql"]
73
+
74
  end_time = time.time()
75
  latency = round(end_time - start_time, 3)
76
 
 
81
  # ZERO ROW HANDLING
82
  if not result["rows"]:
83
  df = pd.DataFrame(columns=result.get("columns", []))
84
+
85
+ explanation = f"""
86
+ ✅ Query executed successfully
87
 
88
  Rows returned: 0
89
+
90
  Note: The query ran perfectly but no matching records exist.
91
 
92
  Execution Time: {latency} sec
93
 
94
+ {explain_sql(final_sql)}
95
+ """
96
+
97
  return final_sql, df, explanation
98
 
99
  df = pd.DataFrame(result["rows"], columns=result["columns"])
100
  actual_rows = len(result["rows"])
101
 
102
+ explanation = f"""
103
+ ✅ Query executed successfully
104
 
105
  Rows returned: {actual_rows}
106
+
107
  Execution Time: {latency} sec
108
 
109
  {explain_sql(final_sql)}
 
115
  """
116
 
117
  limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
118
+
119
  if limit_match:
120
  requested_limit = int(limit_match.group(1))
121
  if actual_rows < requested_limit:
 
123
 
124
  return final_sql, df, explanation
125
 
126
+
127
  def load_sample(selected_question):
128
+
129
  if not selected_question:
130
  return gr.update(), gr.update()
131
+
132
  db = next((db for q, db in SAMPLES if q == selected_question), "chinook_1")
133
+
134
  return gr.update(value=selected_question), gr.update(value=db)
135
 
136
+
137
  def clear_inputs():
138
  return gr.update(value=None), gr.update(value=""), gr.update(value="chinook_1"), "", None, ""
139
 
140
 
141
  # =========================
142
+ # UI LAYOUT
143
  # =========================
144
+ with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL RLHF") as demo:
145
 
146
+ gr.Markdown(
147
+ """
148
+ # Text-to-SQL using RLHF + Execution Reward
149
+ Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.
150
+ """
151
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  DBS = sorted([
154
  "flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1",
 
161
  ])
162
 
163
  with gr.Row():
164
+
165
+ with gr.Column(scale=1):
166
+
167
+ gr.Markdown("### 1. Configuration & Input")
168
+
169
+ sample_dropdown = gr.Dropdown(
170
+ choices=SAMPLE_QUESTIONS,
171
+ label=" Quick Select a Sample Question",
172
+ info="Picking a question will automatically select the right database!"
173
+ )
174
+
175
+ gr.Markdown("---")
176
+
177
+ db_id = gr.Dropdown(
178
+ choices=DBS,
179
+ value="chinook_1",
180
+ label="Select Database",
181
+ interactive=True
182
+ )
183
+
184
+ question = gr.Textbox(
185
+ label="Ask a Question",
186
+ placeholder="Type your own question or select a sample above...",
187
+ lines=3
188
+ )
189
+
190
+ with gr.Row():
191
+ clear_btn = gr.Button("🗑️ Clear", variant="secondary")
192
+ run_btn = gr.Button(" Generate & Run SQL", variant="primary")
193
+
194
+ with gr.Column(scale=2):
195
+
196
+ gr.Markdown("### 2. Execution Results")
197
+
198
+ final_sql = gr.Code(language="sql", label="Final Executed SQL")
199
+
200
+ result_table = gr.Dataframe(
201
+ label="Query Result Table",
202
+ interactive=False,
203
+ wrap=True
204
+ )
205
+
206
+ explanation = gr.Textbox(
207
+ label="AI Explanation + Execution Details",
208
+ lines=8
209
+ )
210
+
 
 
 
 
 
 
 
 
 
 
 
211
  sample_dropdown.change(
212
  fn=load_sample,
213
  inputs=[sample_dropdown],
 
227
  )
228
 
229
  if __name__ == "__main__":
230
+ demo.launch()