from __future__ import annotations

import copy
import json
from collections import Counter, defaultdict
from pathlib import Path
from statistics import median
from typing import Any

from env.actions import FEATURE_REGISTRY, is_missing, validate_action
from env.models import Action, ColumnInfo, Issue, Observation
from env.quality import compute_quality_score
from env.rewards import compute_reward

DATA_DIR = Path(__file__).resolve().parent.parent / "data"


class DataCleaningEnv:
    """Step-based environment for cleaning a small tabular dataset.

    A task is a JSON file named ``<task_name>.json`` under ``DATA_DIR``. The
    code reads the keys ``dataset`` (a list of row dicts), ``expected_dtypes``
    (column name -> dtype name), ``max_steps`` and, optionally,
    ``required_features``.

    ``reset()`` loads the task and detects the initial issues; ``step()``
    validates and applies one action, re-detects issues, and returns the new
    observation together with a reward and a ``done`` flag.
    """

    # Columns whose "wrong_dtype" issue is declared to depend on the same
    # column's "missing" issue (when one has been seen in this episode).
    _DTYPE_AFTER_MISSING_COLUMNS = frozenset({"salary", "rating"})

    def __init__(self, task_name: str = "basic_cleaning"):
        """Create an unloaded environment; call ``reset()`` to load the task."""
        self.task_name = task_name
        self.task_config: dict[str, Any] = {}
        self.dataset: list[dict[str, Any]] = []
        self.original_dataset: list[dict[str, Any]] = []
        self.issues: list[Issue] = []
        self.pending_issues: list[Issue] = []
        self.resolved_issues: list[Issue] = []
        self.action_history: list[dict[str, Any]] = []
        self.steps_remaining = 0
        self.max_steps = 0
        self.total_issues_at_start = 0
        self.quality_score = 0.0
        self.expected_dtypes: dict[str, str] = {}
        self.required_features: list[str] = []
        # (issue_type, column) -> stable issue id for the current episode.
        self._issue_id_map: dict[tuple[str, str], str] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def reset(self) -> Observation:
        """Load the task file, reset all episode state and return the first observation.

        Raises:
            OSError: if the task file cannot be opened.
            KeyError: if a mandatory key is absent from the task config.
        """
        config_path = DATA_DIR / f"{self.task_name}.json"
        with config_path.open("r", encoding="utf-8") as handle:
            self.task_config = json.load(handle)

        # Deep copies keep the loaded config and the pristine snapshot
        # untouched by in-place cleaning of ``self.dataset``.
        self.dataset = copy.deepcopy(self.task_config["dataset"])
        self.original_dataset = copy.deepcopy(self.dataset)
        self.expected_dtypes = dict(self.task_config["expected_dtypes"])
        self.required_features = list(self.task_config.get("required_features", []))

        self.action_history = []
        self.resolved_issues = []
        self.max_steps = int(self.task_config["max_steps"])
        self.steps_remaining = self.max_steps
        self._issue_id_map = {}

        detected = self._detect_issues(self.dataset)
        self.pending_issues = detected
        self.issues = list(detected)
        self.total_issues_at_start = len(detected)
        self.quality_score = compute_quality_score(
            self.dataset,
            self._build_column_infos(),
            self.total_issues_at_start,
        )
        return self.state()

    def step(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Validate and apply one action.

        Every call consumes one step, whether or not the action is valid. The
        environment resets itself first if no dataset is loaded.

        Returns:
            ``(observation, reward, done, info)``. ``info`` is empty for a
            valid action and carries ``error``/``message`` for an invalid one.
            ``done`` is true once the step budget is spent or no issue is
            pending.
        """
        if not self.dataset:
            self.reset()

        self.steps_remaining -= 1
        old_quality = self.quality_score
        columns = self._build_column_infos()

        action_valid, message, matched_issue, dependency_ok = validate_action(
            self.dataset,
            self.pending_issues,
            columns,
            self.expected_dtypes,
            action,
            self.resolved_issues,
        )

        if not action_valid:
            # The dataset is untouched, so quality is passed as unchanged.
            reward = compute_reward(old_quality, old_quality, False, False)
            self._record_action(action, reward, message)
            info: dict[str, Any] = {"error": "invalid_action", "message": message}
            return self.state(), reward, self._is_done(), info

        self._apply_action(action)

        redetected = self._detect_issues(self.dataset)
        self.pending_issues = redetected
        self.issues = list(redetected)

        # The targeted issue counts as resolved only if it really disappeared.
        if matched_issue and not self._issue_present(
            redetected, matched_issue.issue_type, matched_issue.column
        ):
            self.resolved_issues.append(matched_issue)

        self.quality_score = compute_quality_score(
            self.dataset,
            self._build_column_infos(),
            self.total_issues_at_start,
        )
        reward = compute_reward(old_quality, self.quality_score, True, dependency_ok)
        self._record_action(action, reward, None)
        return self.state(), reward, self._is_done(), {}

    def state(self) -> Observation:
        """Return a snapshot of the environment; mutable parts are deep-copied."""
        return Observation(
            data_preview=copy.deepcopy(self.dataset[:5]),
            columns=self._build_column_infos(),
            pending_issues=copy.deepcopy(self.pending_issues),
            resolved_issues=copy.deepcopy(self.resolved_issues),
            action_history=copy.deepcopy(self.action_history),
            quality_score=self.quality_score,
            steps_remaining=self.steps_remaining,
            total_rows=len(self.dataset),
            total_issues_at_start=self.total_issues_at_start,
        )

    # ------------------------------------------------------------------
    # Step bookkeeping
    # ------------------------------------------------------------------

    def _record_action(self, action: Action, reward: float, error: str | None) -> None:
        """Append one entry describing ``action`` and its outcome to the history."""
        self.action_history.append(
            {
                "action_type": action.action_type,
                "column": action.column,
                "params": action.params,
                "reward": reward,
                "error": error,
            }
        )

    def _is_done(self) -> bool:
        """Return True when the step budget is exhausted or nothing is pending."""
        return self.steps_remaining <= 0 or not self.pending_issues

    # ------------------------------------------------------------------
    # Issue detection
    # ------------------------------------------------------------------

    def _detect_issues(self, dataset: list[dict[str, Any]]) -> list[Issue]:
        """Scan ``dataset`` and return the issues currently present.

        Detection order (which also fixes the order in which new issue ids
        are allocated): missing values, duplicate rows, wrong dtypes,
        inconsistent categories, missing required features. An empty dataset
        yields no issues.

        Raises:
            KeyError: if a required feature that is still missing has no
                entry in ``FEATURE_REGISTRY``.
        """
        if not dataset:
            return []

        raw_issues: list[dict[str, Any]] = []
        columns = list(self.expected_dtypes)

        for column in columns:
            missing_count = sum(1 for row in dataset if is_missing(row.get(column)))
            if missing_count:
                raw_issues.append(
                    {
                        "issue_type": "missing",
                        "column": column,
                        "description": f"Column '{column}' has {missing_count} missing values that should be filled.",
                    }
                )

        if self._has_duplicates(dataset):
            raw_issues.append(
                {
                    "issue_type": "duplicate",
                    "column": "__all__",
                    "description": "Dataset contains duplicate rows that should be removed.",
                }
            )

        for column in columns:
            expected_dtype = self.expected_dtypes[column]
            actual_dtype = self._infer_runtime_dtype(dataset, column)
            if expected_dtype in {"int", "float", "bool"} and actual_dtype != expected_dtype:
                raw_issues.append(
                    {
                        "issue_type": "wrong_dtype",
                        "column": column,
                        "description": (
                            f"Column '{column}' should be '{expected_dtype}' but is currently represented as '{actual_dtype}'."
                        ),
                    }
                )

        for column in columns:
            if self.expected_dtypes[column] != "str":
                continue
            if self._has_inconsistent_categories(dataset, column):
                raw_issues.append(
                    {
                        "issue_type": "inconsistent_category",
                        "column": column,
                        "description": f"Column '{column}' has inconsistent categorical values that differ only by casing.",
                    }
                )

        for feature_name in self.required_features:
            if not all(feature_name in row for row in dataset):
                raw_issues.append(
                    {
                        "issue_type": "missing_feature",
                        "column": feature_name,
                        "description": f"Required feature '{feature_name}' has not been created yet.",
                    }
                )

        # Ids are keyed by (type, column) and kept for the whole episode, so
        # an issue keeps its id across successive detections.
        for raw_issue in raw_issues:
            signature = (raw_issue["issue_type"], raw_issue["column"])
            if signature not in self._issue_id_map:
                self._issue_id_map[signature] = f"issue_{len(self._issue_id_map) + 1:03d}"

        return [
            Issue(
                issue_id=self._issue_id_map[(raw_issue["issue_type"], raw_issue["column"])],
                issue_type=raw_issue["issue_type"],
                column=raw_issue["column"],
                description=raw_issue["description"],
                depends_on=self._dependencies_for(raw_issue),
            )
            for raw_issue in raw_issues
        ]

    def _dependencies_for(self, raw_issue: dict[str, Any]) -> list[str]:
        """Return the ids of issues that ``raw_issue`` depends on.

        Lookups go through the episode-wide id map, so a dependency id is
        reported whenever that issue has been seen in this episode, even if
        it is no longer pending.
        """
        depends_on: list[str] = []
        issue_type = raw_issue["issue_type"]
        column = raw_issue["column"]

        if issue_type == "wrong_dtype" and column in self._DTYPE_AFTER_MISSING_COLUMNS:
            missing_signature = ("missing", column)
            if missing_signature in self._issue_id_map:
                depends_on.append(self._issue_id_map[missing_signature])

        if issue_type == "missing_feature":
            # For a feature issue, "column" holds the feature name.
            source_column = FEATURE_REGISTRY[column]["source"]
            for dependency_type in ("missing", "wrong_dtype"):
                source_signature = (dependency_type, source_column)
                if source_signature in self._issue_id_map:
                    depends_on.append(self._issue_id_map[source_signature])

        return depends_on

    def _build_column_infos(self) -> list[ColumnInfo]:
        """Summarise each column (dtype, null count, distinct count).

        The column list is taken from the keys of the first row only.
        """
        if not self.dataset:
            return []

        infos: list[ColumnInfo] = []
        for column in self.dataset[0].keys():
            values = [row.get(column) for row in self.dataset]
            non_missing = [value for value in values if not is_missing(value)]
            infos.append(
                ColumnInfo(
                    name=column,
                    dtype=self._infer_runtime_dtype(self.dataset, column),
                    null_count=len(values) - len(non_missing),
                    # Distinctness is judged on the string form of each value.
                    unique_count=len({str(value) for value in non_missing}),
                )
            )
        return infos

    def _infer_runtime_dtype(self, dataset: list[dict[str, Any]], column: str) -> str:
        """Infer ``"bool"``, ``"int"``, ``"float"`` or ``"str"`` from the values present.

        Missing values are ignored. A column with no usable values reports
        its expected dtype (``"str"`` when none is configured).
        """
        values = [row.get(column) for row in dataset if not is_missing(row.get(column))]
        if not values:
            return self.expected_dtypes.get(column, "str")
        if all(isinstance(value, bool) for value in values):
            return "bool"
        # bool is a subclass of int, so it is excluded from the numeric checks.
        if all(isinstance(value, int) and not isinstance(value, bool) for value in values):
            return "int"
        if all(isinstance(value, (int, float)) and not isinstance(value, bool) for value in values):
            return "float"
        return "str"

    @staticmethod
    def _row_key(row: dict[str, Any]) -> tuple[tuple[str, Any], ...]:
        """Return a hashable, key-order-independent fingerprint of ``row``."""
        return tuple(sorted(row.items()))

    def _has_duplicates(self, dataset: list[dict[str, Any]]) -> bool:
        """Return True if two rows hold equal key/value pairs."""
        seen: set[tuple[tuple[str, Any], ...]] = set()
        for row in dataset:
            key = self._row_key(row)
            if key in seen:
                return True
            seen.add(key)
        return False

    def _has_inconsistent_categories(self, dataset: list[dict[str, Any]], column: str) -> bool:
        """Return True if some value appears in more than one casing."""
        groups: defaultdict[str, set[str]] = defaultdict(set)
        for row in dataset:
            value = row.get(column)
            if is_missing(value):
                continue
            surface = str(value)
            groups[surface.lower()].add(surface)
        return any(len(forms) > 1 for forms in groups.values())

    def _issue_present(self, issues: list[Issue], issue_type: str, column: str) -> bool:
        """Return True if ``issues`` contains an issue of this type on this column."""
        return any(issue.issue_type == issue_type and issue.column == column for issue in issues)

    # ------------------------------------------------------------------
    # Action application
    # ------------------------------------------------------------------

    def _apply_action(self, action: Action) -> None:
        """Dispatch ``action`` to its handler; unknown action types are ignored."""
        if action.action_type == "fill_missing":
            self._apply_fill_missing(action.column, action.params["strategy"])
        elif action.action_type == "drop_duplicates":
            self._apply_drop_duplicates()
        elif action.action_type == "convert_dtype":
            self._apply_convert_dtype(action.column, action.params["target_dtype"])
        elif action.action_type == "normalize_category":
            self._apply_normalize_category(action.column)
        elif action.action_type == "create_feature":
            self._apply_create_feature(action.params["feature_name"])

    def _apply_drop_duplicates(self) -> None:
        """Keep only the first occurrence of each distinct row."""
        unique_rows: list[dict[str, Any]] = []
        seen: set[tuple[tuple[str, Any], ...]] = set()
        for row in self.dataset:
            key = self._row_key(row)
            if key in seen:
                continue
            seen.add(key)
            unique_rows.append(row)
        self.dataset = unique_rows

    def _apply_convert_dtype(self, column: str, target_dtype: str) -> None:
        """Convert every value of ``column`` in place; missing values become ``None``."""
        for row in self.dataset:
            value = row.get(column)
            row[column] = None if is_missing(value) else self._convert_value(value, target_dtype)

    def _apply_fill_missing(self, column: str, strategy: str) -> None:
        """Fill the missing values of ``column`` in place.

        Numeric columns (expected dtype ``int``/``float``) support the
        ``mean`` and ``median`` strategies; any other strategy fills with
        ``0``. Other columns support ``mode``; any other strategy fills with
        ``"unknown"``. When the column has no non-missing value to compute a
        statistic from, the same defaults are used instead of raising.
        """
        expected_dtype = self.expected_dtypes.get(column, "str")
        valid_values = [row.get(column) for row in self.dataset if not is_missing(row.get(column))]

        fill_value: Any
        if expected_dtype in {"int", "float"}:
            numeric_values = [self._convert_value(value, expected_dtype) for value in valid_values]
            # An all-missing column has nothing to average; fall through to 0.
            if strategy == "mean" and numeric_values:
                fill_value = sum(numeric_values) / len(numeric_values)
            elif strategy == "median" and numeric_values:
                fill_value = median(numeric_values)
            else:
                fill_value = 0
            if expected_dtype == "int":
                fill_value = int(round(fill_value))
        elif strategy == "mode" and valid_values:
            fill_value = self._pick_mode([str(value) for value in valid_values])
        else:
            fill_value = "unknown"

        for row in self.dataset:
            if is_missing(row.get(column)):
                row[column] = fill_value

    def _apply_normalize_category(self, column: str) -> None:
        """Rewrite each value of ``column`` to the preferred casing of its group.

        Values are grouped by their lower-cased string form; the preferred
        spelling of a group is chosen by ``_most_common_form``.
        """
        groups: defaultdict[str, Counter[str]] = defaultdict(Counter)
        for row in self.dataset:
            value = row.get(column)
            if is_missing(value):
                continue
            surface = str(value)
            groups[surface.lower()][surface] += 1

        canonical = {
            lowered: self._most_common_form(counts) for lowered, counts in groups.items()
        }

        for row in self.dataset:
            value = row.get(column)
            if is_missing(value):
                continue
            row[column] = canonical[str(value).lower()]

    def _apply_create_feature(self, feature_name: str) -> None:
        """Add ``feature_name`` to every row by binning its source column.

        The registry entry supplies ``source``, ``bins`` and ``labels``.
        Bins are half-open ``[lower, upper)`` except the last one, which also
        includes its upper edge. Rows with a missing source value, or a value
        outside every bin, receive ``None``.
        """
        feature_config = FEATURE_REGISTRY[feature_name]
        source = feature_config["source"]
        bins = feature_config["bins"]
        labels = feature_config["labels"]
        last_index = len(labels) - 1

        for row in self.dataset:
            source_value = row.get(source)
            if is_missing(source_value):
                row[feature_name] = None
                continue

            numeric_value = float(source_value)
            assigned = None
            for index, label in enumerate(labels):
                lower = bins[index]
                upper = bins[index + 1]
                in_half_open = lower <= numeric_value < upper
                on_final_edge = index == last_index and lower <= numeric_value <= upper
                if in_half_open or on_final_edge:
                    assigned = label
                    break
            row[feature_name] = assigned

    # ------------------------------------------------------------------
    # Value helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _most_common_form(counts: dict[str, int]) -> str:
        """Return the preferred key of ``counts``.

        Preference order: highest count, then alphabetical by lower-cased
        text, then all-lowercase spellings first, then plain string order.

        Raises:
            ValueError: if ``counts`` is empty.
        """
        return min(
            counts.items(),
            key=lambda item: (-item[1], item[0].lower(), 0 if item[0].islower() else 1, item[0]),
        )[0]

    def _pick_mode(self, values: list[str]) -> str:
        """Return the most frequent string in ``values`` (ties: see ``_most_common_form``).

        Raises:
            ValueError: if ``values`` is empty.
        """
        return self._most_common_form(Counter(values))

    def _convert_value(self, value: Any, target_dtype: str) -> Any:
        """Convert ``value`` to ``target_dtype`` via its string form.

        ``"int"`` truncates decimals toward zero; ``"bool"`` is true only for
        the texts ``true``, ``1`` and ``yes`` (case-insensitive, surrounding
        whitespace ignored); any unrecognised dtype yields ``str(value)``.

        Raises:
            ValueError: if the text is not numeric for ``int``/``float``.
        """
        if target_dtype == "int":
            text = str(value)
            try:
                # Parse integer text directly: a detour through float would
                # silently corrupt integers beyond 2**53.
                return int(text)
            except ValueError:
                # Decimal or exponent notation ("3.7", "1e3"): truncate.
                return int(float(text))
        if target_dtype == "float":
            return float(str(value))
        if target_dtype == "bool":
            normalized = str(value).strip().lower()
            return normalized in {"true", "1", "yes"}
        return str(value)