UtkarshSatav commited on
Commit
1c5c280
Β·
verified Β·
1 Parent(s): c51dde5

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. .gitignore +1 -1
  2. server/app.py +6 -24
  3. server/gradio_ui.py +241 -0
  4. server/requirements.txt +1 -0
  5. uv.lock +0 -0
.gitignore CHANGED
@@ -10,4 +10,4 @@ venv/
10
  *.sqlite
11
  .env
12
  .DS_Store
13
- uv.lock
 
10
  *.sqlite
11
  .env
12
  .DS_Store
13
+ # uv.lock is tracked for reproducible builds
server/app.py CHANGED
@@ -6,10 +6,8 @@ Endpoints:
6
  - POST /step: Execute an action (SQL query)
7
  - GET /state: Get current environment state
8
  - GET /health: Health check
 
9
  - WS /ws: WebSocket endpoint for persistent sessions
10
-
11
- Usage:
12
- uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
13
  """
14
 
15
  try:
@@ -36,28 +34,12 @@ app = create_app(
36
  )
37
 
38
 
39
- from fastapi.responses import HTMLResponse
40
-
 
41
 
42
- @app.get("/", response_class=HTMLResponse)
43
- def root():
44
- """Root endpoint β€” required by HF Spaces to detect the app is running."""
45
- return """
46
- <html><head><title>SQLEnv - SQL Query Writing Environment</title></head>
47
- <body style="font-family:sans-serif;max-width:800px;margin:40px auto;padding:0 20px">
48
- <h1>SQLEnv</h1>
49
- <p>SQL Query Writing Environment for AI Agents</p>
50
- <h3>API Endpoints</h3>
51
- <ul>
52
- <li><b>POST /reset</b> β€” Reset environment, get first question</li>
53
- <li><b>POST /step</b> β€” Submit SQL query, get graded result</li>
54
- <li><b>GET /state</b> β€” Current episode state</li>
55
- <li><b>GET /health</b> β€” Health check</li>
56
- <li><b>GET /docs</b> β€” Interactive API docs</li>
57
- </ul>
58
- <p>3 tasks: basic_select (easy), join_aggregate (medium), advanced_analytics (hard)</p>
59
- </body></html>
60
- """
61
 
62
 
63
  def main(host: str = "0.0.0.0", port: int = 8000):
 
6
  - POST /step: Execute an action (SQL query)
7
  - GET /state: Get current environment state
8
  - GET /health: Health check
9
+ - GET /web: Interactive Gradio playground
10
  - WS /ws: WebSocket endpoint for persistent sessions
 
 
 
11
  """
12
 
13
  try:
 
34
  )
35
 
36
 
37
+ # Mount the custom Gradio UI
38
+ import gradio as gr
39
+ from server.gradio_ui import create_gradio_app
40
 
