Spaces:
Sleeping
Sleeping
| """ | |
| Core environment logic for the Data Cleaning RL Environment. | |
| Implements the openenv Environment interface so that create_app() can | |
| expose /ws, /reset, /step, /state endpoints automatically. | |
| """ | |
| from __future__ import annotations | |
| import difflib | |
| import uuid | |
| from typing import Any, Optional | |
| import numpy as np | |
| import pandas as pd | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from data_cleaning_env.action_registry import ACTION_COSTS | |
| from data_cleaning_env.datasets import load_clean_dataset | |
| from data_cleaning_env.grader import compute_quality_score | |
| from data_cleaning_env.noise_injector import inject_noise | |
| from data_cleaning_env.models import ( | |
| ActionType, | |
| CleaningAction, | |
| ColumnIssues, | |
| ColumnStats, | |
| Observation, | |
| ) | |
# Per-task cap on the number of step() calls; the episode is ended once the
# step counter reaches this value.
MAX_STEPS: dict[str, int] = dict(easy=20, medium=40, hard=60, expert=80)

# Symmetric bound applied to the per-step quality delta before it is used as
# the reward.
REWARD_CLIP = 0.1

# Per-task action-cost budget. All task levels currently share the same value.
EPISODE_BUDGET: dict[str, float] = dict(easy=1.0, medium=1.0, hard=1.0, expert=1.0)
| class DataCleaningEnvironment(Environment): | |
| """ | |
| OpenEnv-compatible Data Cleaning Environment. | |
| Each instance manages a single episode. The openenv server framework | |
| creates a new instance per WebSocket session via the factory callable. | |
| """ | |
| SUPPORTS_CONCURRENT_SESSIONS: bool = True | |
def __init__(self) -> None:
    """Create an idle environment; episode data is only filled in by reset()."""
    super().__init__()
    # Every episode started on this instance, keyed by episode id
    # (multi-episode store for the REST grader endpoint).
    self.episodes: dict[str, dict[str, Any]] = {}
    # Working data of the active episode; empty until reset() populates it.
    self._ep: dict[str, Any] = {}
    self._state = State(step_count=0, episode_id=str(uuid.uuid4()))
| # ------------------------------------------------------------------ | |
| # openenv Environment interface | |
| # ------------------------------------------------------------------ | |
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    task: str = "easy",
    difficulty: Optional[float] = None,
    use_synthetic: bool = False,
    **kwargs,
) -> Observation:
    """Start a new episode and return the initial observation.

    Args:
        seed: Seed for the synthetic dataset generator. Only consulted when
            ``use_synthetic`` is true; when omitted, a wall-clock-derived
            seed is used, so synthetic episodes differ between calls.
        episode_id: Identifier for the episode; a random UUID is generated
            when not supplied.
        task: Task level; must be a key of ``MAX_STEPS``.
        difficulty: Forwarded to ``inject_noise`` as ``severity``.
        use_synthetic: Generate a synthetic clean dataset instead of loading
            the dataset registered for ``task``.
        **kwargs: Accepted and ignored.

    Returns:
        The observation of the freshly noised dataset, with reward 0.0.

    Raises:
        ValueError: If ``task`` is not a known task level.
    """
    # Local import: nothing else in this module needs ``time``.
    import time as _time

    if task not in MAX_STEPS:
        raise ValueError(f"Unknown task '{task}'. Must be one of: {list(MAX_STEPS)}")

    eid = episode_id or str(uuid.uuid4())
    self._state = State(episode_id=eid, step_count=0)

    if use_synthetic:
        from data_cleaning_env.datasets import generate_synthetic_dataset

        # Honour the caller's seed so synthetic episodes are reproducible;
        # fall back to the previous time-based seed otherwise.
        syn_seed = seed if seed is not None else int(_time.time()) % 100000
        clean_df, target_col = generate_synthetic_dataset(seed=syn_seed)
    else:
        clean_df, target_col = load_clean_dataset(task)

    dirty_df = inject_noise(clean_df, task, severity=difficulty)
    initial_quality = compute_quality_score(dirty_df, clean_df)["composite"]

    self._ep = {
        "task": task,
        "clean_df": clean_df,
        "current_df": dirty_df.copy(),
        "dirty_df": dirty_df.copy(),
        "initial_quality": initial_quality,
        "target_col": target_col,
        "created_at": _time.monotonic(),
        "step": 0,
        "max_steps": MAX_STEPS[task],
        "done": False,
        "prev_quality": initial_quality,
        "action_history": [],
        "budget": EPISODE_BUDGET.get(task, 1.0),
        "cost_used": 0.0,
        "checkpoint_df": None,
    }
    # Also store in the episodes dict for the REST grader/baseline. The same
    # dict object is shared, so later step() updates are visible there too.
    self.episodes[eid] = self._ep
    return self._make_observation(reward=0.0)
