avanigupta committed on
Commit
5d90461
·
1 Parent(s): 081eb22

expand datasets to include harder real-world scenarios

Browse files
README.md CHANGED
@@ -52,9 +52,9 @@ This creates a rich multi-step decision problem where agents must explore datase
52
 
53
  | Task | Issues | Difficulty | Domain | Description |
54
  |------|--------|-----------|--------|-------------|
55
- | `easy` | 4 | Beginner | HR/Employee data | Nulls, wrong types, duplicates, out-of-range values |
56
- | `medium` | 6 | Intermediate | E-commerce orders | Format violations, inconsistent computed fields, duplicate keys |
57
- | `hard` | 10 | Advanced | ML experiment metadata | Data leakage signals, unreasonable GPU memory, impossibly fast training, SOTA-exceeding accuracy, timestamp ordering, whitespace-only fields |
58
 
59
  **Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
60
 
 
52
 
53
  | Task | Issues | Difficulty | Domain | Description |
54
  |------|--------|-----------|--------|-------------|
55
+ | `easy` | 6 | Beginner | HR/Employee data (21 rows) | Nulls, wrong types, duplicates, out-of-range, email-name mismatch, future dates |
56
+ | `medium` | 8 | Intermediate | E-commerce orders (31 rows) | Inconsistent totals, invalid categories, duplicate keys, wrong date formats, invalid country codes, future-date deliveries |
57
+ | `hard` | 10 | Advanced | ML experiment metadata (31 rows) | Data leakage signals, unreasonable GPU memory, impossibly fast training, SOTA-exceeding accuracy, timestamp ordering, whitespace-only fields |
58
 
59
  **Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
60
 
dataqa_env/server/gradio_ui.py CHANGED
@@ -26,6 +26,7 @@ AGENT_TRAJECTORIES = {
26
  "row:4,col:name,issue:missing_value",
27
  "row:7,col:salary,issue:wrong_type",
28
  "row:9,col:salary,issue:out_of_range",
 
29
  "row:3,col:email,issue:format_violation", # FP
30
  ],
31
  "fixes": [],
@@ -35,12 +36,16 @@ AGENT_TRAJECTORIES = {
35
  "row:4,col:name,issue:missing_value",
36
  "row:7,col:salary,issue:wrong_type",
37
  "row:9,col:salary,issue:out_of_range",
38
- "row:11,col:employee_id,issue:duplicate_row",
 
 
39
  ],
40
  "fixes": [
41
  "row:4,col:name,fix:David Kim",
42
  "row:7,col:salary,fix:75000",
43
  "row:9,col:salary,fix:73000",
 
 
44
  ],
45
  },
46
  ],
@@ -53,12 +58,28 @@ AGENT_TRAJECTORIES = {
53
  "row:17,col:quantity,issue:out_of_range",
54
  "row:19,col:order_id,issue:duplicate_row",
55
  "row:12,col:order_date,issue:format_violation",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  ],
57
  "fixes": [
58
  "row:5,col:total,fix:42.00",
59
  "row:10,col:category,fix:Sports",
60
  "row:12,col:order_date,fix:2024-01-26",
61
  "row:14,col:product_name,fix:LED Strip Lights",
 
 
62
  ],
63
  },
64
  ],
 
26
  "row:4,col:name,issue:missing_value",
27
  "row:7,col:salary,issue:wrong_type",
28
  "row:9,col:salary,issue:out_of_range",
29
+ "row:18,col:start_date,issue:out_of_range",
30
  "row:3,col:email,issue:format_violation", # FP
31
  ],
32
  "fixes": [],
 
36
  "row:4,col:name,issue:missing_value",
37
  "row:7,col:salary,issue:wrong_type",
38
  "row:9,col:salary,issue:out_of_range",
39
+ "row:21,col:employee_id,issue:duplicate_row",
40
+ "row:15,col:email,issue:inconsistent_value",
41
+ "row:18,col:start_date,issue:out_of_range",
42
  ],
43
  "fixes": [
44
  "row:4,col:name,fix:David Kim",
45
  "row:7,col:salary,fix:75000",
46
  "row:9,col:salary,fix:73000",
47
+ "row:15,col:email,fix:oscar.rivera@company.com",
48
+ "row:18,col:start_date,fix:2022-01-19",
49
  ],
50
  },
51
  ],
 
58
  "row:17,col:quantity,issue:out_of_range",
59
  "row:19,col:order_id,issue:duplicate_row",
60
  "row:12,col:order_date,issue:format_violation",
