"""DataClean Environment: core logic for the data cleaning RL environment. Implements reset(), step(), state property following OpenEnv spec. All 10 action handlers fully implemented. Delta reward system. """ from __future__ import annotations import copy import re from collections import Counter from datetime import datetime from typing import Any, Dict, List, Optional from uuid import uuid4 import logging from openenv.core.env_server import Environment logger = logging.getLogger(__name__) from dataclean_env.models import ( ActionResult, DataCleanAction, DataCleanObservation, DataCleanState, DataSummary, IssueGroup, QualityIssue, ) from dataclean_env.server.grader import DataCleanGrader from dataclean_env.server.tasks import get_task, list_tasks # US state name -> abbreviation mapping US_STATES: Dict[str, str] = { "alabama": "AL", "alaska": "AK", "arizona": "AZ", "arkansas": "AR", "california": "CA", "colorado": "CO", "connecticut": "CT", "delaware": "DE", "florida": "FL", "georgia": "GA", "hawaii": "HI", "idaho": "ID", "illinois": "IL", "indiana": "IN", "iowa": "IA", "kansas": "KS", "kentucky": "KY", "louisiana": "LA", "maine": "ME", "maryland": "MD", "massachusetts": "MA", "michigan": "MI", "minnesota": "MN", "mississippi": "MS", "missouri": "MO", "montana": "MT", "nebraska": "NE", "nevada": "NV", "new hampshire": "NH", "new jersey": "NJ", "new mexico": "NM", "new york": "NY", "north carolina": "NC", "north dakota": "ND", "ohio": "OH", "oklahoma": "OK", "oregon": "OR", "pennsylvania": "PA", "rhode island": "RI", "south carolina": "SC", "south dakota": "SD", "tennessee": "TN", "texas": "TX", "utah": "UT", "vermont": "VT", "virginia": "VA", "washington": "WA", "west virginia": "WV", "wisconsin": "WI", "wyoming": "WY", } # Date parsing formats (most specific first) DATE_PARSE_FORMATS = [ "%Y-%m-%d", # 2023-01-15 "%m/%d/%Y", # 01/15/2023 "%d-%m-%Y", # 15-01-2023 "%B %d, %Y", # January 15, 2023 "%b %d, %Y", # Jan 15, 2023 "%d %B %Y", # 15 January 2023 "%d-%b-%Y", # 15-Jan-2023 "%m-%d-%Y", # 01-15-2023 "%B %d %Y", # January 15 2023 "%d/%m/%Y", # 15/01/2023 "%Y/%m/%d", # 2023/01/15 ] # Per-action costs for the intervention budget system ACTION_COSTS: Dict[str, float] = { "fix_value": 1.0, "delete_row": 6.0, "fill_missing": 1.0, "standardize_format": 2.0, "merge_duplicates": 4.0, "flag_anomaly": 0.5, "split_column": 3.0, "rename_column": 0.5, "cast_type": 2.0, "escalate_to_human": 0.5, "mark_complete": 0.0, } # Budget allocation per difficulty level DIFFICULTY_BUDGETS: Dict[str, float] = { "easy": 50.0, "medium": 100.0, "hard": 150.0, } # Per-step penalty in delta reward computation STEP_COST: float = 0.005 # Default seed when none provided (deterministic fallback) DEFAULT_SEED: int = 42 class DataCleanEnvironment( Environment[DataCleanAction, DataCleanObservation, DataCleanState] ): """Data Cleaning environment for training AI agents.""" SUPPORTS_CONCURRENT_SESSIONS = False def __init__(self) -> None: super().__init__() self._state = DataCleanState() self._grader = DataCleanGrader() self._utility_probes: list = [] self._ambiguous_cells: list = [] self._task_name: str = "" self._last_grade_result = None self._next_row_id: int = 0 def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, ) -> DataCleanObservation: """Initialize a new data cleaning episode.""" task_id = kwargs.get("task_id", "easy_contacts") task = get_task(task_id) # raises KeyError/ValueError on unknown task_id actual_seed = seed if seed is not None else DEFAULT_SEED from dataclean_env.server.data_generator import generate_dirty_data dirty_data = generate_dirty_data( clean_data=task.ground_truth, corruptions=task.corruptions, seed=actual_seed, ) # Assign stable row_ids (persist through delete/merge within episode) self._next_row_id = 0 for row in dirty_data: row["_row_id"] = self._next_row_id self._next_row_id += 1 # Compute initial score (dirty data vs ground truth) initial_score = self._grader.grade( final_data=dirty_data, ground_truth=task.ground_truth, original_data=dirty_data, action_history=[], schema=task.schema, flagged_cells=[], escalated_cells=[], ambiguous_cells=list(getattr(task, "ambiguous_cells", [])), utility_probes=list(getattr(task, "utility_probes", [])), ).score budget = DIFFICULTY_BUDGETS.get(task.difficulty, 100.0) self._state = DataCleanState( episode_id=episode_id or str(uuid4()), step_count=0, task_id=task_id, difficulty=task.difficulty, current_data=copy.deepcopy(dirty_data), ground_truth=copy.deepcopy(task.ground_truth), original_dirty=copy.deepcopy(dirty_data), schema_def=task.schema, action_log=[], flagged_cells=[], escalated_cells=[], max_steps=task.max_steps, is_complete=False, previous_score=initial_score, initial_raw_score=initial_score, action_budget=budget, budget_spent=0.0, budget_remaining=budget, ) self._task_name = task.name self._ambiguous_cells: List[tuple[str, str]] = list( getattr(task, "ambiguous_cells", []) ) self._utility_probes = list(getattr(task, "utility_probes", [])) return self._build_observation(reward=None, done=False) def step( self, action: DataCleanAction, timeout_s: Optional[float] = None, **kwargs: Any, ) -> DataCleanObservation: """Process one cleaning action. Returns observation with delta reward.""" # Guard: episode already ended if self._state.is_complete: return self._build_observation( reward=0.0, done=True, ) self._state.step_count += 1 # Enforce budget: reject actions (except mark_complete) when exhausted cost = ACTION_COSTS.get(action.action_type, 1.0) if cost > 0 and self._state.budget_remaining < cost: result = { "action": action.action_type, "status": "error", "message": f"Budget exhausted ({self._state.budget_remaining:.1f} remaining, " f"action costs {cost:.1f})", "cells_modified": 0, } self._state.action_log.append(result) return self._build_observation( reward=-0.01, done=False, ) # Execute the action result = self._execute_action(action) self._state.action_log.append(result) # Deduct action cost from budget self._state.budget_spent += cost self._state.budget_remaining -= cost # Check termination is_done = ( action.action_type == "mark_complete" or self._state.step_count >= self._state.max_steps ) self._state.is_complete = is_done # Compute reward if is_done: # Terminal: return absolute final score grade_result = self._grader.grade( final_data=self._state.current_data, ground_truth=self._state.ground_truth, original_data=self._state.original_dirty, action_history=self._state.action_log, schema=self._state.schema_def, flagged_cells=self._state.flagged_cells, budget_spent=self._state.budget_spent, action_budget=self._state.action_budget, escalated_cells=self._state.escalated_cells, ambiguous_cells=self._ambiguous_cells, utility_probes=self._utility_probes, ) reward = grade_result.score self._last_grade_result = grade_result else: # Non-terminal: delta reward reward = self._compute_delta_reward(result) return self._build_observation(reward=reward, done=is_done) @property def state(self) -> DataCleanState: return self._state # ------------------------------------------------------------------ # Delta Reward System # ------------------------------------------------------------------ def _compute_delta_reward(self, action_result: Dict[str, Any]) -> float: """Compute reward = current_score - previous_score - step_cost. Penalizes no-ops and errors explicitly. """ # Explicit penalties for bad actions if action_result.get("status") == "error": return -0.02 if action_result.get("status") == "no_effect": return -0.01 if action_result.get("cells_modified", 0) == 0 and action_result.get("action") not in ("flag_anomaly", "escalate_to_human"): return -0.01 # Compute current score (raw, without normalization — delta is relative # so the baseline cancels out; normalization only at terminal grading) current_score = self._grader.grade( final_data=self._state.current_data, ground_truth=self._state.ground_truth, original_data=self._state.original_dirty, action_history=self._state.action_log, schema=self._state.schema_def, flagged_cells=self._state.flagged_cells, budget_spent=self._state.budget_spent, action_budget=self._state.action_budget, escalated_cells=self._state.escalated_cells, ambiguous_cells=self._ambiguous_cells, utility_probes=self._utility_probes, ).score delta = current_score - self._state.previous_score - STEP_COST self._state.previous_score = current_score return round(delta, 4) # ------------------------------------------------------------------ # Action Dispatch # ------------------------------------------------------------------ @staticmethod def _normalize_action_params(action_type: str, params: Dict[str, Any]) -> Dict[str, Any]: """Normalize common LLM param aliases to canonical names.""" p = dict(params) # Universal aliases if "row" in p and "row_id" not in p: p["row_id"] = p.pop("row") if "col" in p and "column" not in p: p["column"] = p.pop("col") # Action-specific aliases if action_type == "fix_value" and "value" in p and "new_value" not in p: p["new_value"] = p.pop("value") if action_type == "merge_duplicates": if "row_id_1" in p and "row_id1" not in p: p["row_id1"] = p.pop("row_id_1") if "row_id_2" in p and "row_id2" not in p: p["row_id2"] = p.pop("row_id_2") if "row1" in p and "row_id1" not in p: p["row_id1"] = p.pop("row1") if "row2" in p and "row_id2" not in p: p["row_id2"] = p.pop("row2") return p def _execute_action(self, action: DataCleanAction) -> Dict[str, Any]: """Dispatch action to the appropriate handler.""" handler = getattr(self, f"_action_{action.action_type}", None) if handler is None: return { "action": action.action_type, "status": "error", "message": f"Unknown action type: {action.action_type}", "cells_modified": 0, } # Normalize param aliases before dispatching normalized_params = self._normalize_action_params(action.action_type, action.params) try: return handler(normalized_params) except (KeyError, TypeError, IndexError) as exc: return { "action": action.action_type, "status": "error", "message": f"Invalid params: {exc}", "cells_modified": 0, } except Exception as exc: logger.exception("Unexpected error in action handler %s", action.action_type) return { "action": action.action_type, "status": "error", "message": str(exc), "cells_modified": 0, } # ------------------------------------------------------------------ # Row Lookup by Stable row_id # ------------------------------------------------------------------ def _find_row_by_id(self, row_id: int) -> tuple[int, Dict[str, Any] | None]: """Find the list index and row dict for a given stable row_id. Returns (index, row_dict) or (-1, None) if not found. """ for i, row in enumerate(self._state.current_data): if row.get("_row_id") == row_id: return i, row return -1, None # ------------------------------------------------------------------ # Action Handlers (10 total) — all use stable row_id # ------------------------------------------------------------------ def _action_fix_value(self, params: Dict[str, Any]) -> Dict[str, Any]: row_id = int(params["row_id"]) column = str(params["column"]) new_value = params["new_value"] idx, row = self._find_row_by_id(row_id) if row is None: return {"action": "fix_value", "status": "error", "message": f"row_id {row_id} not found", "cells_modified": 0} if column not in row or column.startswith("_"): return {"action": "fix_value", "status": "error", "message": f"Column '{column}' not found", "cells_modified": 0} old_value = row[column] if str(old_value) == str(new_value): return {"action": "fix_value", "status": "no_effect", "message": f"Value unchanged at (row_id={row_id}, '{column}')", "cells_modified": 0} row[column] = new_value return {"action": "fix_value", "status": "success", "message": f"(row_id={row_id}, '{column}'): '{old_value}' -> '{new_value}'", "cells_modified": 1, "old_value": old_value, "new_value": new_value, "row_id": row_id, "column": column} def _action_delete_row(self, params: Dict[str, Any]) -> Dict[str, Any]: row_id = int(params["row_id"]) idx, row = self._find_row_by_id(row_id) if row is None: return {"action": "delete_row", "status": "error", "message": f"row_id {row_id} not found", "cells_modified": 0} deleted = self._state.current_data.pop(idx) return {"action": "delete_row", "status": "success", "message": f"row_id={row_id} deleted", "cells_modified": len(deleted), "deleted_data": deleted, "row_id": row_id, "deleted_entity_id": deleted.get("_entity_id")} def _action_fill_missing(self, params: Dict[str, Any]) -> Dict[str, Any]: row_id = int(params["row_id"]) column = str(params["column"]) value = params["value"] idx, row = self._find_row_by_id(row_id) if row is None: return {"action": "fill_missing", "status": "error", "message": f"row_id {row_id} not found", "cells_modified": 0} if column not in row or column.startswith("_"): return {"action": "fill_missing", "status": "error", "message": f"Column '{column}' not found", "cells_modified": 0} current = row.get(column) if current is not None and str(current).strip() != "": return {"action": "fill_missing", "status": "error", "message": f"Cell (row_id={row_id}, '{column}') is not empty: '{current}'", "cells_modified": 0} row[column] = value return {"action": "fill_missing", "status": "success", "message": f"(row_id={row_id}, '{column}'): NULL -> '{value}'", "cells_modified": 1, "row_id": row_id, "column": column, "new_value": value} def _action_standardize_format(self, params: Dict[str, Any]) -> Dict[str, Any]: column = str(params["column"]) format_type = str(params["format_type"]) data = self._state.current_data modified = 0 errors: List[str] = [] for row in data: if column not in row or row[column] is None: continue try: new_val = self._apply_format(row[column], format_type) if str(new_val) != str(row[column]): row[column] = new_val modified += 1 except (ValueError, TypeError) as exc: errors.append(f"row_id={row.get('_row_id', '?')}: {exc}") if modified == 0 and not errors: return {"action": "standardize_format", "status": "no_effect", "message": f"No changes needed in '{column}' for {format_type}", "cells_modified": 0} msg = f"Formatted {modified} cell(s) in '{column}' to {format_type}" if errors: msg += f". {len(errors)} parse failure(s)." return {"action": "standardize_format", "status": "success", "message": msg, "cells_modified": modified} def _action_merge_duplicates(self, params: Dict[str, Any]) -> Dict[str, Any]: row_id1 = int(params["row_id1"]) row_id2 = int(params["row_id2"]) strategy = str(params.get("strategy", "merge_prefer_nonnull")) if row_id1 == row_id2: return {"action": "merge_duplicates", "status": "error", "message": "Cannot merge a row with itself", "cells_modified": 0} idx1, r1 = self._find_row_by_id(row_id1) idx2, r2 = self._find_row_by_id(row_id2) if r1 is None or r2 is None: missing = row_id1 if r1 is None else row_id2 return {"action": "merge_duplicates", "status": "error", "message": f"row_id {missing} not found", "cells_modified": 0} # Track entity IDs for penalty checking eid1 = r1.get("_entity_id", "") eid2 = r2.get("_entity_id", "") merged = self._merge_rows(r1, r2, strategy) # Merged row keeps the first row's entity_id and row_id merged["_entity_id"] = eid1 merged["_row_id"] = r1["_row_id"] # Remove both, insert merged at first position data = self._state.current_data lo_idx = min(idx1, idx2) hi_idx = max(idx1, idx2) data.pop(hi_idx) data.pop(lo_idx) data.insert(lo_idx, merged) return {"action": "merge_duplicates", "status": "success", "message": f"Merged row_id={row_id1} and row_id={row_id2} using '{strategy}'", "cells_modified": len(merged), "row_id1": row_id1, "row_id2": row_id2, "entity_id1": eid1, "entity_id2": eid2, "deleted_entity_id": eid2, "strategy": strategy} def _action_flag_anomaly(self, params: Dict[str, Any]) -> Dict[str, Any]: row_id = int(params["row_id"]) column = str(params["column"]) reason = str(params.get("reason", "")) idx, row = self._find_row_by_id(row_id) if row is None: return {"action": "flag_anomaly", "status": "error", "message": f"row_id {row_id} not found", "cells_modified": 0} self._state.flagged_cells.append( {"row_id": row_id, "column": column, "reason": reason} ) return {"action": "flag_anomaly", "status": "success", "message": f"Flagged (row_id={row_id}, '{column}'): {reason}", "cells_modified": 0} def _action_escalate_to_human(self, params: Dict[str, Any]) -> Dict[str, Any]: """Escalate a cell to human review -- agent signals it is uncertain.""" row_id = int(params["row_id"]) column = str(params["column"]) confidence = float(params.get("confidence", 0.5)) reason = str(params.get("reason", "")) idx, row = self._find_row_by_id(row_id) if row is None: return {"action": "escalate_to_human", "status": "error", "message": f"row_id {row_id} not found", "cells_modified": 0} self._state.escalated_cells.append({ "row_id": row_id, "column": column, "confidence": confidence, "reason": reason, }) return {"action": "escalate_to_human", "status": "success", "message": f"Escalated (row_id={row_id}, '{column}'): {reason} (confidence={confidence})", "cells_modified": 0} def _action_split_column(self, params: Dict[str, Any]) -> Dict[str, Any]: column = str(params["column"]) delimiter = str(params["delimiter"]) new_names = list(params["new_names"]) data = self._state.current_data modified = 0 for row in data: if column not in row or row[column] is None: continue parts = str(row[column]).split(delimiter, maxsplit=len(new_names) - 1) for i, name in enumerate(new_names): row[name] = parts[i].strip() if i < len(parts) else None del row[column] modified += 1 if modified == 0: return {"action": "split_column", "status": "no_effect", "message": f"Column '{column}' not found or all null", "cells_modified": 0} return {"action": "split_column", "status": "success", "message": f"Split '{column}' into {new_names} ({modified} rows)", "cells_modified": modified} def _action_rename_column(self, params: Dict[str, Any]) -> Dict[str, Any]: old_name = str(params["old_name"]) new_name = str(params["new_name"]) data = self._state.current_data if not data or old_name not in data[0]: return {"action": "rename_column", "status": "error", "message": f"Column '{old_name}' not found", "cells_modified": 0} if data and new_name in data[0]: return {"action": "rename_column", "status": "error", "message": f"Column '{new_name}' already exists", "cells_modified": 0} if new_name.startswith("_"): return {"action": "rename_column", "status": "error", "message": f"Column names starting with '_' are reserved", "cells_modified": 0} for row in data: if old_name in row: row[new_name] = row.pop(old_name) return {"action": "rename_column", "status": "success", "message": f"Renamed '{old_name}' -> '{new_name}'", "cells_modified": len(data)} def _action_cast_type(self, params: Dict[str, Any]) -> Dict[str, Any]: column = str(params["column"]) target_type = str(params["target_type"]) valid_types = {"int", "float", "str", "bool", "date"} if target_type not in valid_types: return {"action": "cast_type", "status": "error", "message": f"Unknown type '{target_type}'. Valid: {sorted(valid_types)}", "cells_modified": 0} data = self._state.current_data modified = 0 nullified = 0 for row in data: if column not in row or row[column] is None: continue try: row[column] = self._cast_value(row[column], target_type) modified += 1 except (ValueError, TypeError): row[column] = None nullified += 1 msg = f"Cast {modified} cell(s) in '{column}' to {target_type}" if nullified: msg += f" ({nullified} failed -> null)" status = "success" if modified > 0 else ("error" if nullified > 0 else "no_effect") return {"action": "cast_type", "status": status, "message": msg, "cells_modified": modified + nullified} def _action_mark_complete(self, params: Dict[str, Any]) -> Dict[str, Any]: return {"action": "mark_complete", "status": "success", "message": "Agent signaled completion", "cells_modified": 0} # ------------------------------------------------------------------ # Format Standardization (8 types, fully implemented) # ------------------------------------------------------------------ def _apply_format(self, value: Any, format_type: str) -> Any: """Apply format transformation to a single value.""" val_str = str(value).strip() if not val_str: return value if format_type == "date:YYYY-MM-DD": return self._format_date_iso(val_str) elif format_type == "phone:US": return self._format_phone_us(val_str) elif format_type == "phone:E164": return self._format_phone_e164(val_str) elif format_type == "name:title_case": return val_str.title() elif format_type == "email:lowercase": return val_str.lower() elif format_type == "zip:5digit": return self._format_zip_5digit(val_str) elif format_type == "currency:float": return self._format_currency_float(val_str) elif format_type == "state:abbreviation": return self._format_state_abbrev(val_str) else: raise ValueError(f"Unknown format type: {format_type}") def _format_date_iso(self, val: str) -> str: """Parse various date formats and return YYYY-MM-DD.""" for fmt in DATE_PARSE_FORMATS: try: dt = datetime.strptime(val.strip(), fmt) return dt.strftime("%Y-%m-%d") except ValueError: continue raise ValueError(f"Cannot parse date: '{val}'") def _format_phone_us(self, val: str) -> str: """Normalize phone to (XXX) XXX-XXXX format.""" digits = re.sub(r"\D", "", val) if digits.startswith("1") and len(digits) == 11: digits = digits[1:] if len(digits) != 10: raise ValueError(f"Phone must have 10 digits, got {len(digits)}: '{val}'") return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}" def _format_phone_e164(self, val: str) -> str: """Normalize phone to +1XXXXXXXXXX format.""" digits = re.sub(r"\D", "", val) if digits.startswith("1") and len(digits) == 11: digits = digits[1:] if len(digits) != 10: raise ValueError(f"Phone must have 10 digits, got {len(digits)}: '{val}'") return f"+1{digits}" def _format_zip_5digit(self, val: str) -> str: """Normalize ZIP to 5 digits (pad or truncate).""" digits = re.sub(r"\D", "", val.split("-")[0]) if not digits: raise ValueError(f"No digits in ZIP: '{val}'") return digits[:5].zfill(5) def _format_currency_float(self, val: str) -> float: """Parse currency string to float. '$1,234.56' -> 1234.56""" cleaned = val.replace("$", "").replace(",", "").strip() if cleaned.lower().endswith("k"): return float(cleaned[:-1]) * 1000 return float(cleaned) def _format_state_abbrev(self, val: str) -> str: """Convert full state name to 2-letter abbreviation.""" if len(val) == 2 and val.upper() in US_STATES.values(): return val.upper() lower = val.strip().lower() if lower in US_STATES: return US_STATES[lower] raise ValueError(f"Unknown state: '{val}'") # ------------------------------------------------------------------ # Row Merging (all strategies) # ------------------------------------------------------------------ def _merge_rows(self, r1: Dict, r2: Dict, strategy: str) -> Dict: """Merge two rows according to the given strategy.""" if strategy == "keep_first": return copy.deepcopy(r1) elif strategy == "keep_second": return copy.deepcopy(r2) elif strategy == "merge_prefer_nonnull": merged: Dict[str, Any] = {} for key in dict.fromkeys(list(r1.keys()) + list(r2.keys())): v1 = r1.get(key) v2 = r2.get(key) if v1 is not None and str(v1).strip(): merged[key] = v1 elif v2 is not None and str(v2).strip(): merged[key] = v2 else: merged[key] = v1 return merged elif strategy == "merge_prefer_row1": merged = copy.deepcopy(r2) for key, val in r1.items(): if val is not None and str(val).strip(): merged[key] = val return merged elif strategy == "merge_prefer_row2": merged = copy.deepcopy(r1) for key, val in r2.items(): if val is not None and str(val).strip(): merged[key] = val return merged else: raise ValueError(f"Unknown merge strategy: '{strategy}'") # ------------------------------------------------------------------ # Type Casting # ------------------------------------------------------------------ def _cast_value(self, value: Any, target_type: str) -> Any: """Cast a value to the target type.""" val_str = str(value).strip() if target_type == "int": return int(float(val_str.replace(",", "").replace("$", ""))) elif target_type == "float": return float(val_str.replace(",", "").replace("$", "")) elif target_type == "str": return val_str elif target_type == "bool": return val_str.lower() in ("true", "1", "yes", "y") elif target_type == "date": return self._format_date_iso(val_str) else: raise ValueError(f"Unknown target type: '{target_type}'") # ------------------------------------------------------------------ # Observation Builder # ------------------------------------------------------------------ def _build_observation( self, reward: float | None, done: bool ) -> DataCleanObservation: """Build issue-first observation from current state.""" data = self._state.current_data columns = list(data[0].keys()) if data else [] # Filter out internal fields EXCEPT _row_id (renamed to row_id for agent) hidden = {"_entity_id"} visible_columns = ["row_id"] + [c for c in columns if c not in hidden and c != "_row_id"] rows = [ [row.get("_row_id")] + [row.get(col) for col in columns if col not in hidden and col != "_row_id"] for row in data ] # Quality analysis quality_issues = self._analyze_quality() issue_groups = self._group_issues(quality_issues) # Data summary — count nulls using INTERNAL column names (not aliases) internal_cols = [c for c in columns if c not in hidden and c != "_row_id"] null_count = sum( 1 for row in data for col in internal_cols if row.get(col) is None ) data_summary = DataSummary( row_count=len(data), column_count=len(visible_columns), total_cells=len(visible_columns) * len(data), null_count=null_count, issue_count=len(quality_issues), columns=visible_columns, dtypes={ col: self._state.schema_def.get("expected_types", {}).get(col, "str") for col in visible_columns }, ) # Recent actions (last 5) recent = [ ActionResult( action=a.get("action", ""), status=a.get("status", ""), message=a.get("message", ""), cells_modified=a.get("cells_modified", 0), ) for a in self._state.action_log[-5:] ] # Build metadata with grade breakdown when episode ends metadata: Dict[str, Any] = {} grade = getattr(self, "_last_grade_result", None) if done and grade is not None: metadata = { "grade_breakdown": { "accuracy": grade.accuracy, "completeness": grade.completeness, "format_consistency": grade.format_consistency, "row_correctness": grade.row_correctness, "efficiency": grade.efficiency, "utility_score": grade.utility_score, "penalties": grade.penalties, "bonuses": grade.bonuses, }, "utility_details": grade.utility_details, } return DataCleanObservation( done=done, reward=reward, metadata=metadata, data_summary=data_summary, quality_issues=quality_issues[:20], # Cap at 20 for readability issue_groups=issue_groups, issues_remaining=len(quality_issues), columns=visible_columns, rows=rows, row_count=len(data), schema_info=self._state.schema_def, step_number=self._state.step_count, max_steps=self._state.max_steps, steps_remaining=self._state.max_steps - self._state.step_count, budget_spent=self._state.budget_spent, budget_remaining=self._state.budget_remaining, action_costs=ACTION_COSTS, last_action_result=recent[-1] if recent else None, recent_actions=recent, task_id=self._state.task_id, task_name=getattr(self, "_task_name", self._state.task_id), difficulty=self._state.difficulty, ) # ------------------------------------------------------------------ # Quality Analysis # ------------------------------------------------------------------ def _analyze_quality(self) -> List[QualityIssue]: """Analyze current data and return detected quality issues.""" issues: List[QualityIssue] = [] schema = self._state.schema_def data = self._state.current_data constraints = schema.get("constraints", {}) for row in data: rid = row.get("_row_id", 0) for col in [c for c in row if not c.startswith("_")]: val = row.get(col) col_constraints = constraints.get(col, {}) # Null check if val is None and col_constraints.get("not_null"): issues.append(QualityIssue( row_id=rid, column=col, issue_type="null", description="Required field is null", )) if val is None: continue # Format check fmt = col_constraints.get("format") if fmt and not self._matches_format(val, fmt): issues.append(QualityIssue( row_id=rid, column=col, issue_type="format", description=f"Does not match format: {fmt}", suggestion=f"Use standardize_format('{col}', '{self._suggest_format_type(fmt)}')", )) # Allowed values allowed = col_constraints.get("allowed_values") if allowed and str(val) not in allowed: issues.append(QualityIssue( row_id=rid, column=col, issue_type="type_violation", description=f"Value '{val}' not in allowed values", )) # Duplicate detection issues.extend(self._detect_potential_duplicates()) # Cross-field validation (for hard mode) issues.extend(self._detect_cross_field_issues()) return issues def _detect_cross_field_issues(self) -> List[QualityIssue]: """Detect cross-field inconsistencies: zip/city, date relationships, insurance ID prefixes.""" issues: List[QualityIssue] = [] data = self._state.current_data schema = self._state.schema_def cross_field_rules = schema.get("cross_field_rules", {}) # Rule: zip_city_match — zip code should correspond to the city zip_city_map = cross_field_rules.get("zip_city_map", {}) if zip_city_map: for row in data: rid = row.get("_row_id", 0) zip_val = str(row.get("zip", row.get("office_zip", ""))).strip() city_val = str(row.get("city", row.get("office_city", ""))).strip().lower() if zip_val in zip_city_map: expected_city = zip_city_map[zip_val].lower() if city_val and city_val != expected_city: issues.append(QualityIssue( row_id=rid, column="zip", issue_type="cross_field", description=f"ZIP '{zip_val}' should map to '{zip_city_map[zip_val]}', got '{row.get('city', row.get('office_city', ''))}'", suggestion=f"fix_value(row_id={rid}, column='city', new_value='{zip_city_map[zip_val]}')", )) # Rule: date_order — dob must be before last_visit_date if "dob" in schema.get("expected_types", {}) and "last_visit_date" in schema.get("expected_types", {}): for row in data: rid = row.get("_row_id", 0) dob = row.get("dob") visit = row.get("last_visit_date") if dob and visit: try: dob_dt = datetime.strptime(str(dob), "%Y-%m-%d") visit_dt = datetime.strptime(str(visit), "%Y-%m-%d") if dob_dt > visit_dt: issues.append(QualityIssue( row_id=rid, column="dob", issue_type="cross_field", description=f"DOB '{dob}' is after last_visit_date '{visit}'", )) if dob_dt > datetime.now(): issues.append(QualityIssue( row_id=rid, column="dob", issue_type="cross_field", description=f"DOB '{dob}' is in the future", )) except ValueError: pass # Rule: insurance_prefix — insurance_id prefix must match provider prefix_map = cross_field_rules.get("insurance_prefix_map", {}) if prefix_map: for row in data: rid = row.get("_row_id", 0) provider = str(row.get("insurance_provider", "")).strip() ins_id = str(row.get("insurance_id", "")).strip() if provider and ins_id and provider in prefix_map: expected_prefix = prefix_map[provider] if not ins_id.startswith(expected_prefix): issues.append(QualityIssue( row_id=rid, column="insurance_id", issue_type="cross_field", description=f"Insurance ID '{ins_id}' should start with '{expected_prefix}' for provider '{provider}'", )) return issues def _detect_potential_duplicates(self) -> List[QualityIssue]: """Detect potential duplicate rows by email, phone, or name similarity.""" issues: List[QualityIssue] = [] data = self._state.current_data # Check by email email_index: Dict[str, List[int]] = {} for row in data: rid = row.get("_row_id", 0) email = row.get("email") if email and str(email).strip(): key = str(email).strip().lower() email_index.setdefault(key, []).append(rid) for email, row_ids in email_index.items(): if len(row_ids) > 1: issues.append(QualityIssue( row_id=row_ids[0], column="email", issue_type="duplicate", description=f"Rows {row_ids} share email '{email}'", suggestion=f"Consider merge_duplicates(row_id1={row_ids[0]}, row_id2={row_ids[1]}, strategy='merge_prefer_nonnull')", )) # Check by phone (digit-only comparison) phone_index: Dict[str, List[int]] = {} for row in data: rid = row.get("_row_id", 0) phone = row.get("phone") if phone and str(phone).strip(): digits = re.sub(r"\D", "", str(phone)) if digits.startswith("1") and len(digits) == 11: digits = digits[1:] if len(digits) == 10: phone_index.setdefault(digits, []).append(rid) for digits, row_ids in phone_index.items(): if len(row_ids) > 1: # Avoid duplicate issues if already flagged by email issues.append(QualityIssue( row_id=row_ids[0], column="phone", issue_type="duplicate", description=f"Rows {row_ids} share phone digits '{digits}'", )) return issues def _group_issues(self, issues: List[QualityIssue]) -> List[IssueGroup]: """Group issues by type for compact display.""" type_counter: Dict[str, List[QualityIssue]] = {} for issue in issues: type_counter.setdefault(issue.issue_type, []).append(issue) return [ IssueGroup( issue_type=itype, count=len(items), examples=items[:3], # Show max 3 examples per type ) for itype, items in sorted(type_counter.items()) ] def _matches_format(self, value: Any, format_spec: str) -> bool: """Check if a value matches the expected format.""" val_str = str(value) format_patterns: Dict[str, str] = { "YYYY-MM-DD": r"^\d{4}-\d{2}-\d{2}$", "(XXX) XXX-XXXX": r"^\(\d{3}\) \d{3}-\d{4}$", "email": r"^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$", "5_digit": r"^\d{5}$", "+1XXXXXXXXXX": r"^\+1\d{10}$", } pattern = format_patterns.get(format_spec) if pattern: return bool(re.match(pattern, val_str)) return True def _suggest_format_type(self, format_spec: str) -> str: """Suggest the standardize_format type for a given format spec.""" mapping = { "YYYY-MM-DD": "date:YYYY-MM-DD", "(XXX) XXX-XXXX": "phone:US", "email": "email:lowercase", "5_digit": "zip:5digit", "+1XXXXXXXXXX": "phone:E164", } return mapping.get(format_spec, format_spec) # ------------------------------------------------------------------ # Metadata # ------------------------------------------------------------------ def get_metadata(self): # type: ignore[override] from openenv.core.env_server.types import EnvironmentMetadata return EnvironmentMetadata( name="dataclean_env", description=( "Data Cleaning environment for training AI agents to clean " "messy tabular data. Supports 3 difficulty levels (easy, medium, hard) " "with deterministic grading via cell-by-cell comparison against ground truth." ), version="0.1.0", )