File size: 25,685 Bytes
c3002ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5de8f8e
 
 
c3002ad
 
 
 
 
b08652c
 
c3002ad
 
 
 
 
 
 
 
b08652c
5d90461
f1b7439
b08652c
c3002ad
 
b08652c
 
 
 
 
f1b7439
c3002ad
 
 
 
 
 
 
 
b08652c
c3002ad
b08652c
 
5d90461
 
 
 
 
 
 
b08652c
5d90461
b08652c
 
 
 
c3002ad
 
b08652c
 
 
 
 
 
 
 
c3002ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1f98bf
b08652c
 
 
a1f98bf
c3002ad
 
 
5cb467d
 
 
 
5de8f8e
 
 
 
 
a9620ef
5cb467d
 
 
 
 
5de8f8e
a9620ef
5cb467d
5de8f8e
a9620ef
 
5de8f8e
c699b6f
5de8f8e
 
 
 
5cb467d
 
5de8f8e
8910a26
5de8f8e
 
 
 
 
 
 
5cb467d
 
 
887c1aa
 
 
 
 
 
8560706
887c1aa
8560706
887c1aa
 
 
 
 
 
8560706
887c1aa
 
 
 
 
 
fcce834
8560706
887c1aa
 
 
 
 
 
 
 
 
fcce834
8560706
887c1aa
 
 
c3002ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
"""
Gradio UI — Agent Trajectory Replay Viewer for DataQA.

Designed for judges: no setup needed — the first task's replay loads automatically.
Task dropdown, step slider, prominent metric cards, and a color-coded dataset view;
a second tab lets you submit and score your own issues and fixes.
"""

from __future__ import annotations

import csv
import io

import gradio as gr

from .environment import DataQAEnvironment, parse_issue_key
from .tasks import list_tasks, PlantedIssue
from ..models import DataQAAction


# ── Pre-built agent trajectories (simulates baseline agent) ──