61
+ "row:24,col:shipping_country,issue:format_violation",
62
+ ],
63
+ "fixes": [],
64
+ },
65
+ {
66
+ "issues": [
67
+ "row:5,col:total,issue:inconsistent_value",
68
+ "row:10,col:category,issue:format_violation",
69
+ "row:14,col:product_name,issue:missing_value",
70
+ "row:17,col:quantity,issue:out_of_range",
71
+ "row:19,col:order_id,issue:duplicate_row",
72
+ "row:12,col:order_date,issue:format_violation",
73
+ "row:24,col:shipping_country,issue:format_violation",
74
+ "row:29,col:order_date,issue:inconsistent_value",
75
  ],
76
  "fixes": [
77
  "row:5,col:total,fix:42.00",
78
  "row:10,col:category,fix:Sports",
79
  "row:12,col:order_date,fix:2024-01-26",
80
  "row:14,col:product_name,fix:LED Strip Lights",
81
+ "row:24,col:shipping_country,fix:US",
82
+ "row:29,col:order_date,fix:2024-02-12",
83
  ],
84
  },
85
  ],
dataqa_env/server/tasks.py CHANGED
@@ -150,6 +150,20 @@ def create_task_easy(seed: int = 42) -> Task:
150
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
151
  description="Salary 5000 is below minimum 50000", difficulty=1.0))
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  corrupted = _rows_to_csv([header] + data)
154
 
155
  return Task(
@@ -269,6 +283,20 @@ ORD-030,CUST-128,Dumbbells Set,Sports,1,89.00,2024-02-13,US,shipped,89.00"""
269
  issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
270
  description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  corrupted = _rows_to_csv([header] + data)
273
 
274
  return Task(
 
150
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
151
  description="Salary 5000 is below minimum 50000", difficulty=1.0))
152
 
153
+ # Issue 5: Email doesn't match name pattern (moderate — cross-column check)
154
+ r = 14 # Oscar Rivera -> email should be oscar.rivera@company.com
155
+ data[r][2] = "john.doe@company.com"
156
+ issues.append(PlantedIssue(row=r + 1, col="email", issue_type="inconsistent_value",
157
+ description="Email john.doe@company.com doesn't match name Oscar Rivera",
158
+ difficulty=1.5))
159
+
160
+ # Issue 6: Future start date (requires knowing current date context)
161
+ r = 17 # Rosa Diaz
162
+ data[r][5] = "2027-06-15"
163
+ issues.append(PlantedIssue(row=r + 1, col="start_date", issue_type="out_of_range",
164
+ description="Start date 2027-06-15 is in the future (beyond 2025-12-31)",
165
+ difficulty=1.5))
166
+
167
  corrupted = _rows_to_csv([header] + data)
168
 
169
  return Task(
 
283
  issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
284
  description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
285
 
286
+ # Issue 7: Invalid country code (requires ISO knowledge)
287
+ r = 23 # ORD-024
288
+ data[r][7] = "XX" # not a valid ISO country code
289
+ issues.append(PlantedIssue(row=r + 1, col="shipping_country", issue_type="format_violation",
290
+ description="'XX' is not a valid ISO 2-letter country code", difficulty=1.5))
291
+
292
+ # Issue 8: Status-date inconsistency — an order from Feb 13 still marked "processing" would be suspicious,
293
+ # but more importantly: an order whose status is "delivered" has an order date in the future
294
+ r = 28 # ORD-029
295
+ data[r][6] = "2025-12-25" # future date but status is "delivered"
296
+ issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="inconsistent_value",
297
+ description="Order date 2025-12-25 is in the future but status is 'delivered'",
298
+ difficulty=2.0))
299
+
300
  corrupted = _rows_to_csv([header] + data)
301
 
302
  return Task(
tests/test_environment.py CHANGED
@@ -228,16 +228,16 @@ class TestGradeFixes:
228
  assert result["fixes_correct"] >= 1
229
 
230
  def test_all_fixes_correct(self, easy_task):
231
- # Fix all 4 issues with exact values
232
  fixes = [
233
  (4, "name", "David Kim"),
234
  (7, "salary", "75000"),
235
  (9, "salary", "73000"),
236
- # Row 11 is duplicate — clean value for employee_id is "Bob Martinez" row
237
- # The duplicate is of row 2 (Bob Martinez), so the clean row 11 doesn't exist
238
  ]
239
  result = grade_fixes(fixes, easy_task)
240
- assert result["fix_score"] > 0.5 # at least 3/4 issues fixed
241
 
242
  def test_fix_score_bounded(self, easy_task):
243
  fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
@@ -260,7 +260,7 @@ class TestDataQAEnvironment:
260
  assert obs.schema_description
261
  assert obs.validation_rules
262
  assert obs.task_description
263
- assert obs.num_issues_hint == 4
264
  assert obs.max_steps == 3
265
  assert obs.done is False
266
  assert obs.reward == 0.0
@@ -268,7 +268,7 @@ class TestDataQAEnvironment:
268
 
269
  def test_reset_medium(self, env):
270
  obs = env.reset(task_id="medium")
271
- assert obs.num_issues_hint == 6
272
 
273
  def test_reset_hard(self, env):
274
  obs = env.reset(task_id="hard")
@@ -277,12 +277,15 @@ class TestDataQAEnvironment:
277
  def test_step_identify_only(self, env):
278
  """Backward compatible: only issues, no fixes."""
279
  env.reset(task_id="easy")
 
280
  action = DataQAAction(
281
  issues=[
282
  "row:4,col:name,issue:missing_value",
283
  "row:7,col:salary,issue:wrong_type",
284
- "row:11,col:employee_id,issue:duplicate_row",
285
  "row:9,col:salary,issue:out_of_range",
 
 
286
  ],
287
  task_id="easy",
288
  )
@@ -291,30 +294,17 @@ class TestDataQAEnvironment:
291
  assert obs.reward >= 0.999 # identify-only uses identify_score directly
292
 
293
  def test_step_with_fixes_increases_reward(self, env):
294
- """Submitting correct fixes should increase reward beyond identify-only."""
295
  env.reset(task_id="easy")
296
- # Step 1: identify only
297
- action1 = DataQAAction(
298
- issues=[
299
- "row:4,col:name,issue:missing_value",
300
- "row:7,col:salary,issue:wrong_type",
301
- "row:11,col:employee_id,issue:duplicate_row",
302
- "row:9,col:salary,issue:out_of_range",
303
- ],
304
- task_id="easy",
305
- )
306
- obs1 = env.step(action1)
307
- score_identify = obs1.reward
308
-
309
- # Reset for fair comparison
310
- env.reset(task_id="easy")
311
- # Step with identify + fixes
312
- action2 = DataQAAction(
313
  issues=[
314
  "row:4,col:name,issue:missing_value",
315
  "row:7,col:salary,issue:wrong_type",
316
- "row:11,col:employee_id,issue:duplicate_row",
317
  "row:9,col:salary,issue:out_of_range",
 
 
318
  ],
319
  fixes=[
320
  "row:4,col:name,fix:David Kim",
@@ -323,11 +313,9 @@ class TestDataQAEnvironment:
323
  ],
324
  task_id="easy",
325
  )
326
- obs2 = env.step(action2)
327
- score_with_fixes = obs2.metadata["combined_reward"]
328
-
329
- # With correct fixes, combined should be close to 1.0
330
- assert score_with_fixes > 0.8
331
 
332
  def test_step_with_partial_issues(self, env):
333
  env.reset(task_id="easy")
@@ -426,12 +414,7 @@ class TestDataQAEnvironment:
426
  """Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
427
  env.reset(task_id="easy")
428
  action = DataQAAction(
429
- issues=[
430
- "row:4,col:name,issue:missing_value",
431
- "row:7,col:salary,issue:wrong_type",
432
- "row:11,col:employee_id,issue:duplicate_row",
433
- "row:9,col:salary,issue:out_of_range",
434
- ],
435
  fixes=["row:4,col:name,fix:David Kim"],
436
  task_id="easy",
437
  )
@@ -458,13 +441,15 @@ class TestDataQAEnvironment:
458
  issues=[
459
  "row:4,col:name,issue:missing_value",
460
  "row:7,col:salary,issue:wrong_type",
461
- "row:11,col:employee_id,issue:duplicate_row",
462
  "row:9,col:salary,issue:out_of_range",
 
 
463
  ],
464
  task_id="easy",
465
  )
