data-quality-env / space_app.py
Hemanth Kunta
Fix invalid rewards and Space query guards
aa25459
from __future__ import annotations
import json
import os
import re
from typing import Any
import gradio as gr
from fastapi import Body, FastAPI, HTTPException
from env.inprocess_backend import BACKEND
# Single module-level environment backend used by every Gradio handler and
# every FastAPI route in this file; the alias only gives call sites a clearer name.
# NOTE(review): because this is one shared object, concurrent Space visitors
# (and API callers) appear to drive the same episode — confirm whether BACKEND
# isolates sessions internally.
SESSION = BACKEND
def health() -> dict[str, str]:
    """Return a static liveness payload identifying this Space UI."""
    payload = dict(status="ok", env="DataQualityEnv", mode="space-ui")
    return payload
def session_status(obs: dict[str, Any] | None) -> str:
    """Render a one-line status summary for an observation dict.

    Missing keys render as ``None``; an empty or missing observation yields a
    prompt to reset an episode.
    """
    if obs:
        get = obs.get
        return (
            f"Task {get('task_id')} | phase={get('phase')} | "
            f"step={get('step')}/{get('max_steps')} | "
            f"credits={get('query_credits_remaining')}"
        )
    return "No active episode. Choose a task and click Reset."
def initial_chat() -> list[dict[str, str]]:
    """Return a fresh, empty chat history (a new list on every call)."""
    history: list[dict[str, str]] = list()
    return history
def format_observation(obs: dict[str, Any] | None) -> str:
    """Pretty-print an observation as 2-space-indented JSON.

    Falsy input renders as ``{}``; values JSON cannot encode are stringified.
    """
    payload = obs if obs else {}
    return json.dumps(payload, indent=2, default=str)
def format_reward(reward: dict[str, Any] | None) -> str:
    """Pretty-print a reward payload as 2-space-indented JSON (falsy -> ``{}``)."""
    if not reward:
        reward = {}
    return json.dumps(reward, default=str, indent=2)
def task_hint(task_id: int) -> str:
    """Return a short probing suggestion for the given task (task 4 text is the default)."""
    known_hints = (
        (1, "Try null-like value checks and duplicate-row grouping on the customers table."),
        (2, "Try type parsing, negative values, and date-format checks on orders."),
        (3, "Try baseline/current comparisons, new categories, and user population drift."),
    )
    for candidate, hint in known_hints:
        if task_id == candidate:
            return hint
    return "Try orphaned foreign keys, temporal checks, and aggregate consistency."
def heuristic_queries(task_id: int) -> list[str]:
    """Return the canned diagnostic SQL probes used by auto audit for a task.

    Tasks 1-3 have dedicated probes; any other task id gets the task-4
    (relational/temporal/aggregate) probes. A new list is returned each call.
    """
    if task_id == 1:
        probes = [
            "SELECT COUNT(*) AS total_rows FROM customers",
            "SELECT SUM(CASE WHEN email IS NULL OR lower(trim(cast(email as varchar))) IN ('null','n/a','unknown','-','','0','none') THEN 1 ELSE 0 END) AS email_null_total, SUM(CASE WHEN customer_id IS NULL THEN 1 ELSE 0 END) AS cid_nulls FROM customers",
            "SELECT COALESCE(SUM(c-1),0) AS exact_duplicate_rows FROM (SELECT customer_id,email,name,signup_date,country, COUNT(*) c FROM customers GROUP BY 1,2,3,4,5 HAVING COUNT(*)>1) t",
        ]
    elif task_id == 2:
        probes = [
            "SELECT SUM(CASE WHEN quantity < 0 THEN 1 ELSE 0 END) AS neg_qty, SUM(CASE WHEN try_cast(replace(amount,'$','') AS DOUBLE) IS NULL THEN 1 ELSE 0 END) AS bad_amt FROM orders",
            "SELECT amount, order_date FROM orders LIMIT 10",
        ]
    elif task_id == 3:
        probes = [
            "SELECT (SELECT AVG(amount) FROM transactions_baseline) AS baseline_mean, (SELECT AVG(amount) FROM transactions_current) AS current_mean",
            "SELECT DISTINCT c.category FROM transactions_current c LEFT JOIN (SELECT DISTINCT category FROM transactions_baseline) b ON c.category=b.category WHERE b.category IS NULL ORDER BY c.category",
        ]
    else:
        probes = [
            "SELECT COUNT(*) AS orphan_count FROM orders o LEFT JOIN customers c ON o.customer_id=c.customer_id WHERE c.customer_id IS NULL",
            "SELECT COUNT(*) AS temporal_count FROM orders WHERE try_cast(ship_date AS TIMESTAMP) < try_cast(order_date AS TIMESTAMP)",
            "SELECT COUNT(*) AS aggregate_count FROM (SELECT o.order_id, o.order_total, SUM(li.subtotal) AS s FROM orders o JOIN line_items li ON o.order_id=li.order_id GROUP BY o.order_id, o.order_total HAVING abs(o.order_total - SUM(li.subtotal)) > 1e-6) x",
        ]
    return probes
