""" Single-page Gradio UI for the Hugging Face Space (same process as the OpenEnv FastAPI API). Playground uses POST /reset and POST /step via loopback HTTP with X-Session-Id. """ from __future__ import annotations import json import os import uuid from pathlib import Path from typing import Any, Optional, Tuple import httpx COLAB_FIRST_TRAINING = ( "https://colab.research.google.com/drive/1H6SLfCBhHzRJtnymLgevjfyytWUximF5" "#scrollTo=j-9MptXvmPk8" ) COLAB_TRAINING_ROOT = ( "https://colab.research.google.com/drive/1H6SLfCBhHzRJtnymLgevjfyytWUximF5" "#scrollTo=x5YuvatGyyu_" ) HF_SPACE = "https://huggingface.co/spaces/md896/sql-debug-env" HF_SAMPLE_REWARDS = ( "https://huggingface.co/spaces/md896/sql-debug-env/tree/main/" "artifacts/runs/20260426-064318-sample-rewards-32eval" ) HF_EVAL_32 = ( "https://huggingface.co/spaces/md896/sql-debug-env/tree/main/" "artifacts/runs/20260426-060502-final-pass-32eval" ) HF_MODEL = "https://huggingface.co/md896/sql-debug-agent-qwen25-05b-grpo-wandb-continue-v2" GITHUB_REPO = "https://github.com/mdayan8/sql-debug-env.git" WANDB_TRAINING_RUN = "https://wandb.ai/mdayanbag-pesitm/sql-debug-grpo-best-budget/workspace?nw=nwusermdayanbag" GCLOUD_TEXT2SQL_BLOG = "https://cloud.google.com/blog/products/databases/techniques-for-improving-text-to-sql" OURBENCH_PAPER = "https://arxiv.org/abs/2601.18119" PREDEFINED_QUERIES: dict[str, list[tuple[str, str]]] = { "easy_syntax_fix": [ ("Broken baseline: typo table", "SELECT * FROM userss;"), ("Simple lookup", "SELECT id, name FROM users ORDER BY id LIMIT 10;"), ("Potential invalid write", "UPDATE users SET name='test';"), ], "medium_logic_fix": [ ("Broken: missing GROUP BY", "SELECT department, COUNT(*) FROM employees;"), ("Revenue by month", "SELECT strftime('%Y-%m', order_date) AS ym, SUM(amount) FROM orders GROUP BY ym ORDER BY ym;"), ("Top entities", "SELECT customer_id, SUM(total) AS spend FROM invoices GROUP BY customer_id ORDER BY spend DESC LIMIT 5;"), ], "hard_multi_bug": [ ("Broken join alias", "SELECT u.name, o.total FROM users u JOIN orders o ON user.id = o.user_id;"), ("Join + aggregate", "SELECT p.category, AVG(p.price) AS avg_price FROM products p GROUP BY p.category ORDER BY avg_price DESC;"), ("Nested query", "SELECT name FROM customers WHERE id IN (SELECT customer_id FROM orders GROUP BY customer_id HAVING COUNT(*) > 2);"), ], "hard_finance_explosion": [ ("Broken finance calc", "SELECT account_id, SUM(amount) / COUNT(*) AS risk FROM txn GROUP BY account;"), ("PnL-style aggregate", "SELECT symbol, SUM(CASE WHEN side='BUY' THEN -notional ELSE notional END) AS pnl FROM trades GROUP BY symbol ORDER BY pnl DESC;"), ("Daily exposure", "SELECT date(trade_ts) AS d, SUM(abs(notional)) AS exposure FROM trades GROUP BY d ORDER BY d;"), ], } GRADIO_CSS = """ :root { --sde-ink: #0f172a; --sde-muted: #64748b; --sde-line: #e2e8f0; --sde-card: #ffffff; --sde-glow: radial-gradient(120% 140% at 0% 0%, rgba(45, 212, 191, 0.22) 0%, rgba(45, 212, 191, 0) 52%), radial-gradient(120% 140% at 100% 0%, rgba(147, 197, 253, 0.22) 0%, rgba(147, 197, 253, 0) 55%), linear-gradient(132deg, #0f172a 0%, #1e293b 45%, #0f766e 100%); /* HF Space embed is often dark; keep prose readable (theme alone used near-black on black). */ --sde-body-text: #e2e8f0; --sde-heading-text: #f8fafc; } .gradio-container { max-width: 1180px !important; margin-left: auto !important; margin-right: auto !important; color: #e2e8f0 !important; } /* Gradio 5 + HF: markdown / prose defaults can match a dark shell and disappear */ .gradio-container .prose, .gradio-container .prose p, .gradio-container .prose li, .gradio-container .prose td, .gradio-container .prose th { color: var(--sde-body-text) !important; } .gradio-container .prose h1, .gradio-container .prose h2, .gradio-container .prose h3, .gradio-container .prose h4, .gradio-container .prose strong { color: var(--sde-heading-text) !important; } .gradio-container .prose a { color: #7dd3fc !important; } .gradio-container .prose code { color: #fef3c7 !important; background: rgba(15, 23, 42, 0.55) !important; } .sde-hero-wrap { background: var(--sde-glow); color: #f8fafc; border-radius: 20px; padding: 1.75rem 1.5rem 1.5rem; margin-bottom: 1.25rem; border: 1px solid rgba(148, 163, 184, 0.24); box-shadow: 0 18px 40px rgba(15, 23, 42, 0.20), inset 0 1px 0 rgba(255, 255, 255, 0.12); } .sde-hero-wrap .sde-hero-title { margin: 0 0 0.35rem 0; font-size: 1.85rem; font-weight: 800; letter-spacing: -0.03em; line-height: 1.2; color: #f8fafc !important; } .sde-hero-wrap .sde-hero-lede { margin: 0 0 0.5rem 0; color: #e2e8f0 !important; font-size: 0.95rem; line-height: 1.55; } .sde-hero-subnav { margin-bottom: 0.75rem; font-size: 0.88rem; } .sde-hero-wrap .sde-hero-subnav a { color: #a5f3fc !important; font-weight: 600; } .sde-pill-row { display: flex; flex-wrap: wrap; gap: 0.5rem; margin-top: 1rem; } .sde-pill { display: inline-block; padding: 0.35rem 0.75rem; border-radius: 999px; font-size: 0.72rem; font-weight: 700; letter-spacing: 0.06em; text-transform: uppercase; background: rgba(15, 23, 42, 0.26); border: 1px solid rgba(226, 232, 240, 0.34); color: #f8fafc; } .sde-section-title { font-size: 1.05rem; font-weight: 700; color: var(--sde-heading-text) !important; margin: 1.5rem 0 0.75rem 0; letter-spacing: -0.02em; } .sde-muted-caption { color: #94a3b8 !important; font-size: 0.9rem; } .sde-link-row a { color: #7dd3fc !important; font-weight: 600; margin-right: 1rem; } .sde-kpi-grid { display: grid; grid-template-columns: repeat(4, minmax(0, 1fr)); gap: 0.75rem; margin: 0.5rem 0 1rem; } .sde-kpi { background: #ffffff; border: 1px solid #dbe3f0; border-radius: 14px; padding: 0.85rem 0.95rem; box-shadow: 0 10px 24px rgba(15, 23, 42, 0.06); } .sde-kpi .v { font-size: 1.25rem; font-weight: 800; letter-spacing: -0.02em; color: #0f172a; } .sde-kpi .k { margin-top: 0.15rem; font-size: 0.73rem; text-transform: uppercase; letter-spacing: 0.06em; color: #64748b; } .sde-callout { border-left: 4px solid #2563eb; background: #eff6ff; color: #1e3a8a; padding: 0.7rem 0.8rem; border-radius: 8px; margin: 0.5rem 0 0.75rem; font-size: 0.86rem; } @media (max-width: 900px) { .sde-kpi-grid { grid-template-columns: repeat(2, minmax(0, 1fr)); } } """ def _api_base() -> str: return os.environ.get( "INTERNAL_API_BASE", f"http://127.0.0.1:{os.environ.get('PORT', '7860')}", ).rstrip("/") def _blog_url() -> str: return (os.environ.get("BLOG_URL") or "").strip() def _http() -> httpx.Client: return httpx.Client(timeout=120.0) def _img_path(static_dir: Path, *names: str) -> Optional[str]: for n in names: p = static_dir / n if p.is_file(): return str(p.resolve()) return None def _preset_options(task_id: str) -> list[str]: return [name for name, _ in PREDEFINED_QUERIES.get(task_id, [])] def _preset_query(task_id: str, preset_name: str) -> str: for name, query in PREDEFINED_QUERIES.get(task_id, []): if name == preset_name: return query return "" def _safe_reward(value: Any) -> float: try: return float(value) except Exception: return 0.0 def build_blocks(static_dir: Path) -> Any: import gradio as gr wf = _img_path(static_dir, "diagram-end-to-end-workflow.png", "environment-workflow.png") chart_leap = _img_path(static_dir, "chart-performance-leap.png", "hero_performance_leap.png") chart_dual = _img_path(static_dir, "chart-comparison-shift.png", "hero_dual_benchmark.png") chart_spider = _img_path(static_dir, "chart-spider-benchmark.png", "hero_spider_sota.png") proof_combo = _img_path(static_dir, "proof-combo.png") proof_dist = _img_path(static_dir, "proof-distribution-shift.png") final_gallery_paths = [ "training_reward_curve_final.png", "training_diagnostics_dual_axis_final.png", "baseline_vs_trained_by_task_final.png", "task_delta_post_minus_base_final.png", "reward_distribution_shift_red_green_final.png", "presentation_combo_final.png", "benchmark_style_summary_final.png", "checkpoint_leaderboard_step_vs_reward_final.png", "cost_vs_performance_final.png", ] final_gallery: list[tuple[str, str]] = [] for filename in final_gallery_paths: path = _img_path(static_dir, filename) if path: title = filename.replace("_final.png", "").replace("_", " ").title() final_gallery.append((path, title)) blog = _blog_url() blog_md = ( f"### Blog\n[Read the write-up]({blog})" if blog else "### Blog\nAdd a **Space secret** named `BLOG_URL` with your post URL (e.g. Medium, personal site, or Hugging Face blog)." ) task_choices = [ "easy_syntax_fix", "medium_logic_fix", "hard_multi_bug", "hard_finance_explosion", ] def reset_fn( task_id: str, session_id: Optional[str] ) -> Tuple[str, str, str, str]: sid = session_id or str(uuid.uuid4()) try: with _http() as client: r = client.post( f"{_api_base()}/reset", json={"task_id": task_id}, headers={"X-Session-Id": sid}, ) r.raise_for_status() data = r.json() except Exception as e: err = {"error": str(e), "hint": "Is the server listening on PORT?"} return json.dumps(err, indent=2), "", sid, f"Session: `{sid}` · **error**" obs = json.dumps(data, indent=2) q = (data.get("observation") or {}).get("original_query") or "" return obs, q, sid, f"Session: `{sid}`" def submit_fn( query: str, session_id: Optional[str] ) -> Tuple[str, str]: if not session_id: return ( json.dumps({"error": "Click “Reset task” first to create a session."}, indent=2), "", ) payload = {"action": {"action_type": "submit_query", "query": query or ""}} try: with _http() as client: r = client.post( f"{_api_base()}/step", json=payload, headers={"X-Session-Id": session_id}, ) r.raise_for_status() data = r.json() except httpx.HTTPStatusError as e: try: detail = e.response.json() except Exception: detail = e.response.text return json.dumps({"error": str(e), "detail": detail}, indent=2), "" except Exception as e: return json.dumps({"error": str(e)}, indent=2), "" out = json.dumps(data, indent=2) reward = data.get("reward") done = data.get("done") return out, f"**reward** `{reward}` · **done** `{done}`" def run_preset_suite( task_id: str, session_id: Optional[str] ) -> Tuple[str, str, str, str]: sid = session_id or str(uuid.uuid4()) presets = PREDEFINED_QUERIES.get(task_id, []) if not presets: return "No presets for selected task.", "{}", sid, f"Session: `{sid}`" rows: list[str] = [] rewards: list[float] = [] done_count = 0 error_count = 0 with _http() as client: for idx, (name, query) in enumerate(presets, start=1): try: client.post( f"{_api_base()}/reset", json={"task_id": task_id}, headers={"X-Session-Id": sid}, ).raise_for_status() step_resp = client.post( f"{_api_base()}/step", json={"action": {"action_type": "submit_query", "query": query}}, headers={"X-Session-Id": sid}, ) step_resp.raise_for_status() data = step_resp.json() reward = _safe_reward(data.get("reward")) done = bool(data.get("done")) info = data.get("info") or {} label = "pass" if reward >= 0.5 else "check" rewards.append(reward) done_count += int(done) note = "review_rejected" if info.get("review_rejected") else "" rows.append( f"| {idx} | {name} | `{reward:.3f}` | `{done}` | {label} {note} |" ) except Exception as e: error_count += 1 rows.append( f"| {idx} | {name} | `0.000` | `False` | error: {str(e)[:120]} |" ) avg_reward = (sum(rewards) / len(rewards)) if rewards else 0.0 max_reward = max(rewards) if rewards else 0.0 min_reward = min(rewards) if rewards else 0.0 suite_md = ( "#### Preset suite report\n" "| # | Preset | Reward | Done | Note |\n" "|---|---|---:|:---:|---|\n" + "\n".join(rows) + "\n\n" + f"**Summary:** avg reward `{avg_reward:.3f}` · min `{min_reward:.3f}` · max `{max_reward:.3f}` · " f"done count `{done_count}` · errors `{error_count}`" ) suite_json = json.dumps( { "task_id": task_id, "session_id": sid, "n_presets": len(presets), "avg_reward": round(avg_reward, 4), "min_reward": round(min_reward, 4), "max_reward": round(max_reward, 4), "done_count": done_count, "error_count": error_count, }, indent=2, ) return suite_md, suite_json, sid, f"Session: `{sid}`" font = gr.themes.GoogleFont("Plus Jakarta Sans") mono = gr.themes.GoogleFont("JetBrains Mono") theme = gr.themes.Soft( primary_hue="indigo", secondary_hue="slate", neutral_hue="slate", font=(font, "ui-sans-serif", "system-ui"), font_mono=(mono, "ui-monospace", "monospace"), ) with gr.Blocks( title="SQL Debug Environment", analytics_enabled=False, theme=theme, css=GRADIO_CSS, ) as demo: gr.HTML( """
OpenEnv-compliant SQL repair · live SQLite rewards · TRL / GRPO training on this same Space. One page: benchmarks, artifacts, architecture, and a live playground.
Benchmark visuals
') gr.Markdown( "| Metric snapshot | Value |\n" "|---|---|\n" "| Spider chart: Industry baseline | **48.2%** |\n" "| Spider chart: Qwen-7B base | **52.4%** |\n" "| Spider chart: RL agent | **78.5%** |\n" "| Performance leap chart | **0.0% -> 25.0%** (base to RL in that run view) |\n" ) with gr.Row(equal_height=True): if chart_leap: gr.Image(value=chart_leap, label="Performance leap (Spider-style)", type="filepath", scale=1) if chart_dual: gr.Image(value=chart_dual, label="Comparison + reward shift", type="filepath", scale=2) if chart_spider: gr.Image(value=chart_spider, label="Spider-style headline chart", type="filepath", scale=1) gr.Markdown( 'Training run charts (repo static)
' 'Training plots from real runs. Regenerate with `presentation_graphs.py`; commit PNGs under `server/static/`.' ) with gr.Row(): if proof_combo: gr.Image(value=proof_combo, label="Presentation combo", type="filepath", scale=1) if proof_dist: gr.Image(value=proof_dist, label="Reward distribution shift", type="filepath", scale=1) if final_gallery: gr.Markdown( 'Hard-testing proof set (presentation_graphs_out_final)
' "All generated graphs from the final evaluation set." ) gr.Gallery( value=final_gallery, label="Final hard-testing charts", preview=True, columns=3, height="auto", object_fit="contain", ) gr.Markdown('Environment architecture
') if wf: gr.Image(value=wf, label="End-to-end workflow", type="filepath", show_label=True) else: gr.Markdown("*Add `server/static/diagram-end-to-end-workflow.png`*") gr.Markdown( 'OpenEnv HTTP API
' f"`GET /health` · `GET /tasks` · `POST /reset` · `POST /step` · `POST /step_with_review` · `GET /state` · `GET /benchmark` · " f"loopback base `{_api_base()}` (override with **INTERNAL_API_BASE**)." ) gr.Markdown('Live playground
') session = gr.State(None) session_md = gr.Markdown("Session: *click “Reset task”*") with gr.Row(): task = gr.Dropdown( choices=task_choices, value="easy_syntax_fix", label="Task", scale=1, ) btn_reset = gr.Button("Reset task", variant="primary", scale=0, min_width=140) btn_submit = gr.Button("Submit query", variant="secondary", scale=0, min_width=140) btn_run_suite = gr.Button("Run preset suite", variant="secondary", scale=0, min_width=160) preset_name = gr.Dropdown( choices=_preset_options("easy_syntax_fix"), value=_preset_options("easy_syntax_fix")[0], label="Predefined test query", ) btn_load_preset = gr.Button("Load predefined query", variant="secondary") sql = gr.Code(label="Candidate SQL", language="sql", lines=12) result_hint = gr.Markdown("") with gr.Row(): obs_json = gr.Code( language="json", label="Observation (/reset)", lines=12, interactive=False, scale=1, ) step_json = gr.Code( language="json", label="Step (/step)", lines=12, interactive=False, scale=1, ) suite_md = gr.Markdown("") suite_json = gr.Code( label="Preset suite summary", language="json", lines=10, interactive=False, ) btn_reset.click( reset_fn, inputs=[task, session], outputs=[obs_json, sql, session, session_md], ) btn_submit.click( submit_fn, inputs=[sql, session], outputs=[step_json, result_hint], ) task.change( lambda t: gr.Dropdown( choices=_preset_options(t), value=_preset_options(t)[0] if _preset_options(t) else None, ), inputs=[task], outputs=[preset_name], ) btn_load_preset.click( lambda t, p: _preset_query(t, p or ""), inputs=[task, preset_name], outputs=[sql], ) btn_run_suite.click( run_preset_suite, inputs=[task, session], outputs=[suite_md, suite_json, session, session_md], ) gr.Markdown('Blog
') gr.Markdown(blog_md) gr.Markdown( "### Why I picked SQL debugging and why this architecture exists\n" "“The goal is not to generate beautiful SQL text. The goal is to produce SQL fixes that survive execution, repeatedly, under changing runtime conditions.”\n\n" "### The cost of “almost right” SQL\n" "Industry time-use reporting commonly puts **roughly a quarter to a third** of analytics and data-engineering work into fixing queries and pipelines—" "**not** shipping net-new insights, **not** launching features, but **debugging SQL that already looked reasonable** in a notebook or PR.\n\n" "### Benchmarks vs production\n" "On Spider-style leaderboards, headline numbers often sit in the **high 80s to low 90s (%)**. In messy enterprise warehouses—drifting schemas, implicit business rules, " "join explosions, permissioned views—teams routinely describe effective success rates closer to the **10–30%** band unless the system closes the loop with " "**execution-grounded feedback** (run the SQL, read the error or result, attribute reward to what changed).\n\n" "SQL debugging is one of the few tasks where *language quality* and *system quality* diverge sharply: a query can be neat, plausible, and still fail in production. " "This project forces the agent to optimize for **behavior under execution**, not only fluency under prompting." ) gr.HTML( """