466
  obs = env.step(action)
467
- # identify_score should be ~1.0 since all issues found
468
  assert obs.reward >= 0.99
469
  # combined_reward equals identify_score when no fixes
470
  assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
 
228
  assert result["fixes_correct"] >= 1
229
 
230
  def test_all_fixes_correct(self, easy_task):
231
+ # Fix most issues with exact values
232
  fixes = [
233
  (4, "name", "David Kim"),
234
  (7, "salary", "75000"),
235
  (9, "salary", "73000"),
236
+ (15, "email", "oscar.rivera@company.com"),
237
+ (18, "start_date", "2022-01-19"),
238
  ]
239
  result = grade_fixes(fixes, easy_task)
240
+ assert result["fix_score"] > 0.7 # 5 out of 6 issues fixed (duplicate can't be fixed)
241
 
242
  def test_fix_score_bounded(self, easy_task):
243
  fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
 
260
  assert obs.schema_description
261
  assert obs.validation_rules
262
  assert obs.task_description
263
+ assert obs.num_issues_hint == 6
264
  assert obs.max_steps == 3
265
  assert obs.done is False
266
  assert obs.reward == 0.0
 
268
 
269
  def test_reset_medium(self, env):
270
  obs = env.reset(task_id="medium")
271
+ assert obs.num_issues_hint == 8
272
 
273
  def test_reset_hard(self, env):
