Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import copy | |
| import json | |
| from pathlib import Path | |
| from statistics import median | |
| from typing import Any | |
| from env.actions import FEATURE_REGISTRY, is_missing, validate_action | |
| from env.models import Action, ColumnInfo, Issue, Observation | |
| from env.quality import compute_quality_score | |
| from env.rewards import compute_reward | |
# Directory holding the per-task JSON files that DataCleaningEnv.reset reads:
# the "data" folder inside the parent of this module's directory.
DATA_DIR = Path(__file__).resolve().parent.parent.joinpath("data")
class DataCleaningEnv:
    """Step-based environment in which an agent cleans a small tabular dataset.

    ``reset`` loads ``DATA_DIR / f"{task_name}.json"`` and detects data-quality
    issues (missing values, duplicate rows, wrong dtypes, casing-inconsistent
    categories, missing engineered features). Each ``step`` validates and applies
    one cleaning ``Action``, re-detects issues, and returns
    ``(observation, reward, done, info)``. The dataset is held as a list of row
    dicts.
    """

    # Columns whose ``wrong_dtype`` issue is declared dependent on the same
    # column's ``missing`` issue. NOTE(review): these names are hard-coded for the
    # shipped task data. Confirm whether other numeric columns need the same
    # dependency; subclasses may override this attribute.
    _DTYPE_DEPENDS_ON_MISSING: frozenset[str] = frozenset({"salary", "rating"})

    def __init__(self, task_name: str = "basic_cleaning"):
        """Create an environment for ``task_name``. No data is loaded until ``reset``.

        Args:
            task_name: Stem of the JSON task file that ``reset`` reads from ``DATA_DIR``.
        """
        self.task_name = task_name
        self.task_config: dict[str, Any] = {}
        # Working rows (mutated by actions) and an untouched copy of the initial rows.
        self.dataset: list[dict[str, Any]] = []
        self.original_dataset: list[dict[str, Any]] = []
        self.issues: list[Issue] = []
        self.pending_issues: list[Issue] = []
        self.resolved_issues: list[Issue] = []
        self.action_history: list[dict[str, Any]] = []
        self.steps_remaining = 0
        self.max_steps = 0
        self.total_issues_at_start = 0
        self.quality_score = 0.0
        self.expected_dtypes: dict[str, str] = {}
        self.required_features: list[str] = []
        # (issue_type, column) -> issue id. Only grows during an episode; cleared by reset.
        self._issue_id_map: dict[tuple[str, str], str] = {}

    def reset(self) -> Observation:
        """Load the task file, detect the initial issues, and return the first observation.

        Reads ``DATA_DIR / f"{self.task_name}.json"``. The keys ``dataset``,
        ``expected_dtypes``, and ``max_steps`` are required; ``required_features``
        is optional.

        Raises:
            FileNotFoundError: If the task file does not exist.
            json.JSONDecodeError: If the task file is not valid JSON.
            KeyError: If a required key is missing from the task file.
        """
        config_path = DATA_DIR / f"{self.task_name}.json"
        with config_path.open("r", encoding="utf-8") as handle:
            self.task_config = json.load(handle)
        self.dataset = copy.deepcopy(self.task_config["dataset"])
        self.original_dataset = copy.deepcopy(self.dataset)
        self.expected_dtypes = dict(self.task_config["expected_dtypes"])
        self.required_features = list(self.task_config.get("required_features", []))
        self.action_history = []
        self.resolved_issues = []
        self.max_steps = int(self.task_config["max_steps"])
        self.steps_remaining = self.max_steps
        # Issue ids restart at issue_001 for every episode.
        self._issue_id_map = {}
        detected = self._detect_issues(self.dataset)
        self.pending_issues = detected
        self.issues = list(detected)
        self.total_issues_at_start = len(detected)
        self.quality_score = compute_quality_score(
            self.dataset,
            self._build_column_infos(),
            self.total_issues_at_start,
        )
        return self.state()

    def step(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Validate and apply one cleaning action.

        Every call consumes one step, whether or not the action is valid.

        An invalid action leaves the dataset unchanged and is rewarded as if
        quality did not change. A valid action is applied and all issues are
        re-detected. The matched issue (if any) is marked resolved once it is no
        longer detected, and the quality score is recomputed.

        Returns:
            ``(observation, reward, done, info)``. ``done`` is true when no steps
            remain or no issues are pending. ``info`` is ``{}`` for a valid action
            and ``{"error": "invalid_action", "message": ...}`` otherwise.
        """
        if not self.dataset:
            # Lazily start an episode if reset() has not produced any rows yet.
            self.reset()
        self.steps_remaining -= 1
        old_quality = self.quality_score
        action_valid, message, matched_issue, dependency_ok = validate_action(
            self.dataset,
            self.pending_issues,
            self._build_column_infos(),
            self.expected_dtypes,
            action,
            self.resolved_issues,
        )
        if not action_valid:
            reward = compute_reward(old_quality, old_quality, False, False)
            self._record_action(action, reward, message)
            info: dict[str, Any] = {"error": "invalid_action", "message": message}
            return self.state(), reward, self._is_done(), info

        self._apply_action(action)
        redetected = self._detect_issues(self.dataset)
        self.pending_issues = redetected
        self.issues = list(redetected)
        # An issue only counts as resolved if the action actually made it disappear.
        if matched_issue and not self._issue_present(
            redetected, matched_issue.issue_type, matched_issue.column
        ):
            self.resolved_issues.append(matched_issue)
        self.quality_score = compute_quality_score(
            self.dataset,
            self._build_column_infos(),
            self.total_issues_at_start,
        )
        reward = compute_reward(old_quality, self.quality_score, True, dependency_ok)
        self._record_action(action, reward, None)
        return self.state(), reward, self._is_done(), {}

    def _record_action(self, action: Action, reward: float, error: str | None) -> None:
        """Append one ``action_history`` entry. ``error`` is None for a valid action."""
        self.action_history.append(
            {
                "action_type": action.action_type,
                "column": action.column,
                # Snapshot params so later mutation by the caller cannot rewrite history.
                "params": copy.deepcopy(action.params),
                "reward": reward,
                "error": error,
            }
        )

    def _is_done(self) -> bool:
        """Return True when the step budget is spent or no issues remain pending."""
        return self.steps_remaining <= 0 or not self.pending_issues

    def state(self) -> Observation:
        """Return a deep-copied snapshot of the episode (preview = first 5 rows)."""
        return Observation(
            data_preview=copy.deepcopy(self.dataset[:5]),
            columns=self._build_column_infos(),
            pending_issues=copy.deepcopy(self.pending_issues),
            resolved_issues=copy.deepcopy(self.resolved_issues),
            action_history=copy.deepcopy(self.action_history),
            quality_score=self.quality_score,
            steps_remaining=self.steps_remaining,
            total_rows=len(self.dataset),
            total_issues_at_start=self.total_issues_at_start,
        )

    def _detect_issues(self, dataset: list[dict[str, Any]]) -> list[Issue]:
        """Detect every current issue in ``dataset`` and return ``Issue`` objects.

        Ids are stable per ``(issue_type, column)`` for the episode. A signature
        seen before keeps its id, and a new one gets the next ``issue_NNN``.
        ``depends_on`` may reference the id of any signature seen so far this
        episode, including issues that have since been resolved.
        """
        raw_issues = self._collect_raw_issues(dataset)
        # Register ids for all signatures first, so dependencies can point at
        # issues that appear later in this same pass.
        for issue_type, column, _ in raw_issues:
            signature = (issue_type, column)
            if signature not in self._issue_id_map:
                self._issue_id_map[signature] = f"issue_{len(self._issue_id_map) + 1:03d}"
        return [
            Issue(
                issue_id=self._issue_id_map[(issue_type, column)],
                issue_type=issue_type,
                column=column,
                description=description,
                depends_on=self._issue_dependencies(issue_type, column),
            )
            for issue_type, column, description in raw_issues
        ]

    def _collect_raw_issues(self, dataset: list[dict[str, Any]]) -> list[tuple[str, str, str]]:
        """Scan ``dataset`` and return ``(issue_type, column, description)`` triples.

        The triples are ordered as follows: missing values per expected column,
        duplicate rows, wrong dtypes (int/float/bool columns only),
        casing-inconsistent categories (str columns only), then required features
        that have not been created yet.
        """
        if not dataset:
            return []
        columns = list(self.expected_dtypes)
        found: list[tuple[str, str, str]] = []
        for column in columns:
            missing_count = sum(1 for row in dataset if is_missing(row.get(column)))
            if missing_count:
                found.append(
                    (
                        "missing",
                        column,
                        f"Column '{column}' has {missing_count} missing values that should be filled.",
                    )
                )
        if self._has_duplicates(dataset):
            found.append(
                ("duplicate", "__all__", "Dataset contains duplicate rows that should be removed.")
            )
        for column in columns:
            expected_dtype = self.expected_dtypes[column]
            actual_dtype = self._infer_runtime_dtype(dataset, column)
            if expected_dtype in {"int", "float", "bool"} and actual_dtype != expected_dtype:
                found.append(
                    (
                        "wrong_dtype",
                        column,
                        f"Column '{column}' should be '{expected_dtype}' but is currently represented as '{actual_dtype}'.",
                    )
                )
        for column in columns:
            if self.expected_dtypes[column] == "str" and self._has_inconsistent_categories(
                dataset, column
            ):
                found.append(
                    (
                        "inconsistent_category",
                        column,
                        f"Column '{column}' has inconsistent categorical values that differ only by casing.",
                    )
                )
        for feature_name in self.required_features:
            if not all(feature_name in row for row in dataset):
                found.append(
                    (
                        "missing_feature",
                        feature_name,
                        f"Required feature '{feature_name}' has not been created yet.",
                    )
                )
        return found

    def _issue_dependencies(self, issue_type: str, column: str) -> list[str]:
        """Return the ids of issues that should be fixed before ``(issue_type, column)``.

        A dtype conversion of a listed column waits on filling that column. A
        feature waits on the missing-value and dtype issues of its
        ``FEATURE_REGISTRY`` source column. Only signatures already registered
        this episode are returned.
        """
        if issue_type == "wrong_dtype" and column in self._DTYPE_DEPENDS_ON_MISSING:
            prerequisites = [("missing", column)]
        elif issue_type == "missing_feature":
            source_column = FEATURE_REGISTRY[column]["source"]
            prerequisites = [("missing", source_column), ("wrong_dtype", source_column)]
        else:
            return []
        return [
            self._issue_id_map[signature]
            for signature in prerequisites
            if signature in self._issue_id_map
        ]

    def _build_column_infos(self) -> list[ColumnInfo]:
        """Summarize each column of the current dataset: runtime dtype, nulls, distinct values."""
        if not self.dataset:
            return []
        # Union of keys across all rows in first-seen order, so a column that is
        # absent from the first row is still reported.
        column_names = dict.fromkeys(name for row in self.dataset for name in row)
        infos: list[ColumnInfo] = []
        for column in column_names:
            values = [row.get(column) for row in self.dataset]
            non_missing = [value for value in values if not is_missing(value)]
            infos.append(
                ColumnInfo(
                    name=column,
                    dtype=self._infer_runtime_dtype(self.dataset, column),
                    null_count=len(values) - len(non_missing),
                    unique_count=len({str(value) for value in non_missing}),
                )
            )
        return infos

    def _infer_runtime_dtype(self, dataset: list[dict[str, Any]], column: str) -> str:
        """Classify the non-missing values of ``column`` as "bool", "int", "float", or "str".

        "float" covers any mix of ints and floats. Bools mixed with numbers count
        as "str". A column with no observed values reports its expected dtype
        (or "str" if none is declared).
        """
        values = [row.get(column) for row in dataset if not is_missing(row.get(column))]
        if not values:
            return self.expected_dtypes.get(column, "str")
        if all(isinstance(value, bool) for value in values):
            return "bool"
        if all(isinstance(value, int) and not isinstance(value, bool) for value in values):
            return "int"
        if all(isinstance(value, (int, float)) and not isinstance(value, bool) for value in values):
            return "float"
        return "str"

    @staticmethod
    def _row_key(row: dict[str, Any]) -> tuple[tuple[str, Any], ...]:
        """Return a column-order-insensitive key for ``row``.

        The row's values must be hashable for the key to be stored in a set.
        """
        return tuple(sorted(row.items()))

    def _has_duplicates(self, dataset: list[dict[str, Any]]) -> bool:
        """Return True if any two rows hold identical column/value pairs."""
        seen: set[tuple[tuple[str, Any], ...]] = set()
        for row in dataset:
            key = self._row_key(row)
            if key in seen:
                return True
            seen.add(key)
        return False

    def _has_inconsistent_categories(self, dataset: list[dict[str, Any]], column: str) -> bool:
        """Return True if some non-missing values of ``column`` differ only by letter case."""
        groups: dict[str, set[str]] = {}
        for row in dataset:
            value = row.get(column)
            if is_missing(value):
                continue
            groups.setdefault(str(value).lower(), set()).add(str(value))
        return any(len(forms) > 1 for forms in groups.values())

    def _issue_present(self, issues: list[Issue], issue_type: str, column: str) -> bool:
        """Return True if ``issues`` contains an issue of ``issue_type`` on ``column``."""
        return any(issue.issue_type == issue_type and issue.column == column for issue in issues)

    def _apply_action(self, action: Action) -> None:
        """Mutate the dataset according to an already validated ``action``.

        Unrecognized action types are ignored.
        """
        action_type = action.action_type
        if action_type == "fill_missing":
            self._apply_fill_missing(action.column, action.params["strategy"])
        elif action_type == "drop_duplicates":
            self._apply_drop_duplicates()
        elif action_type == "convert_dtype":
            self._apply_convert_dtype(action.column, action.params["target_dtype"])
        elif action_type == "normalize_category":
            self._apply_normalize_category(action.column)
        elif action_type == "create_feature":
            self._apply_create_feature(action.params["feature_name"])

    def _apply_drop_duplicates(self) -> None:
        """Remove repeated rows, keeping the first occurrence and the original row order."""
        unique_rows: list[dict[str, Any]] = []
        seen: set[tuple[tuple[str, Any], ...]] = set()
        for row in self.dataset:
            key = self._row_key(row)
            if key not in seen:
                seen.add(key)
                unique_rows.append(row)
        self.dataset = unique_rows

    def _apply_convert_dtype(self, column: str, target_dtype: str) -> None:
        """Convert every value of ``column`` to ``target_dtype`` and store missing values as None."""
        for row in self.dataset:
            value = row.get(column)
            row[column] = None if is_missing(value) else self._convert_value(value, target_dtype)

    def _apply_fill_missing(self, column: str, strategy: str) -> None:
        """Replace missing values in ``column`` according to ``strategy``.

        For int/float columns, "mean" and "median" use the observed values; any
        other strategy, or a column with no observed values, fills 0. The result
        is rounded to int for int columns and cast to float for float columns.

        For other columns, "mode" fills the most frequent value (ties broken by
        ``_pick_mode``) and is converted back to bool for bool columns. Any other
        strategy, or an empty column, fills "unknown".
        """
        expected_dtype = self.expected_dtypes.get(column, "str")
        observed = [row.get(column) for row in self.dataset if not is_missing(row.get(column))]
        fill_value: Any
        if expected_dtype in {"int", "float"}:
            numbers = [self._convert_value(value, expected_dtype) for value in observed]
            # Guard on `numbers`: mean/median of an empty column would raise
            # ZeroDivisionError / StatisticsError.
            if strategy == "mean" and numbers:
                fill_value = sum(numbers) / len(numbers)
            elif strategy == "median" and numbers:
                fill_value = median(numbers)
            else:
                fill_value = 0
            # Keep the filled value's type consistent with the column's expected dtype.
            fill_value = int(round(fill_value)) if expected_dtype == "int" else float(fill_value)
        elif strategy == "mode" and observed:
            fill_value = self._pick_mode([str(value) for value in observed])
            if expected_dtype == "bool":
                # The mode is computed on strings ("True"/"False"). Map it back so a
                # bool column is not polluted with text.
                fill_value = self._convert_value(fill_value, "bool")
        else:
            fill_value = "unknown"
        for row in self.dataset:
            if is_missing(row.get(column)):
                row[column] = fill_value

    def _apply_normalize_category(self, column: str) -> None:
        """Rewrite each value of ``column`` to the canonical spelling of its case-folded group.

        The canonical spelling is the ``_pick_mode`` winner among the group's
        surface forms. Missing values are left untouched.
        """
        forms_by_key: dict[str, list[str]] = {}
        for row in self.dataset:
            value = row.get(column)
            if not is_missing(value):
                text = str(value)
                forms_by_key.setdefault(text.lower(), []).append(text)
        canonical = {key: self._pick_mode(forms) for key, forms in forms_by_key.items()}
        for row in self.dataset:
            value = row.get(column)
            if not is_missing(value):
                row[column] = canonical[str(value).lower()]

    def _apply_create_feature(self, feature_name: str) -> None:
        """Add ``feature_name`` to every row by binning its registry source column.

        ``FEATURE_REGISTRY[feature_name]`` supplies ``source``, ``bins``, and
        ``labels``. Rows whose source value is missing, or falls outside every
        bin, get ``None``.
        """
        feature_config = FEATURE_REGISTRY[feature_name]
        source = feature_config["source"]
        bins = feature_config["bins"]
        labels = feature_config["labels"]
        for row in self.dataset:
            source_value = row.get(source)
            if is_missing(source_value):
                row[feature_name] = None
            else:
                row[feature_name] = self._bin_label(float(source_value), bins, labels)

    @staticmethod
    def _bin_label(value: float, bins: list[float], labels: list[Any]) -> Any:
        """Return the label whose bin contains ``value``, or None if no bin does.

        Label ``i`` covers ``[bins[i], bins[i + 1])``; the last label's interval is
        also closed on the right. Assumes ``len(bins) == len(labels) + 1``.
        """
        last_index = len(labels) - 1
        for index, label in enumerate(labels):
            lower, upper = bins[index], bins[index + 1]
            if lower <= value < upper or (index == last_index and lower <= value <= upper):
                return label
        return None

    def _pick_mode(self, values: list[str]) -> str:
        """Return the most frequent string in ``values``.

        Ties are broken by case-insensitive order, then by preferring an
        all-lowercase form, then by plain string order.

        Raises:
            ValueError: If ``values`` is empty.
        """
        counts: dict[str, int] = {}
        for value in values:
            counts[value] = counts.get(value, 0) + 1
        return min(
            counts.items(),
            key=lambda item: (-item[1], item[0].lower(), 0 if item[0].islower() else 1, item[0]),
        )[0]

    def _convert_value(self, value: Any, target_dtype: str) -> Any:
        """Convert a single non-missing ``value`` to ``target_dtype``.

        "int" parses integral text exactly and truncates decimal text toward zero.
        "float" parses the value's text. "bool" treats 1/1.0 and the text
        "true"/"1"/"yes" (case-insensitive, surrounding spaces ignored) as True
        and everything else as False. Any other target returns ``str(value)``.

        Raises:
            ValueError: If the value's text is not numeric for an int/float target.
        """
        if target_dtype in ("int", "float") and isinstance(value, bool):
            # str(True) == "True" is not numeric text; map booleans to 1/0 directly.
            return int(value) if target_dtype == "int" else float(value)
        if target_dtype == "int":
            text = str(value).strip()
            try:
                # Exact parse first: going through float() loses precision past 2**53.
                return int(text)
            except ValueError:
                # Decimal or exponent text such as "3.7" or "1e3".
                return int(float(text))
        if target_dtype == "float":
            return float(str(value))
        if target_dtype == "bool":
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                # Numeric flags: str(1.0) == "1.0" would otherwise read as False.
                return value == 1
            return str(value).strip().lower() in {"true", "1", "yes"}
        return str(value)