# Scripted trajectories replayed by the "Demo" tab, keyed by task id.
# Each task maps to an ordered list of steps; each step is a dict with:
#   "issues": issue reports, formatted "row:<N>,col:<column>,issue:<type>"
#   "fixes":  proposed corrections, formatted "row:<N>,col:<column>,fix:<value>"
# _replay_task sends each step to the environment as one DataQAAction, so a
# later step re-lists the issues of earlier steps it still wants counted.
# NOTE(review): row numbers appear to be 1-based data-row indices (the header
# row is not counted) — confirm against the task definitions in tasks.py.
AGENT_TRAJECTORIES = {
    # Demo trajectories: fixes are ONLY proposed where the correct value
    # is logically inferrable (computable, format conversion, or deducible from context).
    # Ambiguous fixes (any valid salary, any past date) are NOT proposed.
    "easy": [
        {
            "issues": [
                "row:4,col:name,issue:missing_value",
                "row:7,col:salary,issue:wrong_type",
                "row:11,col:department,issue:format_violation",
                "row:15,col:email,issue:inconsistent_value",
                "row:3,col:email,issue:format_violation",  # FP
            ],
            "fixes": [],
        },
        {
            "issues": [
                "row:4,col:name,issue:missing_value",
                "row:7,col:salary,issue:wrong_type",
                "row:11,col:department,issue:format_violation",
                "row:15,col:email,issue:inconsistent_value",
                "row:12,col:start_date,issue:format_violation",
                "row:21,col:employee_id,issue:duplicate_row",
            ],
            "fixes": [
                # All deterministic fixes:
                "row:4,col:name,fix:David Kim",                     # from email david.kim@
                "row:7,col:salary,fix:75000",                       # "seventy-five thousand" → 75000
                "row:11,col:department,fix:Engineering",             # "Engneering" → "Engineering"
                "row:15,col:email,fix:oscar.rivera@company.com",    # from name Oscar Rivera
                "row:12,col:start_date,fix:2022-11-03",              # MM-DD-YYYY → YYYY-MM-DD
            ],
        },
    ],
    "medium": [
        {
            "issues": [
                "row:5,col:total,issue:inconsistent_value",
                "row:10,col:category,issue:format_violation",
                "row:10,col:quantity,issue:wrong_type",
                "row:12,col:order_date,issue:format_violation",
                "row:29,col:product_name,issue:format_violation",
                "row:24,col:status,issue:format_violation",
            ],
            "fixes": [],
        },
        {
            "issues": [
                "row:5,col:total,issue:inconsistent_value",
                "row:10,col:category,issue:format_violation",
                "row:10,col:quantity,issue:wrong_type",
                "row:12,col:order_date,issue:format_violation",
                "row:19,col:order_id,issue:duplicate_row",
                "row:21,col:unit_price,issue:format_violation",
                "row:24,col:status,issue:format_violation",
                "row:29,col:product_name,issue:format_violation",
            ],
            "fixes": [
                # All deterministic:
                "row:5,col:total,fix:42.00",             # qty(1) * price(42.00)
                "row:10,col:category,fix:Sports",         # "Fitness" → nearest valid
                "row:10,col:quantity,fix:10",              # "1O" (letter O) → "10"
                "row:12,col:order_date,fix:2024-01-26",   # DD/MM/YYYY → YYYY-MM-DD
                "row:24,col:status,fix:delivered",         # "deliverred" → "delivered"
                "row:29,col:product_name,fix:Wireless Charger",  # "Wireles" → "Wireless"
                "row:21,col:unit_price,fix:24.99",        # 24.999 → round to 2 decimals
            ],
        },
    ],
    "hard": [
        {
            "issues": [
                "row:14,col:training_time_hours,issue:out_of_range",
                "row:13,col:learning_rate,issue:out_of_range",
                "row:15,col:model_name,issue:missing_value",
                "row:9,col:batch_size,issue:format_violation",
                "row:10,col:train_size,issue:inconsistent_value",
            ],
            "fixes": [],
        },
        {
            "issues": [
                "row:14,col:training_time_hours,issue:out_of_range",
                "row:13,col:learning_rate,issue:out_of_range",
                "row:15,col:model_name,issue:missing_value",
                "row:9,col:batch_size,issue:format_violation",
                "row:10,col:train_size,issue:inconsistent_value",
                "row:5,col:val_loss,issue:inconsistent_value",
                "row:7,col:gpu_memory_gb,issue:statistical_outlier",
                "row:11,col:timestamp,issue:inconsistent_value",
                "row:9,col:training_time_hours,issue:statistical_outlier",
                "row:12,col:test_accuracy,issue:statistical_outlier",
            ],
            "fixes": [
                # Only deterministic fixes:
                "row:9,col:batch_size,fix:256",                 # 250 → nearest power of 2
                "row:14,col:training_time_hours,fix:72.0",      # -72.0 → remove negative sign
                "row:15,col:model_name,fix:whisper-small",      # "whsiper-small" → fix spelling
                # NOT proposed: row:13 LR (2.5 is out of range but any valid LR works)
            ],
        },
    ],
    "alignment": [
        {
            "issues": [
                "row:6,col:response,issue:inconsistent_value",
                "row:15,col:response,issue:inconsistent_value",
                "row:28,col:prompt,issue:missing_value",
                "row:20,col:response,issue:inconsistent_value",
                "row:7,col:prompt,issue:duplicate_row",
                "row:25,col:response,issue:missing_value",
                "row:3,col:response,issue:inconsistent_value",
            ],
            "fixes": [],
        },
        {
            "issues": [
                "row:3,col:response,issue:inconsistent_value",
                "row:4,col:response,issue:inconsistent_value",
                "row:6,col:response,issue:inconsistent_value",
                "row:7,col:prompt,issue:duplicate_row",
                "row:8,col:response,issue:inconsistent_value",
                "row:11,col:response,issue:inconsistent_value",
                "row:15,col:response,issue:inconsistent_value",
                "row:23,col:helpfulness,issue:inconsistent_value",
                "row:20,col:response,issue:inconsistent_value",
                "row:25,col:response,issue:missing_value",
                "row:28,col:prompt,issue:missing_value",
                "row:29,col:response,issue:inconsistent_value",
            ],
            "fixes": [
                # Inferrable: Salvator Mundi facts are well-known ($450.3M at Christie's)
                "row:4,col:response,fix:The most expensive painting ever sold at auction is Salvator Mundi by Leonardo da Vinci. It was sold for $450.3 million at Christie's in New York City in 2017.",
                # Inferrable: strip leaked [SYSTEM] prompt prefix
                "row:3,col:response,fix:Kitsch is art or design that is overly sentimental or ornate while camp is a style that is over-the-top and exaggerated often used in satire or irony.",
                # NOT proposed: row:6 wrong scientific name (need taxonomy knowledge)
                # NOT proposed: row:8 harmful advice (need to write safe version)
                # NOT proposed: row:11 self-contradiction (need to rewrite coherently)
                # NOT proposed: row:15 French response (need English translation)
                # NOT proposed: row:29 hallucinated citation (need factual replacement)
            ],
        },
    ],
    "moderation": [
        {
            "issues": [
                "row:16,col:hate,issue:inconsistent_value",
                "row:17,col:harassment,issue:inconsistent_value",
                "row:20,col:violence,issue:inconsistent_value",
                "row:13,col:violence,issue:out_of_range",
                "row:15,col:text,issue:missing_value",
                "row:30,col:text,issue:duplicate_row",
            ],
            "fixes": [],
        },
        {
            "issues": [
                "row:5,col:hate,issue:inconsistent_value",
                "row:13,col:violence,issue:out_of_range",
                "row:15,col:text,issue:missing_value",
                "row:16,col:hate,issue:inconsistent_value",
                "row:17,col:harassment,issue:inconsistent_value",
                "row:20,col:violence,issue:inconsistent_value",
                "row:22,col:self_harm,issue:inconsistent_value",
                "row:24,col:hate,issue:inconsistent_value",
                "row:25,col:violence,issue:inconsistent_value",
                "row:30,col:text,issue:duplicate_row",
            ],
            "fixes": [
                # All deterministic label flips:
                "row:16,col:hate,fix:1",           # unflagged hate → flag it
                "row:17,col:harassment,fix:1",      # unflagged harassment → flag it
                "row:20,col:violence,fix:1",        # unflagged violence → flag it
                "row:22,col:self_harm,fix:1",       # unflagged self-harm → flag it
                "row:5,col:hate,fix:0",             # false positive on idiom → unflag
                "row:24,col:hate,fix:1",            # subset rule: hate_threatening needs hate
                "row:25,col:violence,fix:0",         # chose walk over violence → not violent
                "row:13,col:violence,fix:0",         # out of range 3 → 0
            ],
        },
    ],
}


