File size: 13,253 Bytes
9c195fe
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
42757ca
9c195fe
 
42757ca
9c195fe
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
42757ca
9c195fe
 
42757ca
9c195fe
 
42757ca
9c195fe
 
42757ca
9c195fe
 
42757ca
9c195fe
 
42757ca
9c195fe
 
42757ca
9c195fe
 
 
42757ca
9c195fe
 
42757ca
9c195fe
 
42757ca
9c195fe
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
42757ca
9c195fe
 
 
 
 
 
 
 
 
 
 
 
 
42757ca
 
9c195fe
 
42757ca
9c195fe
 
593f876
42757ca
9c195fe
593f876
42757ca
9c195fe
 
 
 
 
 
 
42757ca
9c195fe
593f876
42757ca
9c195fe
 
 
 
 
 
 
42757ca
9c195fe
 
caa9970
42757ca
9c195fe
42757ca
9c195fe
 
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
 
 
 
42757ca
d714fdb
42757ca
9c195fe
593f876
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import copy
import random
from typing import Any, Dict, List, Tuple


def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    rng = random.Random(seed)

    ground_truth = [
        {"id": 1, "name": "Alice Johnson", "email": "alice@example.com", "age": 30, "department": "Engineering"},
        {"id": 2, "name": "Bob Smith", "email": "bob@example.com", "age": 25, "department": "Marketing"},
        {"id": 3, "name": "Carol Williams", "email": "carol@example.com", "age": 35, "department": "Engineering"},
        {"id": 4, "name": "David Brown", "email": "david@example.com", "age": 28, "department": "Sales"},
        {"id": 5, "name": "Eve Davis", "email": "eve@example.com", "age": 32, "department": "Marketing"},
    ]

    dirty = copy.deepcopy(ground_truth)
    errors = []

    missing_configs = [
        (1, "email", ""),
        (2, "department", ""),
        (4, "name", ""),
    ]

    for row_idx, field, replacement in missing_configs:
        dirty[row_idx][field] = replacement
        errors.append({
            "error_type": "missing",
            "row": row_idx,
            "field": field,
            "current_value": replacement,
            "description": f"Row {row_idx}: '{field}' is missing/empty"
        })

    return dirty, ground_truth, errors


def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    rng = random.Random(seed)

    ground_truth = [
        {"id": 1, "product": "Laptop Pro", "price": 999.99, "quantity": 50, "sku": "LP-001", "category": "Electronics"},
        {"id": 2, "product": "Wireless Mouse", "price": 29.99, "quantity": 200, "sku": "WM-002", "category": "Accessories"},
        {"id": 3, "product": "USB Cable", "price": 9.99, "quantity": 500, "sku": "UC-003", "category": "Accessories"},
        {"id": 4, "product": "Monitor 27in", "price": 449.99, "quantity": 30, "sku": "MN-004", "category": "Electronics"},
        {"id": 5, "product": "Keyboard", "price": 79.99, "quantity": 150, "sku": "KB-005", "category": "Accessories"},
        {"id": 6, "product": "Headphones", "price": 149.99, "quantity": 80, "sku": "HP-006", "category": "Audio"},
        {"id": 7, "product": "Webcam HD", "price": 59.99, "quantity": 120, "sku": "WC-007", "category": "Electronics"},
    ]

    dirty = copy.deepcopy(ground_truth)
    errors = []

    dirty[0]["price"] = "999.99"
    errors.append({
        "error_type": "type",
        "row": 0,
        "field": "price",
        "current_value": "999.99",
        "expected_type": "float",
        "description": "Row 0: 'price' should be float, got string '999.99'"
    })

    dirty[2]["quantity"] = "five hundred"
    errors.append({
        "error_type": "type",
        "row": 2,
        "field": "quantity",
        "current_value": "five hundred",
        "expected_type": "int",
        "description": "Row 2: 'quantity' should be int, got string 'five hundred'"
    })

    dirty[3]["sku"] = "mn004"
    errors.append({
        "error_type": "format",
        "row": 3,
        "field": "sku",
        "current_value": "mn004",
        "expected_format": "XX-NNN",
        "description": "Row 3: 'sku' should match format 'XX-NNN', got 'mn004'"
    })

    dirty[5]["category"] = ""
    errors.append({
        "error_type": "missing",
        "row": 5,
        "field": "category",
        "current_value": "",
        "description": "Row 5: 'category' is missing/empty"
    })

    dirty[4]["price"] = -79.99
    errors.append({
        "error_type": "range",
        "row": 4,
        "field": "price",
        "current_value": -79.99,
        "description": "Row 4: 'price' is negative (-79.99), should be positive"
    })

    dirty[6]["sku"] = "WM-002"
    errors.append({
        "error_type": "duplicate",
        "row": 6,
        "field": "sku",
        "current_value": "WM-002",
        "description": "Row 6: 'sku' value 'WM-002' duplicates row 1"
    })

    return dirty, ground_truth, errors


