Spaces:
Sleeping
Sleeping
| """DataClean Environment: core logic for the data cleaning RL environment. | |
| Implements reset(), step(), state property following OpenEnv spec. | |
| All 10 action handlers fully implemented. Delta reward system. | |
| """ | |
| from __future__ import annotations | |
| import copy | |
| import re | |
| from collections import Counter | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional | |
| from uuid import uuid4 | |
| import logging | |
| from openenv.core.env_server import Environment | |
| logger = logging.getLogger(__name__) | |
| from dataclean_env.models import ( | |
| ActionResult, | |
| DataCleanAction, | |
| DataCleanObservation, | |
| DataCleanState, | |
| DataSummary, | |
| IssueGroup, | |
| QualityIssue, | |
| ) | |
| from dataclean_env.server.grader import DataCleanGrader | |
| from dataclean_env.server.tasks import get_task, list_tasks | |
# Lower-cased full US state name -> two-letter postal abbreviation.
US_STATES: Dict[str, str] = {
    "alabama": "AL",
    "alaska": "AK",
    "arizona": "AZ",
    "arkansas": "AR",
    "california": "CA",
    "colorado": "CO",
    "connecticut": "CT",
    "delaware": "DE",
    "florida": "FL",
    "georgia": "GA",
    "hawaii": "HI",
    "idaho": "ID",
    "illinois": "IL",
    "indiana": "IN",
    "iowa": "IA",
    "kansas": "KS",
    "kentucky": "KY",
    "louisiana": "LA",
    "maine": "ME",
    "maryland": "MD",
    "massachusetts": "MA",
    "michigan": "MI",
    "minnesota": "MN",
    "mississippi": "MS",
    "missouri": "MO",
    "montana": "MT",
    "nebraska": "NE",
    "nevada": "NV",
    "new hampshire": "NH",
    "new jersey": "NJ",
    "new mexico": "NM",
    "new york": "NY",
    "north carolina": "NC",
    "north dakota": "ND",
    "ohio": "OH",
    "oklahoma": "OK",
    "oregon": "OR",
    "pennsylvania": "PA",
    "rhode island": "RI",
    "south carolina": "SC",
    "south dakota": "SD",
    "tennessee": "TN",
    "texas": "TX",
    "utah": "UT",
    "vermont": "VT",
    "virginia": "VA",
    "washington": "WA",
    "west virginia": "WV",
    "wisconsin": "WI",
    "wyoming": "WY",
}

# strptime patterns tried in list order when normalising dates; the first
# pattern that parses wins, so the order decides how ambiguous inputs resolve.
DATE_PARSE_FORMATS = [
    "%Y-%m-%d",   # 2023-01-15
    "%m/%d/%Y",   # 01/15/2023
    "%d-%m-%Y",   # 15-01-2023
    "%B %d, %Y",  # January 15, 2023
    "%b %d, %Y",  # Jan 15, 2023
    "%d %B %Y",   # 15 January 2023
    "%d-%b-%Y",   # 15-Jan-2023
    "%m-%d-%Y",   # 01-15-2023
    "%B %d %Y",   # January 15 2023
    "%d/%m/%Y",   # 15/01/2023
    "%Y/%m/%d",   # 2023/01/15
]

# Budget units charged per action type (intervention budget system).
ACTION_COSTS: Dict[str, float] = {
    "fix_value": 1.0,
    "delete_row": 6.0,
    "fill_missing": 1.0,
    "standardize_format": 2.0,
    "merge_duplicates": 4.0,
    "flag_anomaly": 0.5,
    "split_column": 3.0,
    "rename_column": 0.5,
    "cast_type": 2.0,
    "escalate_to_human": 0.5,
    "mark_complete": 0.0,
}

# Total budget granted to an episode, keyed by task difficulty.
DIFFICULTY_BUDGETS: Dict[str, float] = {
    "easy": 50.0,
    "medium": 100.0,
    "hard": 150.0,
}

# Flat amount subtracted from every non-terminal delta reward.
STEP_COST: float = 0.005