274
  obs = env.reset(task_id="hard")
 
277
  def test_step_identify_only(self, env):
278
  """Backward compatible: only issues, no fixes."""
279
  env.reset(task_id="easy")
280
+ # Submit all 6 correct issues for easy task
281
  action = DataQAAction(
282
  issues=[
283
  "row:4,col:name,issue:missing_value",
284
  "row:7,col:salary,issue:wrong_type",
285
+ "row:21,col:employee_id,issue:duplicate_row",
286
  "row:9,col:salary,issue:out_of_range",
287
+ "row:15,col:email,issue:inconsistent_value",
288
+ "row:18,col:start_date,issue:out_of_range",
289
  ],
290
  task_id="easy",
291
  )
 
294
  assert obs.reward >= 0.999 # identify-only uses identify_score directly
295
 
296
  def test_step_with_fixes_increases_reward(self, env):
297
+ """Submitting correct fixes should produce high combined reward."""
298
  env.reset(task_id="easy")
299
+ # All 6 issues + 3 fixes
300
+ action = DataQAAction(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  issues=[
302
  "row:4,col:name,issue:missing_value",
303
  "row:7,col:salary,issue:wrong_type",
304
+ "row:21,col:employee_id,issue:duplicate_row",
305
  "row:9,col:salary,issue:out_of_range",
306
+ "row:15,col:email,issue:inconsistent_value",
307
+ "row:18,col:start_date,issue:out_of_range",
308
  ],
309
  fixes=[
310
  "row:4,col:name,fix:David Kim",
 
313
  ],
314
  task_id="easy",
315
  )
316
+ obs = env.step(action)
317
+ # Perfect identify + partial fixes -> high combined reward
318
+ assert obs.metadata["combined_reward"] > 0.7
 
 
319
 
320
  def test_step_with_partial_issues(self, env):
321
  env.reset(task_id="easy")
 
414
  """Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
415
  env.reset(task_id="easy")
416
  action = DataQAAction(
417
+ issues=["row:4,col:name,issue:missing_value"],
 
 
 
 
 
418
  fixes=["row:4,col:name,fix:David Kim"],
419
  task_id="easy",
420
  )
 
441
  issues=[
442
  "row:4,col:name,issue:missing_value",
443
  "row:7,col:salary,issue:wrong_type",
444
+ "row:21,col:employee_id,issue:duplicate_row",
445
  "row:9,col:salary,issue:out_of_range",
446
+ "row:15,col:email,issue:inconsistent_value",
447
+ "row:18,col:start_date,issue:out_of_range",
448
  ],
449
  task_id="easy",
450
  )
451
  obs = env.step(action)
452
+ # identify_score should be ~1.0 since all 6 issues found
453
  assert obs.reward >= 0.99
454
  # combined_reward equals identify_score when no fixes
455
  assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
tests/test_tasks.py CHANGED
@@ -49,8 +49,8 @@ class TestTaskEasy:
49
  def test_task_id(self, task):
50
  assert task.task_id == "easy"
51
 
52
- def test_has_4_issues(self, task):
53
- assert len(task.planted_issues) == 4
54
 
55
  def test_issue_types(self, task):
56
  types = {i.issue_type for i in task.planted_issues}
@@ -58,6 +58,7 @@ class TestTaskEasy:
58
  assert "wrong_type" in types
59
  assert "duplicate_row" in types
60
  assert "out_of_range" in types
 
61
 
62
  def test_corrupted_csv_differs_from_clean(self, task):
63
  assert task.corrupted_csv != task.clean_csv
@@ -87,8 +88,8 @@ class TestTaskMedium:
87
  def test_task_id(self, task):
88
  assert task.task_id == "medium"
89
 
90
- def test_has_6_issues(self, task):
91
- assert len(task.planted_issues) == 6
92
 
93
  def test_issue_types(self, task):
94
  types = {i.issue_type for i in task.planted_issues}
 
49
  def test_task_id(self, task):
50
  assert task.task_id == "easy"
51
 
52
+ def test_has_6_issues(self, task):
53
+ assert len(task.planted_issues) == 6
54
 
55
  def test_issue_types(self, task):
56
  types = {i.issue_type for i in task.planted_issues}
 
58
  assert "wrong_type" in types
59
  assert "duplicate_row" in types
60
  assert "out_of_range" in types
61
+ assert "inconsistent_value" in types
62
 
63
  def test_corrupted_csv_differs_from_clean(self, task):
64
  assert task.corrupted_csv != task.clean_csv
 
88
  def test_task_id(self, task):
89
  assert task.task_id == "medium"
90
 
91
+ def test_has_8_issues(self, task):
92
+ assert len(task.planted_issues) == 8
93
 
94
  def test_issue_types(self, task):
95
  types = {i.issue_type for i in task.planted_issues}