"""
Custom Gradio UI for the SQL Query Writing Environment.

Provides an interactive playground where users can:

- Select task difficulty
- See the database schema
- Write and submit SQL queries
- View graded results with reward breakdowns
- Track progress through questions
"""

import gradio as gr
import os
import json
from pathlib import Path

# We use the environment directly (not HTTP) for the Gradio UI
import sys

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from server.sql_env_environment import SQLEnvironment, _load_task
from server.database import Database
from server.graders import grade_query
from models import SQLAction

# Task identifiers offered by both the play dropdown and the demo dropdown.
_TASK_NAMES = ("basic_select", "join_aggregate", "advanced_analytics")


def _format_reward(value) -> str:
    """Format a (possibly fractional) reward with two decimals, e.g. ``"0.75"``.

    Rewards are shown with decimals rather than ``round()``-ed so partial
    credit stays visible (``round`` would also turn 0.5 into 0).
    """
    return f"{float(value):.2f}"


def _reward_color(value) -> str:
    """Map a reward to a traffic-light hex colour: >=0.9 green, >=0.5 yellow, else red."""
    if value >= 0.9:
        return "#22c55e"
    if value >= 0.5:
        return "#eab308"
    return "#ef4444"


def _render_reward(value) -> str:
    """Render a reward as a coloured HTML snippet for the reward panel."""
    return (
        f'<div style="font-size:1.5em;font-weight:700;color:{_reward_color(value)};">'
        f"Reward: {_format_reward(value)}</div>"
    )


