File size: 15,028 Bytes
84607b3
 
 
 
 
 
 
 
84b0d00
84607b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa25459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84607b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa25459
 
 
 
 
84607b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa25459
 
 
 
84607b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2185c5
84b0d00
 
 
 
f2185c5
 
 
 
84b0d00
 
f2185c5
 
 
 
 
 
 
 
84607b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
from __future__ import annotations

import json
import os
import re
from typing import Any

import gradio as gr
from fastapi import Body, FastAPI, HTTPException

from env.inprocess_backend import BACKEND

# Shared in-process environment backend: the UI callbacks and the HTTP routes
# in this module all call reset/step/state through this single handle.
SESSION = BACKEND


def health() -> dict[str, str]:
    """Return the static liveness payload describing this app."""
    payload = dict(status="ok", env="DataQualityEnv", mode="space-ui")
    return payload


def session_status(obs: dict[str, Any] | None) -> str:
    """Summarise the current episode as a single status line.

    Returns a prompt to reset when there is no (or an empty) observation;
    otherwise joins the task id, phase, step counter and remaining query
    credits read from ``obs``. Missing keys render as ``None``.
    """
    if obs:
        fields = (
            f"Task {obs.get('task_id')}",
            f"phase={obs.get('phase')}",
            f"step={obs.get('step')}/{obs.get('max_steps')}",
            f"credits={obs.get('query_credits_remaining')}",
        )
        return " | ".join(fields)
    return "No active episode. Choose a task and click Reset."


def initial_chat() -> list[dict[str, str]]:
    """Return a fresh, empty chat history (a new list on every call)."""
    history: list[dict[str, str]] = list()
    return history


def format_observation(obs: dict[str, Any] | None) -> str:
    """Render an observation as 2-space-indented JSON for display.

    A missing or empty observation renders as ``{}``; values JSON cannot
    encode natively are stringified via ``str``.
    """
    payload = obs if obs else {}
    return json.dumps(payload, default=str, indent=2)


def format_reward(reward: dict[str, Any] | None) -> str:
    """Render a reward payload as 2-space-indented JSON for display.

    A missing or empty reward renders as ``{}``; values JSON cannot encode
    natively are stringified via ``str``.
    """
    payload = reward if reward else {}
    return json.dumps(payload, default=str, indent=2)


def task_hint(task_id: int) -> str:
    """Return the one-line investigation hint shown after resetting a task.

    Tasks 1-3 have dedicated hints; any other id gets the relational hint.
    """
    hints = (
        (1, "Try null-like value checks and duplicate-row grouping on the customers table."),
        (2, "Try type parsing, negative values, and date-format checks on orders."),
        (3, "Try baseline/current comparisons, new categories, and user population drift."),
    )
    for known_id, hint in hints:
        if task_id == known_id:
            return hint
    return "Try orphaned foreign keys, temporal checks, and aggregate consistency."


def heuristic_queries(task_id: int) -> list[str]:
    """Return the canned diagnostic SQL probes for ``task_id``.

    Tasks 1-3 each have their own probe set; every other id receives the
    relational-integrity probes. A new list is built on every call.
    """
    if task_id == 1:
        probes = [
            "SELECT COUNT(*) AS total_rows FROM customers",
            "SELECT SUM(CASE WHEN email IS NULL OR lower(trim(cast(email as varchar))) IN ('null','n/a','unknown','-','','0','none') THEN 1 ELSE 0 END) AS email_null_total, SUM(CASE WHEN customer_id IS NULL THEN 1 ELSE 0 END) AS cid_nulls FROM customers",
            "SELECT COALESCE(SUM(c-1),0) AS exact_duplicate_rows FROM (SELECT customer_id,email,name,signup_date,country, COUNT(*) c FROM customers GROUP BY 1,2,3,4,5 HAVING COUNT(*)>1) t",
        ]
    elif task_id == 2:
        probes = [
            "SELECT SUM(CASE WHEN quantity < 0 THEN 1 ELSE 0 END) AS neg_qty, SUM(CASE WHEN try_cast(replace(amount,'$','') AS DOUBLE) IS NULL THEN 1 ELSE 0 END) AS bad_amt FROM orders",
            "SELECT amount, order_date FROM orders LIMIT 10",
        ]
    elif task_id == 3:
        probes = [
            "SELECT (SELECT AVG(amount) FROM transactions_baseline) AS baseline_mean, (SELECT AVG(amount) FROM transactions_current) AS current_mean",
            "SELECT DISTINCT c.category FROM transactions_current c LEFT JOIN (SELECT DISTINCT category FROM transactions_baseline) b ON c.category=b.category WHERE b.category IS NULL ORDER BY c.category",
        ]
    else:
        probes = [
            "SELECT COUNT(*) AS orphan_count FROM orders o LEFT JOIN customers c ON o.customer_id=c.customer_id WHERE c.customer_id IS NULL",
            "SELECT COUNT(*) AS temporal_count FROM orders WHERE try_cast(ship_date AS TIMESTAMP) < try_cast(order_date AS TIMESTAMP)",
            "SELECT COUNT(*) AS aggregate_count FROM (SELECT o.order_id, o.order_total, SUM(li.subtotal) AS s FROM orders o JOIN line_items li ON o.order_id=li.order_id GROUP BY o.order_id, o.order_total HAVING abs(o.order_total - SUM(li.subtotal)) > 1e-6) x",
        ]
    return probes


