Spaces:
Sleeping
Sleeping
Commit ·
b08652c
1
Parent(s): 5de8f8e
replace ambiguous fixes with deterministic ones across all tasks
Browse filesEasy: misspelled department, extra-digit salary typo (inferrable)
Medium: OCR error (1O→10), misspelled product/status, 3-decimal price
Hard: misspelled model name, truncated sci notation, sign typo
All demo trajectories only propose fixes with logically deducible answers.
Grading now rewards valid fixes (correct type, right range, right format)
even without exact match.
124 tests passing.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- dataqa_env/server/gradio_ui.py +32 -46
- dataqa_env/server/tasks.py +50 -43
- tests/test_environment.py +22 -42
- tests/test_tasks.py +1 -2
dataqa_env/server/gradio_ui.py
CHANGED
|
@@ -28,8 +28,8 @@ AGENT_TRAJECTORIES = {
|
|
| 28 |
"issues": [
|
| 29 |
"row:4,col:name,issue:missing_value",
|
| 30 |
"row:7,col:salary,issue:wrong_type",
|
| 31 |
-
"row:
|
| 32 |
-
"row:
|
| 33 |
"row:3,col:email,issue:format_violation", # FP
|
| 34 |
],
|
| 35 |
"fixes": [],
|
|
@@ -38,21 +38,18 @@ AGENT_TRAJECTORIES = {
|
|
| 38 |
"issues": [
|
| 39 |
"row:4,col:name,issue:missing_value",
|
| 40 |
"row:7,col:salary,issue:wrong_type",
|
| 41 |
-
"row:
|
| 42 |
-
"row:21,col:employee_id,issue:duplicate_row",
|
| 43 |
"row:15,col:email,issue:inconsistent_value",
|
| 44 |
-
"row:18,col:
|
|
|
|
| 45 |
],
|
| 46 |
"fixes": [
|
| 47 |
-
#
|
| 48 |
-
"row:4,col:name,fix:David Kim",
|
| 49 |
-
#
|
| 50 |
-
"row:
|
| 51 |
-
|
| 52 |
-
"row:
|
| 53 |
-
# NOT proposed: row:9 salary (any valid salary 50000-150000 works)
|
| 54 |
-
# NOT proposed: row:18 start_date (any past date works)
|
| 55 |
-
# NOT proposed: row:21 duplicate (remove or reassign — ambiguous)
|
| 56 |
],
|
| 57 |
},
|
| 58 |
],
|
|
@@ -61,11 +58,10 @@ AGENT_TRAJECTORIES = {
|
|
| 61 |
"issues": [
|
| 62 |
"row:5,col:total,issue:inconsistent_value",
|
| 63 |
"row:10,col:category,issue:format_violation",
|
| 64 |
-
"row:
|
| 65 |
-
"row:17,col:quantity,issue:out_of_range",
|
| 66 |
-
"row:19,col:order_id,issue:duplicate_row",
|
| 67 |
"row:12,col:order_date,issue:format_violation",
|
| 68 |
-
"row:
|
|
|
|
| 69 |
],
|
| 70 |
"fixes": [],
|
| 71 |
},
|
|
@@ -73,25 +69,22 @@ AGENT_TRAJECTORIES = {
|
|
| 73 |
"issues": [
|
| 74 |
"row:5,col:total,issue:inconsistent_value",
|
| 75 |
"row:10,col:category,issue:format_violation",
|
| 76 |
-
"row:
|
| 77 |
-
"row:17,col:quantity,issue:out_of_range",
|
| 78 |
-
"row:19,col:order_id,issue:duplicate_row",
|
| 79 |
"row:12,col:order_date,issue:format_violation",
|
| 80 |
-
"row:
|
| 81 |
-
"row:
|
|
|
|
|
|
|
| 82 |
],
|
| 83 |
"fixes": [
|
| 84 |
-
#
|
| 85 |
-
"row:5,col:total,fix:42.00",
|
| 86 |
-
#
|
| 87 |
-
"row:10,col:
|
| 88 |
-
|
| 89 |
-
"row:
|
| 90 |
-
|
| 91 |
-
#
|
| 92 |
-
# NOT proposed: row:19 duplicate order_id (reassign — ambiguous)
|
| 93 |
-
# NOT proposed: row:24 country (could be any valid ISO code)
|
| 94 |
-
# NOT proposed: row:29 future date (any past date works)
|
| 95 |
],
|
| 96 |
},
|
| 97 |
],
|
|
@@ -120,18 +113,11 @@ AGENT_TRAJECTORIES = {
|
|
| 120 |
"row:12,col:test_accuracy,issue:statistical_outlier",
|
| 121 |
],
|
| 122 |
"fixes": [
|
| 123 |
-
#
|
| 124 |
-
"row:9,col:batch_size,fix:256",
|
| 125 |
-
#
|
| 126 |
-
"row:
|
| 127 |
-
|
| 128 |
-
# NOT proposed: row:15 model_name (could be any model)
|
| 129 |
-
# NOT proposed: row:5 val_loss (any val >= train_loss)
|
| 130 |
-
# NOT proposed: row:7 GPU memory (any reasonable value)
|
| 131 |
-
# NOT proposed: row:10 train_size (any value > test_size)
|
| 132 |
-
# NOT proposed: row:11 timestamp (any date after prev)
|
| 133 |
-
# NOT proposed: row:9 training_time (any reasonable hours)
|
| 134 |
-
# NOT proposed: row:12 test_accuracy (any < SOTA)
|
| 135 |
],
|
| 136 |
},
|
| 137 |
],
|
|
|
|
| 28 |
"issues": [
|
| 29 |
"row:4,col:name,issue:missing_value",
|
| 30 |
"row:7,col:salary,issue:wrong_type",
|
| 31 |
+
"row:11,col:department,issue:format_violation",
|
| 32 |
+
"row:15,col:email,issue:inconsistent_value",
|
| 33 |
"row:3,col:email,issue:format_violation", # FP
|
| 34 |
],
|
| 35 |
"fixes": [],
|
|
|
|
| 38 |
"issues": [
|
| 39 |
"row:4,col:name,issue:missing_value",
|
| 40 |
"row:7,col:salary,issue:wrong_type",
|
| 41 |
+
"row:11,col:department,issue:format_violation",
|
|
|
|
| 42 |
"row:15,col:email,issue:inconsistent_value",
|
| 43 |
+
"row:18,col:salary,issue:out_of_range",
|
| 44 |
+
"row:21,col:employee_id,issue:duplicate_row",
|
| 45 |
],
|
| 46 |
"fixes": [
|
| 47 |
+
# All deterministic fixes:
|
| 48 |
+
"row:4,col:name,fix:David Kim", # from email david.kim@
|
| 49 |
+
"row:7,col:salary,fix:75000", # "seventy-five thousand" → 75000
|
| 50 |
+
"row:11,col:department,fix:Engineering", # "Engneering" → "Engineering"
|
| 51 |
+
"row:15,col:email,fix:oscar.rivera@company.com", # from name Oscar Rivera
|
| 52 |
+
"row:18,col:salary,fix:99000", # 990000 → remove extra digit
|
|
|
|
|
|
|
|
|
|
| 53 |
],
|
| 54 |
},
|
| 55 |
],
|
|
|
|
| 58 |
"issues": [
|
| 59 |
"row:5,col:total,issue:inconsistent_value",
|
| 60 |
"row:10,col:category,issue:format_violation",
|
| 61 |
+
"row:10,col:quantity,issue:wrong_type",
|
|
|
|
|
|
|
| 62 |
"row:12,col:order_date,issue:format_violation",
|
| 63 |
+
"row:29,col:product_name,issue:format_violation",
|
| 64 |
+
"row:24,col:status,issue:format_violation",
|
| 65 |
],
|
| 66 |
"fixes": [],
|
| 67 |
},
|
|
|
|
| 69 |
"issues": [
|
| 70 |
"row:5,col:total,issue:inconsistent_value",
|
| 71 |
"row:10,col:category,issue:format_violation",
|
| 72 |
+
"row:10,col:quantity,issue:wrong_type",
|
|
|
|
|
|
|
| 73 |
"row:12,col:order_date,issue:format_violation",
|
| 74 |
+
"row:19,col:order_id,issue:duplicate_row",
|
| 75 |
+
"row:21,col:unit_price,issue:format_violation",
|
| 76 |
+
"row:24,col:status,issue:format_violation",
|
| 77 |
+
"row:29,col:product_name,issue:format_violation",
|
| 78 |
],
|
| 79 |
"fixes": [
|
| 80 |
+
# All deterministic:
|
| 81 |
+
"row:5,col:total,fix:42.00", # qty(1) * price(42.00)
|
| 82 |
+
"row:10,col:category,fix:Sports", # "Fitness" → nearest valid
|
| 83 |
+
"row:10,col:quantity,fix:10", # "1O" (letter O) → "10"
|
| 84 |
+
"row:12,col:order_date,fix:2024-01-26", # DD/MM/YYYY → YYYY-MM-DD
|
| 85 |
+
"row:24,col:status,fix:delivered", # "deliverred" → "delivered"
|
| 86 |
+
"row:29,col:product_name,fix:Wireless Charger", # "Wireles" → "Wireless"
|
| 87 |
+
"row:21,col:unit_price,fix:24.99", # 24.999 → round to 2 decimals
|
|
|
|
|
|
|
|
|
|
| 88 |
],
|
| 89 |
},
|
| 90 |
],
|
|
|
|
| 113 |
"row:12,col:test_accuracy,issue:statistical_outlier",
|
| 114 |
],
|
| 115 |
"fixes": [
|
| 116 |
+
# All deterministic:
|
| 117 |
+
"row:9,col:batch_size,fix:256", # 250 → nearest power of 2
|
| 118 |
+
"row:14,col:training_time_hours,fix:72.0", # -72.0 → remove negative sign
|
| 119 |
+
"row:15,col:model_name,fix:whisper-small", # "whsiper-small" → fix spelling
|
| 120 |
+
"row:13,col:learning_rate,fix:0.000025", # 2.5 → likely 2.5e-5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
],
|
| 122 |
},
|
| 123 |
],
|
dataqa_env/server/tasks.py
CHANGED
|
@@ -144,24 +144,25 @@ def create_task_easy(seed: int = 42) -> Task:
|
|
| 144 |
issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
|
| 145 |
description=f"Exact duplicate of row {dup_source + 1}", difficulty=1.5))
|
| 146 |
|
| 147 |
-
# Issue 4:
|
| 148 |
-
r =
|
| 149 |
-
data[r][
|
| 150 |
-
issues.append(PlantedIssue(row=r + 1, col="
|
| 151 |
-
description="
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 154 |
r = 14 # Oscar Rivera -> email should be oscar.rivera@company.com
|
| 155 |
data[r][2] = "john.doe@company.com"
|
| 156 |
issues.append(PlantedIssue(row=r + 1, col="email", issue_type="inconsistent_value",
|
| 157 |
description="Email john.doe@company.com doesn't match name Oscar Rivera",
|
| 158 |
difficulty=1.5))
|
| 159 |
|
| 160 |
-
# Issue 6:
|
| 161 |
-
r = 17 # Rosa Diaz
|
| 162 |
-
data[r][
|
| 163 |
-
issues.append(PlantedIssue(row=r + 1, col="
|
| 164 |
-
description="
|
| 165 |
difficulty=1.5))
|
| 166 |
|
| 167 |
corrupted = _rows_to_csv([header] + data)
|
|
@@ -259,17 +260,19 @@ ORD-030,CUST-128,Dumbbells Set,Sports,1,89.00,2024-02-13,US,shipped,89.00"""
|
|
| 259 |
issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
|
| 260 |
description="'Fitness' is not in allowed categories", difficulty=1.5))
|
| 261 |
|
| 262 |
-
# Issue 3:
|
| 263 |
-
r =
|
| 264 |
-
data[r][2] = ""
|
| 265 |
-
issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="
|
| 266 |
-
description="
|
|
|
|
| 267 |
|
| 268 |
-
# Issue 4:
|
| 269 |
-
r =
|
| 270 |
-
data[r][4] = "
|
| 271 |
-
issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="
|
| 272 |
-
description="
|
|
|
|
| 273 |
|
| 274 |
# Issue 5: Duplicate order_id (requires cross-row comparison)
|
| 275 |
r = 18 # ORD-019
|
|
@@ -283,19 +286,20 @@ ORD-030,CUST-128,Dumbbells Set,Sports,1,89.00,2024-02-13,US,shipped,89.00"""
|
|
| 283 |
issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
|
| 284 |
description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
|
| 285 |
|
| 286 |
-
# Issue 7:
|
| 287 |
r = 23 # ORD-024
|
| 288 |
-
data[r][
|
| 289 |
-
issues.append(PlantedIssue(row=r + 1, col="
|
| 290 |
-
description="'
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
#
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
| 299 |
|
| 300 |
corrupted = _rows_to_csv([header] + data)
|
| 301 |
|
|
@@ -421,23 +425,26 @@ EXP-030,llama2-13b,oasst1,84437,4401,4401,0.00001,2,3,0.78,0.88,0.0,52.0,12.0,20
|
|
| 421 |
description="train_size (500) is smaller than test_size (1821)",
|
| 422 |
difficulty=2.0))
|
| 423 |
|
| 424 |
-
# Issue 6: Negative training time (
|
| 425 |
r = 13 # EXP-014
|
| 426 |
data[r][13] = "-72.0"
|
| 427 |
issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
|
| 428 |
-
description="Negative training time
|
|
|
|
| 429 |
|
| 430 |
-
# Issue 7: Learning rate
|
| 431 |
r = 12 # EXP-013
|
| 432 |
-
data[r][6] = "2.5" #
|
| 433 |
issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
|
| 434 |
-
description="Learning rate 2.5 exceeds maximum
|
|
|
|
| 435 |
|
| 436 |
-
# Issue 8:
|
| 437 |
r = 14 # EXP-015
|
| 438 |
-
data[r][1] = "
|
| 439 |
-
issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="
|
| 440 |
-
description="
|
|
|
|
| 441 |
|
| 442 |
# Issue 9: Training time impossibly fast for dataset size and epochs
|
| 443 |
# EXP-004: vit-base on imagenet-1k, 300 epochs, but only 96 hours is plausible.
|
|
|
|
| 144 |
issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
|
| 145 |
description=f"Exact duplicate of row {dup_source + 1}", difficulty=1.5))
|
| 146 |
|
| 147 |
+
# Issue 4: Department is not in allowed set (deterministic: "Engneering" is not valid, closest match = "Engineering")
|
| 148 |
+
r = 10 # Kevin Zhang, department is Engineering
|
| 149 |
+
data[r][3] = "Engneering"
|
| 150 |
+
issues.append(PlantedIssue(row=r + 1, col="department", issue_type="format_violation",
|
| 151 |
+
description="Department 'Engneering' is misspelled — should be 'Engineering'",
|
| 152 |
+
difficulty=1.0))
|
| 153 |
+
|
| 154 |
+
# Issue 5: Email doesn't match name pattern (deterministic fix: derive from name)
|
| 155 |
r = 14 # Oscar Rivera -> email should be oscar.rivera@company.com
|
| 156 |
data[r][2] = "john.doe@company.com"
|
| 157 |
issues.append(PlantedIssue(row=r + 1, col="email", issue_type="inconsistent_value",
|
| 158 |
description="Email john.doe@company.com doesn't match name Oscar Rivera",
|
| 159 |
difficulty=1.5))
|
| 160 |
|
| 161 |
+
# Issue 6: Salary with extra digit — typo (deterministic fix: "950000" → "95000")
|
| 162 |
+
r = 17 # Rosa Diaz, original salary is 99000
|
| 163 |
+
data[r][4] = "990000" # extra zero
|
| 164 |
+
issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
|
| 165 |
+
description="Salary 990000 exceeds maximum 150000 — likely extra digit typo (should be 99000)",
|
| 166 |
difficulty=1.5))
|
| 167 |
|
| 168 |
corrupted = _rows_to_csv([header] + data)
|
|
|
|
| 260 |
issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
|
| 261 |
description="'Fitness' is not in allowed categories", difficulty=1.5))
|
| 262 |
|
| 263 |
+
# Issue 3: Product name misspelling (deterministic fix: "Wireles Charger" → "Wireless Charger")
|
| 264 |
+
r = 28 # ORD-029
|
| 265 |
+
data[r][2] = "Wireles Charger"
|
| 266 |
+
issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="format_violation",
|
| 267 |
+
description="Product name 'Wireles Charger' is misspelled — should be 'Wireless Charger'",
|
| 268 |
+
difficulty=1.0))
|
| 269 |
|
| 270 |
+
# Issue 4: Quantity is letter O instead of zero — OCR/encoding error (deterministic: "1O" → "10")
|
| 271 |
+
r = 9 # ORD-010
|
| 272 |
+
data[r][4] = "1O" # letter O not digit 0
|
| 273 |
+
issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="wrong_type",
|
| 274 |
+
description="Quantity '1O' contains letter O instead of digit 0 — should be '10'",
|
| 275 |
+
difficulty=1.5))
|
| 276 |
|
| 277 |
# Issue 5: Duplicate order_id (requires cross-row comparison)
|
| 278 |
r = 18 # ORD-019
|
|
|
|
| 286 |
issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
|
| 287 |
description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
|
| 288 |
|
| 289 |
+
# Issue 7: Status misspelling (deterministic fix: "deliverred" → "delivered")
|
| 290 |
r = 23 # ORD-024
|
| 291 |
+
data[r][8] = "deliverred"
|
| 292 |
+
issues.append(PlantedIssue(row=r + 1, col="status", issue_type="format_violation",
|
| 293 |
+
description="Status 'deliverred' is misspelled — should be 'delivered'",
|
| 294 |
+
difficulty=1.0))
|
| 295 |
+
|
| 296 |
+
# Issue 8: Unit price has 3 decimal places (deterministic fix: "34.999" → "34.99")
|
| 297 |
+
# Rule says: all monetary values must have at most 2 decimal places
|
| 298 |
+
r = 20 # ORD-021
|
| 299 |
+
data[r][5] = "24.999"
|
| 300 |
+
issues.append(PlantedIssue(row=r + 1, col="unit_price", issue_type="format_violation",
|
| 301 |
+
description="Unit price 24.999 has 3 decimal places — rule requires at most 2 (should be 24.99 or 25.00)",
|
| 302 |
+
difficulty=1.5))
|
| 303 |
|
| 304 |
corrupted = _rows_to_csv([header] + data)
|
| 305 |
|
|
|
|
| 425 |
description="train_size (500) is smaller than test_size (1821)",
|
| 426 |
difficulty=2.0))
|
| 427 |
|
| 428 |
+
# Issue 6: Negative training time — sign typo (deterministic: "-72.0" → "72.0")
|
| 429 |
r = 13 # EXP-014
|
| 430 |
data[r][13] = "-72.0"
|
| 431 |
issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
|
| 432 |
+
description="Negative training time -72.0 — likely sign typo (should be 72.0)",
|
| 433 |
+
difficulty=1.0))
|
| 434 |
|
| 435 |
+
# Issue 7: Learning rate in wrong notation (deterministic: "2.5e1" intended as "2.5e-5" → "0.000025")
|
| 436 |
r = 12 # EXP-013
|
| 437 |
+
data[r][6] = "2.5" # clearly missing the "e-5" part
|
| 438 |
issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
|
| 439 |
+
description="Learning rate 2.5 exceeds maximum 1.0 — likely truncated scientific notation (e.g. 2.5e-5 → 0.000025)",
|
| 440 |
+
difficulty=1.5))
|
| 441 |
|
| 442 |
+
# Issue 8: Model name misspelling (deterministic: "whsiper-small" → "whisper-small")
|
| 443 |
r = 14 # EXP-015
|
| 444 |
+
data[r][1] = "whsiper-small"
|
| 445 |
+
issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="format_violation",
|
| 446 |
+
description="Model name 'whsiper-small' is misspelled — should be 'whisper-small'",
|
| 447 |
+
difficulty=1.5))
|
| 448 |
|
| 449 |
# Issue 9: Training time impossibly fast for dataset size and epochs
|
| 450 |
# EXP-004: vit-base on imagenet-1k, 300 epochs, but only 96 hours is plausible.
|
tests/test_environment.py
CHANGED
|
@@ -197,12 +197,11 @@ class TestGradeFixes:
|
|
| 197 |
result = grade_fixes(fixes, easy_task)
|
| 198 |
assert result["fixes_correct"] == 1
|
| 199 |
|
| 200 |
-
def
|
| 201 |
-
# Row
|
| 202 |
-
|
| 203 |
-
fixes = [(9, "salary", "73100")]
|
| 204 |
result = grade_fixes(fixes, easy_task)
|
| 205 |
-
assert result["
|
| 206 |
|
| 207 |
def test_wrong_value_for_issue_cell(self, easy_task):
|
| 208 |
# Row 4 name is empty — propose wrong name
|
|
@@ -228,16 +227,16 @@ class TestGradeFixes:
|
|
| 228 |
assert result["fixes_correct"] >= 1
|
| 229 |
|
| 230 |
def test_all_fixes_correct(self, easy_task):
|
| 231 |
-
# Fix
|
| 232 |
fixes = [
|
| 233 |
-
(4, "name", "David Kim"),
|
| 234 |
-
(7, "salary", "75000"),
|
| 235 |
-
(
|
| 236 |
-
(15, "email", "oscar.rivera@company.com"),
|
| 237 |
-
(18, "
|
| 238 |
]
|
| 239 |
result = grade_fixes(fixes, easy_task)
|
| 240 |
-
assert result["fix_score"] > 0.7
|
| 241 |
|
| 242 |
def test_fix_score_bounded(self, easy_task):
|
| 243 |
fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
|
|
@@ -278,43 +277,31 @@ class TestDataQAEnvironment:
|
|
| 278 |
"""Backward compatible: only issues, no fixes."""
|
| 279 |
env.reset(task_id="easy")
|
| 280 |
# Submit all 6 correct issues for easy task
|
|
|
|
|
|
|
| 281 |
action = DataQAAction(
|
| 282 |
-
issues=[
|
| 283 |
-
"row:4,col:name,issue:missing_value",
|
| 284 |
-
"row:7,col:salary,issue:wrong_type",
|
| 285 |
-
"row:21,col:employee_id,issue:duplicate_row",
|
| 286 |
-
"row:9,col:salary,issue:out_of_range",
|
| 287 |
-
"row:15,col:email,issue:inconsistent_value",
|
| 288 |
-
"row:18,col:start_date,issue:out_of_range",
|
| 289 |
-
],
|
| 290 |
task_id="easy",
|
| 291 |
)
|
| 292 |
obs = env.step(action)
|
| 293 |
assert obs.done is True
|
| 294 |
-
assert obs.reward >= 0.999
|
| 295 |
|
| 296 |
def test_step_with_fixes_increases_reward(self, env):
|
| 297 |
"""Submitting correct fixes should produce high combined reward."""
|
| 298 |
env.reset(task_id="easy")
|
| 299 |
-
|
|
|
|
| 300 |
action = DataQAAction(
|
| 301 |
-
issues=[
|
| 302 |
-
"row:4,col:name,issue:missing_value",
|
| 303 |
-
"row:7,col:salary,issue:wrong_type",
|
| 304 |
-
"row:21,col:employee_id,issue:duplicate_row",
|
| 305 |
-
"row:9,col:salary,issue:out_of_range",
|
| 306 |
-
"row:15,col:email,issue:inconsistent_value",
|
| 307 |
-
"row:18,col:start_date,issue:out_of_range",
|
| 308 |
-
],
|
| 309 |
fixes=[
|
| 310 |
"row:4,col:name,fix:David Kim",
|
| 311 |
"row:7,col:salary,fix:75000",
|
| 312 |
-
"row:9,col:
|
| 313 |
],
|
| 314 |
task_id="easy",
|
| 315 |
)
|
| 316 |
obs = env.step(action)
|
| 317 |
-
# Perfect identify + partial fixes -> high combined reward
|
| 318 |
assert obs.metadata["combined_reward"] > 0.7
|
| 319 |
|
| 320 |
def test_step_with_partial_issues(self, env):
|
|
@@ -437,19 +424,12 @@ class TestDataQAEnvironment:
|
|
| 437 |
def test_no_fix_penalty_when_no_fixes_submitted(self, env):
|
| 438 |
"""If agent submits no fixes, reward = identify_score (no penalty)."""
|
| 439 |
env.reset(task_id="easy")
|
|
|
|
|
|
|
| 440 |
action = DataQAAction(
|
| 441 |
-
issues=[
|
| 442 |
-
"row:4,col:name,issue:missing_value",
|
| 443 |
-
"row:7,col:salary,issue:wrong_type",
|
| 444 |
-
"row:21,col:employee_id,issue:duplicate_row",
|
| 445 |
-
"row:9,col:salary,issue:out_of_range",
|
| 446 |
-
"row:15,col:email,issue:inconsistent_value",
|
| 447 |
-
"row:18,col:start_date,issue:out_of_range",
|
| 448 |
-
],
|
| 449 |
task_id="easy",
|
| 450 |
)
|
| 451 |
obs = env.step(action)
|
| 452 |
-
# identify_score should be ~1.0 since all 6 issues found
|
| 453 |
assert obs.reward >= 0.99
|
| 454 |
-
# combined_reward equals identify_score when no fixes
|
| 455 |
assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
|
|
|
|
| 197 |
result = grade_fixes(fixes, easy_task)
|
| 198 |
assert result["fixes_correct"] == 1
|
| 199 |
|
| 200 |
+
def test_misspelling_fix(self, easy_task):
|
| 201 |
+
# Row 11 has department "Engneering" — fix to "Engineering"
|
| 202 |
+
fixes = [(11, "department", "Engineering")]
|
|
|
|
| 203 |
result = grade_fixes(fixes, easy_task)
|
| 204 |
+
assert result["fixes_correct"] == 1
|
| 205 |
|
| 206 |
def test_wrong_value_for_issue_cell(self, easy_task):
|
| 207 |
# Row 4 name is empty — propose wrong name
|
|
|
|
| 227 |
assert result["fixes_correct"] >= 1
|
| 228 |
|
| 229 |
def test_all_fixes_correct(self, easy_task):
|
| 230 |
+
# Fix deterministic issues with exact values
|
| 231 |
fixes = [
|
| 232 |
+
(4, "name", "David Kim"), # inferred from email
|
| 233 |
+
(7, "salary", "75000"), # type conversion
|
| 234 |
+
(11, "department", "Engineering"), # spelling fix
|
| 235 |
+
(15, "email", "oscar.rivera@company.com"), # pattern match
|
| 236 |
+
(18, "salary", "99000"), # remove extra digit
|
| 237 |
]
|
| 238 |
result = grade_fixes(fixes, easy_task)
|
| 239 |
+
assert result["fix_score"] > 0.7
|
| 240 |
|
| 241 |
def test_fix_score_bounded(self, easy_task):
|
| 242 |
fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
|
|
|
|
| 277 |
"""Backward compatible: only issues, no fixes."""
|
| 278 |
env.reset(task_id="easy")
|
| 279 |
# Submit all 6 correct issues for easy task
|
| 280 |
+
from dataqa_env.server.tasks import get_task
|
| 281 |
+
task = get_task("easy")
|
| 282 |
action = DataQAAction(
|
| 283 |
+
issues=[i.to_key() for i in task.planted_issues],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
task_id="easy",
|
| 285 |
)
|
| 286 |
obs = env.step(action)
|
| 287 |
assert obs.done is True
|
| 288 |
+
assert obs.reward >= 0.999
|
| 289 |
|
| 290 |
def test_step_with_fixes_increases_reward(self, env):
|
| 291 |
"""Submitting correct fixes should produce high combined reward."""
|
| 292 |
env.reset(task_id="easy")
|
| 293 |
+
from dataqa_env.server.tasks import get_task
|
| 294 |
+
task = get_task("easy")
|
| 295 |
action = DataQAAction(
|
| 296 |
+
issues=[i.to_key() for i in task.planted_issues],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
fixes=[
|
| 298 |
"row:4,col:name,fix:David Kim",
|
| 299 |
"row:7,col:salary,fix:75000",
|
| 300 |
+
"row:9,col:department,fix:Engineering",
|
| 301 |
],
|
| 302 |
task_id="easy",
|
| 303 |
)
|
| 304 |
obs = env.step(action)
|
|
|
|
| 305 |
assert obs.metadata["combined_reward"] > 0.7
|
| 306 |
|
| 307 |
def test_step_with_partial_issues(self, env):
|
|
|
|
| 424 |
def test_no_fix_penalty_when_no_fixes_submitted(self, env):
|
| 425 |
"""If agent submits no fixes, reward = identify_score (no penalty)."""
|
| 426 |
env.reset(task_id="easy")
|
| 427 |
+
from dataqa_env.server.tasks import get_task
|
| 428 |
+
task = get_task("easy")
|
| 429 |
action = DataQAAction(
|
| 430 |
+
issues=[i.to_key() for i in task.planted_issues],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
task_id="easy",
|
| 432 |
)
|
| 433 |
obs = env.step(action)
|
|
|
|
| 434 |
assert obs.reward >= 0.99
|
|
|
|
| 435 |
assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
|
tests/test_tasks.py
CHANGED
|
@@ -95,7 +95,7 @@ class TestTaskMedium:
|
|
| 95 |
types = {i.issue_type for i in task.planted_issues}
|
| 96 |
assert "inconsistent_value" in types
|
| 97 |
assert "format_violation" in types
|
| 98 |
-
assert "
|
| 99 |
|
| 100 |
def test_issue_keys_unique(self, task):
|
| 101 |
keys = [i.to_key() for i in task.planted_issues]
|
|
@@ -123,7 +123,6 @@ class TestTaskHard:
|
|
| 123 |
assert "format_violation" in types
|
| 124 |
assert "statistical_outlier" in types
|
| 125 |
assert "out_of_range" in types
|
| 126 |
-
assert "missing_value" in types
|
| 127 |
|
| 128 |
def test_has_high_difficulty_issues(self, task):
|
| 129 |
hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
|
|
|
|
| 95 |
types = {i.issue_type for i in task.planted_issues}
|
| 96 |
assert "inconsistent_value" in types
|
| 97 |
assert "format_violation" in types
|
| 98 |
+
assert "wrong_type" in types
|
| 99 |
|
| 100 |
def test_issue_keys_unique(self, task):
|
| 101 |
keys = [i.to_key() for i in task.planted_issues]
|
|
|
|
| 123 |
assert "format_violation" in types
|
| 124 |
assert "statistical_outlier" in types
|
| 125 |
assert "out_of_range" in types
|
|
|
|
| 126 |
|
| 127 |
def test_has_high_difficulty_issues(self, task):
|
| 128 |
hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
|