def _truncate(text: str, limit: int) -> str:
    """Shorten *text* to *limit* characters, appending ``...`` only if it was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def _progress_cell(label: str, background: str, text_color: str = "#ffffff") -> str:
    """Render a single pill of the progress bar."""
    return (
        '<div style="padding:4px 10px;border-radius:6px;font-size:0.85em;'
        f'background:{background};color:{text_color};">{label}</div>'
    )


def _build_progress_html(current_q, total_q, rewards):
    """Build a visual progress bar with one pill per question.

    Args:
        current_q: Index of the current question. Kept for call-site
            compatibility; the highlighted pill is derived from
            ``len(rewards)`` (one reward is recorded per submitted step).
        total_q: Total number of questions in the task.
        rewards: Rewards recorded so far, in question order.

    Returns:
        An HTML string: answered questions coloured by reward tier, the next
        question marked with ``▶``, and the remaining ones greyed out.
    """
    answered = len(rewards)
    cells = []
    for i in range(total_q):
        label = f"Q{i + 1}"
        if i < answered:
            reward = rewards[i]
            cells.append(_progress_cell(f"{label}: {_format_reward(reward)}", _reward_color(reward)))
        elif i == answered:
            cells.append(_progress_cell(f"{label} ▶", "#3b82f6"))
        else:
            cells.append(_progress_cell(label, "#e5e7eb", text_color="#374151"))
    return '<div style="display:flex;flex-wrap:wrap;gap:4px;">' + "".join(cells) + "</div>"


def create_gradio_app() -> gr.Blocks:
    """Create a custom Gradio Blocks app for the SQL environment.

    Each browser session keeps its own environment in a ``gr.State`` value,
    so concurrent users cannot reset or advance each other's episodes.

    Returns:
        The (not yet launched) ``gr.Blocks`` application.
    """

    def _new_env(task_name):
        """Create and reset a fresh ``SQLEnvironment`` for *task_name*.

        Returns:
            ``(env, first_observation)``.
        """
        # The task is passed to SQLEnvironment through this environment
        # variable (set just before construction).
        # NOTE(review): os.environ is process-global, so two sessions starting
        # different tasks at the same instant could race here; passing the
        # task directly would be safer if SQLEnvironment supports it.
        os.environ["SQL_ENV_TASK"] = task_name
        env = SQLEnvironment()
        obs = env.reset()
        return env, obs

    def reset_env(task_name):
        """Start a new episode for the selected task in this session.

        Returns the values for the start button's outputs; the final element
        is the new per-session state dict.
        """
        env, obs = _new_env(task_name)
        session = {"env": env, "task_name": task_name, "done": False}

        task = _load_task(task_name)
        difficulty = task.get("difficulty", "unknown")
        status = (
            f"**Task:** {task_name} ({difficulty}) | "
            f"**Question 1/{obs.total_questions}** | "
            f"**Attempts left:** {obs.steps_remaining}"
        )
        return (
            obs.question,
            obs.schema_description,
            "",  # clear query input
            "",  # clear result
            "",  # clear feedback
            _format_reward(0.0),
            status,
            _build_progress_html(0, obs.total_questions, []),
            session,
        )

    def submit_query(query, question_text, session):
        """Submit a SQL query for the session's current question and grade it.

        Returns six values for: question box, result box, feedback box,
        reward panel, status markdown and progress bar.
        """
        if not session or session.get("env") is None:
            return (
                question_text,
                "Please click 'Start Task' first!",
                "Environment not initialized",
                _format_reward(0.0),
                "**Error:** Not initialized",
                "",
            )
        if session.get("done"):
            # Do not step a finished episode; keep the final screen as-is.
            return (
                question_text,
                gr.update(),
                "This episode is finished. Click 'Start Task' to play again.",
                gr.update(),
                gr.update(),
                gr.update(),
            )
        if not (query or "").strip():
            # An accidental Enter on an empty box should not burn an attempt.
            return (
                question_text,
                gr.update(),
                "Please enter a SQL query before submitting.",
                gr.update(),
                gr.update(),
                gr.update(),
            )

        env = session["env"]
        obs = env.step(SQLAction(query=query))
        session["done"] = bool(obs.done)

        metadata = obs.metadata or {}
        feedback = metadata.get("feedback", "")
        reward_markup = _render_reward(obs.reward or 0.0)
        # Rewards recorded so far by the environment (private attribute; the
        # public metadata is preferred where it provides them).
        env_rewards = list(getattr(env, "_rewards", []))

        if obs.done:
            rewards = metadata.get("rewards", env_rewards)
            total = metadata.get("total_reward", sum(rewards))
            status = (
                f"**Episode Complete!** | **Total Reward:** {_format_reward(total)} | "
                f"**Steps:** {len(rewards)}"
            )
            next_question = "All questions answered! Click 'Start Task' to try again."
            progress = _build_progress_html(len(rewards), obs.total_questions, rewards)
        else:
            status = (
                f"**Task:** {session['task_name']} | "
                f"**Question {obs.question_index}/{obs.total_questions}** | "
                f"**Attempts left:** {obs.steps_remaining}"
            )
            next_question = obs.question
            progress = _build_progress_html(obs.question_index - 1, obs.total_questions, env_rewards)

        result_display = obs.query_result or "(no output)"
        if obs.error:
            result_display = f"ERROR: {obs.error}\n\n{result_display}"

        return (
            next_question,
            result_display,
            feedback,
            reward_markup,
            status,
            progress,
        )

    def run_ground_truth(task_name):
        """Run every ground-truth query of *task_name* and report the rewards as Markdown."""
        env, _ = _new_env(task_name)
        questions = _load_task(task_name)["questions"]

        lines = []
        total = 0.0
        for number, q in enumerate(questions, start=1):
            obs = env.step(SQLAction(query=q["ground_truth_sql"]))
            reward = obs.reward or 0.0
            total += reward
            lines.append(
                f"**Q{number}:** {_truncate(q['question'], 80)}\n"
                f"- SQL: `{_truncate(q['ground_truth_sql'], 100)}`\n"
                f"- Reward: **{_format_reward(reward)}**\n"
            )
        lines.append(f"\n---\n**Total: {_format_reward(total)} / {len(questions)}**")
        return "\n".join(lines)

    def preview_schema():
        """Return the database schema description from a fresh ``Database``.

        NOTE: not currently wired to any event in this function.
        """
        db = Database()
        try:
            db.initialize()
            return db.get_schema_description()
        finally:
            # Close the database even if initialization or description fails.
            db.close()

    # Build the Gradio interface
    with gr.Blocks(title="SQLEnv — SQL Query Writing Environment") as app:
        gr.Markdown(
            "# 🗃️ SQLEnv — SQL Query Writing Environment\n\n"
            "Write SQL queries to answer natural language questions about an "
            "e-commerce database. Get graded with partial-credit scoring — "
            "syntax, columns, rows, and exact match."
        )

        # Per-browser-session state: {"env", "task_name", "done"} once started.
        session_state = gr.State(None)

        with gr.Row():
            with gr.Column(scale=1):
                task_selector = gr.Dropdown(
                    choices=list(_TASK_NAMES),
                    value="basic_select",
                    label="Select Task Difficulty",
                )
                start_btn = gr.Button("🚀 Start Task", variant="primary", size="lg")
                status_md = gr.Markdown("Click **Start Task** to begin")
                progress_html = gr.HTML("")
                reward_html = gr.HTML("")
                gr.Markdown("---")
                feedback_box = gr.Textbox(label="Grader Feedback", lines=3, interactive=False)

            with gr.Column(scale=2):
                question_box = gr.Textbox(
                    label="Question",
                    lines=2,
                    interactive=False,
                    placeholder="Start a task to see the question...",
                )
                query_input = gr.Textbox(
                    label="Your SQL Query",
                    lines=5,
                    placeholder="SELECT name, age FROM customers WHERE age > 30 ORDER BY age DESC",
                    elem_classes=["query-input"],
                )
                submit_btn = gr.Button("▶ Execute & Grade", variant="primary", size="lg")
                result_box = gr.Textbox(
                    label="Query Result",
                    lines=10,
                    interactive=False,
                    elem_classes=["result-output"],
                )

        with gr.Accordion("📋 Database Schema", open=False):
            schema_box = gr.Textbox(
                label="Schema",
                lines=20,
                interactive=False,
                elem_classes=["result-output"],
            )

        with gr.Accordion("🏆 Run Ground Truth Demo", open=False):
            gr.Markdown("See how perfect SQL queries score on each task:")
            with gr.Row():
                demo_task = gr.Dropdown(
                    choices=list(_TASK_NAMES),
                    value="basic_select",
                    label="Task",
                )
                demo_btn = gr.Button("Run Demo")
            demo_output = gr.Markdown("")

        # Event handlers
        start_btn.click(
            fn=reset_env,
            inputs=[task_selector],
            outputs=[
                question_box,
                schema_box,
                query_input,
                result_box,
                feedback_box,
                reward_html,
                status_md,
                progress_html,
                session_state,
            ],
        )

        submit_inputs = [query_input, question_box, session_state]
        submit_outputs = [question_box, result_box, feedback_box, reward_html, status_md, progress_html]
        submit_btn.click(fn=submit_query, inputs=submit_inputs, outputs=submit_outputs)
        # Also submit on Enter (Shift+Enter for newline)
        query_input.submit(fn=submit_query, inputs=submit_inputs, outputs=submit_outputs)

        demo_btn.click(
            fn=run_ground_truth,
            inputs=[demo_task],
            outputs=[demo_output],
        )

    return app