""" Custom Gradio UI for the SQL Query Writing Environment. Provides an interactive playground where users can: - Select task difficulty - See the database schema - Write and submit SQL queries - View graded results with reward breakdowns - Track progress through questions """ import gradio as gr import os import json from pathlib import Path # We use the environment directly (not HTTP) for the Gradio UI import sys sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from server.sql_env_environment import SQLEnvironment, _load_task from server.database import Database from server.graders import grade_query from models import SQLAction def create_gradio_app() -> gr.Blocks: """Create a custom Gradio Blocks app for the SQL environment.""" # Shared state env_state = {"env": None, "task_name": "basic_select"} def reset_env(task_name): """Reset environment with selected task.""" os.environ["SQL_ENV_TASK"] = task_name env = SQLEnvironment() obs = env.reset() env_state["env"] = env env_state["task_name"] = task_name task = _load_task(task_name) difficulty = task.get("difficulty", "unknown") status = f"**Task:** {task_name} ({difficulty}) | **Question 1/{obs.total_questions}** | **Attempts left:** {obs.steps_remaining}" return ( obs.question, obs.schema_description, "", # clear query input "", # clear result "", # clear feedback "0.0", # reward status, _build_progress_html(0, obs.total_questions, []), ) def submit_query(query, question_text): """Submit a SQL query and get graded results.""" env = env_state.get("env") if env is None: return ( question_text, "Please click 'Start Task' first!", "Environment not initialized", "0.0", "**Error:** Not initialized", "", ) obs = env.step(SQLAction(query=query)) feedback = obs.metadata.get("feedback", "") reward_display = round(obs.reward) # show 0 or 1 # Color the reward if reward_display == 1: reward_html = f'{reward_display}' else: reward_html = f'{reward_display}' if obs.done: rewards = obs.metadata.get("rewards", []) total = obs.metadata.get("total_reward", sum(rewards)) status = f"**Episode Complete!** | **Total Reward:** {round(total)} | **Steps:** {len(rewards)}" next_question = "All questions answered! Click 'Start Task' to try again." progress = _build_progress_html(len(rewards), obs.total_questions, rewards) else: status = f"**Task:** {env_state['task_name']} | **Question {obs.question_index}/{obs.total_questions}** | **Attempts left:** {obs.steps_remaining}" next_question = obs.question # Collect rewards from episode so far rewards = env._rewards progress = _build_progress_html(obs.question_index - 1, obs.total_questions, rewards) result_display = obs.query_result if obs.query_result else "(no output)" if obs.error: result_display = f"ERROR: {obs.error}\n\n{result_display}" return ( next_question, result_display, feedback, reward_html, status, progress, ) def run_ground_truth(task_name): """Run all ground truth queries for demo purposes.""" os.environ["SQL_ENV_TASK"] = task_name env = SQLEnvironment() obs = env.reset() task = _load_task(task_name) results = [] for q in task["questions"]: obs = env.step(SQLAction(query=q["ground_truth_sql"])) results.append(f"**Q{len(results)+1}:** {q['question'][:80]}...\n- SQL: `{q['ground_truth_sql'][:100]}...`\n- Reward: **{round(obs.reward)}**\n") total = sum(env._rewards) results.append(f"\n---\n**Total: {round(total)} / {len(task['questions'])}**") return "\n".join(results) def preview_schema(): """Show the database schema.""" db = Database() db.initialize() schema = db.get_schema_description() db.close() return schema def _build_progress_html(current_q, total_q, rewards): """Build a visual progress bar.""" bars = [] for i in range(total_q): if i < len(rewards): r = rewards[i] if i < len(rewards) else 0 if r >= 0.9: color = "#22c55e" elif r >= 0.5: color = "#eab308" else: color = "#ef4444" bars.append(f'