def current_tables(obs: dict[str, Any] | None) -> set[str]:
    """Return the lower-cased names of the tables listed under ``obs["tables"]``.

    A missing/empty observation or a missing/empty ``tables`` entry yields an
    empty set (which callers treat as "no restriction known").
    """
    source = obs if obs else {}
    table_map = source.get("tables") or {}
    names: set[str] = set()
    for name in table_map.keys():
        names.add(str(name).lower())
    return names
# Single-quoted SQL string literals ('' is an escaped quote) and SQL comments.
# Their contents must not be scanned for FROM/JOIN keywords.
_SQL_LITERAL_OR_COMMENT_RE = re.compile(r"'(?:[^']|'')*'|--[^\n]*|/\*.*?\*/", re.DOTALL)

# FROM/JOIN followed by a bare, optionally schema-qualified identifier.
_TABLE_REF_RE = re.compile(r"\b(from|join)\s+([a-zA-Z_][\w\.]*)", re.IGNORECASE)

# Names introduced by a WITH clause:
# "WITH [RECURSIVE] name [(cols)] AS [[NOT] MATERIALIZED] (" plus later ", name AS (" entries.
_CTE_NAME_RE = re.compile(
    r"(?:\bwith(?:\s+recursive)?|,)\s*([a-zA-Z_]\w*)\s*(?:\([^()]*\))?\s*\bas\s*(?:not\s+)?(?:materialized\s*)?\(",
    re.IGNORECASE,
)

# Functions whose argument syntax uses FROM as a keyword, e.g. extract(year FROM d).
_FROM_KEYWORD_FUNCS = frozenset({"extract", "trim", "substring", "overlay"})


def _enclosing_call_name(sql: str, pos: int) -> str:
    """Return the lower-cased word before the innermost '(' still open at ``pos`` ('' if none)."""
    depth = 0
    for idx in range(pos - 1, -1, -1):
        char = sql[idx]
        if char == ")":
            depth += 1
        elif char == "(":
            if depth == 0:
                word = re.search(r"([a-zA-Z_]\w*)\s*$", sql[:idx])
                return word.group(1).lower() if word else ""
            depth -= 1
    return ""


def referenced_tables(sql_text: str) -> set[str]:
    """Return lower-cased table names that ``sql_text`` reads via FROM/JOIN.

    This is a best-effort lexical scan, not a SQL parser. It is used only to give
    a friendly error before a query is sent to the backend. Schema prefixes are
    dropped (``main.Orders`` -> ``orders``).

    These are ignored so that valid SQL is not rejected:
      * anything inside string literals or comments;
      * names defined by a WITH clause (CTEs);
      * FROM used inside ``extract``/``trim``/``substring``/``overlay`` calls;
      * ``IS [NOT] DISTINCT FROM``.

    Known limits: quoted identifiers (``"t"``) are not reported. A CTE that
    reuses a real table's name hides that table from the result.

    Args:
        sql_text: Raw SQL; ``None`` is treated as empty.

    Returns:
        Set of referenced table names (possibly empty).
    """
    # Same normalization as normalize_command, inlined to keep this helper standalone.
    sql = _SQL_LITERAL_OR_COMMENT_RE.sub(" ", (sql_text or "").strip())
    cte_names = {name.lower() for name in _CTE_NAME_RE.findall(sql)}
    refs: set[str] = set()
    for match in _TABLE_REF_RE.finditer(sql):
        if match.group(1).lower() == "from":
            keyword_pos = match.start(1)
            # "a IS DISTINCT FROM b" compares values; b is not a table.
            if re.search(r"\bdistinct\s*$", sql[:keyword_pos], re.IGNORECASE):
                continue
            # extract(year FROM order_date) etc.: the operand is a column.
            if _enclosing_call_name(sql, keyword_pos) in _FROM_KEYWORD_FUNCS:
                continue
        identifier = match.group(2).split(".")[-1].lower()
        # A trailing dot ("customers.") leaves an empty identifier; skip it.
        if identifier and identifier not in cte_names:
            refs.add(identifier)
    return refs