def step(self, action: CleaningAction, timeout_s=None, **kwargs) -> Observation:
    """Apply a cleaning action and return the new observation.

    Three kinds of action are distinguished:

    * ``done`` ends the episode; it is charged its cost but is not written
      to the action history.
    * ``profile_column`` leaves the data unchanged and earns reward 0.0.
    * Every other action is applied to a copy of the current frame. On
      success the reward is the clipped change in composite quality,
      scaled by the remaining share of the budget (never below 0.1). On
      failure the episode state is left untouched and the reward is -0.05.

    The action's cost is charged in every case.

    Args:
        action: The action to apply.
        timeout_s: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The observation after the action.

    Raises:
        ValueError: If no episode is active or the episode is already done.
    """
    ep = self._ep
    if not ep:
        raise ValueError("No active episode. Call reset() first.")
    if ep["done"]:
        raise ValueError("Episode already done. Call reset() first.")

    action_type = action.action_type
    reward = 0.0
    _NON_MUTATING = {ActionType.done, ActionType.profile_column}

    def _log(step_reward: float) -> None:
        # One history entry per recorded action.
        ep["action_history"].append({
            "action_type": action_type.value,
            "column": action.column,
            "reward": step_reward,
            "step": ep["step"],
        })

    if action_type == ActionType.done:
        ep["done"] = True
        ep["cost_used"] += ACTION_COSTS.get("done", 0.0)
    elif action_type in _NON_MUTATING:
        try:
            ep["current_df"] = self._apply_action(ep["current_df"], action, ep)
        except Exception:
            # Deliberate best-effort: a failed profile (e.g. unknown column)
            # is not penalised; the observation simply carries no profile.
            pass
        ep["cost_used"] += ACTION_COSTS.get(action_type.value, 0.0)
        _log(reward)
    else:
        before_df = ep["current_df"]
        cost = ACTION_COSTS.get(action_type.value, 0.0)
        try:
            # Apply and grade first; commit only if both succeed so a failed
            # action cannot leave a half-updated episode.
            after_df = self._apply_action(before_df, action, ep)
            new_quality = compute_quality_score(after_df, ep["clean_df"])["composite"]
        except Exception:
            reward = -0.05
            ep["cost_used"] += cost
        else:
            # Record the undo checkpoint only once the action has succeeded,
            # so a failed action does not overwrite the previous checkpoint.
            if action_type != ActionType.undo:
                ep["checkpoint_df"] = before_df.copy()
            ep["current_df"] = after_df
            delta = new_quality - ep["prev_quality"]
            reward = float(np.clip(delta, -REWARD_CLIP, REWARD_CLIP))
            ep["prev_quality"] = new_quality
            ep["cost_used"] += cost
            cost_efficiency = 1.0 - min(ep["cost_used"] / ep["budget"], 1.0)
            reward *= max(cost_efficiency, 0.1)
        _log(reward)

    ep["step"] += 1
    self._state.step_count = ep["step"]
    if ep["step"] >= ep["max_steps"]:
        ep["done"] = True
        # NOTE(review): the step that exhausts the step limit reports reward
        # 0.0 even though the history entry keeps the computed reward.
        # Preserved as-is; confirm this is intended.
        reward = 0.0
    return self._make_observation(reward=reward)
def state(self) -> State:
    """Return the State object (episode id and step count) of the current episode."""
    return self._state
def close(self) -> None:
    """Do nothing; this environment performs no cleanup on close."""
    pass