# ── HTML rendering ──

def _metric_card(label: str, value: str, color: str = "#333") -> str:
    return (
        f'<div style="text-align:center;padding:12px 16px;background:#f8f9fa;'
        f'border-radius:8px;min-width:100px;">'
        f'<div style="font-size:11px;color:#666;text-transform:uppercase;letter-spacing:1px;">{label}</div>'
        f'<div style="font-size:28px;font-weight:700;color:{color};margin-top:2px;">{value}</div>'
        f'</div>'
    )


def _csv_to_html(
    csv_text: str,
    planted: list[PlantedIssue],
    correct: set[tuple[int, str]],
    fp: set[tuple[int, str]],
    missed: set[tuple[int, str]],
    fixed: dict[tuple[int, str], str],
    fix_values: dict[tuple[int, str], str] | None = None,
) -> str:
    """Render CSV as HTML with color-coded cells and inline fix proposals."""
    fix_values = fix_values or {}
    desc_map = {(i.row, i.col): i for i in planted}
    reader = csv.reader(io.StringIO(csv_text.strip()))
    rows = list(reader)
    if not rows:
        return ""

    header = rows[0]
    header_lower = [h.strip().lower() for h in header]
    data = rows[1:]

    t = ['<table style="border-collapse:collapse;width:100%;font-size:12px;font-family:\'SF Mono\',monospace;">']
    t.append('<tr>')
    t.append('<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">Row</th>')
    for h in header:
        t.append(f'<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">{h}</th>')
    t.append('</tr>')

    for i, row in enumerate(data):
        rn = i + 1
        bg = "#fff" if i % 2 == 0 else "#f8f9fa"
        t.append(f'<tr style="background:{bg};">')
        t.append(f'<td style="border:1px solid #dee2e6;padding:4px 8px;color:#adb5bd;text-align:center;font-size:11px;">{rn}</td>')
        for j, val in enumerate(row):
            col = header_lower[j] if j < len(header_lower) else ""
            ck = (rn, col)
            s = "border:1px solid #dee2e6;padding:4px 8px;"
            tip = ""
            badge = ""

            issue = desc_map.get(ck)

            if ck in correct:
                s += "background:#d4edda;"
                tip = f"FOUND: {issue.description}" if issue else ""
                badge = '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">TP</span>'
            elif ck in fp:
                s += "background:#f8d7da;"
                badge = '<span style="font-size:9px;background:#dc3545;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">FP</span>'
            elif ck in missed:
                s += "background:#fff3cd;"
                tip = f"MISSED: {issue.description}" if issue else ""
                badge = '<span style="font-size:9px;background:#856404;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">MISS</span>'

            fx = fixed.get(ck)
            proposed = fix_values.get(ck)
            if fx == "correct":
                s += "box-shadow:inset 0 0 0 2px #28a745;"
                badge += '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:2px;">FIX</span>'
            elif fx == "partial":
                s += "box-shadow:inset 0 0 0 2px #ffc107;"
                badge += '<span style="font-size:9px;background:#ffc107;color:#333;padding:1px 4px;border-radius:3px;margin-left:2px;">~FIX</span>'

            dv = val if val.strip() else '<em style="color:#dc3545;font-style:italic;">empty</em>'

            # Show proposed fix value below the corrupted value
            fix_line = ""
            if proposed is not None:
                fix_color = "#28a745" if fx == "correct" else ("#b8860b" if fx == "partial" else "#dc3545")
                fix_line = (
                    f'<div style="font-size:10px;color:{fix_color};margin-top:2px;'
                    f'border-top:1px dashed {fix_color};padding-top:2px;">'
                    f'\u2192 {proposed}</div>'
                )

            t.append(f'<td style="{s}" title="{tip}">{dv}{badge}{fix_line}</td>')
        t.append('</tr>')
    t.append('</table>')
    return "".join(t)


