Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| from typing import Any | |
| import gradio as gr | |
| from fastapi import Body, FastAPI, HTTPException | |
| from env.inprocess_backend import BACKEND | |
# Module-wide alias for the in-process environment backend; the UI callbacks
# and HTTP handlers below call reset/step/state on this single shared object.
SESSION = BACKEND
def health() -> dict[str, str]:
    """Return the static health payload describing this Space."""
    payload = {"status": "ok"}
    payload["env"] = "DataQualityEnv"
    payload["mode"] = "space-ui"
    return payload
def session_status(obs: dict[str, Any] | None) -> str:
    """Summarise an observation as a one-line status string.

    Args:
        obs: Latest observation mapping, or None/empty when no episode is active.

    Returns:
        A prompt to reset when there is no observation, otherwise a line with
        the task id, phase, step counter and remaining query credits.
    """
    if obs:
        task = obs.get("task_id")
        phase = obs.get("phase")
        step = obs.get("step")
        limit = obs.get("max_steps")
        remaining = obs.get("query_credits_remaining")
        return f"Task {task} | phase={phase} | step={step}/{limit} | credits={remaining}"
    return "No active episode. Choose a task and click Reset."
def initial_chat() -> list[dict[str, str]]:
    """Return a fresh, empty chat history."""
    history: list[dict[str, str]] = []
    return history
def format_observation(obs: dict[str, Any] | None) -> str:
    """Render an observation as indented JSON text.

    A missing or empty observation renders as ``{}``; values that JSON cannot
    encode natively are converted with ``str``.
    """
    payload = obs if obs else {}
    return json.dumps(payload, indent=2, default=str)
def format_reward(reward: dict[str, Any] | None) -> str:
    """Render a reward mapping as indented JSON text.

    A missing or empty reward renders as ``{}``; values that JSON cannot
    encode natively are converted with ``str``.
    """
    payload = reward if reward else {}
    return json.dumps(payload, indent=2, default=str)
def task_hint(task_id: int) -> str:
    """Return a short probing hint for the given task.

    Tasks 1-3 have dedicated hints; any other id gets the relational-checks
    hint.
    """
    hints = (
        (1, "Try null-like value checks and duplicate-row grouping on the customers table."),
        (2, "Try type parsing, negative values, and date-format checks on orders."),
        (3, "Try baseline/current comparisons, new categories, and user population drift."),
    )
    for known_id, hint in hints:
        if task_id == known_id:
            return hint
    return "Try orphaned foreign keys, temporal checks, and aggregate consistency."
def heuristic_queries(task_id: int) -> list[str]:
    """Return the canned diagnostic SQL probes for a task.

    Tasks 1-3 each have their own probe set; any other id receives the
    relational-integrity probes. A new list is built on every call.
    """
    probes_by_task = (
        (
            1,
            (
                "SELECT COUNT(*) AS total_rows FROM customers",
                "SELECT SUM(CASE WHEN email IS NULL OR lower(trim(cast(email as varchar))) IN ('null','n/a','unknown','-','','0','none') THEN 1 ELSE 0 END) AS email_null_total, SUM(CASE WHEN customer_id IS NULL THEN 1 ELSE 0 END) AS cid_nulls FROM customers",
                "SELECT COALESCE(SUM(c-1),0) AS exact_duplicate_rows FROM (SELECT customer_id,email,name,signup_date,country, COUNT(*) c FROM customers GROUP BY 1,2,3,4,5 HAVING COUNT(*)>1) t",
            ),
        ),
        (
            2,
            (
                "SELECT SUM(CASE WHEN quantity < 0 THEN 1 ELSE 0 END) AS neg_qty, SUM(CASE WHEN try_cast(replace(amount,'$','') AS DOUBLE) IS NULL THEN 1 ELSE 0 END) AS bad_amt FROM orders",
                "SELECT amount, order_date FROM orders LIMIT 10",
            ),
        ),
        (
            3,
            (
                "SELECT (SELECT AVG(amount) FROM transactions_baseline) AS baseline_mean, (SELECT AVG(amount) FROM transactions_current) AS current_mean",
                "SELECT DISTINCT c.category FROM transactions_current c LEFT JOIN (SELECT DISTINCT category FROM transactions_baseline) b ON c.category=b.category WHERE b.category IS NULL ORDER BY c.category",
            ),
        ),
    )
    # Fallback probe set used for every id not listed above.
    relational_probes = (
        "SELECT COUNT(*) AS orphan_count FROM orders o LEFT JOIN customers c ON o.customer_id=c.customer_id WHERE c.customer_id IS NULL",
        "SELECT COUNT(*) AS temporal_count FROM orders WHERE try_cast(ship_date AS TIMESTAMP) < try_cast(order_date AS TIMESTAMP)",
        "SELECT COUNT(*) AS aggregate_count FROM (SELECT o.order_id, o.order_total, SUM(li.subtotal) AS s FROM orders o JOIN line_items li ON o.order_id=li.order_id GROUP BY o.order_id, o.order_total HAVING abs(o.order_total - SUM(li.subtotal)) > 1e-6) x",
    )
    for known_id, probes in probes_by_task:
        if task_id == known_id:
            return list(probes)
    return list(relational_probes)