| # ------------------------------------------------------------------ | |
| # Action application | |
| # ------------------------------------------------------------------ | |
def _apply_action(self, df: pd.DataFrame, action: CleaningAction, ep: dict) -> pd.DataFrame:
    """Dispatch ``action`` to its implementation and return the resulting frame.

    The handlers receive a copy of ``df``, so the frame passed in is never
    modified. ``undo`` returns a copy of the stored checkpoint and clears it.

    Raises:
        ValueError: For ``undo`` without a checkpoint, for an action type
            with no handler, or whatever the selected handler raises.
    """
    work = df.copy()
    column = action.column
    kind = action.action_type

    if kind == ActionType.fill_missing:
        return self._fill_missing(work, column, action)
    if kind == ActionType.drop_duplicates:
        deduped = work.drop_duplicates()
        return deduped.reset_index(drop=True)
    if kind == ActionType.fix_type:
        return self._fix_type(work, column, action)
    if kind == ActionType.normalize:
        return self._normalize(work, column)
    if kind == ActionType.drop_outliers:
        return self._drop_outliers(work, column, action)
    if kind == ActionType.fix_schema_violation:
        return self._fix_schema_violation(work, column, action, ep)
    if kind == ActionType.rename_column:
        return self._rename_column(work, column, action)
    if kind == ActionType.cast_datetime:
        return self._cast_datetime(work, column, action)
    if kind == ActionType.deduplicate_fuzzy:
        return self._deduplicate_fuzzy(work, column, action)
    if kind == ActionType.split_column:
        return self._split_column(work, column, action)
    if kind == ActionType.merge_columns:
        return self._merge_columns(work, column, action)
    if kind == ActionType.fix_format_regex:
        return self._fix_format_regex(work, column, action)
    if kind == ActionType.standardize_categories:
        return self._standardize_categories(work, column, action)
    if kind == ActionType.undo:
        if ep.get("checkpoint_df") is None:
            raise ValueError("Nothing to undo.")
        restored = ep["checkpoint_df"].copy()
        # Only one level of undo is kept.
        ep["checkpoint_df"] = None
        return restored
    if kind == ActionType.profile_column:
        return self._profile_column(work, column, ep)
    raise ValueError(f"Unhandled action type: {action.action_type}")