# Seed used when reset() is called without one (keeps episodes deterministic).
DEFAULT_SEED: int = 42
| class DataCleanEnvironment( | |
| Environment[DataCleanAction, DataCleanObservation, DataCleanState] | |
| ): | |
| """Data Cleaning environment for training AI agents.""" | |
| SUPPORTS_CONCURRENT_SESSIONS = False | |
def __init__(self) -> None:
    """Create the environment with placeholder state.

    The values set here are overwritten by reset(), which starts a real
    episode.
    """
    super().__init__()

    # Episode state container and the grader used for scoring.
    self._state = DataCleanState()
    self._grader = DataCleanGrader()

    # Per-task metadata, filled in by reset().
    self._task_name: str = ""
    self._ambiguous_cells: list = []
    self._utility_probes: list = []

    # Book-keeping: next stable row id and the most recent terminal grade.
    self._next_row_id: int = 0
    self._last_grade_result = None
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    **kwargs: Any,
) -> DataCleanObservation:
    """Start a new data-cleaning episode and return its first observation.

    Args:
        seed: Seed forwarded to the dirty-data generator; ``DEFAULT_SEED``
            is used when omitted.
        episode_id: Identifier for the episode; a random UUID4 string is
            generated when omitted or empty.
        **kwargs: ``task_id`` selects the task (default ``"easy_contacts"``).

    Returns:
        The initial observation, with ``reward=None`` and ``done=False``.
    """
    task_id = kwargs.get("task_id", "easy_contacts")
    task = get_task(task_id)  # raises KeyError/ValueError on unknown task_id
    actual_seed = seed if seed is not None else DEFAULT_SEED

    # Kept as a function-local import, exactly as the original code had it.
    from dataclean_env.server.data_generator import generate_dirty_data

    dirty_data = generate_dirty_data(
        clean_data=task.ground_truth,
        corruptions=task.corruptions,
        seed=actual_seed,
    )

    # Assign stable row_ids (persist through delete/merge within episode).
    self._next_row_id = 0
    for row in dirty_data:
        row["_row_id"] = self._next_row_id
        self._next_row_id += 1

    # Optional task attributes; tasks that do not define them get empty lists.
    ambiguous_cells = list(getattr(task, "ambiguous_cells", []))
    utility_probes = list(getattr(task, "utility_probes", []))

    # Baseline score of the untouched dirty data against the ground truth.
    initial_score = self._grader.grade(
        final_data=dirty_data,
        ground_truth=task.ground_truth,
        original_data=dirty_data,
        action_history=[],
        schema=task.schema,
        flagged_cells=[],
        escalated_cells=[],
        ambiguous_cells=list(ambiguous_cells),
        utility_probes=list(utility_probes),
    ).score

    budget = DIFFICULTY_BUDGETS.get(task.difficulty, 100.0)
    self._state = DataCleanState(
        episode_id=episode_id or str(uuid4()),
        step_count=0,
        task_id=task_id,
        difficulty=task.difficulty,
        current_data=copy.deepcopy(dirty_data),
        ground_truth=copy.deepcopy(task.ground_truth),
        original_dirty=copy.deepcopy(dirty_data),
        schema_def=task.schema,
        action_log=[],
        flagged_cells=[],
        escalated_cells=[],
        max_steps=task.max_steps,
        is_complete=False,
        previous_score=initial_score,
        initial_raw_score=initial_score,
        action_budget=budget,
        budget_spent=0.0,
        budget_remaining=budget,
    )
    self._task_name = task.name
    self._ambiguous_cells = ambiguous_cells
    self._utility_probes = utility_probes

    # Drop the previous episode's terminal grade so it cannot be attached
    # to observations of the new episode.
    self._last_grade_result = None

    return self._build_observation(reward=None, done=False)