# Colour key shown under every rendered table: one swatch per cell state.
# Each entry is (CSS that draws the swatch, caption).
LEGEND_HTML = (
    '<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:10px;font-size:11px;">'
    + "".join(
        f'<span style="{swatch}padding:2px 8px;border-radius:4px;">{caption}</span>'
        for swatch, caption in (
            ("background:#d4edda;", "Found (TP)"),
            ("background:#f8d7da;", "False Positive"),
            ("background:#fff3cd;", "Missed"),
            ("box-shadow:inset 0 0 0 2px #28a745;", "Fix Correct"),
            ("box-shadow:inset 0 0 0 2px #ffc107;", "Fix Partial"),
        )
    )
    + "</div>"
)


# ── Core replay logic ──

def _replay_task(task_id: str) -> list[dict]:
    """Replay the scripted trajectory for *task_id* and collect one frame per step.

    A fresh ``DataQAEnvironment`` is reset to *task_id*, then every step in
    ``AGENT_TRAJECTORIES[task_id]`` (none if the task has no script) is sent
    as a ``DataQAAction`` and the resulting observation is rendered.

    Returns:
        A list of frames, beginning with the untouched dataset (step 0). Each
        frame is a dict with ``"label"`` (caption), ``"html"`` (rendered
        table), ``"metrics"`` (``reward``, ``tp``, ``fp``, ``fn``,
        ``identify``, ``fix``, ``fixes_correct``) and ``"feedback"`` (text).
    """
    from .environment import parse_fix

    env = DataQAEnvironment()
    obs = env.reset(task_id=task_id)
    task = env._current_task
    planted_keys = {i.to_key() for i in task.planted_issues}

    # Step 0: initial state, nothing reported yet.
    steps_data = [{
        "label": "Initial — corrupted dataset",
        "html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
        "metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
                    "identify": 0.0, "fix": 0.0, "fixes_correct": 0},
        "feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
    }]

    for step_no, step_data in enumerate(AGENT_TRAJECTORIES.get(task_id, []), start=1):
        issues = step_data["issues"]
        fixes = step_data.get("fixes", [])
        obs = env.step(DataQAAction(issues=issues, fixes=fixes, task_id=task_id))

        # Unparseable issue strings are simply not counted.
        reported_keys = {key for key in map(parse_issue_key, issues) if key}
        correct = {_kc(k) for k in reported_keys & planted_keys}
        fp = {_kc(k) for k in reported_keys - planted_keys}
        # Missed cells are only highlighted once the episode is over.
        missed = {_kc(k) for k in planted_keys - reported_keys} if obs.done else set()

        # Grade each scored fix: >= 0.99 counts as correct, any positive score as partial.
        fixed: dict[tuple[int, str], str] = {}
        for d in obs.metadata.get("fix_details", []):
            score = d["score"]
            fixed[(d["row"], d["col"])] = "correct" if score >= 0.99 else ("partial" if score > 0 else "wrong")

        # Proposed values, so the table can show what the agent suggested.
        fix_values: dict[tuple[int, str], str] = {}
        for raw_fix in fixes:
            parsed = parse_fix(raw_fix)
            if parsed:
                row, col, val = parsed
                fix_values[(row, col)] = val

        html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values)
        label = f"Step {step_no} — identify + fix" if fixes else f"Step {step_no} — identify only"

        m = obs.metadata
        steps_data.append({
            "label": label,
            "html": html,
            "metrics": {
                "reward": obs.reward,
                "tp": m["tp"],
                "fp": m["fp"],
                "fn": m["fn"],
                "identify": m["identify_score"],
                "fix": m["fix_score"],
                "fixes_correct": m["fixes_correct"],
            },
            "feedback": obs.feedback,
        })

    return steps_data