def validate_query_tables(sql_text: str, obs: dict[str, Any] | None) -> str | None:
    """Return a user-facing error if the SQL names tables outside the current task.

    Returns ``None`` (allow) when the observation lists no tables, when no table
    references are found, or when every reference is an exposed table. Otherwise
    the message lists the exposed tables and the offending names, both sorted.
    """
    allowed = current_tables(obs)
    # With no known table list there is nothing to check against.
    refs = referenced_tables(sql_text) if allowed else set()
    unknown = sorted(refs - allowed)
    if not unknown:
        return None
    available = ", ".join(sorted(allowed))
    return f"This task only exposes: {available}. Please query one of those tables instead of: {', '.join(unknown)}."
def normalize_command(text: str) -> str:
    """Return ``text`` with surrounding whitespace removed; ``None``/empty gives ``""``."""
    if not text:
        return ""
    return text.strip()
def parse_json_fragment(text: str) -> dict[str, Any] | None:
    """Extract a JSON object from free-form chat text.

    Markdown code fences (```json / ```) are removed first. The whole text is
    tried as JSON; failing that, the span from the first ``{`` to the last ``}``
    is tried. Only a JSON *object* counts as a result. Top-level numbers,
    strings and lists are rejected so callers never forward a non-dict
    report.

    Args:
        text: Raw user input; ``None`` is treated as empty.

    Returns:
        The parsed dict, or ``None`` if no JSON object could be decoded.
    """
    # Same normalization as normalize_command, inlined to keep this helper standalone.
    raw = (text or "").strip()
    raw = raw.replace("```json", "").replace("```", "").strip()
    candidates = [raw]
    embedded = re.search(r"\{.*\}", raw, re.DOTALL)
    if embedded:
        candidates.append(embedded.group())
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; RecursionError guards pathological nesting.
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
def fallback_report_from_obs(obs: dict[str, Any] | None) -> dict[str, Any]:
    """Build a minimal, low-confidence report for automatic submission.

    The report is sent when the step budget is about to run out, so the episode
    still ends with a report. The task id comes from ``obs["task_id"]``, and a
    missing or falsy value means task 1. Task 1 gets one placeholder schema
    violation; every other task gets an empty violation list.
    """
    task_id = int((obs if obs else {}).get("task_id", 1) or 1)
    violations: list[dict[str, Any]] = []
    if task_id == 1:
        violations.append(
            dict(
                column="customers",
                issue_type="partial_audit",
                example="auto_submit_guard",
                count=1,
                confidence=0.4,
            )
        )
    return {
        "null_issues": {},
        "duplicate_row_count": {"value": 0, "confidence": 0.5},
        "schema_violations": violations,
        "drifted_columns": [],
        "drift_details": {},
        "relational_issues": [],
        "recommended_fixes": [
            "Auto-submitted fallback report to avoid max_steps termination",
            "Run additional targeted probes in earlier steps for higher confidence",
        ],
    }
def reset_ui(task_id: int, seed: int):
    """Start a new episode on the shared backend and reset the UI panels.

    Returns the five values wired to the Gradio outputs: a chat containing a
    single greeting with the task hint, the observation JSON, the status line,
    a zero-reward placeholder, and the new observation for the UI state.
    """
    obs = SESSION.reset({"task_id": task_id, "seed": seed})
    greeting = f"Reset complete for task {task_id}. {task_hint(task_id)}"
    chat = initial_chat() + [{"role": "assistant", "content": greeting}]
    placeholder_reward = {"value": 0.0, "done": False}
    return (
        chat,
        format_observation(obs),
        session_status(obs),
        format_reward(placeholder_reward),
        obs,
    )