def current_tables(obs: dict[str, Any] | None) -> set[str]:
    """Return the lower-cased names of the tables listed in ``obs["tables"]``.

    Yields an empty set when there is no observation or no table listing.
    """
    table_map: Any = {}
    if obs:
        table_map = obs.get("tables") or {}
    return set(map(str.lower, map(str, table_map.keys())))


def referenced_tables(sql_text: str) -> set[str]:
    """Return the lower-cased table names a SQL string appears to read from.

    Table names are taken from the identifier following each ``FROM`` or
    ``JOIN`` keyword; schema qualifiers (``main.orders``) are dropped so only
    the last dotted component is kept. Names introduced by the query itself
    as common table expressions (``WITH name AS (...)``, optionally with a
    column list) are excluded, because they are not tables of the
    environment and would otherwise be reported as unknown.

    This is a lightweight regex scan, not a SQL parser: quoted identifiers
    and ``FROM`` keywords inside functions such as ``EXTRACT`` are not
    treated specially.

    Args:
        sql_text: Raw SQL; ``None`` or blank input yields an empty set.

    Returns:
        The set of referenced table names, lower-cased.
    """
    sql = (sql_text or "").strip()

    # "name AS (" (optionally "name(col, ...) AS (") only occurs where a CTE is
    # defined; column aliases and CAST(... AS type) never put "(" right after AS.
    cte_names = {
        match.group(1).lower()
        for match in re.finditer(
            r"\b([a-zA-Z_]\w*)\s*(?:\([^()]*\)\s*)?\bas\s*\(",
            sql,
            flags=re.IGNORECASE,
        )
    }

    refs: set[str] = set()
    for match in re.finditer(r"\b(?:from|join)\s+([a-zA-Z_][\w\.]*)", sql, flags=re.IGNORECASE):
        # Keep only the final component of a dotted name; a trailing dot
        # leaves an empty string, which is skipped.
        identifier = match.group(1).split(".")[-1].lower()
        if identifier:
            refs.add(identifier)
    return refs - cte_names


def validate_query_tables(sql_text: str, obs: dict[str, Any] | None) -> str | None:
    """Check that a query only touches tables exposed by the current task.

    Returns ``None`` when the query is acceptable or cannot be judged (no
    table listing in ``obs``, or no table reference found in the SQL);
    otherwise returns a user-facing message naming the allowed tables and
    the offending ones.
    """
    exposed = current_tables(obs)
    if not exposed:
        return None

    unknown = sorted(referenced_tables(sql_text) - exposed)
    if not unknown:
        return None

    available = ", ".join(sorted(exposed))
    return f"This task only exposes: {available}. Please query one of those tables instead of: {', '.join(unknown)}."


def normalize_command(text: str) -> str:
    """Return ``text`` without surrounding whitespace; falsy input becomes ``""``."""
    if not text:
        return ""
    return text.strip()


def parse_json_fragment(text: str) -> dict[str, Any] | None:
    """Extract a JSON object from free-form text.

    Markdown code fences are removed first. The whole text is then parsed;
    if that fails, the widest ``{...}`` span in the text is tried instead.

    Only JSON *objects* are accepted: text that parses to a list, string,
    number, boolean or ``null`` yields ``None`` so callers never receive a
    non-dict value.

    Args:
        text: Raw user text; ``None`` or blank input yields ``None``.

    Returns:
        The decoded dict, or ``None`` when no JSON object can be recovered.
    """
    raw = (text or "").strip()
    raw = raw.replace("```json", "").replace("```", "").strip()

    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        # Not valid JSON as a whole: fall back to the outermost brace span,
        # e.g. when the object is embedded in a sentence.
        match = re.search(r"\{.*\}", raw, re.DOTALL)
        if match is None:
            return None
        try:
            parsed = json.loads(match.group())
        except (ValueError, RecursionError):
            return None

    return parsed if isinstance(parsed, dict) else None