def current_tables(obs: dict[str, Any] | None) -> set[str]:
    """Return the lower-cased names of the tables listed in an observation.

    Reads the keys of ``obs["tables"]``; a missing observation or a
    missing/empty ``tables`` entry yields an empty set.
    """
    if not obs:
        return set()
    table_map = obs.get("tables")
    if not table_map:
        return set()
    lowered: set[str] = set()
    for table_name in table_map.keys():
        lowered.add(str(table_name).lower())
    return lowered
def referenced_tables(sql_text: str) -> set[str]:
    """Return the lower-cased table names a SQL string appears to read from.

    This is a regex heuristic, not a SQL parser: it collects the identifier
    that follows each ``FROM`` or ``JOIN`` keyword, keeping only the last
    dotted component (``schema.table`` -> ``table``).

    Names introduced by a ``WITH name AS (...)`` clause are excluded, because
    they are query-local aliases rather than tables; otherwise every CTE query
    would be reported as referencing an unknown table.

    Args:
        sql_text: Raw SQL text; None or blank yields an empty set.

    Returns:
        The set of referenced table names, lower-cased.
    """
    sql = (sql_text or "").strip()
    if not sql:
        return set()

    # Common-table-expression names: `WITH [RECURSIVE] name [(cols)] AS (`
    # and every following `, name [(cols)] AS (`.
    cte_names = {
        match.group(1).lower()
        for match in re.finditer(
            r"(?:\bwith\s+(?:recursive\s+)?|,\s*)([a-zA-Z_]\w*)\b\s*"
            r"(?:\([^()]*\)\s*)?\bas\b\s*(?:(?:not\s+)?materialized\s*)?\(",
            sql,
            flags=re.IGNORECASE,
        )
    }

    refs: set[str] = set()
    for match in re.finditer(r"\b(?:from|join)\s+([a-zA-Z_][\w\.]*)", sql, flags=re.IGNORECASE):
        # A trailing dot ("schema.") leaves an empty last component; skip it.
        identifier = match.group(1).split(".")[-1].lower()
        if identifier and identifier not in cte_names:
            refs.add(identifier)
    return refs
def validate_query_tables(sql_text: str, obs: dict[str, Any] | None) -> str | None:
    """Check that a query only touches tables exposed by the observation.

    Args:
        sql_text: SQL text to inspect.
        obs: Current observation, whose ``tables`` entry lists the exposed tables.

    Returns:
        None when the observation lists no tables, when no table reference is
        detected, or when every detected reference is exposed; otherwise a
        user-facing message naming the exposed and the rejected tables.
    """
    exposed = current_tables(obs)
    if exposed:
        requested = referenced_tables(sql_text)
        missing = sorted(name for name in requested if name not in exposed)
        if missing:
            listing = ", ".join(sorted(exposed))
            rejected = ", ".join(missing)
            return f"This task only exposes: {listing}. Please query one of those tables instead of: {rejected}."
    return None
def normalize_command(text: str) -> str:
    """Return *text* with surrounding whitespace removed; falsy input becomes ``""``."""
    if not text:
        return ""
    return text.strip()