def run_query(sql_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Execute one SQL probe as a ``query`` action and append the outcome to the chat.

    No environment step is taken when:
      * the step budget is nearly spent (``step >= max_steps - 1``);
      * the SQL is empty;
      * the SQL references tables the current observation does not expose.
    In those cases the user gets an explanatory chat message instead.

    Args:
        sql_text: Raw SQL typed by the user.
        current_obs: Latest observation dict held in the UI state, or ``None``.
        chat: Chat history of role/content dicts; ``None`` is treated as empty.

    Returns:
        ``(chat, observation_json, status_text, reward_json, observation)``, the
        five values wired to the Gradio outputs.
    """
    chat = list(chat or [])
    if current_obs:
        step = int(current_obs.get("step", 0) or 0)
        max_steps = int(current_obs.get("max_steps", 12) or 12)
        # With one step left, ask for a report instead of risking max_steps termination.
        if step >= max_steps - 1:
            chat.append(
                {
                    "role": "assistant",
                    "content": "Step budget is almost exhausted. Submit your report now (`submit: {...}`) to avoid `max_steps` termination.",
                }
            )
            return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    sql = normalize_command(sql_text)
    if not sql:
        chat.append({"role": "assistant", "content": "Send a SQL query first."})
        return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    table_error = validate_query_tables(sql, current_obs)
    if table_error:
        chat.append({"role": "assistant", "content": table_error})
        return chat, format_observation(current_obs), session_status(current_obs), format_reward({"value": 0.0, "done": False}), current_obs
    out = SESSION.step({"action": {"action_type": "query", "sql": sql}}) or {}
    # Keep the previous observation if the backend omits one, so the budget guard
    # and status line keep working (the same fallback handle_command uses).
    obs = out.get("observation", current_obs)
    # A missing or None reward used to crash on reward.get(...).
    reward = out.get("reward") or {}
    chat += [
        {"role": "user", "content": f"query: {sql}"},
        {"role": "assistant", "content": f"Ran query. reward={reward.get('value', 0.0)}"},
    ]
    return chat, format_observation(obs), session_status(obs), format_reward(reward), obs
def submit_report(report_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Parse a JSON report from chat text and submit it as a ``submit_report`` action.

    Args:
        report_text: Text containing the report (fenced or embedded JSON is accepted).
        current_obs: Latest observation dict held in the UI state, or ``None``.
        chat: Chat history of role/content dicts; ``None`` is treated as empty.

    Returns:
        ``(chat, observation_json, status_text, reward_json, observation)``, the
        five values wired to the Gradio outputs.
    """
    chat = list(chat or [])
    report = parse_json_fragment(report_text)
    # Reports are JSON objects (compare fallback_report_from_obs). Reject a bare
    # list, number or string here instead of forwarding it to the backend.
    if not isinstance(report, dict):
        chat.append({"role": "assistant", "content": "I couldn’t parse that as JSON. Paste a valid report object."})
        return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    out = SESSION.step({"action": {"action_type": "submit_report", "report": report}}) or {}
    obs = out.get("observation", current_obs)
    # A missing or None reward used to crash on reward.get(...).
    reward = out.get("reward") or {}
    chat += [
        {"role": "user", "content": "submit report"},
        {"role": "assistant", "content": f"Submitted report. reward={reward.get('value', 0.0)}"},
    ]
    return chat, format_observation(obs), session_status(obs), format_reward(reward), obs
def auto_audit(current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Run the task's canned diagnostic probes (``heuristic_queries``) as query actions.

    Each probe is table-checked first. A probe naming tables the current
    observation does not expose is reported in the chat and skipped.

    The run stops early in two cases:
      * only one step remains (the budget rule ``run_query`` uses), so a report
        can still be submitted;
      * the latest reward marks the episode done.

    Args:
        current_obs: Latest observation dict held in the UI state, or ``None``.
        chat: Chat history of role/content dicts; ``None`` is treated as empty.

    Returns:
        ``(chat, observation_json, status_text, reward_json, observation)``, the
        five values wired to the Gradio outputs.
    """
    chat = list(chat or [])
    if not current_obs:
        chat.append({"role": "assistant", "content": "Reset a task before running auto audit."})
        return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    task_id = int(current_obs.get("task_id", 1) or 1)
    queries = heuristic_queries(task_id)
    running_chat = chat + [{"role": "assistant", "content": f"Running {len(queries)} diagnostic probes..."}]
    obs = current_obs
    reward = None
    for sql in queries:
        snapshot = obs or {}
        step = int(snapshot.get("step", 0) or 0)
        max_steps = int(snapshot.get("max_steps", 12) or 12)
        # Without this check the probes could use up the final step and end the
        # episode at max_steps with no report.
        if step >= max_steps - 1:
            running_chat.append(
                {
                    "role": "assistant",
                    "content": "Step budget is almost exhausted, so auto audit stopped early. Submit your report now (`submit: {...}`) to avoid `max_steps` termination.",
                }
            )
            break
        table_error = validate_query_tables(sql, obs)
        if table_error:
            running_chat.append({"role": "assistant", "content": table_error})
            continue
        out = SESSION.step({"action": {"action_type": "query", "sql": sql}}) or {}
        obs = out.get("observation", obs)
        reward = out.get("reward") or {}
        running_chat.append({"role": "user", "content": sql})
        running_chat.append({"role": "assistant", "content": f"reward={reward.get('value', 0.0)}"})
        # NOTE(review): assumes the backend flags episode end via reward["done"]
        # (the UI's own placeholder reward carries that key) — confirm.
        if reward.get("done"):
            break
    return running_chat, format_observation(obs), session_status(obs), format_reward(reward), obs
def _command_payload(text: str, keyword: str) -> str:
    """Return the argument that follows a leading command ``keyword``.

    An optional ``:`` right after the keyword is dropped, so ``"submit: {...}"``
    and ``"submit {...}"`` both yield ``"{...}"``. Splitting on the first ``:``
    anywhere in ``text`` would instead cut into a JSON report or a SQL literal
    such as ``'10:30'``.
    """
    rest = text[len(keyword):].lstrip()
    if rest.startswith(":"):
        rest = rest[1:]
    return rest.strip()


def handle_command(user_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]], task_id: int, seed: int):
    """Route one chat line to the matching action.

    Recognised forms (case-insensitive): ``help``/``?``, ``reset``, ``state``,
    ``auto``, ``submit[:] {json}``, ``query[:] SQL``, or bare SQL containing
    ``select``/``with``. Anything else gets a short usage hint.

    Some commands consume an environment step: ``auto``, ``query`` and bare
    SQL. If one arrives with at most one step left, a fallback report is
    submitted instead, so the episode does not end at ``max_steps`` without a
    report. Commands that never step the environment do not trigger this.

    Returns:
        ``(chat, observation_json, status_text, reward_json, observation)``, the
        five values wired to the Gradio outputs.
    """
    chat = list(chat or [])
    text = normalize_command(user_text)
    if not text:
        return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    lower = text.lower()
    if lower in {"help", "?"}:
        chat.append({"role": "assistant", "content": "Commands: `reset`, `query: SELECT ...`, `submit: {...json...}`, `auto`, or `state`."})
        return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    is_control = lower.startswith(("submit", "reset")) or lower == "state"
    looks_like_sql = re.search(r"\bselect\b|\bwith\b", lower) is not None
    consumes_step = not is_control and (lower.startswith(("auto", "query")) or looks_like_sql)
    if current_obs and consumes_step:
        step = int(current_obs.get("step", 0) or 0)
        max_steps = int(current_obs.get("max_steps", 12) or 12)
        if step >= max_steps - 1:
            fallback = fallback_report_from_obs(current_obs)
            out = SESSION.step({"action": {"action_type": "submit_report", "report": fallback}}) or {}
            obs = out.get("observation", current_obs)
            reward = out.get("reward", {})
            chat.append(
                {
                    "role": "assistant",
                    "content": "Step budget exhausted. I auto-submitted a fallback report to prevent `max_steps` zero-output failure.",
                }
            )
            return chat, format_observation(obs), session_status(obs), format_reward(reward), obs
    if lower.startswith("reset"):
        return reset_ui(task_id=task_id, seed=seed)
    if lower == "state":
        chat.append({"role": "assistant", "content": session_status(current_obs)})
        return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
    if lower.startswith("auto"):
        return auto_audit(current_obs, chat)
    if lower.startswith("submit"):
        return submit_report(_command_payload(text, "submit"), current_obs, chat)
    if lower.startswith("query"):
        return run_query(_command_payload(text, "query"), current_obs, chat)
    if looks_like_sql:
        return run_query(text, current_obs, chat)
    chat.append({"role": "assistant", "content": "I can help with `reset`, `query`, `submit`, `auto`, or `state`."})
    return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
# JSON API exposed next to the Gradio UI. Every route delegates to the shared SESSION.
# Plain comments (not docstrings) are used on the handlers so the generated
# OpenAPI descriptions stay exactly as before.
fastapi_app = FastAPI(title="DataQualityEnv Space")


# Liveness probe; returns the same payload as health().
@fastapi_app.get("/health")
def _health() -> dict[str, str]:
    return health()


# Start an episode. task_id defaults to 1 and seed to 42 when the body omits them.
@fastapi_app.post("/reset")
def _reset(payload: dict = Body(default_factory=dict)) -> dict:
    request = payload if payload else {}
    for key, default in (("task_id", 1), ("seed", 42)):
        request.setdefault(key, default)
    return SESSION.reset(request)


# Forward the action payload to the backend as-is (an empty body becomes {}).
@fastapi_app.post("/step")
def _step(payload: dict = Body(default_factory=dict)) -> dict:
    if not payload:
        payload = {}
    return SESSION.step(payload)


# Return whatever SESSION.state() reports.
@fastapi_app.get("/state")
def _state() -> dict:
    return SESSION.state()
# Gradio UI. The left column holds the task/seed controls, the reset and auto-audit
# buttons, the status line, the last reward and the observation JSON. The right
# column is a chat that accepts commands or raw SQL. Every handler wired below
# returns the same five outputs: chat, obs_box, status_box, reward_box and
# current_obs.
with gr.Blocks(title="DataQualityEnv") as demo:
    gr.Markdown(
        "# DataQualityEnv\n"
        "A self-contained Hugging Face Space demo. No `ENV_URL`, no localhost dependency, no external API hop for the environment."
    )
    with gr.Row():
        with gr.Column(scale=1):
            task_id = gr.Dropdown(choices=[1, 2, 3, 4], value=1, label="Task")
            seed = gr.Number(value=42, precision=0, label="Seed")
            reset_btn = gr.Button("Reset episode", variant="primary")
            auto_btn = gr.Button("Auto audit")
            gr.Markdown("### Session status")
            status_box = gr.Markdown("No active episode. Choose a task and click Reset.")
            reward_box = gr.Textbox(label="Last reward", lines=8, interactive=False)
            obs_box = gr.Textbox(label="Observation JSON", lines=22, interactive=False)
        with gr.Column(scale=2):
            # NOTE(review): the handlers append {"role": ..., "content": ...} dicts
            # to this history. Gradio versions whose Chatbot defaults to the
            # "tuples" format need type="messages" here — confirm against the
            # pinned Gradio version.
            chat = gr.Chatbot(label="Chat", height=520)
            user_text = gr.Textbox(
                label="Command or SQL",
                placeholder="Type reset, query: SELECT ..., submit: {...}, auto, or state",
                lines=3,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear chat")
    # Holds the latest observation dict returned by the handlers. The budget
    # guards and table checks read it. It starts as None (no episode yet).
    current_obs = gr.State(None)
    reset_btn.click(
        reset_ui,
        inputs=[task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    auto_btn.click(
        auto_audit,
        inputs=[current_obs, chat],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    # The Send button and pressing Enter in the textbox route through the same command parser.
    send_btn.click(
        handle_command,
        inputs=[user_text, current_obs, chat, task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    user_text.submit(
        handle_command,
        inputs=[user_text, current_obs, chat, task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    # Clears only the visible chat; current_obs and the backend episode are untouched.
    clear_btn.click(lambda: [], inputs=None, outputs=chat)
# Serve the Gradio UI at "/" on the FastAPI app. The JSON routes (/health,
# /reset, /step, /state) were registered on fastapi_app before this mount.
app = gr.mount_gradio_app(fastapi_app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "space_app:app" import string. The string form
    # re-imports this file under a second module name, which rebuilds the
    # Gradio and FastAPI apps. It also fails when the file is renamed or launched
    # from another directory. An import string is only required for reload/workers.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))