def fallback_report_from_obs(obs: dict[str, Any] | None) -> dict[str, Any]:
    """Build the placeholder report that is auto-submitted near the step limit.

    The report is empty apart from a zero duplicate count and two fixed
    recommendations. For task 1 (also the default when ``obs`` carries no
    task id) a single low-confidence ``partial_audit`` schema violation is
    included.
    """
    source = obs or {}
    task_id = int(source.get("task_id", 1) or 1)

    violations: list[dict[str, Any]] = []
    if task_id == 1:
        violations.append(
            {
                "column": "customers",
                "issue_type": "partial_audit",
                "example": "auto_submit_guard",
                "count": 1,
                "confidence": 0.4,
            }
        )

    return {
        "null_issues": {},
        "duplicate_row_count": {"value": 0, "confidence": 0.5},
        "schema_violations": violations,
        "drifted_columns": [],
        "drift_details": {},
        "relational_issues": [],
        "recommended_fixes": [
            "Auto-submitted fallback report to avoid max_steps termination",
            "Run additional targeted probes in earlier steps for higher confidence",
        ],
    }


def reset_ui(task_id: int, seed: int):
    """Reset the shared session and return the refreshed UI state.

    Returns a 5-tuple matching the UI outputs: a new chat history holding a
    single confirmation message with the task hint, the observation JSON,
    the status line, a zeroed reward JSON, and the raw observation.
    """
    obs = SESSION.reset({"task_id": task_id, "seed": seed})

    chat = initial_chat()
    greeting = {"role": "assistant", "content": f"Reset complete for task {task_id}. {task_hint(task_id)}"}
    chat += [greeting]

    zero_reward = {"value": 0.0, "done": False}
    return (
        chat,
        format_observation(obs),
        session_status(obs),
        format_reward(zero_reward),
        obs,
    )


