""" Core environment logic for the Data Cleaning RL Environment. Implements the openenv Environment interface so that create_app() can expose /ws, /reset, /step, /state endpoints automatically. """ from __future__ import annotations import difflib import uuid from typing import Any, Optional import numpy as np import pandas as pd from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State from data_cleaning_env.action_registry import ACTION_COSTS from data_cleaning_env.datasets import load_clean_dataset from data_cleaning_env.grader import compute_quality_score from data_cleaning_env.noise_injector import inject_noise from data_cleaning_env.models import ( ActionType, CleaningAction, ColumnIssues, ColumnStats, Observation, ) MAX_STEPS: dict[str, int] = { "easy": 20, "medium": 40, "hard": 60, "expert": 80, } REWARD_CLIP = 0.1 EPISODE_BUDGET: dict[str, float] = { "easy": 1.0, "medium": 1.0, "hard": 1.0, "expert": 1.0, } class DataCleaningEnvironment(Environment): """ OpenEnv-compatible Data Cleaning Environment. Each instance manages a single episode. The openenv server framework creates a new instance per WebSocket session via the factory callable. """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self) -> None: super().__init__() self._state = State(episode_id=str(uuid.uuid4()), step_count=0) # Episode data — populated by reset() self._ep: dict[str, Any] = {} # Multi-episode store for REST grader endpoint self.episodes: dict[str, dict[str, Any]] = {} # ------------------------------------------------------------------ # openenv Environment interface # ------------------------------------------------------------------ def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, task: str = "easy", difficulty: Optional[float] = None, use_synthetic: bool = False, **kwargs, ) -> Observation: """Start a new episode and return the initial observation.""" if task not in MAX_STEPS: raise ValueError(f"Unknown task '{task}'. Must be one of: {list(MAX_STEPS)}") eid = episode_id or str(uuid.uuid4()) self._state = State(episode_id=eid, step_count=0) if use_synthetic: from data_cleaning_env.datasets import generate_synthetic_dataset import time as _time syn_seed = int(_time.time()) % 100000 clean_df, target_col = generate_synthetic_dataset(seed=syn_seed) else: clean_df, target_col = load_clean_dataset(task) dirty_df = inject_noise(clean_df, task, severity=difficulty) initial_quality = compute_quality_score(dirty_df, clean_df)["composite"] import time as _time self._ep = { "task": task, "clean_df": clean_df, "current_df": dirty_df.copy(), "dirty_df": dirty_df.copy(), "initial_quality": initial_quality, "target_col": target_col, "created_at": _time.monotonic(), "step": 0, "max_steps": MAX_STEPS[task], "done": False, "prev_quality": initial_quality, "action_history": [], "budget": EPISODE_BUDGET.get(task, 1.0), "cost_used": 0.0, "checkpoint_df": None, } # Also store in episodes dict for REST grader/baseline self.episodes[eid] = self._ep return self._make_observation(reward=0.0) def step(self, action: CleaningAction, timeout_s=None, **kwargs) -> Observation: """Apply a cleaning action and return the new observation.""" ep = self._ep if not ep: raise ValueError("No active episode. Call reset() first.") if ep["done"]: raise ValueError("Episode already done. Call reset() first.") reward = 0.0 df = ep["current_df"] _NON_MUTATING = {ActionType.done, ActionType.profile_column} if action.action_type == ActionType.done: ep["done"] = True reward = 0.0 cost = ACTION_COSTS.get("done", 0.0) ep["cost_used"] += cost elif action.action_type in _NON_MUTATING: try: df = self._apply_action(df, action, ep) ep["current_df"] = df except Exception: pass reward = 0.0 cost = ACTION_COSTS.get(action.action_type.value, 0.0) ep["cost_used"] += cost ep["action_history"].append({ "action_type": action.action_type.value, "column": action.column, "reward": reward, "step": ep["step"], }) else: try: if action.action_type != ActionType.undo: ep["checkpoint_df"] = ep["current_df"].copy() df = self._apply_action(df, action, ep) ep["current_df"] = df new_quality = compute_quality_score(df, ep["clean_df"])["composite"] delta = new_quality - ep["prev_quality"] reward = float(np.clip(delta, -REWARD_CLIP, REWARD_CLIP)) ep["prev_quality"] = new_quality cost = ACTION_COSTS.get(action.action_type.value, 0.0) ep["cost_used"] += cost cost_efficiency = 1.0 - min(ep["cost_used"] / ep["budget"], 1.0) reward *= max(cost_efficiency, 0.1) ep["action_history"].append({ "action_type": action.action_type.value, "column": action.column, "reward": reward, "step": ep["step"], }) except Exception: reward = -0.05 cost = ACTION_COSTS.get(action.action_type.value, 0.0) ep["cost_used"] += cost ep["action_history"].append({ "action_type": action.action_type.value, "column": action.column, "reward": reward, "step": ep["step"], }) ep["step"] += 1 self._state.step_count = ep["step"] if ep["step"] >= ep["max_steps"]: ep["done"] = True reward = 0.0 return self._make_observation(reward=reward) @property def state(self) -> State: return self._state def close(self) -> None: pass # ------------------------------------------------------------------ # Action application # ------------------------------------------------------------------ def _apply_action(self, df: pd.DataFrame, action: CleaningAction, ep: dict) -> pd.DataFrame: df = df.copy() col = action.column if action.action_type == ActionType.fill_missing: df = self._fill_missing(df, col, action) elif action.action_type == ActionType.drop_duplicates: df = df.drop_duplicates().reset_index(drop=True) elif action.action_type == ActionType.fix_type: df = self._fix_type(df, col, action) elif action.action_type == ActionType.normalize: df = self._normalize(df, col) elif action.action_type == ActionType.drop_outliers: df = self._drop_outliers(df, col, action) elif action.action_type == ActionType.fix_schema_violation: df = self._fix_schema_violation(df, col, action, ep) elif action.action_type == ActionType.rename_column: df = self._rename_column(df, col, action) elif action.action_type == ActionType.cast_datetime: df = self._cast_datetime(df, col, action) elif action.action_type == ActionType.deduplicate_fuzzy: df = self._deduplicate_fuzzy(df, col, action) elif action.action_type == ActionType.split_column: df = self._split_column(df, col, action) elif action.action_type == ActionType.merge_columns: df = self._merge_columns(df, col, action) elif action.action_type == ActionType.fix_format_regex: df = self._fix_format_regex(df, col, action) elif action.action_type == ActionType.standardize_categories: df = self._standardize_categories(df, col, action) elif action.action_type == ActionType.undo: if ep.get("checkpoint_df") is None: raise ValueError("Nothing to undo.") df = ep["checkpoint_df"].copy() ep["checkpoint_df"] = None elif action.action_type == ActionType.profile_column: df = self._profile_column(df, col, ep) else: raise ValueError(f"Unhandled action type: {action.action_type}") return df # ------------------------------------------------------------------ # Action implementations # ------------------------------------------------------------------ def _fill_missing(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") strategy = action.strategy.value if action.strategy else "median" numeric = pd.to_numeric(df[col], errors="coerce") non_null = df[col].dropna() is_numeric = (numeric.notna().sum() / max(len(non_null), 1) > 0.9) if len(non_null) > 0 else False if strategy == "mode": mode_vals = df[col].mode() fill_value = mode_vals.iloc[0] if not mode_vals.empty else None if fill_value is not None: df[col] = df[col].fillna(fill_value) elif strategy == "constant": df[col] = df[col].fillna(action.constant_value) elif is_numeric: fill_value = numeric.mean() if strategy == "mean" else numeric.median() df[col] = numeric.fillna(fill_value) else: mode_vals = df[col].mode() fill_value = mode_vals.iloc[0] if not mode_vals.empty else "" df[col] = df[col].fillna(fill_value) return df def _fix_type(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") dtype = action.dtype.value if action.dtype else "float" if dtype in ("int", "float"): coerced = pd.to_numeric(df[col], errors="coerce") df[col] = coerced.astype("Int64") if dtype == "int" else coerced.astype("float64") else: df[col] = df[col].astype(str) return df def _normalize(self, df, col): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") numeric = pd.to_numeric(df[col], errors="coerce") mean, std = numeric.mean(), numeric.std() if pd.isna(mean) or std == 0 or pd.isna(std): return df df[col] = (numeric - mean) / std return df def _drop_outliers(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") numeric = pd.to_numeric(df[col], errors="coerce") method = action.method.value if action.method else "iqr" if method == "iqr": q1, q3 = numeric.quantile(0.25), numeric.quantile(0.75) iqr = q3 - q1 if iqr == 0: return df mask = numeric.between(q1 - 1.5 * iqr, q3 + 1.5 * iqr) | numeric.isna() else: mean, std = numeric.mean(), numeric.std() if std == 0 or pd.isna(std): return df z = (numeric - mean) / std mask = z.abs() < 3 return df[mask].reset_index(drop=True) def _fix_schema_violation(self, df, col, action, ep): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") numeric = pd.to_numeric(df[col], errors="coerce") constraint = action.constraint.value if action.constraint else "non_negative" if constraint == "non_negative": df[col] = numeric.clip(lower=0) elif constraint == "clamp_range": clean_col = pd.to_numeric(ep["clean_df"][col], errors="coerce") lo, hi = clean_col.quantile(0.05), clean_col.quantile(0.95) df[col] = numeric.clip(lo, hi) return df def _rename_column(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") if not action.new_name: raise ValueError("new_name is required.") if action.new_name in df.columns: raise ValueError(f"Column '{action.new_name}' already exists.") return df.rename(columns={col: action.new_name}) def _cast_datetime(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") df[col] = pd.to_datetime(df[col], format=action.datetime_format, errors="coerce") return df def _deduplicate_fuzzy(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") threshold = action.threshold if action.threshold is not None else 0.8 values = df[col].dropna().astype(str) unique_vals = values.unique() if len(unique_vals) > 500: top_vals = values.value_counts().head(500).index.tolist() unique_vals = top_vals mapping: dict[str, str] = {} processed: set[str] = set() for val in unique_vals: if val in processed: continue group = [val] for other in unique_vals: if other == val or other in processed: continue ratio = difflib.SequenceMatcher(None, val.lower(), other.lower()).ratio() if ratio >= threshold: group.append(other) group_counts = {v: int((values == v).sum()) for v in group} canonical = max(group_counts, key=group_counts.get) for v in group: mapping[v] = canonical processed.add(v) mask = df[col].notna() df.loc[mask, col] = df.loc[mask, col].astype(str).map(lambda x: mapping.get(x, x)) return df def _split_column(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") delimiter = action.delimiter or "," split_result = df[col].astype(str).str.split(delimiter, n=1, expand=True) df[f"{col}_0"] = split_result[0] if 0 in split_result.columns else None df[f"{col}_1"] = split_result[1] if 1 in split_result.columns else None return df.drop(columns=[col]) def _merge_columns(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") col2 = action.column2 if not col2 or col2 not in df.columns: raise ValueError(f"column2 '{col2}' not found.") strategy = action.merge_strategy.value if action.merge_strategy else "concat" if strategy == "concat": df[col] = df[col].astype(str) + " " + df[col2].astype(str) elif strategy == "first_non_null": df[col] = df[col].fillna(df[col2]) elif strategy == "sum": num1 = pd.to_numeric(df[col], errors="coerce") num2 = pd.to_numeric(df[col2], errors="coerce") df[col] = num1.fillna(0) + num2.fillna(0) return df.drop(columns=[col2]) def _fix_format_regex(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") if not action.pattern: raise ValueError("pattern is required.") replacement = action.replacement or "" try: df[col] = df[col].astype(str).str.replace(action.pattern, replacement, regex=True) except Exception as e: raise ValueError(f"Invalid regex: {e}") return df def _standardize_categories(self, df, col, action): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") mask = df[col].notna() df.loc[mask, col] = ( df.loc[mask, col].astype(str).str.lower().str.strip().str.replace(r"\s+", " ", regex=True) ) return df def _profile_column(self, df, col, ep): if col not in df.columns: raise ValueError(f"Column '{col}' not found.") col_data = df[col] profile: dict = {"column": col} numeric = pd.to_numeric(col_data, errors="coerce") if numeric.notna().sum() > 0: profile["min"] = float(numeric.min()) if numeric.notna().any() else None profile["max"] = float(numeric.max()) if numeric.notna().any() else None profile["median"] = float(numeric.median()) if numeric.notna().any() else None profile["q25"] = float(numeric.quantile(0.25)) if numeric.notna().any() else None profile["q75"] = float(numeric.quantile(0.75)) if numeric.notna().any() else None value_counts = col_data.dropna().astype(str).value_counts().head(10) profile["top_values"] = {str(k): int(v) for k, v in value_counts.items()} null_mask = col_data.isna() profile["null_positions"] = [int(i) for i in null_mask[null_mask].index[:5]] if col_data.dtype == object or str(col_data.dtype) == "string": sample = col_data.dropna().astype(str).head(20) patterns = set() for v in sample: if v.replace("-", "").replace("/", "").isdigit(): patterns.add("date-like") elif v.replace(".", "").replace("-", "").isdigit(): patterns.add("numeric-string") elif v.isupper(): patterns.add("UPPERCASE") elif v.islower(): patterns.add("lowercase") else: patterns.add("mixed-case") profile["value_patterns"] = list(patterns) ep["last_profile"] = profile return df # ------------------------------------------------------------------ # Observation assembly # ------------------------------------------------------------------ def _make_observation(self, reward: float) -> Observation: ep = self._ep df = ep["current_df"] clean_df = ep["clean_df"] column_issues: dict[str, Any] = {} column_stats: dict[str, Any] = {} n_dups = int(df.duplicated().sum()) all_cols = list(dict.fromkeys(list(df.columns) + list(clean_df.columns))) for col in all_cols: if col not in df.columns: column_issues[col] = ColumnIssues( missing_count=len(clean_df), missing_pct=1.0, type_errors=0, outlier_count=0, has_duplicates=n_dups > 0, ).model_dump() column_stats[col] = ColumnStats(null_count=len(clean_df), unique_count=0).model_dump() continue col_data = df[col] numeric = pd.to_numeric(col_data, errors="coerce") has_clean_ref = col in clean_df.columns clean_numeric = pd.to_numeric(clean_df[col], errors="coerce") if has_clean_ref else pd.Series(dtype=float) type_errs = 0 if has_clean_ref and clean_numeric.notna().mean() > 0.9: type_errs = max(0, int(numeric.isna().sum() - col_data.isna().sum())) outlier_count = 0 if numeric.notna().sum() > 4: q1, q3 = numeric.quantile(0.25), numeric.quantile(0.75) iqr = q3 - q1 if iqr > 0: outlier_count = int(((numeric < q1 - 1.5 * iqr) | (numeric > q3 + 1.5 * iqr)).sum()) format_violation_count = 0 clean_col = clean_df[col] if col in clean_df.columns else pd.Series(dtype=object) cn = pd.to_numeric(clean_col, errors="coerce") if len(clean_col) > 0 and cn.notna().mean() > 0.9: dirty_numeric = pd.to_numeric(col_data, errors="coerce") non_null_mask = col_data.notna() format_violation_count = max(0, int((non_null_mask & dirty_numeric.isna()).sum() - col_data.isna().sum())) if (cn.dropna() >= 0).all(): format_violation_count += int((dirty_numeric < 0).sum()) else: if col_data.dtype == object or str(col_data.dtype) == "string": str_vals = col_data.dropna().astype(str) if len(str_vals) > 0: format_violation_count = int((str_vals != str_vals.str.strip()).sum()) encoding_issue_count = 0 if col_data.dtype == object or str(col_data.dtype) == "string": sample_vals = col_data.dropna().astype(str) if len(sample_vals) > 200: sample_vals = sample_vals.sample(n=200, random_state=42) clean_chars = set() if col in clean_df.columns: for v in clean_df[col].dropna().astype(str): clean_chars.update(c for c in v if ord(c) > 127) for v in sample_vals: if any(ord(c) > 127 and c not in clean_chars for c in v): encoding_issue_count += 1 semantic_duplicate_count = 0 if col_data.dtype == object or str(col_data.dtype) == "string": unique_vals = col_data.dropna().astype(str).unique() if len(unique_vals) < 500: normalized = {} for v in unique_vals: key = v.lower().strip() normalized.setdefault(key, []).append(v) for forms in normalized.values(): if len(forms) > 1: semantic_duplicate_count += len(forms) - 1 column_issues[col] = ColumnIssues( missing_count=int(col_data.isna().sum()), missing_pct=round(float(col_data.isna().mean()), 4), type_errors=type_errs, outlier_count=outlier_count, has_duplicates=n_dups > 0, format_violation_count=format_violation_count, encoding_issue_count=encoding_issue_count, semantic_duplicate_count=semantic_duplicate_count, ).model_dump() column_stats[col] = ColumnStats( mean=round(float(numeric.mean()), 4) if numeric.notna().any() else None, std=round(float(numeric.std()), 4) if numeric.notna().sum() > 1 else None, null_count=int(col_data.isna().sum()), unique_count=int(col_data.nunique()), ).model_dump() sample_size = min(5, len(df)) sample_rows = [] if sample_size > 0: sample = df.sample(n=sample_size, random_state=ep["step"]) for _, row in sample.iterrows(): sanitized = {} for k, v in row.items(): if pd.isna(v): sanitized[k] = None elif hasattr(v, "item"): sanitized[k] = v.item() elif isinstance(v, str) and len(v) > 100: sanitized[k] = v[:100] else: sanitized[k] = v sample_rows.append(sanitized) action_history = ep.get("action_history", [])[-5:] profile_result = ep.pop("last_profile", None) return Observation( done=ep["done"], reward=reward, task=ep["task"], step=ep["step"], max_steps=ep["max_steps"], columns=list(df.columns), column_issues=column_issues, column_stats=column_stats, sample_rows=sample_rows, action_history=action_history, budget_remaining=max(0.0, 1.0 - ep.get("cost_used", 0.0) / ep.get("budget", 1.0)), profile_result=profile_result, )