def _kc(key: str) -> tuple[int, str]:
    parts = key.split(",")
    return (int(parts[0].split(":")[1]), parts[1].split(":")[1])


# ── Gradio app ──

def build_gradio_ui():
    """Build the two-tab DataQA Gradio app and return it without launching it.

    * **Demo (Baseline Agent)** replays the scripted ``AGENT_TRAJECTORIES``.
      Every task's replay is computed once, up front, via ``_replay_task``.
    * **Try Your Own Agent** sends user-typed issues/fixes to a live
      ``DataQAEnvironment`` and renders the scored result.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve it.
    """
    from .environment import parse_fix

    # Pre-compute all replays at startup so moving the slider is instant.
    all_replays: dict[str, list[dict]] = {tid: _replay_task(tid) for tid in list_tasks()}

    def metric_cards(reward, tp, fp, fn, identify, fix) -> str:
        """Render the row of metric cards shared by both tabs."""
        reward_color = "#28a745" if reward >= 0.8 else ("#ffc107" if reward >= 0.4 else "#dc3545")
        return (
            '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
            + _metric_card("Reward", f"{reward:.2f}", reward_color)
            + _metric_card("Found", str(tp), "#28a745")
            + _metric_card("False Pos", str(fp), "#dc3545" if fp > 0 else "#28a745")
            + _metric_card("Missed", str(fn), "#dc3545" if fn > 0 else "#28a745")
            + _metric_card("Identify", f"{identify:.2f}", "#333")
            + _metric_card("Fix", f"{fix:.2f}", "#333")
            + '</div>'
        )

    def show_step(task_id: str, step_idx):
        """Return ``(html, feedback)`` for frame *step_idx* of *task_id*'s replay."""
        replay = all_replays.get(task_id, [])
        if not replay:
            return "", f"No replay available for task {task_id!r}."
        # Clamp into range (e.g. a slider value left over from another task).
        idx = min(max(int(step_idx or 0), 0), len(replay) - 1)
        sd = replay[idx]
        m = sd["metrics"]
        cards = metric_cards(m["reward"], m["tp"], m["fp"], m["fn"], m["identify"], m["fix"])
        full_html = (
            f'<div style="font-size:14px;font-weight:600;margin-bottom:8px;color:#495057;">'
            f'{sd["label"]}</div>'
            + cards + sd["html"] + LEGEND_HTML
        )
        return full_html, sd["feedback"]

    def on_task_change(task_id):
        """Resize the slider to *task_id*'s replay, rewind it, and show step 0."""
        max_step = max(len(all_replays.get(task_id, [])) - 1, 0)
        html, fb = show_step(task_id, 0)
        return gr.update(maximum=max_step, value=0), html, fb

    # ── Live agent runner (drives an in-process environment) ──

    # NOTE(review): this environment and state dict are created once per app
    # build and captured by the handlers below, so every browser session shares
    # them; concurrent users will overwrite each other's episode. Per-session
    # state (e.g. gr.State) would be needed to isolate users.
    live_env = DataQAEnvironment()
    live_state: dict = {"obs": None, "task_id": "easy", "steps": []}

    def live_reset(task_id):
        """Start a new live episode on *task_id* and show the untouched dataset."""
        obs = live_env.reset(task_id=task_id)
        task = live_env._current_task
        live_state["obs"] = obs
        live_state["task_id"] = task_id
        live_state["steps"] = []
        html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
        info = f"**{task.name}** — {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
        return html, info, "", "0.000"

    def live_step(issues_text, fixes_text):
        """Score one user-submitted step.

        Returns ``(html, feedback, reward_text, issues_box_text)``. The issues
        box is cleared after a scored step but left intact when the step is
        rejected because no episode has been started.
        """
        if live_state["obs"] is None:
            return "Reset first.", "", "", issues_text
        task = live_env._current_task
        planted_keys = {i.to_key() for i in task.planted_issues}

        # One entry per non-blank line; None (empty component) is treated as "".
        issues = [line.strip() for line in (issues_text or "").splitlines() if line.strip()]
        fixes = [line.strip() for line in (fixes_text or "").splitlines() if line.strip()]

        obs = live_env.step(DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"]))
        live_state["obs"] = obs

        reported_keys = {key for key in map(parse_issue_key, issues) if key}
        correct = {_kc(k) for k in reported_keys & planted_keys}
        fp_set = {_kc(k) for k in reported_keys - planted_keys}
        # Missed cells are only highlighted once the episode is over.
        missed = {_kc(k) for k in planted_keys - reported_keys} if obs.done else set()

        fixed: dict[tuple[int, str], str] = {}
        for d in obs.metadata.get("fix_details", []):
            score = d["score"]
            fixed[(d["row"], d["col"])] = "correct" if score >= 0.99 else ("partial" if score > 0 else "wrong")

        fix_values: dict[tuple[int, str], str] = {}
        for raw in fixes:
            parsed = parse_fix(raw)
            if parsed:
                fix_values[(parsed[0], parsed[1])] = parsed[2]

        table = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)

        m = obs.metadata
        r = obs.reward
        cards = metric_cards(r, m["tp"], m["fp"], m["fn"], m["identify_score"], m["fix_score"])
        return cards + table + LEGEND_HTML, obs.feedback, f"{r:.3f}", ""

    # ── Build the UI ──

    with gr.Blocks(title="DataQA Environment") as demo:
        gr.Markdown(
            "# DataQA — Data Quality Assurance Environment\n"
            "Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
        )

        with gr.Tabs():
            # ── Tab 1: Demo replay ──
            with gr.Tab("Demo (Baseline Agent)"):
                gr.Markdown(
                    "*Replay of the baseline Qwen-72B agent. "
                    "Use the slider to step through the agent's trajectory.*"
                )
                with gr.Row():
                    task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
                    step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)

                viz_html = gr.HTML()
                feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)

                task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
                step_slider.change(show_step, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
                # Show the default task's first frame as soon as the page loads.
                demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])

            # ── Tab 2: Try your own agent ──
            with gr.Tab("Try Your Own Agent"):
                gr.Markdown(
                    "*Submit your own issues and fixes to see how the environment scores them. "
                    "This is the same environment the baseline agent talks to.*"
                )
                with gr.Row():
                    live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
                    live_reset_btn = gr.Button("Reset", variant="primary", scale=1)

                with gr.Row():
                    live_info = gr.Markdown()
                    live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)

                live_viz = gr.HTML()

                with gr.Row():
                    live_issues = gr.Textbox(
                        label="Issues (one per line)",
                        placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
                        lines=5,
                    )
                    live_fixes = gr.Textbox(
                        label="Fixes (one per line, optional)",
                        placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
                        lines=5,
                    )

                live_step_btn = gr.Button("Submit Step", variant="primary")
                live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)

                live_reset_btn.click(
                    live_reset, inputs=[live_task_dd],
                    outputs=[live_viz, live_info, live_feedback, live_reward],
                )
                live_step_btn.click(
                    live_step, inputs=[live_issues, live_fixes],
                    outputs=[live_viz, live_feedback, live_reward, live_issues],
                )

    return demo


if __name__ == "__main__":
    # Build the viewer and serve it with Gradio's default launch settings.
    # NOTE(review): this module uses relative imports (".environment",
    # "..models"), so it presumably has to be run as a package module
    # (python -m <package>.<module>) rather than as a plain script path.
    demo = build_gradio_ui()
    demo.launch()