def step(
    self,
    action: DataCleanAction,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> DataCleanObservation:
    """Apply one cleaning action and return the resulting observation.

    Reward semantics:
        * episode already finished: ``0.0`` and ``done=True``;
        * action rejected because the budget cannot cover it: ``-0.01``;
        * terminal step (``mark_complete`` or step limit reached): the
          absolute final grade;
        * otherwise: the delta reward from ``_compute_delta_reward``.

    The step limit is enforced on every path, including budget rejections,
    so an agent that keeps requesting unaffordable actions still terminates.

    Args:
        action: Action to execute (``action_type`` plus ``params``).
        timeout_s: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    state = self._state

    # Guard: episode already ended.
    if state.is_complete:
        return self._build_observation(reward=0.0, done=True)

    state.step_count += 1

    # Unknown action types are charged 1.0; mark_complete is free and is
    # therefore never rejected for budget reasons.
    cost = ACTION_COSTS.get(action.action_type, 1.0)
    rejected = cost > 0 and state.budget_remaining < cost

    if rejected:
        result = {
            "action": action.action_type,
            "status": "error",
            "message": f"Budget exhausted ({state.budget_remaining:.1f} remaining, "
                       f"action costs {cost:.1f})",
            "cells_modified": 0,
        }
    else:
        result = self._execute_action(action)
        # The cost is charged whether or not the handler succeeded.
        state.budget_spent += cost
        state.budget_remaining -= cost
    state.action_log.append(result)

    # Termination: explicit completion, or the step limit (checked even
    # when the action was rejected).
    is_done = (
        (not rejected and action.action_type == "mark_complete")
        or state.step_count >= state.max_steps
    )
    state.is_complete = is_done

    if is_done:
        # Terminal: return absolute final score.
        grade_result = self._grader.grade(
            final_data=state.current_data,
            ground_truth=state.ground_truth,
            original_data=state.original_dirty,
            action_history=state.action_log,
            schema=state.schema_def,
            flagged_cells=state.flagged_cells,
            budget_spent=state.budget_spent,
            action_budget=state.action_budget,
            escalated_cells=state.escalated_cells,
            ambiguous_cells=self._ambiguous_cells,
            utility_probes=self._utility_probes,
        )
        reward = grade_result.score
        self._last_grade_result = grade_result
    elif rejected:
        reward = -0.01
    else:
        # Non-terminal: delta reward.
        reward = self._compute_delta_reward(result)

    return self._build_observation(reward=reward, done=is_done)
| def state(self) -> DataCleanState: | |
| return self._state | |
| # ------------------------------------------------------------------ | |
| # Delta Reward System | |
| # ------------------------------------------------------------------ | |
| def _compute_delta_reward(self, action_result: Dict[str, Any]) -> float: | |
| """Compute reward = current_score - previous_score - step_cost. | |
| Penalizes no-ops and errors explicitly. | |
| """ | |
| # Explicit penalties for bad actions | |
| if action_result.get("status") == "error": | |
| return -0.02 | |
| if action_result.get("status") == "no_effect": | |
| return -0.01 | |
| if action_result.get("cells_modified", 0) == 0 and action_result.get("action") not in ("flag_anomaly", "escalate_to_human"): | |
| return -0.01 | |
| # Compute current score (raw, without normalization — delta is relative | |
| # so the baseline cancels out; normalization only at terminal grading) | |
| current_score = self._grader.grade( | |
| final_data=self._state.current_data, | |
| ground_truth=self._state.ground_truth, | |
| original_data=self._state.original_dirty, | |
| action_history=self._state.action_log, | |
| schema=self._state.schema_def, | |
| flagged_cells=self._state.flagged_cells, | |
| budget_spent=self._state.budget_spent, | |
| action_budget=self._state.action_budget, | |
| escalated_cells=self._state.escalated_cells, | |
| ambiguous_cells=self._ambiguous_cells, | |
| utility_probes=self._utility_probes, | |
| ).score | |
| delta = current_score - self._state.previous_score - STEP_COST | |
| self._state.previous_score = current_score | |
| return round(delta, 4) | |
| # ------------------------------------------------------------------ | |
| # Action Dispatch | |
| # ------------------------------------------------------------------ | |
| def _normalize_action_params(action_type: str, params: Dict[str, Any]) -> Dict[str, Any]: | |
| """Normalize common LLM param aliases to canonical names.""" | |
| p = dict(params) | |
| # Universal aliases | |
| if "row" in p and "row_id" not in p: | |
| p["row_id"] = p.pop("row") | |
| if "col" in p and "column" not in p: | |
| p["column"] = p.pop("col") | |
| # Action-specific aliases | |
| if action_type == "fix_value" and "value" in p and "new_value" not in p: | |
| p["new_value"] = p.pop("value") | |
| if action_type == "merge_duplicates": | |
| if "row_id_1" in p and "row_id1" not in p: | |
| p["row_id1"] = p.pop("row_id_1") | |
| if "row_id_2" in p and "row_id2" not in p: | |
| p["row_id2"] = p.pop("row_id_2") | |
| if "row1" in p and "row_id1" not in p: | |
| p["row_id1"] = p.pop("row1") | |
| if "row2" in p and "row_id2" not in p: | |
| p["row_id2"] = p.pop("row2") | |
| return p | |
| def _execute_action(self, action: DataCleanAction) -> Dict[str, Any]: | |
| """Dispatch action to the appropriate handler.""" | |
| handler = getattr(self, f"_action_{action.action_type}", None) | |
| if handler is None: | |
| return { | |
| "action": action.action_type, | |
| "status": "error", | |
| "message": f"Unknown action type: {action.action_type}", | |
| "cells_modified": 0, | |
| } | |
| # Normalize param aliases before dispatching | |
| normalized_params = self._normalize_action_params(action.action_type, action.params) | |
| try: | |
| return handler(normalized_params) | |
| except (KeyError, TypeError, IndexError) as exc: | |
| return { | |
| "action": action.action_type, | |
| "status": "error", | |
| "message": f"Invalid params: {exc}", | |
| "cells_modified": 0, | |
| } | |
| except Exception as exc: | |
| logger.exception("Unexpected error in action handler %s", action.action_type) | |
| return { | |
| "action": action.action_type, | |
| "status": "error", | |
| "message": str(exc), | |
| "cells_modified": 0, | |
| } | |
| # ------------------------------------------------------------------ | |
| # Row Lookup by Stable row_id | |
| # ------------------------------------------------------------------ | |
| def _find_row_by_id(self, row_id: int) -> tuple[int, Dict[str, Any] | None]: | |
| """Find the list index and row dict for a given stable row_id. | |
| Returns (index, row_dict) or (-1, None) if not found. | |
| """ | |
| for i, row in enumerate(self._state.current_data): | |
| if row.get("_row_id") == row_id: | |
| return i, row | |
| return -1, None | |
| # ------------------------------------------------------------------ | |
| # Action Handlers (10 total) — all use stable row_id | |
| # ------------------------------------------------------------------ | |
| def _action_fix_value(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| row_id = int(params["row_id"]) | |
| column = str(params["column"]) | |
| new_value = params["new_value"] | |
| idx, row = self._find_row_by_id(row_id) | |
| if row is None: | |
| return {"action": "fix_value", "status": "error", | |
| "message": f"row_id {row_id} not found", "cells_modified": 0} | |
| if column not in row or column.startswith("_"): | |
| return {"action": "fix_value", "status": "error", | |
| "message": f"Column '{column}' not found", "cells_modified": 0} | |
| old_value = row[column] | |
| if str(old_value) == str(new_value): | |
| return {"action": "fix_value", "status": "no_effect", | |
| "message": f"Value unchanged at (row_id={row_id}, '{column}')", "cells_modified": 0} | |
| row[column] = new_value | |
| return {"action": "fix_value", "status": "success", | |
| "message": f"(row_id={row_id}, '{column}'): '{old_value}' -> '{new_value}'", | |
| "cells_modified": 1, "old_value": old_value, "new_value": new_value, | |
| "row_id": row_id, "column": column} | |
| def _action_delete_row(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| row_id = int(params["row_id"]) | |
| idx, row = self._find_row_by_id(row_id) | |
| if row is None: | |
| return {"action": "delete_row", "status": "error", | |
| "message": f"row_id {row_id} not found", "cells_modified": 0} | |
| deleted = self._state.current_data.pop(idx) | |
| return {"action": "delete_row", "status": "success", | |
| "message": f"row_id={row_id} deleted", | |
| "cells_modified": len(deleted), "deleted_data": deleted, | |
| "row_id": row_id, "deleted_entity_id": deleted.get("_entity_id")} | |
| def _action_fill_missing(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| row_id = int(params["row_id"]) | |
| column = str(params["column"]) | |
| value = params["value"] | |
| idx, row = self._find_row_by_id(row_id) | |
| if row is None: | |
| return {"action": "fill_missing", "status": "error", | |
| "message": f"row_id {row_id} not found", "cells_modified": 0} | |
| if column not in row or column.startswith("_"): | |
| return {"action": "fill_missing", "status": "error", | |
| "message": f"Column '{column}' not found", "cells_modified": 0} | |
| current = row.get(column) | |
| if current is not None and str(current).strip() != "": | |
| return {"action": "fill_missing", "status": "error", | |
| "message": f"Cell (row_id={row_id}, '{column}') is not empty: '{current}'", | |
| "cells_modified": 0} | |
| row[column] = value | |
| return {"action": "fill_missing", "status": "success", | |
| "message": f"(row_id={row_id}, '{column}'): NULL -> '{value}'", | |
| "cells_modified": 1, "row_id": row_id, "column": column, "new_value": value} | |
| def _action_standardize_format(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| column = str(params["column"]) | |
| format_type = str(params["format_type"]) | |
| data = self._state.current_data | |
| modified = 0 | |
| errors: List[str] = [] | |
| for row in data: | |
| if column not in row or row[column] is None: | |
| continue | |
| try: | |
| new_val = self._apply_format(row[column], format_type) | |
| if str(new_val) != str(row[column]): | |
| row[column] = new_val | |
| modified += 1 | |
| except (ValueError, TypeError) as exc: | |
| errors.append(f"row_id={row.get('_row_id', '?')}: {exc}") | |
| if modified == 0 and not errors: | |
| return {"action": "standardize_format", "status": "no_effect", | |
| "message": f"No changes needed in '{column}' for {format_type}", | |
| "cells_modified": 0} | |
| msg = f"Formatted {modified} cell(s) in '{column}' to {format_type}" | |
| if errors: | |
| msg += f". {len(errors)} parse failure(s)." | |
| return {"action": "standardize_format", "status": "success", | |
| "message": msg, "cells_modified": modified} | |
| def _action_merge_duplicates(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| row_id1 = int(params["row_id1"]) | |
| row_id2 = int(params["row_id2"]) | |
| strategy = str(params.get("strategy", "merge_prefer_nonnull")) | |
| if row_id1 == row_id2: | |
| return {"action": "merge_duplicates", "status": "error", | |
| "message": "Cannot merge a row with itself", "cells_modified": 0} | |
| idx1, r1 = self._find_row_by_id(row_id1) | |
| idx2, r2 = self._find_row_by_id(row_id2) | |
| if r1 is None or r2 is None: | |
| missing = row_id1 if r1 is None else row_id2 | |
| return {"action": "merge_duplicates", "status": "error", | |
| "message": f"row_id {missing} not found", "cells_modified": 0} | |
| # Track entity IDs for penalty checking | |
| eid1 = r1.get("_entity_id", "") | |
| eid2 = r2.get("_entity_id", "") | |
| merged = self._merge_rows(r1, r2, strategy) | |
| # Merged row keeps the first row's entity_id and row_id | |
| merged["_entity_id"] = eid1 | |
| merged["_row_id"] = r1["_row_id"] | |
| # Remove both, insert merged at first position | |
| data = self._state.current_data | |
| lo_idx = min(idx1, idx2) | |
| hi_idx = max(idx1, idx2) | |
| data.pop(hi_idx) | |
| data.pop(lo_idx) | |
| data.insert(lo_idx, merged) | |
| return {"action": "merge_duplicates", "status": "success", | |
| "message": f"Merged row_id={row_id1} and row_id={row_id2} using '{strategy}'", | |
| "cells_modified": len(merged), | |
| "row_id1": row_id1, "row_id2": row_id2, | |
| "entity_id1": eid1, "entity_id2": eid2, | |
| "deleted_entity_id": eid2, | |
| "strategy": strategy} | |
| def _action_flag_anomaly(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| row_id = int(params["row_id"]) | |
| column = str(params["column"]) | |
| reason = str(params.get("reason", "")) | |
| idx, row = self._find_row_by_id(row_id) | |
| if row is None: | |
| return {"action": "flag_anomaly", "status": "error", | |
| "message": f"row_id {row_id} not found", "cells_modified": 0} | |
| self._state.flagged_cells.append( | |
| {"row_id": row_id, "column": column, "reason": reason} | |
| ) | |
| return {"action": "flag_anomaly", "status": "success", | |
| "message": f"Flagged (row_id={row_id}, '{column}'): {reason}", | |
| "cells_modified": 0} | |
| def _action_escalate_to_human(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| """Escalate a cell to human review -- agent signals it is uncertain.""" | |
| row_id = int(params["row_id"]) | |
| column = str(params["column"]) | |
| confidence = float(params.get("confidence", 0.5)) | |
| reason = str(params.get("reason", "")) | |
| idx, row = self._find_row_by_id(row_id) | |
| if row is None: | |
| return {"action": "escalate_to_human", "status": "error", | |
| "message": f"row_id {row_id} not found", "cells_modified": 0} | |
| self._state.escalated_cells.append({ | |
| "row_id": row_id, "column": column, | |
| "confidence": confidence, "reason": reason, | |
| }) | |
| return {"action": "escalate_to_human", "status": "success", | |
| "message": f"Escalated (row_id={row_id}, '{column}'): {reason} (confidence={confidence})", | |
| "cells_modified": 0} | |
| def _action_split_column(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| column = str(params["column"]) | |
| delimiter = str(params["delimiter"]) | |
| new_names = list(params["new_names"]) | |
| data = self._state.current_data | |
| modified = 0 | |
| for row in data: | |
| if column not in row or row[column] is None: | |
| continue | |
| parts = str(row[column]).split(delimiter, maxsplit=len(new_names) - 1) | |
| for i, name in enumerate(new_names): | |
| row[name] = parts[i].strip() if i < len(parts) else None | |
| del row[column] | |
| modified += 1 | |
| if modified == 0: | |
| return {"action": "split_column", "status": "no_effect", | |
| "message": f"Column '{column}' not found or all null", "cells_modified": 0} | |
| return {"action": "split_column", "status": "success", | |
| "message": f"Split '{column}' into {new_names} ({modified} rows)", | |
| "cells_modified": modified} | |
| def _action_rename_column(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| old_name = str(params["old_name"]) | |
| new_name = str(params["new_name"]) | |
| data = self._state.current_data | |
| if not data or old_name not in data[0]: | |
| return {"action": "rename_column", "status": "error", | |
| "message": f"Column '{old_name}' not found", "cells_modified": 0} | |
| if data and new_name in data[0]: | |
| return {"action": "rename_column", "status": "error", | |
| "message": f"Column '{new_name}' already exists", "cells_modified": 0} | |
| if new_name.startswith("_"): | |
| return {"action": "rename_column", "status": "error", | |
| "message": f"Column names starting with '_' are reserved", "cells_modified": 0} | |
| for row in data: | |
| if old_name in row: | |
| row[new_name] = row.pop(old_name) | |
| return {"action": "rename_column", "status": "success", | |
| "message": f"Renamed '{old_name}' -> '{new_name}'", | |
| "cells_modified": len(data)} | |
| def _action_cast_type(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| column = str(params["column"]) | |
| target_type = str(params["target_type"]) | |
| valid_types = {"int", "float", "str", "bool", "date"} | |
| if target_type not in valid_types: | |
| return {"action": "cast_type", "status": "error", | |
| "message": f"Unknown type '{target_type}'. Valid: {sorted(valid_types)}", | |
| "cells_modified": 0} | |
| data = self._state.current_data | |
| modified = 0 | |
| nullified = 0 | |
| for row in data: | |
| if column not in row or row[column] is None: | |
| continue | |
| try: | |
| row[column] = self._cast_value(row[column], target_type) | |
| modified += 1 | |
| except (ValueError, TypeError): | |
| row[column] = None | |
| nullified += 1 | |
| msg = f"Cast {modified} cell(s) in '{column}' to {target_type}" | |
| if nullified: | |
| msg += f" ({nullified} failed -> null)" | |
| status = "success" if modified > 0 else ("error" if nullified > 0 else "no_effect") | |
| return {"action": "cast_type", "status": status, | |
| "message": msg, "cells_modified": modified + nullified} | |
| def _action_mark_complete(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
| return {"action": "mark_complete", "status": "success", | |
| "message": "Agent signaled completion", "cells_modified": 0} | |
| # ------------------------------------------------------------------ | |
| # Format Standardization (8 types, fully implemented) | |
| # ------------------------------------------------------------------ | |
| def _apply_format(self, value: Any, format_type: str) -> Any: | |
| """Apply format transformation to a single value.""" | |
| val_str = str(value).strip() | |
| if not val_str: | |
| return value | |
| if format_type == "date:YYYY-MM-DD": | |
| return self._format_date_iso(val_str) | |
| elif format_type == "phone:US": | |
| return self._format_phone_us(val_str) | |
| elif format_type == "phone:E164": | |
| return self._format_phone_e164(val_str) | |
| elif format_type == "name:title_case": | |
| return val_str.title() | |
| elif format_type == "email:lowercase": | |
| return val_str.lower() | |
| elif format_type == "zip:5digit": | |
| return self._format_zip_5digit(val_str) | |
| elif format_type == "currency:float": | |
| return self._format_currency_float(val_str) | |
| elif format_type == "state:abbreviation": | |
| return self._format_state_abbrev(val_str) | |
| else: | |
| raise ValueError(f"Unknown format type: {format_type}") | |
def _format_date_iso(self, val: str) -> str:
    """Return *val* re-formatted as ``YYYY-MM-DD``.

    Each pattern in ``DATE_PARSE_FORMATS`` is tried in order against the
    stripped input; the first one that parses decides the result.

    Raises:
        ValueError: when no pattern matches.
    """
    text = val.strip()
    for pattern in DATE_PARSE_FORMATS:
        try:
            return datetime.strptime(text, pattern).strftime("%Y-%m-%d")
        except ValueError:
            pass
    raise ValueError(f"Cannot parse date: '{val}'")
| def _format_phone_us(self, val: str) -> str: | |
| """Normalize phone to (XXX) XXX-XXXX format.""" | |
| digits = re.sub(r"\D", "", val) | |
| if digits.startswith("1") and len(digits) == 11: | |
| digits = digits[1:] | |
| if len(digits) != 10: | |
| raise ValueError(f"Phone must have 10 digits, got {len(digits)}: '{val}'") | |
| return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}" | |
| def _format_phone_e164(self, val: str) -> str: | |
| """Normalize phone to +1XXXXXXXXXX format.""" | |
| digits = re.sub(r"\D", "", val) | |
| if digits.startswith("1") and len(digits) == 11: | |
| digits = digits[1:] | |
| if len(digits) != 10: | |
| raise ValueError(f"Phone must have 10 digits, got {len(digits)}: '{val}'") | |
| return f"+1{digits}" | |
| def _format_zip_5digit(self, val: str) -> str: | |
| """Normalize ZIP to 5 digits (pad or truncate).""" | |
| digits = re.sub(r"\D", "", val.split("-")[0]) | |
| if not digits: | |
| raise ValueError(f"No digits in ZIP: '{val}'") | |
| return digits[:5].zfill(5) | |
| def _format_currency_float(self, val: str) -> float: | |
| """Parse currency string to float. '$1,234.56' -> 1234.56""" | |
| cleaned = val.replace("$", "").replace(",", "").strip() | |
| if cleaned.lower().endswith("k"): | |
| return float(cleaned[:-1]) * 1000 | |
| return float(cleaned) | |
def _format_state_abbrev(self, val: str) -> str:
    """Return the two-letter abbreviation for a US state.

    A two-character input that is already a known abbreviation (any case) is
    returned upper-cased; otherwise the stripped, lower-cased input is looked
    up as a full state name in ``US_STATES``.

    Raises:
        ValueError: when the input is neither form.
    """
    shouted = val.upper()
    if len(val) == 2 and shouted in US_STATES.values():
        return shouted

    abbreviation = US_STATES.get(val.strip().lower())
    if abbreviation is None:
        raise ValueError(f"Unknown state: '{val}'")
    return abbreviation
| # ------------------------------------------------------------------ | |
| # Row Merging (all strategies) | |
| # ------------------------------------------------------------------ | |
| def _merge_rows(self, r1: Dict, r2: Dict, strategy: str) -> Dict: | |
| """Merge two rows according to the given strategy.""" | |
| if strategy == "keep_first": | |
| return copy.deepcopy(r1) | |
| elif strategy == "keep_second": | |
| return copy.deepcopy(r2) | |
| elif strategy == "merge_prefer_nonnull": | |
| merged: Dict[str, Any] = {} | |
| for key in dict.fromkeys(list(r1.keys()) + list(r2.keys())): | |
| v1 = r1.get(key) | |
| v2 = r2.get(key) | |
| if v1 is not None and str(v1).strip(): | |
| merged[key] = v1 | |
| elif v2 is not None and str(v2).strip(): | |
| merged[key] = v2 | |
| else: | |
| merged[key] = v1 | |
| return merged | |
| elif strategy == "merge_prefer_row1": | |
| merged = copy.deepcopy(r2) | |
| for key, val in r1.items(): | |
| if val is not None and str(val).strip(): | |
| merged[key] = val | |
| return merged | |
| elif strategy == "merge_prefer_row2": | |
| merged = copy.deepcopy(r1) | |
| for key, val in r2.items(): | |
| if val is not None and str(val).strip(): | |
| merged[key] = val | |
| return merged | |
| else: | |
| raise ValueError(f"Unknown merge strategy: '{strategy}'") | |
| # ------------------------------------------------------------------ | |
| # Type Casting | |
| # ------------------------------------------------------------------ | |
| def _cast_value(self, value: Any, target_type: str) -> Any: | |
| """Cast a value to the target type.""" | |
| val_str = str(value).strip() | |
| if target_type == "int": | |
| return int(float(val_str.replace(",", "").replace("$", ""))) | |
| elif target_type == "float": | |
| return float(val_str.replace(",", "").replace("$", "")) | |
| elif target_type == "str": | |
| return val_str | |
| elif target_type == "bool": | |
| return val_str.lower() in ("true", "1", "yes", "y") | |
| elif target_type == "date": | |
| return self._format_date_iso(val_str) | |
| else: | |
| raise ValueError(f"Unknown target type: '{target_type}'") | |
| # ------------------------------------------------------------------ | |
| # Observation Builder | |
| # ------------------------------------------------------------------ | |
def _build_observation(
    self, reward: float | None, done: bool
) -> DataCleanObservation:
    """Assemble the agent-facing observation from the current state.

    The column list is taken from the first data row.  ``_entity_id`` is
    hidden and ``_row_id`` is exposed to the agent as the leading ``row_id``
    column.  When *done* is true and a terminal grade is available, its
    breakdown is attached as metadata.
    """
    state = self._state
    data = state.current_data

    # Column layout: internal names for data access, "row_id" alias in front.
    hidden = {"_entity_id"}
    columns = list(data[0].keys()) if data else []
    data_columns = [c for c in columns if c not in hidden and c != "_row_id"]
    visible_columns = ["row_id", *data_columns]
    rows = [
        [row.get("_row_id"), *(row.get(col) for col in data_columns)]
        for row in data
    ]

    # Quality analysis.
    quality_issues = self._analyze_quality()
    issue_groups = self._group_issues(quality_issues)

    # Data summary; nulls are counted over the internal column names.
    null_count = sum(
        row.get(col) is None for row in data for col in data_columns
    )
    expected_types = state.schema_def.get("expected_types", {})
    data_summary = DataSummary(
        row_count=len(data),
        column_count=len(visible_columns),
        total_cells=len(visible_columns) * len(data),
        null_count=null_count,
        issue_count=len(quality_issues),
        columns=visible_columns,
        dtypes={col: expected_types.get(col, "str") for col in visible_columns},
    )

    # Recent actions (last 5).
    recent = []
    for entry in state.action_log[-5:]:
        recent.append(
            ActionResult(
                action=entry.get("action", ""),
                status=entry.get("status", ""),
                message=entry.get("message", ""),
                cells_modified=entry.get("cells_modified", 0),
            )
        )
    last_action = recent[-1] if recent else None

    # Grade breakdown is only published once the episode has ended.
    metadata: Dict[str, Any] = {}
    grade = getattr(self, "_last_grade_result", None)
    if done and grade is not None:
        breakdown = {
            "accuracy": grade.accuracy,
            "completeness": grade.completeness,
            "format_consistency": grade.format_consistency,
            "row_correctness": grade.row_correctness,
            "efficiency": grade.efficiency,
            "utility_score": grade.utility_score,
            "penalties": grade.penalties,
            "bonuses": grade.bonuses,
        }
        metadata = {
            "grade_breakdown": breakdown,
            "utility_details": grade.utility_details,
        }

    return DataCleanObservation(
        done=done,
        reward=reward,
        metadata=metadata,
        data_summary=data_summary,
        quality_issues=quality_issues[:20],  # Cap at 20 for readability
        issue_groups=issue_groups,
        issues_remaining=len(quality_issues),
        columns=visible_columns,
        rows=rows,
        row_count=len(data),
        schema_info=state.schema_def,
        step_number=state.step_count,
        max_steps=state.max_steps,
        steps_remaining=state.max_steps - state.step_count,
        budget_spent=state.budget_spent,
        budget_remaining=state.budget_remaining,
        action_costs=ACTION_COSTS,
        last_action_result=last_action,
        recent_actions=recent,
        task_id=state.task_id,
        task_name=getattr(self, "_task_name", state.task_id),
        difficulty=state.difficulty,
    )
| # ------------------------------------------------------------------ | |
| # Quality Analysis | |
| # ------------------------------------------------------------------ | |
def _analyze_quality(self) -> List[QualityIssue]:
    """Scan the current data and list every detected quality issue.

    Per-cell checks driven by the schema's ``constraints``: required-but-null
    fields, format mismatches and values outside ``allowed_values``.
    Underscore-prefixed columns are skipped.  Duplicate candidates and
    cross-field issues are appended afterwards by their own detectors.
    """
    constraints = self._state.schema_def.get("constraints", {})
    found: List[QualityIssue] = []

    for row in self._state.current_data:
        rid = row.get("_row_id", 0)
        for col, val in row.items():
            if col.startswith("_"):
                continue
            rules = constraints.get(col, {})

            # Null cells: report only when the field is required, then stop.
            if val is None:
                if rules.get("not_null"):
                    found.append(QualityIssue(
                        row_id=rid, column=col, issue_type="null",
                        description="Required field is null",
                    ))
                continue

            # Format check.
            fmt = rules.get("format")
            if fmt and not self._matches_format(val, fmt):
                found.append(QualityIssue(
                    row_id=rid, column=col, issue_type="format",
                    description=f"Does not match format: {fmt}",
                    suggestion=f"Use standardize_format('{col}', '{self._suggest_format_type(fmt)}')",
                ))

            # Allowed values.
            allowed = rules.get("allowed_values")
            if allowed and str(val) not in allowed:
                found.append(QualityIssue(
                    row_id=rid, column=col, issue_type="type_violation",
                    description=f"Value '{val}' not in allowed values",
                ))

    # Duplicate detection.
    found.extend(self._detect_potential_duplicates())
    # Cross-field validation (for hard mode).
    found.extend(self._detect_cross_field_issues())
    return found
| def _detect_cross_field_issues(self) -> List[QualityIssue]: | |
| """Detect cross-field inconsistencies: zip/city, date relationships, insurance ID prefixes.""" | |
| issues: List[QualityIssue] = [] | |
| data = self._state.current_data | |
| schema = self._state.schema_def | |
| cross_field_rules = schema.get("cross_field_rules", {}) | |
| # Rule: zip_city_match — zip code should correspond to the city | |
| zip_city_map = cross_field_rules.get("zip_city_map", {}) | |
| if zip_city_map: | |
| for row in data: | |
| rid = row.get("_row_id", 0) | |
| zip_val = str(row.get("zip", row.get("office_zip", ""))).strip() | |
| city_val = str(row.get("city", row.get("office_city", ""))).strip().lower() | |
| if zip_val in zip_city_map: | |
| expected_city = zip_city_map[zip_val].lower() | |
| if city_val and city_val != expected_city: | |
| issues.append(QualityIssue( | |
| row_id=rid, column="zip", | |
| issue_type="cross_field", | |
| description=f"ZIP '{zip_val}' should map to '{zip_city_map[zip_val]}', got '{row.get('city', row.get('office_city', ''))}'", | |
| suggestion=f"fix_value(row_id={rid}, column='city', new_value='{zip_city_map[zip_val]}')", | |
| )) | |
| # Rule: date_order — dob must be before last_visit_date | |
| if "dob" in schema.get("expected_types", {}) and "last_visit_date" in schema.get("expected_types", {}): | |
| for row in data: | |
| rid = row.get("_row_id", 0) | |
| dob = row.get("dob") | |
| visit = row.get("last_visit_date") | |
| if dob and visit: | |
| try: | |
| dob_dt = datetime.strptime(str(dob), "%Y-%m-%d") | |
| visit_dt = datetime.strptime(str(visit), "%Y-%m-%d") | |
| if dob_dt > visit_dt: | |
| issues.append(QualityIssue( | |
| row_id=rid, column="dob", | |
| issue_type="cross_field", | |
| description=f"DOB '{dob}' is after last_visit_date '{visit}'", | |
| )) | |
| if dob_dt > datetime.now(): | |
| issues.append(QualityIssue( | |
| row_id=rid, column="dob", | |
| issue_type="cross_field", | |
| description=f"DOB '{dob}' is in the future", | |
| )) | |
| except ValueError: | |
| pass | |
| # Rule: insurance_prefix — insurance_id prefix must match provider | |
| prefix_map = cross_field_rules.get("insurance_prefix_map", {}) | |
| if prefix_map: | |
| for row in data: | |
| rid = row.get("_row_id", 0) | |
| provider = str(row.get("insurance_provider", "")).strip() | |
| ins_id = str(row.get("insurance_id", "")).strip() | |
| if provider and ins_id and provider in prefix_map: | |
| expected_prefix = prefix_map[provider] | |
| if not ins_id.startswith(expected_prefix): | |
| issues.append(QualityIssue( | |
| row_id=rid, column="insurance_id", | |
| issue_type="cross_field", | |
| description=f"Insurance ID '{ins_id}' should start with '{expected_prefix}' for provider '{provider}'", | |
| )) | |
| return issues | |
| def _detect_potential_duplicates(self) -> List[QualityIssue]: | |
| """Detect potential duplicate rows by email, phone, or name similarity.""" | |
| issues: List[QualityIssue] = [] | |
| data = self._state.current_data | |
| # Check by email | |
| email_index: Dict[str, List[int]] = {} | |
| for row in data: | |
| rid = row.get("_row_id", 0) | |
| email = row.get("email") | |
| if email and str(email).strip(): | |
| key = str(email).strip().lower() | |
| email_index.setdefault(key, []).append(rid) | |
| for email, row_ids in email_index.items(): | |
| if len(row_ids) > 1: | |
| issues.append(QualityIssue( | |
| row_id=row_ids[0], column="email", issue_type="duplicate", | |
| description=f"Rows {row_ids} share email '{email}'", | |
| suggestion=f"Consider merge_duplicates(row_id1={row_ids[0]}, row_id2={row_ids[1]}, strategy='merge_prefer_nonnull')", | |
| )) | |
| # Check by phone (digit-only comparison) | |
| phone_index: Dict[str, List[int]] = {} | |
| for row in data: | |
| rid = row.get("_row_id", 0) | |
| phone = row.get("phone") | |
| if phone and str(phone).strip(): | |
| digits = re.sub(r"\D", "", str(phone)) | |
| if digits.startswith("1") and len(digits) == 11: | |
| digits = digits[1:] | |
| if len(digits) == 10: | |
| phone_index.setdefault(digits, []).append(rid) | |
| for digits, row_ids in phone_index.items(): | |
| if len(row_ids) > 1: | |
| # Avoid duplicate issues if already flagged by email | |
| issues.append(QualityIssue( | |
| row_id=row_ids[0], column="phone", issue_type="duplicate", | |
| description=f"Rows {row_ids} share phone digits '{digits}'", | |
| )) | |
| return issues | |
| def _group_issues(self, issues: List[QualityIssue]) -> List[IssueGroup]: | |
| """Group issues by type for compact display.""" | |
| type_counter: Dict[str, List[QualityIssue]] = {} | |
| for issue in issues: | |
| type_counter.setdefault(issue.issue_type, []).append(issue) | |
| return [ | |
| IssueGroup( | |
| issue_type=itype, | |
| count=len(items), | |
| examples=items[:3], # Show max 3 examples per type | |
| ) | |
| for itype, items in sorted(type_counter.items()) | |
| ] | |
| def _matches_format(self, value: Any, format_spec: str) -> bool: | |
| """Check if a value matches the expected format.""" | |
| val_str = str(value) | |
| format_patterns: Dict[str, str] = { | |
| "YYYY-MM-DD": r"^\d{4}-\d{2}-\d{2}$", | |
| "(XXX) XXX-XXXX": r"^\(\d{3}\) \d{3}-\d{4}$", | |
| "email": r"^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$", | |
| "5_digit": r"^\d{5}$", | |
| "+1XXXXXXXXXX": r"^\+1\d{10}$", | |
| } | |
| pattern = format_patterns.get(format_spec) | |
| if pattern: | |
| return bool(re.match(pattern, val_str)) | |
| return True | |
| def _suggest_format_type(self, format_spec: str) -> str: | |
| """Suggest the standardize_format type for a given format spec.""" | |
| mapping = { | |
| "YYYY-MM-DD": "date:YYYY-MM-DD", | |
| "(XXX) XXX-XXXX": "phone:US", | |
| "email": "email:lowercase", | |
| "5_digit": "zip:5digit", | |
| "+1XXXXXXXXXX": "phone:E164", | |
| } | |
| return mapping.get(format_spec, format_spec) | |
| # ------------------------------------------------------------------ | |
| # Metadata | |
| # ------------------------------------------------------------------ | |
def get_metadata(self):  # type: ignore[override]
    """Return the static descriptor of this environment.

    Returns:
        An ``EnvironmentMetadata`` carrying the environment's name, a short
        description and its version string.
    """
    from openenv.core.env_server.types import EnvironmentMetadata

    summary = (
        "Data Cleaning environment for training AI agents to clean "
        "messy tabular data. Supports 3 difficulty levels (easy, medium, hard) "
        "with deterministic grading via cell-by-cell comparison against ground truth."
    )
    return EnvironmentMetadata(
        name="dataclean_env",
        description=summary,
        version="0.1.0",
    )