| # ------------------------------------------------------------------ | |
| # Action implementations | |
| # ------------------------------------------------------------------ | |
| def _fill_missing(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| strategy = action.strategy.value if action.strategy else "median" | |
| numeric = pd.to_numeric(df[col], errors="coerce") | |
| non_null = df[col].dropna() | |
| is_numeric = (numeric.notna().sum() / max(len(non_null), 1) > 0.9) if len(non_null) > 0 else False | |
| if strategy == "mode": | |
| mode_vals = df[col].mode() | |
| fill_value = mode_vals.iloc[0] if not mode_vals.empty else None | |
| if fill_value is not None: | |
| df[col] = df[col].fillna(fill_value) | |
| elif strategy == "constant": | |
| df[col] = df[col].fillna(action.constant_value) | |
| elif is_numeric: | |
| fill_value = numeric.mean() if strategy == "mean" else numeric.median() | |
| df[col] = numeric.fillna(fill_value) | |
| else: | |
| mode_vals = df[col].mode() | |
| fill_value = mode_vals.iloc[0] if not mode_vals.empty else "" | |
| df[col] = df[col].fillna(fill_value) | |
| return df | |
| def _fix_type(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| dtype = action.dtype.value if action.dtype else "float" | |
| if dtype in ("int", "float"): | |
| coerced = pd.to_numeric(df[col], errors="coerce") | |
| df[col] = coerced.astype("Int64") if dtype == "int" else coerced.astype("float64") | |
| else: | |
| df[col] = df[col].astype(str) | |
| return df | |
| def _normalize(self, df, col): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| numeric = pd.to_numeric(df[col], errors="coerce") | |
| mean, std = numeric.mean(), numeric.std() | |
| if pd.isna(mean) or std == 0 or pd.isna(std): | |
| return df | |
| df[col] = (numeric - mean) / std | |
| return df | |
| def _drop_outliers(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| numeric = pd.to_numeric(df[col], errors="coerce") | |
| method = action.method.value if action.method else "iqr" | |
| if method == "iqr": | |
| q1, q3 = numeric.quantile(0.25), numeric.quantile(0.75) | |
| iqr = q3 - q1 | |
| if iqr == 0: | |
| return df | |
| mask = numeric.between(q1 - 1.5 * iqr, q3 + 1.5 * iqr) | numeric.isna() | |
| else: | |
| mean, std = numeric.mean(), numeric.std() | |
| if std == 0 or pd.isna(std): | |
| return df | |
| z = (numeric - mean) / std | |
| mask = z.abs() < 3 | |
| return df[mask].reset_index(drop=True) | |
| def _fix_schema_violation(self, df, col, action, ep): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| numeric = pd.to_numeric(df[col], errors="coerce") | |
| constraint = action.constraint.value if action.constraint else "non_negative" | |
| if constraint == "non_negative": | |
| df[col] = numeric.clip(lower=0) | |
| elif constraint == "clamp_range": | |
| clean_col = pd.to_numeric(ep["clean_df"][col], errors="coerce") | |
| lo, hi = clean_col.quantile(0.05), clean_col.quantile(0.95) | |
| df[col] = numeric.clip(lo, hi) | |
| return df | |
| def _rename_column(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| if not action.new_name: | |
| raise ValueError("new_name is required.") | |
| if action.new_name in df.columns: | |
| raise ValueError(f"Column '{action.new_name}' already exists.") | |
| return df.rename(columns={col: action.new_name}) | |
| def _cast_datetime(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| df[col] = pd.to_datetime(df[col], format=action.datetime_format, errors="coerce") | |
| return df | |
| def _deduplicate_fuzzy(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| threshold = action.threshold if action.threshold is not None else 0.8 | |
| values = df[col].dropna().astype(str) | |
| unique_vals = values.unique() | |
| if len(unique_vals) > 500: | |
| top_vals = values.value_counts().head(500).index.tolist() | |
| unique_vals = top_vals | |
| mapping: dict[str, str] = {} | |
| processed: set[str] = set() | |
| for val in unique_vals: | |
| if val in processed: | |
| continue | |
| group = [val] | |
| for other in unique_vals: | |
| if other == val or other in processed: | |
| continue | |
| ratio = difflib.SequenceMatcher(None, val.lower(), other.lower()).ratio() | |
| if ratio >= threshold: | |
| group.append(other) | |
| group_counts = {v: int((values == v).sum()) for v in group} | |
| canonical = max(group_counts, key=group_counts.get) | |
| for v in group: | |
| mapping[v] = canonical | |
| processed.add(v) | |
| mask = df[col].notna() | |
| df.loc[mask, col] = df.loc[mask, col].astype(str).map(lambda x: mapping.get(x, x)) | |
| return df | |
| def _split_column(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| delimiter = action.delimiter or "," | |
| split_result = df[col].astype(str).str.split(delimiter, n=1, expand=True) | |
| df[f"{col}_0"] = split_result[0] if 0 in split_result.columns else None | |
| df[f"{col}_1"] = split_result[1] if 1 in split_result.columns else None | |
| return df.drop(columns=[col]) | |
| def _merge_columns(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| col2 = action.column2 | |
| if not col2 or col2 not in df.columns: | |
| raise ValueError(f"column2 '{col2}' not found.") | |
| strategy = action.merge_strategy.value if action.merge_strategy else "concat" | |
| if strategy == "concat": | |
| df[col] = df[col].astype(str) + " " + df[col2].astype(str) | |
| elif strategy == "first_non_null": | |
| df[col] = df[col].fillna(df[col2]) | |
| elif strategy == "sum": | |
| num1 = pd.to_numeric(df[col], errors="coerce") | |
| num2 = pd.to_numeric(df[col2], errors="coerce") | |
| df[col] = num1.fillna(0) + num2.fillna(0) | |
| return df.drop(columns=[col2]) | |
| def _fix_format_regex(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| if not action.pattern: | |
| raise ValueError("pattern is required.") | |
| replacement = action.replacement or "" | |
| try: | |
| df[col] = df[col].astype(str).str.replace(action.pattern, replacement, regex=True) | |
| except Exception as e: | |
| raise ValueError(f"Invalid regex: {e}") | |
| return df | |
| def _standardize_categories(self, df, col, action): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| mask = df[col].notna() | |
| df.loc[mask, col] = ( | |
| df.loc[mask, col].astype(str).str.lower().str.strip().str.replace(r"\s+", " ", regex=True) | |
| ) | |
| return df | |
| def _profile_column(self, df, col, ep): | |
| if col not in df.columns: | |
| raise ValueError(f"Column '{col}' not found.") | |
| col_data = df[col] | |
| profile: dict = {"column": col} | |
| numeric = pd.to_numeric(col_data, errors="coerce") | |
| if numeric.notna().sum() > 0: | |
| profile["min"] = float(numeric.min()) if numeric.notna().any() else None | |
| profile["max"] = float(numeric.max()) if numeric.notna().any() else None | |
| profile["median"] = float(numeric.median()) if numeric.notna().any() else None | |
| profile["q25"] = float(numeric.quantile(0.25)) if numeric.notna().any() else None | |
| profile["q75"] = float(numeric.quantile(0.75)) if numeric.notna().any() else None | |
| value_counts = col_data.dropna().astype(str).value_counts().head(10) | |
| profile["top_values"] = {str(k): int(v) for k, v in value_counts.items()} | |
| null_mask = col_data.isna() | |
| profile["null_positions"] = [int(i) for i in null_mask[null_mask].index[:5]] | |
| if col_data.dtype == object or str(col_data.dtype) == "string": | |
| sample = col_data.dropna().astype(str).head(20) | |
| patterns = set() | |
| for v in sample: | |
| if v.replace("-", "").replace("/", "").isdigit(): | |
| patterns.add("date-like") | |
| elif v.replace(".", "").replace("-", "").isdigit(): | |
| patterns.add("numeric-string") | |
| elif v.isupper(): | |
| patterns.add("UPPERCASE") | |
| elif v.islower(): | |
| patterns.add("lowercase") | |
| else: | |
| patterns.add("mixed-case") | |
| profile["value_patterns"] = list(patterns) | |
| ep["last_profile"] = profile | |
| return df | |
| # ------------------------------------------------------------------ | |
| # Observation assembly | |
| # ------------------------------------------------------------------ | |
def _make_observation(self, reward: float) -> Observation:
    """Build the Observation for the active episode.

    For every column present in the current or the clean frame this
    computes issue counts and summary statistics, then adds up to five
    sampled rows, the last five history entries, the remaining budget
    share, and any pending column profile (which is consumed here).

    Args:
        reward: Reward to report in the observation.

    Returns:
        The assembled Observation.
    """
    ep = self._ep
    df = ep["current_df"]
    clean_df = ep["clean_df"]
    column_issues: dict[str, Any] = {}
    column_stats: dict[str, Any] = {}
    # Duplicate rows are a frame-level property reported on every column.
    has_dup_rows = int(df.duplicated().sum()) > 0

    all_cols = list(dict.fromkeys(list(df.columns) + list(clean_df.columns)))
    for col in all_cols:
        if col not in df.columns:
            # Column exists only in the clean reference: report it as
            # entirely missing.
            column_issues[col] = ColumnIssues(
                missing_count=len(clean_df), missing_pct=1.0,
                type_errors=0, outlier_count=0, has_duplicates=has_dup_rows,
            ).model_dump()
            column_stats[col] = ColumnStats(null_count=len(clean_df), unique_count=0).model_dump()
            continue

        col_data = df[col]
        null_count = int(col_data.isna().sum())
        numeric = pd.to_numeric(col_data, errors="coerce")
        is_text = col_data.dtype == object or str(col_data.dtype) == "string"

        has_clean_ref = col in clean_df.columns
        clean_numeric = (
            pd.to_numeric(clean_df[col], errors="coerce") if has_clean_ref else pd.Series(dtype=float)
        )
        # The reference column counts as numeric when over 90% of it parses.
        clean_is_numeric = has_clean_ref and clean_numeric.notna().mean() > 0.9

        # Values that are present but do not parse as numbers.
        type_errs = 0
        if clean_is_numeric:
            type_errs = max(0, int(numeric.isna().sum() - null_count))

        outlier_count = 0
        if numeric.notna().sum() > 4:
            q1, q3 = numeric.quantile(0.25), numeric.quantile(0.75)
            iqr = q3 - q1
            if iqr > 0:
                outlier_count = int(((numeric < q1 - 1.5 * iqr) | (numeric > q3 + 1.5 * iqr)).sum())

        format_violation_count = 0
        if clean_is_numeric:
            # The mask already excludes nulls, so the null count must not be
            # subtracted a second time.
            format_violation_count = int((col_data.notna() & numeric.isna()).sum())
            if (clean_numeric.dropna() >= 0).all():
                # Negatives violate a column whose reference is non-negative.
                format_violation_count += int((numeric < 0).sum())
        elif is_text:
            str_vals = col_data.dropna().astype(str)
            if len(str_vals) > 0:
                # Leading/trailing whitespace counts as a format violation.
                format_violation_count = int((str_vals != str_vals.str.strip()).sum())

        encoding_issue_count = 0
        if is_text:
            sample_vals = col_data.dropna().astype(str)
            if len(sample_vals) > 200:
                sample_vals = sample_vals.sample(n=200, random_state=42)
            # Non-ASCII characters seen in the clean column are legitimate.
            clean_chars = set()
            if has_clean_ref:
                for v in clean_df[col].dropna().astype(str):
                    clean_chars.update(c for c in v if ord(c) > 127)
            for v in sample_vals:
                if any(ord(c) > 127 and c not in clean_chars for c in v):
                    encoding_issue_count += 1

        semantic_duplicate_count = 0
        if is_text:
            unique_vals = col_data.dropna().astype(str).unique()
            if len(unique_vals) < 500:
                # Spellings that differ only by case or surrounding whitespace.
                normalized: dict[str, list] = {}
                for v in unique_vals:
                    normalized.setdefault(v.lower().strip(), []).append(v)
                semantic_duplicate_count = sum(
                    len(forms) - 1 for forms in normalized.values() if len(forms) > 1
                )

        column_issues[col] = ColumnIssues(
            missing_count=null_count,
            missing_pct=round(float(col_data.isna().mean()), 4),
            type_errors=type_errs,
            outlier_count=outlier_count,
            has_duplicates=has_dup_rows,
            format_violation_count=format_violation_count,
            encoding_issue_count=encoding_issue_count,
            semantic_duplicate_count=semantic_duplicate_count,
        ).model_dump()
        column_stats[col] = ColumnStats(
            mean=round(float(numeric.mean()), 4) if numeric.notna().any() else None,
            std=round(float(numeric.std()), 4) if numeric.notna().sum() > 1 else None,
            null_count=null_count,
            unique_count=int(col_data.nunique()),
        ).model_dump()

    sample_size = min(5, len(df))
    sample_rows = []
    if sample_size > 0:
        # Seeded with the step number: stable within a step, varies across steps.
        sample = df.sample(n=sample_size, random_state=ep["step"])
        for _, row in sample.iterrows():
            sanitized = {}
            for k, v in row.items():
                if pd.isna(v):
                    sanitized[k] = None
                elif hasattr(v, "item"):
                    # Unwrap NumPy scalars to plain Python values.
                    sanitized[k] = v.item()
                elif isinstance(v, str) and len(v) > 100:
                    sanitized[k] = v[:100]
                else:
                    sanitized[k] = v
            sample_rows.append(sanitized)

    action_history = ep.get("action_history", [])[-5:]
    # pop(): a profile is reported once, in the observation right after it
    # was produced.
    profile_result = ep.pop("last_profile", None)
    return Observation(
        done=ep["done"],
        reward=reward,
        task=ep["task"],
        step=ep["step"],
        max_steps=ep["max_steps"],
        columns=list(df.columns),
        column_issues=column_issues,
        column_stats=column_stats,
        sample_rows=sample_rows,
        action_history=action_history,
        budget_remaining=max(0.0, 1.0 - ep.get("cost_used", 0.0) / ep.get("budget", 1.0)),
        profile_result=profile_result,
    )