def run_query(sql_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Run one SQL query against the session and return the refreshed UI state.

    The query is refused (without stepping the environment) when the step
    budget is nearly spent, when the SQL is blank, or when it references a
    table the current task does not expose.

    Args:
        sql_text: SQL entered by the user.
        current_obs: Latest observation, or ``None`` before the first reset.
        chat: Existing chat history; it is copied, never mutated. ``None`` is
            treated as an empty history.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``.
    """
    history = list(chat or [])

    if current_obs:
        step = int(current_obs.get("step", 0) or 0)
        max_steps = int(current_obs.get("max_steps", 12) or 12)
        # Keep the final step free so a report can still be submitted.
        if step >= max_steps - 1:
            history.append(
                {
                    "role": "assistant",
                    "content": "Step budget is almost exhausted. Submit your report now (`submit: {...}`) to avoid `max_steps` termination.",
                }
            )
            return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    sql = normalize_command(sql_text)
    if not sql:
        history.append({"role": "assistant", "content": "Send a SQL query first."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    table_error = validate_query_tables(sql, current_obs)
    if table_error:
        history.append({"role": "assistant", "content": table_error})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({"value": 0.0, "done": False}), current_obs

    out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
    obs = out.get("observation")
    # A step result without a reward must not crash the chat message below.
    reward = out.get("reward") or {}
    history.append({"role": "user", "content": f"query: {sql}"})
    history.append({"role": "assistant", "content": f"Ran query. reward={reward.get('value', 0.0)}"})
    return history, format_observation(obs), session_status(obs), format_reward(reward), obs


def submit_report(report_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Parse a JSON report from text, submit it, and return the refreshed UI state.

    When the text cannot be parsed, the environment is not stepped and an
    explanatory message is appended to the chat instead.

    Args:
        report_text: Text expected to contain the report JSON.
        current_obs: Latest observation, or ``None`` before the first reset.
        chat: Existing chat history; it is copied, never mutated. ``None`` is
            treated as an empty history.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``.
    """
    history = list(chat or [])

    report = parse_json_fragment(report_text)
    if report is None:
        history.append({"role": "assistant", "content": "I couldn’t parse that as JSON. Paste a valid report object."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    out = SESSION.step({"action": {"action_type": "submit_report", "report": report}})
    obs = out.get("observation")
    # A step result without a reward must not crash the chat message below.
    reward = out.get("reward") or {}
    history.append({"role": "user", "content": "submit report"})
    history.append({"role": "assistant", "content": f"Submitted report. reward={reward.get('value', 0.0)}"})
    return history, format_observation(obs), session_status(obs), format_reward(reward), obs


def auto_audit(current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
    """Run the canned diagnostic probes for the active task.

    Each probe is validated against the tables in the latest observation
    and then executed as a ``query`` step. Probing stops early once only one
    step remains, so the episode is not driven into ``max_steps`` before a
    report can be submitted (the same threshold manual queries use).

    Args:
        current_obs: Latest observation; when missing, nothing is run and the
            user is asked to reset first.
        chat: Existing chat history; it is copied, never mutated. ``None`` is
            treated as an empty history.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``
        reflecting the last probe that actually ran.
    """
    history = list(chat or [])

    if not current_obs:
        history.append({"role": "assistant", "content": "Reset a task before running auto audit."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    task_id = int(current_obs.get("task_id", 1) or 1)
    queries = heuristic_queries(task_id)
    history.append({"role": "assistant", "content": f"Running {len(queries)} diagnostic probes..."})

    obs = current_obs
    reward: dict[str, Any] = {}
    for sql in queries:
        budget = obs or {}
        step = int(budget.get("step", 0) or 0)
        max_steps = int(budget.get("max_steps", 12) or 12)
        # Keep the final step free so a report can still be submitted.
        if step >= max_steps - 1:
            history.append(
                {
                    "role": "assistant",
                    "content": "Step budget is almost exhausted. Submit your report now (`submit: {...}`) to avoid `max_steps` termination.",
                }
            )
            break

        table_error = validate_query_tables(sql, obs)
        if table_error:
            history.append({"role": "assistant", "content": table_error})
            continue

        out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
        obs = out.get("observation")
        # A step result without a reward must not crash the chat message below.
        reward = out.get("reward") or {}
        history.append({"role": "user", "content": sql})
        history.append({"role": "assistant", "content": f"reward={reward.get('value', 0.0)}"})

    return history, format_observation(obs), session_status(obs), format_reward(reward), obs


def _command_payload(text: str, keyword: str) -> str:
    """Return what follows a leading ``keyword`` and one optional ``:`` separator."""
    rest = text[len(keyword):].lstrip()
    if rest.startswith(":"):
        rest = rest[1:]
    return rest.strip()


def handle_command(user_text: str, current_obs: dict[str, Any] | None, chat: list[dict[str, str]], task_id: int, seed: int):
    """Dispatch one line of chat input to the matching action.

    Recognised inputs: ``help``/``?``, ``reset``, ``state``, ``auto``,
    ``submit: {...}``, ``query: SELECT ...``, or bare SQL containing
    ``select``/``with``. When an episode is active and only one step is
    left, anything other than submit/reset/state/help triggers an automatic
    fallback report submission instead of the requested action.

    The payload of ``submit``/``query`` is whatever follows the keyword and
    an optional colon, so colons inside the JSON or SQL itself are kept.

    Args:
        user_text: Raw text from the input box.
        current_obs: Latest observation, or ``None`` before the first reset.
        chat: Existing chat history; it is copied, never mutated. ``None`` is
            treated as an empty history.
        task_id: Task selected in the UI, used by ``reset``.
        seed: Seed selected in the UI, used by ``reset``.

    Returns:
        A 5-tuple ``(chat, observation_json, status_line, reward_json, obs)``.
    """
    history = list(chat or [])

    text = normalize_command(user_text)
    if not text:
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    lower = text.lower()
    if lower in {"help", "?"}:
        history.append({"role": "assistant", "content": "Commands: `reset`, `query: SELECT ...`, `submit: {...json...}`, `auto`, or `state`."})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    if current_obs and not (lower.startswith(("submit", "reset")) or lower == "state"):
        step = int(current_obs.get("step", 0) or 0)
        max_steps = int(current_obs.get("max_steps", 12) or 12)
        # Spend the last step on a fallback report rather than another action.
        if step >= max_steps - 1:
            fallback = fallback_report_from_obs(current_obs)
            out = SESSION.step({"action": {"action_type": "submit_report", "report": fallback}})
            obs = out.get("observation", current_obs)
            reward = out.get("reward", {})
            history.append(
                {
                    "role": "assistant",
                    "content": "Step budget exhausted. I auto-submitted a fallback report to prevent `max_steps` zero-output failure.",
                }
            )
            return history, format_observation(obs), session_status(obs), format_reward(reward), obs

    if lower.startswith("reset"):
        return reset_ui(task_id=task_id, seed=seed)

    if lower == "state":
        history.append({"role": "assistant", "content": session_status(current_obs)})
        return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs

    if lower.startswith("auto"):
        return auto_audit(current_obs, history)

    if lower.startswith("submit"):
        return submit_report(_command_payload(text, "submit"), current_obs, history)

    if lower.startswith("query"):
        return run_query(_command_payload(text, "query"), current_obs, history)

    if re.search(r"\bselect\b|\bwith\b", lower):
        return run_query(text, current_obs, history)

    history.append({"role": "assistant", "content": "I can help with `reset`, `query`, `submit`, `auto`, or `state`."})
    return history, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs


# HTTP application that carries the JSON routes below; the Gradio UI is
# mounted onto it further down in this module.
fastapi_app = FastAPI(title="DataQualityEnv Space")


@fastapi_app.get("/health")
def _health() -> dict[str, str]:
    """GET /health: return the static liveness payload from ``health()``."""
    return health()


@fastapi_app.post("/reset")
def _reset(payload: dict = Body(default_factory=dict)) -> dict:
    """POST /reset: start a new episode on the shared session.

    ``task_id`` and ``seed`` default to 1 and 42 when the request body does
    not provide them; the resulting payload is passed to ``SESSION.reset``
    and its return value is sent back unchanged.
    """
    # Tolerate an empty or null request body.
    payload = payload or {}
    payload.setdefault("task_id", 1)
    payload.setdefault("seed", 42)
    return SESSION.reset(payload)


@fastapi_app.post("/step")
def _step(payload: dict = Body(default_factory=dict)) -> dict:
    """POST /step: forward the request body to ``SESSION.step`` and return its result."""
    # Tolerate an empty or null request body.
    payload = payload or {}
    return SESSION.step(payload)


@fastapi_app.get("/state")
def _state() -> dict:
    """GET /state: return whatever ``SESSION.state()`` reports."""
    return SESSION.state()


# Gradio UI: a control/status column on the left and a chat column on the
# right. Every callback returns the same five outputs
# (chat, observation JSON, status line, reward JSON, raw observation).
with gr.Blocks(title="DataQualityEnv") as demo:
    gr.Markdown(
        "# DataQualityEnv\n"
        "A self-contained Hugging Face Space demo. No `ENV_URL`, no localhost dependency, no external API hop for the environment."
    )
    with gr.Row():
        # Left column: task selection, episode controls and read-only state views.
        with gr.Column(scale=1):
            task_id = gr.Dropdown(choices=[1, 2, 3, 4], value=1, label="Task")
            seed = gr.Number(value=42, precision=0, label="Seed")
            reset_btn = gr.Button("Reset episode", variant="primary")
            auto_btn = gr.Button("Auto audit")
            gr.Markdown("### Session status")
            status_box = gr.Markdown("No active episode. Choose a task and click Reset.")
            reward_box = gr.Textbox(label="Last reward", lines=8, interactive=False)
            obs_box = gr.Textbox(label="Observation JSON", lines=22, interactive=False)
        # Right column: chat transcript plus the free-text command box.
        with gr.Column(scale=2):
            # NOTE(review): the callbacks feed this component role/content dicts;
            # confirm the installed Gradio version accepts that message format
            # without an explicit setting.
            chat = gr.Chatbot(label="Chat", height=520)
            user_text = gr.Textbox(
                label="Command or SQL",
                placeholder="Type reset, query: SELECT ..., submit: {...}, auto, or state",
                lines=3,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear chat")

    # Holds the latest raw observation between callbacks (None until a reset).
    current_obs = gr.State(None)

    reset_btn.click(
        reset_ui,
        inputs=[task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    auto_btn.click(
        auto_audit,
        inputs=[current_obs, chat],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    # The Send button and pressing Enter in the textbox trigger the same handler.
    send_btn.click(
        handle_command,
        inputs=[user_text, current_obs, chat, task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    user_text.submit(
        handle_command,
        inputs=[user_text, current_obs, chat, task_id, seed],
        outputs=[chat, obs_box, status_box, reward_box, current_obs],
    )
    # Clears only the transcript; the stored observation is left as is.
    clear_btn.click(lambda: [], inputs=None, outputs=chat)


# Combined application: the Gradio UI mounted at "/" on the FastAPI app that
# defines the JSON routes above. This is the object the server runs.
app = gr.mount_gradio_app(fastapi_app, demo, path="/")


if __name__ == "__main__":
    # Imported lazily so uvicorn is only needed when run as a script.
    import uvicorn

    # Listens on all interfaces; the port comes from $PORT, defaulting to 7860.
    # NOTE(review): the import string assumes this module is importable as
    # `space_app` — confirm against the actual filename.
    uvicorn.run("space_app:app", host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))