def _generate_task_hard(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    rng = random.Random(seed)

    ground_truth = [
        {"id": 1, "customer": "Acme Corp", "email": "orders@acme.com", "amount": 1500.00, "currency": "USD", "status": "completed", "date": "2024-03-15", "region": "North America", "priority": "high"},
        {"id": 2, "customer": "GlobalTech", "email": "sales@globaltech.io", "amount": 2300.50, "currency": "EUR", "status": "pending", "date": "2024-03-16", "region": "Europe", "priority": "medium"},
        {"id": 3, "customer": "DataFlow Inc", "email": "info@dataflow.com", "amount": 750.00, "currency": "USD", "status": "completed", "date": "2024-03-17", "region": "North America", "priority": "low"},
        {"id": 4, "customer": "NexGen Labs", "email": "contact@nexgen.co", "amount": 4200.00, "currency": "GBP", "status": "shipped", "date": "2024-03-18", "region": "Europe", "priority": "high"},
        {"id": 5, "customer": "SwiftServe", "email": "hello@swiftserve.com", "amount": 890.25, "currency": "USD", "status": "pending", "date": "2024-03-19", "region": "Asia Pacific", "priority": "medium"},
        {"id": 6, "customer": "CloudNine", "email": "team@cloudnine.dev", "amount": 3100.00, "currency": "EUR", "status": "completed", "date": "2024-03-20", "region": "Europe", "priority": "high"},
        {"id": 7, "customer": "PeakTech", "email": "support@peaktech.com", "amount": 560.00, "currency": "USD", "status": "cancelled", "date": "2024-03-21", "region": "North America", "priority": "low"},
        {"id": 8, "customer": "Quantum AI", "email": "admin@quantumai.org", "amount": 6750.00, "currency": "USD", "status": "completed", "date": "2024-03-22", "region": "North America", "priority": "high"},
        {"id": 9, "customer": "EcoSmart", "email": "green@ecosmart.co", "amount": 1200.00, "currency": "AUD", "status": "shipped", "date": "2024-03-23", "region": "Asia Pacific", "priority": "medium"},
        {"id": 10, "customer": "BlueOcean", "email": "info@blueocean.net", "amount": 980.75, "currency": "USD", "status": "pending", "date": "2024-03-24", "region": "North America", "priority": "low"},
    ]

    dirty = copy.deepcopy(ground_truth)
    errors = []

    dirty[0]["email"] = ""
    errors.append({"error_type": "missing", "row": 0, "field": "email", "current_value": "", "description": "Row 0: 'email' is missing"})

    dirty[1]["amount"] = -2300.50
    errors.append({"error_type": "range", "row": 1, "field": "amount", "current_value": -2300.50, "description": "Row 1: 'amount' is negative"})

    dirty[2]["date"] = "03/17/2024"
    errors.append({"error_type": "format", "row": 2, "field": "date", "current_value": "03/17/2024", "expected_format": "YYYY-MM-DD", "description": "Row 2: 'date' wrong format, expected YYYY-MM-DD"})

    dirty[3]["amount"] = "4200"
    errors.append({"error_type": "type", "row": 3, "field": "amount", "current_value": "4200", "expected_type": "float", "description": "Row 3: 'amount' should be float, got string"})

    dirty[4]["status"] = "in-progress"
    errors.append({"error_type": "format", "row": 4, "field": "status", "current_value": "in-progress", "expected_format": "one of: completed, pending, shipped, cancelled", "description": "Row 4: 'status' invalid value 'in-progress'"})

    dirty[5]["region"] = ""
    errors.append({"error_type": "missing", "row": 5, "field": "region", "current_value": "", "description": "Row 5: 'region' is missing"})

    dirty[6]["customer"] = "Acme Corp"
    dirty[6]["email"] = "orders@acme.com"
    errors.append({"error_type": "duplicate", "row": 6, "field": "customer", "current_value": "Acme Corp", "description": "Row 6: 'customer' duplicates row 0"})

    dirty[7]["amount"] = 99999.99
    errors.append({"error_type": "range", "row": 7, "field": "amount", "current_value": 99999.99, "description": "Row 7: 'amount' exceeds maximum threshold (should be 6750.00)"})

    dirty[8]["currency"] = "AUSD"
    errors.append({"error_type": "format", "row": 8, "field": "currency", "current_value": "AUSD", "expected_format": "3-letter ISO code", "description": "Row 8: 'currency' invalid code 'AUSD'"})

    dirty[9]["priority"] = ""
    errors.append({"error_type": "missing", "row": 9, "field": "priority", "current_value": "", "description": "Row 9: 'priority' is missing"})

    return dirty, ground_truth, errors


# Catalogue of available tasks, keyed by task name. Each entry carries:
#   generator   - callable taking a seed and returning (dirty, ground_truth, errors)
#   name        - the task name (same string as the registry key)
#   description - human-readable summary of the task
#   difficulty  - "easy", "medium" or "hard"
#   max_steps   - step budget associated with the task
#   hint        - guidance text naming the action types relevant to the task
TASK_REGISTRY: Dict[str, Dict[str, Any]] = {
    "easy_missing_values": {
        "generator": _generate_task_easy,
        "name": "easy_missing_values",
        "description": "Fix missing values in a small employee dataset. Each empty field needs to be filled with the correct value.",
        "difficulty": "easy",
        "max_steps": 10,
        "hint": "Look for empty string values in the dataset. Use 'fix_missing' action with target_field and new_value to fill them in. Check the errors_found list for all issues."
    },
    "medium_mixed_errors": {
        "generator": _generate_task_medium,
        "name": "medium_mixed_errors",
        "description": "Fix type errors, format issues, missing values, range violations, and duplicates in a product inventory dataset.",
        "difficulty": "medium",
        "max_steps": 15,
        "hint": "Multiple error types: use fix_type for wrong data types, fix_format for pattern mismatches, fix_missing for empty fields, fix_range for out-of-bound values, and fix_duplicate for duplicate entries."
    },
    "hard_multi_constraint": {
        "generator": _generate_task_hard,
        "name": "hard_multi_constraint",
        "description": "Clean a complex customer orders dataset with 10 interrelated errors spanning missing data, type mismatches, format violations, range errors, and duplicates. Requires planning to resolve dependencies.",
        "difficulty": "hard",
        "max_steps": 20,
        "hint": "This dataset has 10 errors across 10 rows. Examine each error carefully. Some errors may need specific domain knowledge (e.g., valid statuses are: completed, pending, shipped, cancelled)."
    },
}


def get_task_names() -> List[str]:
    """Return the registered task names in registry order."""
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    return [registered_name for registered_name in TASK_REGISTRY]


def generate_task(task_name: str, seed: int = 42) -> Dict[str, Any]:
    """Instantiate a registered task and bundle it with its metadata.

    Args:
        task_name: Key of the task in ``TASK_REGISTRY``.
        seed: Forwarded to the task's generator.

    Returns:
        A dict with the registry metadata (``name``, ``description``,
        ``difficulty``, ``max_steps``, ``hint``) followed by ``dataset`` (the
        dirty rows), ``ground_truth``, ``errors`` and ``field_names`` (the
        keys of the first ground-truth row, or an empty list if there are no
        rows).

    Raises:
        ValueError: If ``task_name`` is not registered.
    """
    spec = TASK_REGISTRY.get(task_name)
    if spec is None:
        raise ValueError(f"Unknown task: {task_name}. Available: {get_task_names()}")

    dirty_rows, clean_rows, error_list = spec["generator"](seed)

    # Copy the metadata fields first so the result keeps a stable key order.
    metadata_keys = ("name", "description", "difficulty", "max_steps", "hint")
    task = {key: spec[key] for key in metadata_keys}
    task["dataset"] = dirty_rows
    task["ground_truth"] = clean_rows
    task["errors"] = error_list
    task["field_names"] = list(clean_rows[0]) if clean_rows else []
    return task


def _match_submitted_value(new_value: Any, gt_value: Any) -> Tuple[bool, Any]:
    """Compare a submitted value with the ground truth.

    Returns ``(matches, value_to_store)``; ``value_to_store`` is only
    meaningful when ``matches`` is True.
    """
    # bool is a subclass of int but must not be parsed as a number.
    is_number = isinstance(gt_value, (int, float)) and not isinstance(gt_value, bool)
    if is_number:
        try:
            number = float(new_value)
        except (ValueError, TypeError, OverflowError):
            number = None
        if number is not None:
            if isinstance(gt_value, float):
                # Floats are accepted within a small absolute tolerance.
                if abs(number - gt_value) < 0.01:
                    return True, number
                return False, None
            # Integers must match exactly: comparing as a float rejects
            # "500.9" (which truncation would accept) and handles "inf"
            # without raising.
            if number == gt_value:
                return True, gt_value
            return False, None

    text = str(new_value).strip()
    if text != str(gt_value).strip():
        return False, None
    # Store the stripped text so surrounding whitespace in the submission does
    # not leave the dataset differing from the ground truth.
    return True, (text if isinstance(gt_value, str) else gt_value)


def grade_action(action_type: str, target_field: str, target_row: int,
                 new_value: str, dirty_dataset: List[Dict],
                 ground_truth: List[Dict], errors: List[Dict]) -> Tuple[float, str, bool]:
    """Score one agent action and, when it is a correct fix, apply it.

    Args:
        action_type: ``"validate"``, ``"skip"`` or one of ``"fix_missing"``,
            ``"fix_type"``, ``"fix_range"``, ``"fix_format"``,
            ``"fix_duplicate"``.
        target_field: Field name of the cell being fixed.
        target_row: 0-based row index of the cell being fixed.
        new_value: Proposed replacement value.
        dirty_dataset: Working dataset; mutated in place on a correct fix.
        ground_truth: Reference rows the proposed value is checked against.
        errors: Error entries for the task; the matched entry gets
            ``"fixed": True`` on a correct fix.

    Returns:
        ``(reward, message, success)``. A correct fix earns
        ``0.8 / number_of_errors`` with ``success`` True. Every other outcome
        (validate, skip, no matching unfixed error, wrong action type, wrong
        value) earns ``0.01`` with ``success`` False.
    """
    error_list = errors or []
    # Fall back to 1 so the ratios below never divide by zero.
    total_errors = len(error_list) if error_list else 1

    if action_type == "validate":
        fixed = sum(1 for e in error_list if e.get("fixed", False))
        return 0.01, f"Validation: {fixed}/{total_errors} errors fixed ({fixed/total_errors*100:.0f}%)", False

    if action_type == "skip":
        return 0.01, "Skipped current action", False

    # First still-open error on the targeted cell, if any.
    matching_error = next(
        (
            e for e in error_list
            if not e.get("fixed", False)
            and e["row"] == target_row
            and e["field"] == target_field
        ),
        None,
    )

    if matching_error is None:
        return 0.01, f"No unfixed error at row {target_row}, field '{target_field}'", False

    action_to_error_map = {
        "fix_missing": "missing",
        "fix_type": "type",
        "fix_range": "range",
        "fix_format": "format",
        "fix_duplicate": "duplicate",
    }

    # Unknown actions map to "" and therefore never match an error type.
    expected_error_type = action_to_error_map.get(action_type, "")
    if expected_error_type != matching_error["error_type"]:
        return 0.01, f"Wrong action type '{action_type}' for error type '{matching_error['error_type']}'", False

    gt_value = ground_truth[target_row][target_field]
    is_correct, stored_value = _match_submitted_value(new_value, gt_value)

    if not is_correct:
        return 0.01, f"Wrong value for row {target_row}, field '{target_field}'. Got '{new_value}', expected something else.", False

    matching_error["fixed"] = True
    dirty_dataset[target_row][target_field] = stored_value

    reward = 0.8 / total_errors
    return reward, f"Fixed: row {target_row}, field '{target_field}' -> '{new_value}'", True