41
+ gradio_app = create_gradio_app()
42
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  def main(host: str = "0.0.0.0", port: int = 8000):
server/gradio_ui.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Gradio UI for the SQL Query Writing Environment.
3
+
4
+ Provides an interactive playground where users can:
5
+ - Select task difficulty
6
+ - See the database schema
7
+ - Write and submit SQL queries
8
+ - View graded results with reward breakdowns
9
+ - Track progress through questions
10
+ """
11
+
12
+ import gradio as gr
13
+ import os
14
+ import json
15
+ from pathlib import Path
16
+
17
+ # We use the environment directly (not HTTP) for the Gradio UI
18
+ import sys
19
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
20
+
21
+ from server.sql_env_environment import SQLEnvironment, _load_task
22
+ from server.database import Database
23
+ from server.graders import grade_query
24
+ from models import SQLAction
25
+
26
+
27
+ def create_gradio_app() -> gr.Blocks:
28
+ """Create a custom Gradio Blocks app for the SQL environment."""
29
+
30
+ # Shared state
31
+ env_state = {"env": None, "task_name": "basic_select"}
32
+
33
+ def reset_env(task_name):
34
+ """Reset environment with selected task."""
35
+ os.environ["SQL_ENV_TASK"] = task_name
36
+ env = SQLEnvironment()
37
+ obs = env.reset()
38
+ env_state["env"] = env
39
+ env_state["task_name"] = task_name
40
+
41
+ task = _load_task(task_name)
42
+ difficulty = task.get("difficulty", "unknown")
43
+
44
+ status = f"**Task:** {task_name} ({difficulty}) | **Question 1/{obs.total_questions}** | **Attempts left:** {obs.steps_remaining}"
45
+
46
+ return (
47
+ obs.question,
48
+ obs.schema_description,
49
+ "", # clear query input
50
+ "", # clear result
51
+ "", # clear feedback
52
+ "0.0", # reward
53
+ status,
54
+ _build_progress_html(0, obs.total_questions, []),
55
+ )
56
+
57
+ def submit_query(query, question_text):
58
+ """Submit a SQL query and get graded results."""
59
+ env = env_state.get("env")
60
+ if env is None:
61
+ return (
62
+ question_text,
63
+ "Please click 'Start Task' first!",
64
+ "Environment not initialized",
65
+ "0.0",
66
+ "**Error:** Not initialized",
67
+ "",
68
+ )
69
+
70
+ obs = env.step(SQLAction(query=query))
71
+
72
+ feedback = obs.metadata.get("feedback", "")
73
+ reward_display = f"{obs.reward:.2f}"
74
+
75
+ # Color the reward
76
+ if obs.reward >= 0.9:
77
+ reward_html = f'<span style="color:#22c55e;font-size:2em;font-weight:bold">{reward_display}</span>'
78
+ elif obs.reward >= 0.5:
79
+ reward_html = f'<span style="color:#eab308;font-size:2em;font-weight:bold">{reward_display}</span>'
80
+ else:
81
+ reward_html = f'<span style="color:#ef4444;font-size:2em;font-weight:bold">{reward_display}</span>'
82
+
83
+ if obs.done:
84
+ rewards = obs.metadata.get("rewards", [])
85
+ total = obs.metadata.get("total_reward", sum(rewards))
86
+ status = f"**Episode Complete!** | **Total Reward:** {total:.2f} | **Steps:** {len(rewards)}"
87
+ next_question = "All questions answered! Click 'Start Task' to try again."
88
+ progress = _build_progress_html(len(rewards), obs.total_questions, rewards)
89
+ else:
90
+ status = f"**Task:** {env_state['task_name']} | **Question {obs.question_index}/{obs.total_questions}** | **Attempts left:** {obs.steps_remaining}"
91
+ next_question = obs.question
92
+ # Collect rewards from episode so far
93
+ rewards = env._rewards
94
+ progress = _build_progress_html(obs.question_index - 1, obs.total_questions, rewards)
95
+
96
+ result_display = obs.query_result if obs.query_result else "(no output)"
97
+ if obs.error:
98
+ result_display = f"ERROR: {obs.error}\n\n{result_display}"
99
+
100
+ return (
101
+ next_question,
102
+ result_display,
103
+ feedback,
104
+ reward_html,
105
+ status,
106
+ progress,
107
+ )
108
+
109
+ def run_ground_truth(task_name):
110
+ """Run all ground truth queries for demo purposes."""
111
+ os.environ["SQL_ENV_TASK"] = task_name
112
+ env = SQLEnvironment()
113
+ obs = env.reset()
114
+ task = _load_task(task_name)
115
+
116
+ results = []
117
+ for q in task["questions"]:
118
+ obs = env.step(SQLAction(query=q["ground_truth_sql"]))
119
+ results.append(f"**Q{len(results)+1}:** {q['question'][:80]}...\n- SQL: `{q['ground_truth_sql'][:100]}...`\n- Reward: **{obs.reward:.2f}**\n")
120
+
121
+ total = sum(env._rewards)
122
+ results.append(f"\n---\n**Total: {total:.2f} / {len(task['questions']):.1f}**")
123
+ return "\n".join(results)
124
+
125
+ def preview_schema():
126
+ """Show the database schema."""
127
+ db = Database()
128
+ db.initialize()
129
+ schema = db.get_schema_description()
130
+ db.close()
131
+ return schema
132
+
133
+ def _build_progress_html(current_q, total_q, rewards):
134
+ """Build a visual progress bar."""
135
+ bars = []
136
+ for i in range(total_q):
137
+ if i < len(rewards):
138
+ r = rewards[i] if i < len(rewards) else 0
139
+ if r >= 0.9:
140
+ color = "#22c55e"
141
+ elif r >= 0.5:
142
+ color = "#eab308"
143
+ else:
144
+ color = "#ef4444"
145
+ bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:{color};margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1}: {r:.2f}</div>')
146
+ elif i == len(rewards):
147
+ bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:#3b82f6;margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1} β–Ά</div>')
148
+ else:
149
+ bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:#374151;margin:1%;border-radius:4px;text-align:center;line-height:30px;color:#9ca3af">Q{i+1}</div>')
150
+ return "<div style='margin:10px 0'>" + "".join(bars) + "</div>"
151
+
152
+ # Build the Gradio interface
153
+ with gr.Blocks(title="SQLEnv β€” SQL Query Writing Environment") as app:
154
+
155
+ gr.Markdown("""
156
+ # πŸ—ƒοΈ SQLEnv β€” SQL Query Writing Environment
157
+ Write SQL queries to answer natural language questions about an e-commerce database.
158
+ Get graded with partial-credit scoring β€” syntax, columns, rows, and exact match.
159
+ """)
160
+
161
+ with gr.Row():
162
+ with gr.Column(scale=1):
163
+ task_selector = gr.Dropdown(
164
+ choices=["basic_select", "join_aggregate", "advanced_analytics"],
165
+ value="basic_select",
166
+ label="Select Task Difficulty",
167
+ )
168
+ start_btn = gr.Button("πŸš€ Start Task", variant="primary", size="lg")
169
+ status_md = gr.Markdown("Click **Start Task** to begin")
170
+ progress_html = gr.HTML("")
171
+ reward_html = gr.HTML('<span style="color:#666;font-size:2em">β€”</span>')
172
+ gr.Markdown("---")
173
+ feedback_box = gr.Textbox(label="Grader Feedback", lines=3, interactive=False)
174
+
175
+ with gr.Column(scale=2):
176
+ question_box = gr.Textbox(
177
+ label="Question",
178
+ lines=2,
179
+ interactive=False,
180
+ placeholder="Start a task to see the question...",
181
+ )
182
+ query_input = gr.Textbox(
183
+ label="Your SQL Query",
184
+ lines=5,
185
+ placeholder="SELECT name, age FROM customers WHERE age > 30 ORDER BY age DESC",
186
+ elem_classes=["query-input"],
187
+ )
188
+ submit_btn = gr.Button("β–Ά Execute & Grade", variant="primary", size="lg")
189
+ result_box = gr.Textbox(
190
+ label="Query Result",
191
+ lines=10,
192
+ interactive=False,
193
+ elem_classes=["result-output"],
194
+ )
195
+
196
+ with gr.Accordion("πŸ“‹ Database Schema", open=False):
197
+ schema_box = gr.Textbox(
198
+ label="Schema",
199
+ lines=20,
200
+ interactive=False,
201
+ elem_classes=["result-output"],
202
+ )
203
+
204
+ with gr.Accordion("πŸ† Run Ground Truth Demo", open=False):
205
+ gr.Markdown("See how perfect SQL queries score on each task:")
206
+ with gr.Row():
207
+ demo_task = gr.Dropdown(
208
+ choices=["basic_select", "join_aggregate", "advanced_analytics"],
209
+ value="basic_select",
210
+ label="Task",
211
+ )
212
+ demo_btn = gr.Button("Run Demo")
213
+ demo_output = gr.Markdown("")
214
+
215
+ # Event handlers
216
+ start_btn.click(
217
+ fn=reset_env,
218
+ inputs=[task_selector],
219
+ outputs=[question_box, schema_box, query_input, result_box, feedback_box, reward_html, status_md, progress_html],
220
+ )
221
+
222
+ submit_btn.click(
223
+ fn=submit_query,
224
+ inputs=[query_input, question_box],
225
+ outputs=[question_box, result_box, feedback_box, reward_html, status_md, progress_html],
226
+ )
227
+
228
+ # Also submit on Enter (Shift+Enter for newline)
229
+ query_input.submit(
230
+ fn=submit_query,
231
+ inputs=[query_input, question_box],
232
+ outputs=[question_box, result_box, feedback_box, reward_html, status_md, progress_html],
233
+ )
234
+
235
+ demo_btn.click(
236
+ fn=run_ground_truth,
237
+ inputs=[demo_task],
238
+ outputs=[demo_output],
239
+ )
240
+
241
+ return app
server/requirements.txt CHANGED
@@ -2,3 +2,4 @@ openenv-core[core]>=0.2.0
2
  fastapi>=0.115.0
3
  uvicorn>=0.24.0
4
  openai>=1.0.0
 
 
2
  fastapi>=0.115.0
3
  uvicorn>=0.24.0
4
  openai>=1.0.0
5
+ gradio>=4.0.0
uv.lock ADDED
The diff for this file is too large to render. See raw diff