def parse_json_fragment(text: str) -> dict[str, Any] | None:
    """Extract a JSON object from free-form text.

    Markdown code fences are stripped first. The whole text is then parsed as
    JSON; if that fails, the span from the first ``{`` to the last ``}`` is
    parsed instead.

    Only JSON *objects* are accepted. A bare list, number or string parses as
    valid JSON but is not a report, so it is rejected here rather than being
    forwarded to the environment as a malformed report.

    Args:
        text: User-supplied text; None or blank yields None.

    Returns:
        The decoded dict, or None when no JSON object could be decoded.
    """
    raw = (text or "").strip()
    raw = raw.replace("```json", "").replace("```", "").strip()
    if not raw:
        return None
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input. Fall back to the outermost {...} span
        # (greedy, so nested braces stay inside the match).
        match = re.search(r"\{.*\}", raw, re.DOTALL)
        if match is None:
            return None
        try:
            parsed = json.loads(match.group())
        except (ValueError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None
def fallback_report_from_obs(obs: dict[str, Any] | None) -> dict[str, Any]:
    """Build the placeholder report that is auto-submitted when steps run out.

    The report is empty apart from two fixed recommendations; for task 1 (also
    the default when the observation has no usable ``task_id``) a single
    low-confidence ``partial_audit`` schema violation is included.

    Args:
        obs: Current observation; only ``task_id`` is read.

    Returns:
        A new report dict.
    """
    observation = obs or {}
    task_id = int(observation.get("task_id", 1) or 1)

    violations: list[dict[str, Any]] = []
    if task_id == 1:
        violations.append(
            {
                "column": "customers",
                "issue_type": "partial_audit",
                "example": "auto_submit_guard",
                "count": 1,
                "confidence": 0.4,
            }
        )

    return {
        "null_issues": {},
        "duplicate_row_count": {"value": 0, "confidence": 0.5},
        "schema_violations": violations,
        "drifted_columns": [],
        "drift_details": {},
        "relational_issues": [],
        "recommended_fixes": [
            "Auto-submitted fallback report to avoid max_steps termination",
            "Run additional targeted probes in earlier steps for higher confidence",
        ],
    }
def reset_ui(task_id: int, seed: int):
    """Start a new episode and return the refreshed UI state.

    Args:
        task_id: Task to reset the shared session to.
        seed: Seed forwarded to the session reset.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``;
        the chat is a new history holding only the reset confirmation and the
        task hint, and the reward is a zeroed placeholder.
    """
    obs = SESSION.reset({"task_id": task_id, "seed": seed})
    greeting = {
        "role": "assistant",
        "content": f"Reset complete for task {task_id}. {task_hint(task_id)}",
    }
    history = [*initial_chat(), greeting]
    zero_reward = {"value": 0.0, "done": False}
    return history, format_observation(obs), session_status(obs), format_reward(zero_reward), obs
def run_query(sql_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Run one SQL query against the shared session and return the UI state.

    The query is refused (no environment step is taken) when the step budget
    is down to its last step, when the text is blank, or when it references a
    table the current task does not expose.

    Args:
        sql_text: SQL entered by the user.
        current_obs: Latest observation, or None when no episode is active.
        chat: Current chat history; None is treated as empty.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``.
    """
    # Work on a copy so the caller's list is never mutated.
    history = list(chat or [])

    if current_obs:
        step = int(current_obs.get("step", 0) or 0)
        max_steps = int(current_obs.get("max_steps", 12) or 12)
        # Keep the final step free for a report submission.
        if step >= max_steps - 1:
            history.append(
                {
                    "role": "assistant",
                    "content": "Step budget is almost exhausted. Submit your report now (`submit: {...}`) to avoid `max_steps` termination.",
                }
            )
            return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    sql = normalize_command(sql_text)
    if not sql:
        history.append({"role": "assistant", "content": "Send a SQL query first."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    table_error = validate_query_tables(sql, current_obs)
    if table_error:
        history.append({"role": "assistant", "content": table_error})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({"value": 0.0, "done": False}), current_obs

    out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
    obs = out.get("observation")
    # A step result without a reward must not crash the chat message below.
    reward = out.get("reward") or {}
    history.extend(
        [
            {"role": "user", "content": f"query: {sql}"},
            {"role": "assistant", "content": f"Ran query. reward={reward.get('value', 0.0)}"},
        ]
    )
    return history, format_observation(obs), session_status(obs), format_reward(reward), obs
def submit_report(report_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Parse a JSON report and submit it to the shared session.

    Args:
        report_text: Text expected to contain the report as JSON.
        current_obs: Latest observation; returned unchanged when parsing fails.
        chat: Current chat history; None is treated as empty.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``.
        When the text cannot be parsed, no environment step is taken.
    """
    # Work on a copy so the caller's list is never mutated.
    history = list(chat or [])

    report = parse_json_fragment(report_text)
    if report is None:
        history.append({"role": "assistant", "content": "I couldn’t parse that as JSON. Paste a valid report object."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    out = SESSION.step({"action": {"action_type": "submit_report", "report": report}})
    obs = out.get("observation")
    # A step result without a reward must not crash the chat message below.
    reward = out.get("reward") or {}
    history.extend(
        [
            {"role": "user", "content": "submit report"},
            {"role": "assistant", "content": f"Submitted report. reward={reward.get('value', 0.0)}"},
        ]
    )
    return history, format_observation(obs), session_status(obs), format_reward(reward), obs
def auto_audit(current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Run the canned diagnostic probes for the current task.

    Each probe is validated against the exposed tables and then executed as a
    query step. Probing stops as soon as only the final step remains, using
    the same threshold as ``run_query``, so that the step is left for a report
    submission instead of being consumed by a probe.

    Args:
        current_obs: Latest observation; when missing, the user is asked to reset.
        chat: Current chat history; None is treated as empty.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``
        reflecting the state after the last executed probe.
    """
    # Work on a copy so the caller's list is never mutated.
    history = list(chat or [])

    if not current_obs:
        history.append({"role": "assistant", "content": "Reset a task before running auto audit."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    task_id = int(current_obs.get("task_id", 1) or 1)
    queries = heuristic_queries(task_id)
    history.append({"role": "assistant", "content": f"Running {len(queries)} diagnostic probes..."})

    obs = current_obs
    reward = None
    for sql in queries:
        # Re-check the budget before every probe: each executed probe advances the step.
        progress = obs or {}
        step = int(progress.get("step", 0) or 0)
        max_steps = int(progress.get("max_steps", 12) or 12)
        if step >= max_steps - 1:
            history.append(
                {
                    "role": "assistant",
                    "content": "Step budget is almost exhausted. Submit your report now (`submit: {...}`) to avoid `max_steps` termination.",
                }
            )
            break

        table_error = validate_query_tables(sql, obs)
        if table_error:
            history.append({"role": "assistant", "content": table_error})
            continue

        out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
        obs = out.get("observation")
        # A step result without a reward must not crash the chat message below.
        reward = out.get("reward") or {}
        history.append({"role": "user", "content": sql})
        history.append({"role": "assistant", "content": f"reward={reward.get('value', 0.0)}"})

    return history, format_observation(obs), session_status(obs), format_reward(reward), obs
def _command_payload(text: str, keyword: str) -> str:
    """Return what follows a leading command keyword, minus one optional ``:`` separator."""
    rest = text[len(keyword):].lstrip()
    if rest.startswith(":"):
        rest = rest[1:]
    return rest.strip()


def handle_command(user_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]], task_id: int, seed: int):
    """Dispatch one line of chat input to the matching action.

    Recognised inputs: ``help``/``?``, ``reset``, ``state``, ``auto``,
    ``submit[:] {...}``, ``query[:] SELECT ...``, and bare SQL containing
    ``select`` or ``with``.

    The payload of ``submit`` / ``query`` is whatever follows the keyword and
    an optional separator colon. Only that separator is removed, so colons
    inside the JSON or SQL itself are preserved (``submit {"a": 1}`` works
    with or without the colon after ``submit``).

    Args:
        user_text: Raw text from the input box.
        current_obs: Latest observation, or None when no episode is active.
        chat: Current chat history; None is treated as empty.
        task_id: Task selected in the UI, used by ``reset``.
        seed: Seed selected in the UI, used by ``reset``.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``.
    """
    # Work on a copy so the caller's list is never mutated.
    history = list(chat or [])

    text = normalize_command(user_text)
    if not text:
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    lower = text.lower()
    if lower in {"help", "?"}:
        history.append({"role": "assistant", "content": "Commands: `reset`, `query: SELECT ...`, `submit: {...json...}`, `auto`, or `state`."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    # Budget guard: once only the final step remains, anything other than
    # submit/reset/state triggers an automatic fallback submission.
    # NOTE(review): this also fires for unrecognised text that would not have
    # consumed a step — confirm that is intended.
    if current_obs and not (lower.startswith("submit") or lower.startswith("reset") or lower == "state"):
        step = int(current_obs.get("step", 0) or 0)
        max_steps = int(current_obs.get("max_steps", 12) or 12)
        if step >= max_steps - 1:
            fallback = fallback_report_from_obs(current_obs)
            out = SESSION.step({"action": {"action_type": "submit_report", "report": fallback}})
            obs = out.get("observation", current_obs)
            reward = out.get("reward", {})
            history.append(
                {
                    "role": "assistant",
                    "content": "Step budget exhausted. I auto-submitted a fallback report to prevent `max_steps` zero-output failure.",
                }
            )
            return history, format_observation(obs), session_status(obs), format_reward(reward), obs

    if lower.startswith("reset"):
        return reset_ui(task_id=task_id, seed=seed)
    if lower == "state":
        history.append({"role": "assistant", "content": session_status(current_obs)})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    if lower.startswith("auto"):
        return auto_audit(current_obs, history)
    if lower.startswith("submit"):
        return submit_report(_command_payload(text, "submit"), current_obs, history)
    if lower.startswith("query"):
        return run_query(_command_payload(text, "query"), current_obs, history)
    # Bare SQL typed without the `query` prefix.
    if re.search(r"\bselect\b|\bwith\b", lower):
        return run_query(text, current_obs, history)

    history.append({"role": "assistant", "content": "I can help with `reset`, `query`, `submit`, `auto`, or `state`."})
    return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
# FastAPI application object; the Gradio UI is mounted onto it at the bottom of this file.
fastapi_app = FastAPI(title="DataQualityEnv Space")
def _health() -> dict[str, str]:
    """Return the health payload produced by ``health()``."""
    # NOTE(review): no route decorator is visible on this handler in this view —
    # confirm it is registered on fastapi_app.
    status = health()
    return status
def _reset(payload: dict = Body(default_factory=dict)) -> dict:
    """Reset the shared session, defaulting to task 1 with seed 42.

    Args:
        payload: Request body; ``task_id`` and ``seed`` are filled in when absent.

    Returns:
        Whatever ``SESSION.reset`` returns.
    """
    # NOTE(review): no route decorator is visible on this handler in this view —
    # confirm it is registered on fastapi_app.
    request = payload or {}
    for key, default in (("task_id", 1), ("seed", 42)):
        request.setdefault(key, default)
    return SESSION.reset(request)
def _step(payload: dict = Body(default_factory=dict)) -> dict:
    """Forward a step request body to the shared session.

    Args:
        payload: Request body; a falsy body is replaced by an empty dict.

    Returns:
        Whatever ``SESSION.step`` returns.
    """
    # NOTE(review): no route decorator is visible on this handler in this view —
    # confirm it is registered on fastapi_app.
    request = payload if payload else {}
    return SESSION.step(request)
def _state() -> dict:
    """Return the shared session's current state."""
    # NOTE(review): no route decorator is visible on this handler in this view —
    # confirm it is registered on fastapi_app.
    snapshot = SESSION.state()
    return snapshot
# Gradio UI: controls and read-only state panels on the left, chat on the right.
# NOTE(review): the nesting below is reconstructed from context because the
# original indentation was not available — confirm against the real layout.
with gr.Blocks(title="DataQualityEnv") as demo:
    gr.Markdown(
        "# DataQualityEnv\n"
        "A self-contained Hugging Face Space demo. No `ENV_URL`, no localhost dependency, no external API hop for the environment."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Episode controls; task_id and seed feed both the Reset button and the `reset` chat command.
            task_id = gr.Dropdown(choices=[1, 2, 3, 4], value=1, label="Task")
            seed = gr.Number(value=42, precision=0, label="Seed")
            reset_btn = gr.Button("Reset episode", variant="primary")
            auto_btn = gr.Button("Auto audit")
            gr.Markdown("### Session status")
            # Read-only panels refreshed by every callback below.
            status_box = gr.Markdown("No active episode. Choose a task and click Reset.")
            reward_box = gr.Textbox(label="Last reward", lines=8, interactive=False)
            obs_box = gr.Textbox(label="Observation JSON", lines=22, interactive=False)
        with gr.Column(scale=2):
            # The callbacks append {"role": ..., "content": ...} dicts to this history.
            chat = gr.Chatbot(label="Chat", height=520)
            user_text = gr.Textbox(
                label="Command or SQL",
                placeholder="Type reset, query: SELECT ..., submit: {...}, auto, or state",
                lines=3,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear chat")
    # Holds the latest observation dict between callbacks; None until the first reset.
    current_obs = gr.State(None)
    # Every callback returns the same 5 values, wired to the same 5 outputs:
    # chat history, observation JSON, status line, reward JSON, raw observation.
    reset_btn.click(
        reset_ui,
        inputs=[task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    auto_btn.click(
        auto_audit,
        inputs=[current_obs, chat],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    # The Send button and the textbox's submit event share one handler.
    send_btn.click(
        handle_command,
        inputs=[user_text, current_obs, chat, task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    user_text.submit(
        handle_command,
        inputs=[user_text, current_obs, chat, task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    # Clears only the visible chat; the stored observation is left untouched.
    clear_btn.click(lambda: [], inputs=None, outputs=chat)
# Mount the Gradio UI at the root path of the FastAPI app; `app` is the ASGI object that gets served.
app = gr.mount_gradio_app(fastapi_app, demo, path="/")
if __name__ == "__main__":
    import uvicorn
    # Bind on all interfaces; the port comes from $PORT and defaults to 7860.
    # NOTE(review): the import string assumes this module is importable as `space_app` — confirm the filename.
    uvicorn.run("space_app:app", host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))