# yashmarathe's picture
# refactor: full openenv protocol compliance
# 1a55ff4
"""
Core environment logic for the Data Cleaning RL Environment.
Implements the openenv Environment interface so that create_app() can
expose /ws, /reset, /step, /state endpoints automatically.
"""
from __future__ import annotations
import difflib
import uuid
from typing import Any, Optional
import numpy as np
import pandas as pd
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from data_cleaning_env.action_registry import ACTION_COSTS
from data_cleaning_env.datasets import load_clean_dataset
from data_cleaning_env.grader import compute_quality_score
from data_cleaning_env.noise_injector import inject_noise
from data_cleaning_env.models import (
ActionType,
CleaningAction,
ColumnIssues,
ColumnStats,
Observation,
)
# Per-task cap on the number of step() calls; step() marks the episode done
# once its step counter reaches this value.
MAX_STEPS: dict[str, int] = dict(easy=20, medium=40, hard=60, expert=80)

# Symmetric bound applied to the per-step quality delta before it is scaled
# into a reward.
REWARD_CLIP = 0.1

# Per-task action-cost budget; accumulated action costs are divided by this
# value to obtain the fraction of the budget already spent.
EPISODE_BUDGET: dict[str, float] = {
    task: 1.0 for task in ("easy", "medium", "hard", "expert")
}
class DataCleaningEnvironment(Environment):
    """
    OpenEnv-compatible Data Cleaning Environment.

    Each instance manages a single episode. The openenv server framework
    creates a new instance per WebSocket session via the factory callable.

    Instance attributes (assigned in ``__init__``):
        _state: ``State`` carrying the current episode id and step count.
        _ep: dict holding the active episode's data; empty until ``reset()``
            has been called.
        episodes: mapping of episode id to episode dict, filled by ``reset()``.
    """

    # Class-level flag; presumably read by the openenv server to decide whether
    # several sessions may be served concurrently — TODO confirm against the
    # framework documentation.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
def __init__(self) -> None:
    """Create an environment with a placeholder state and no active episode."""
    super().__init__()

    # A throwaway episode id so `state` is usable before the first reset().
    placeholder_id = str(uuid.uuid4())
    self._state = State(episode_id=placeholder_id, step_count=0)

    # Data of the episode currently being played; reset() fills this in.
    self._ep: dict[str, Any] = {}

    # Every episode started on this instance, keyed by episode id, kept for
    # the REST grader endpoint.
    self.episodes: dict[str, dict[str, Any]] = {}
# ------------------------------------------------------------------
# openenv Environment interface
# ------------------------------------------------------------------
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    task: str = "easy",
    difficulty: Optional[float] = None,
    use_synthetic: bool = False,
    **kwargs,
) -> Observation:
    """Start a new episode and return the initial observation.

    Args:
        seed: Seed handed to the synthetic dataset generator. When omitted a
            clock-derived seed is used, as before. Only consulted when
            ``use_synthetic`` is true; the noise injector is not given a seed
            by this method.
        episode_id: Identifier for the new episode; a random UUID is
            generated when not supplied.
        task: Task name; must be a key of ``MAX_STEPS``.
        difficulty: Forwarded to ``inject_noise`` as ``severity``.
        use_synthetic: Generate a synthetic clean dataset instead of loading
            the dataset registered for ``task``.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        The observation of the freshly corrupted dataset, with reward 0.0.

    Raises:
        ValueError: If ``task`` is not a known task name.
    """
    import time as _time

    if task not in MAX_STEPS:
        raise ValueError(f"Unknown task '{task}'. Must be one of: {list(MAX_STEPS)}")

    eid = episode_id or str(uuid.uuid4())
    self._state = State(episode_id=eid, step_count=0)

    if use_synthetic:
        from data_cleaning_env.datasets import generate_synthetic_dataset

        # Honour an explicit seed so callers can reproduce a synthetic
        # episode; fall back to the previous clock-derived seed otherwise.
        syn_seed = seed if seed is not None else int(_time.time()) % 100000
        clean_df, target_col = generate_synthetic_dataset(seed=syn_seed)
    else:
        clean_df, target_col = load_clean_dataset(task)

    dirty_df = inject_noise(clean_df, task, severity=difficulty)
    initial_quality = compute_quality_score(dirty_df, clean_df)["composite"]

    self._ep = {
        "task": task,
        "clean_df": clean_df,
        # Two independent copies: one is mutated by actions, the other keeps
        # the starting point of the episode.
        "current_df": dirty_df.copy(),
        "dirty_df": dirty_df.copy(),
        "initial_quality": initial_quality,
        "target_col": target_col,
        "created_at": _time.monotonic(),
        "step": 0,
        "max_steps": MAX_STEPS[task],
        "done": False,
        "prev_quality": initial_quality,
        "action_history": [],
        "budget": EPISODE_BUDGET.get(task, 1.0),
        "cost_used": 0.0,
        "checkpoint_df": None,
    }
    # Also store in episodes dict for REST grader/baseline
    self.episodes[eid] = self._ep
    return self._make_observation(reward=0.0)
def step(self, action: CleaningAction, timeout_s=None, **kwargs) -> Observation:
    """Apply one cleaning action and return the resulting observation.

    Reward rules implemented here:
      * ``done`` ends the episode with reward 0.0 and is not added to the
        action history.
      * ``profile_column`` yields reward 0.0; errors raised while profiling
        are ignored.
      * Any other action is rewarded with the change in composite quality,
        clipped to ``[-REWARD_CLIP, REWARD_CLIP]`` and multiplied by the
        remaining budget fraction (floored at 0.1). If the action raises,
        the reward is -0.05 instead.
    Every branch adds the action's ``ACTION_COSTS`` entry to the episode's
    ``cost_used``. When the step counter reaches ``max_steps`` the episode is
    marked done and the reward reported for that step is reset to 0.0.

    Args:
        action: The cleaning action to apply.
        timeout_s: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If there is no active episode or it has already ended.
    """
    ep = self._ep
    if not ep:
        raise ValueError("No active episode. Call reset() first.")
    if ep["done"]:
        raise ValueError("Episode already done. Call reset() first.")

    kind = action.action_type
    non_mutating = {ActionType.done, ActionType.profile_column}

    def charge() -> None:
        # Add this action's cost to the running total for the episode.
        ep["cost_used"] += ACTION_COSTS.get(action.action_type.value, 0.0)

    def record(step_reward: float) -> None:
        # Append this step to the episode's action history.
        ep["action_history"].append({
            "action_type": action.action_type.value,
            "column": action.column,
            "reward": step_reward,
            "step": ep["step"],
        })

    reward = 0.0
    if kind == ActionType.done:
        ep["done"] = True
        ep["cost_used"] += ACTION_COSTS.get("done", 0.0)
    elif kind in non_mutating:
        try:
            ep["current_df"] = self._apply_action(ep["current_df"], action, ep)
        except Exception:
            pass
        charge()
        record(reward)
    else:
        try:
            # Snapshot the frame so a later undo can restore it; an undo
            # itself must not overwrite the snapshot it is about to use.
            if kind != ActionType.undo:
                ep["checkpoint_df"] = ep["current_df"].copy()
            updated = self._apply_action(ep["current_df"], action, ep)
            ep["current_df"] = updated

            quality = compute_quality_score(updated, ep["clean_df"])["composite"]
            improvement = quality - ep["prev_quality"]
            reward = float(np.clip(improvement, -REWARD_CLIP, REWARD_CLIP))
            ep["prev_quality"] = quality

            charge()
            budget_left = 1.0 - min(ep["cost_used"] / ep["budget"], 1.0)
            reward *= max(budget_left, 0.1)
            record(reward)
        except Exception:
            reward = -0.05
            charge()
            record(reward)

    ep["step"] += 1
    self._state.step_count = ep["step"]
    if ep["step"] >= ep["max_steps"]:
        ep["done"] = True
        reward = 0.0
    return self._make_observation(reward=reward)
@property
def state(self) -> State:
    """Return the ``State`` object tracking the episode id and step count."""
    return self._state
def close(self) -> None:
    """Release resources; this environment holds none, so it does nothing."""
    return None
# ------------------------------------------------------------------
# Action application
# ------------------------------------------------------------------
def _apply_action(self, df: pd.DataFrame, action: CleaningAction, ep: dict) -> pd.DataFrame:
    """Run ``action`` against a copy of ``df`` and return the resulting frame.

    The input frame is copied first, so handlers may modify their argument
    freely. ``undo`` returns a copy of the episode's checkpoint and clears
    it; ``profile_column`` stores its result in ``ep`` and returns the frame
    unchanged.

    Raises:
        ValueError: For ``undo`` without a checkpoint, for an action type
            with no handler, or whatever the selected handler raises.
    """
    frame = df.copy()
    column = action.column

    def undo_last():
        # Restore the snapshot taken before the previous mutating action.
        if ep.get("checkpoint_df") is None:
            raise ValueError("Nothing to undo.")
        restored = ep["checkpoint_df"].copy()
        ep["checkpoint_df"] = None
        return restored

    # Each entry is a zero-argument callable so only the chosen handler runs.
    handlers = {
        ActionType.fill_missing: lambda: self._fill_missing(frame, column, action),
        ActionType.drop_duplicates: lambda: frame.drop_duplicates().reset_index(drop=True),
        ActionType.fix_type: lambda: self._fix_type(frame, column, action),
        ActionType.normalize: lambda: self._normalize(frame, column),
        ActionType.drop_outliers: lambda: self._drop_outliers(frame, column, action),
        ActionType.fix_schema_violation: lambda: self._fix_schema_violation(frame, column, action, ep),
        ActionType.rename_column: lambda: self._rename_column(frame, column, action),
        ActionType.cast_datetime: lambda: self._cast_datetime(frame, column, action),
        ActionType.deduplicate_fuzzy: lambda: self._deduplicate_fuzzy(frame, column, action),
        ActionType.split_column: lambda: self._split_column(frame, column, action),
        ActionType.merge_columns: lambda: self._merge_columns(frame, column, action),
        ActionType.fix_format_regex: lambda: self._fix_format_regex(frame, column, action),
        ActionType.standardize_categories: lambda: self._standardize_categories(frame, column, action),
        ActionType.undo: undo_last,
        ActionType.profile_column: lambda: self._profile_column(frame, column, ep),
    }

    handler = handlers.get(action.action_type)
    if handler is None:
        raise ValueError(f"Unhandled action type: {action.action_type}")
    return handler()
# ------------------------------------------------------------------
# Action implementations
# ------------------------------------------------------------------
def _fill_missing(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
strategy = action.strategy.value if action.strategy else "median"
numeric = pd.to_numeric(df[col], errors="coerce")
non_null = df[col].dropna()
is_numeric = (numeric.notna().sum() / max(len(non_null), 1) > 0.9) if len(non_null) > 0 else False
if strategy == "mode":
mode_vals = df[col].mode()
fill_value = mode_vals.iloc[0] if not mode_vals.empty else None
if fill_value is not None:
df[col] = df[col].fillna(fill_value)
elif strategy == "constant":
df[col] = df[col].fillna(action.constant_value)
elif is_numeric:
fill_value = numeric.mean() if strategy == "mean" else numeric.median()
df[col] = numeric.fillna(fill_value)
else:
mode_vals = df[col].mode()
fill_value = mode_vals.iloc[0] if not mode_vals.empty else ""
df[col] = df[col].fillna(fill_value)
return df
def _fix_type(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
dtype = action.dtype.value if action.dtype else "float"
if dtype in ("int", "float"):
coerced = pd.to_numeric(df[col], errors="coerce")
df[col] = coerced.astype("Int64") if dtype == "int" else coerced.astype("float64")
else:
df[col] = df[col].astype(str)
return df
def _normalize(self, df, col):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
numeric = pd.to_numeric(df[col], errors="coerce")
mean, std = numeric.mean(), numeric.std()
if pd.isna(mean) or std == 0 or pd.isna(std):
return df
df[col] = (numeric - mean) / std
return df
def _drop_outliers(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
numeric = pd.to_numeric(df[col], errors="coerce")
method = action.method.value if action.method else "iqr"
if method == "iqr":
q1, q3 = numeric.quantile(0.25), numeric.quantile(0.75)
iqr = q3 - q1
if iqr == 0:
return df
mask = numeric.between(q1 - 1.5 * iqr, q3 + 1.5 * iqr) | numeric.isna()
else:
mean, std = numeric.mean(), numeric.std()
if std == 0 or pd.isna(std):
return df
z = (numeric - mean) / std
mask = z.abs() < 3
return df[mask].reset_index(drop=True)
def _fix_schema_violation(self, df, col, action, ep):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
numeric = pd.to_numeric(df[col], errors="coerce")
constraint = action.constraint.value if action.constraint else "non_negative"
if constraint == "non_negative":
df[col] = numeric.clip(lower=0)
elif constraint == "clamp_range":
clean_col = pd.to_numeric(ep["clean_df"][col], errors="coerce")
lo, hi = clean_col.quantile(0.05), clean_col.quantile(0.95)
df[col] = numeric.clip(lo, hi)
return df
def _rename_column(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
if not action.new_name:
raise ValueError("new_name is required.")
if action.new_name in df.columns:
raise ValueError(f"Column '{action.new_name}' already exists.")
return df.rename(columns={col: action.new_name})
def _cast_datetime(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
df[col] = pd.to_datetime(df[col], format=action.datetime_format, errors="coerce")
return df
def _deduplicate_fuzzy(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
threshold = action.threshold if action.threshold is not None else 0.8
values = df[col].dropna().astype(str)
unique_vals = values.unique()
if len(unique_vals) > 500:
top_vals = values.value_counts().head(500).index.tolist()
unique_vals = top_vals
mapping: dict[str, str] = {}
processed: set[str] = set()
for val in unique_vals:
if val in processed:
continue
group = [val]
for other in unique_vals:
if other == val or other in processed:
continue
ratio = difflib.SequenceMatcher(None, val.lower(), other.lower()).ratio()
if ratio >= threshold:
group.append(other)
group_counts = {v: int((values == v).sum()) for v in group}
canonical = max(group_counts, key=group_counts.get)
for v in group:
mapping[v] = canonical
processed.add(v)
mask = df[col].notna()
df.loc[mask, col] = df.loc[mask, col].astype(str).map(lambda x: mapping.get(x, x))
return df
def _split_column(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
delimiter = action.delimiter or ","
split_result = df[col].astype(str).str.split(delimiter, n=1, expand=True)
df[f"{col}_0"] = split_result[0] if 0 in split_result.columns else None
df[f"{col}_1"] = split_result[1] if 1 in split_result.columns else None
return df.drop(columns=[col])
def _merge_columns(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
col2 = action.column2
if not col2 or col2 not in df.columns:
raise ValueError(f"column2 '{col2}' not found.")
strategy = action.merge_strategy.value if action.merge_strategy else "concat"
if strategy == "concat":
df[col] = df[col].astype(str) + " " + df[col2].astype(str)
elif strategy == "first_non_null":
df[col] = df[col].fillna(df[col2])
elif strategy == "sum":
num1 = pd.to_numeric(df[col], errors="coerce")
num2 = pd.to_numeric(df[col2], errors="coerce")
df[col] = num1.fillna(0) + num2.fillna(0)
return df.drop(columns=[col2])
def _fix_format_regex(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
if not action.pattern:
raise ValueError("pattern is required.")
replacement = action.replacement or ""
try:
df[col] = df[col].astype(str).str.replace(action.pattern, replacement, regex=True)
except Exception as e:
raise ValueError(f"Invalid regex: {e}")
return df
def _standardize_categories(self, df, col, action):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
mask = df[col].notna()
df.loc[mask, col] = (
df.loc[mask, col].astype(str).str.lower().str.strip().str.replace(r"\s+", " ", regex=True)
)
return df
def _profile_column(self, df, col, ep):
if col not in df.columns:
raise ValueError(f"Column '{col}' not found.")
col_data = df[col]
profile: dict = {"column": col}
numeric = pd.to_numeric(col_data, errors="coerce")
if numeric.notna().sum() > 0:
profile["min"] = float(numeric.min()) if numeric.notna().any() else None
profile["max"] = float(numeric.max()) if numeric.notna().any() else None
profile["median"] = float(numeric.median()) if numeric.notna().any() else None
profile["q25"] = float(numeric.quantile(0.25)) if numeric.notna().any() else None
profile["q75"] = float(numeric.quantile(0.75)) if numeric.notna().any() else None
value_counts = col_data.dropna().astype(str).value_counts().head(10)
profile["top_values"] = {str(k): int(v) for k, v in value_counts.items()}
null_mask = col_data.isna()
profile["null_positions"] = [int(i) for i in null_mask[null_mask].index[:5]]
if col_data.dtype == object or str(col_data.dtype) == "string":
sample = col_data.dropna().astype(str).head(20)
patterns = set()
for v in sample:
if v.replace("-", "").replace("/", "").isdigit():
patterns.add("date-like")
elif v.replace(".", "").replace("-", "").isdigit():
patterns.add("numeric-string")
elif v.isupper():
patterns.add("UPPERCASE")
elif v.islower():
patterns.add("lowercase")
else:
patterns.add("mixed-case")
profile["value_patterns"] = list(patterns)
ep["last_profile"] = profile
return df
# ------------------------------------------------------------------
# Observation assembly
# ------------------------------------------------------------------
def _make_observation(self, reward: float) -> Observation:
    """Assemble the ``Observation`` describing the current episode state.

    For every column found in either the working frame or the clean
    reference frame this reports issue counts (missing values, type errors,
    IQR outliers, format violations, unexpected non-ASCII characters,
    case/whitespace duplicates) and basic statistics. Columns that exist
    only in the clean frame are reported as entirely missing. The
    observation also carries up to five sampled rows, the last five history
    entries, the remaining budget fraction and, if one is pending, the most
    recent column profile (which is consumed here).

    Args:
        reward: Reward to attach to the observation.
    """
    ep = self._ep
    df = ep["current_df"]
    clean_df = ep["clean_df"]

    column_issues: dict[str, Any] = {}
    column_stats: dict[str, Any] = {}
    n_dups = int(df.duplicated().sum())

    # Working-frame columns first, then clean-only columns, without repeats.
    all_cols = list(dict.fromkeys(list(df.columns) + list(clean_df.columns)))
    for col in all_cols:
        if col not in df.columns:
            # Column was dropped or renamed away: report it as fully missing.
            column_issues[col] = ColumnIssues(
                missing_count=len(clean_df), missing_pct=1.0,
                type_errors=0, outlier_count=0, has_duplicates=n_dups > 0,
            ).model_dump()
            column_stats[col] = ColumnStats(null_count=len(clean_df), unique_count=0).model_dump()
            continue

        col_data = df[col]
        numeric = pd.to_numeric(col_data, errors="coerce")
        null_count = int(col_data.isna().sum())
        is_text = col_data.dtype == object or str(col_data.dtype) == "string"

        has_clean_ref = col in clean_df.columns
        clean_numeric = (
            pd.to_numeric(clean_df[col], errors="coerce") if has_clean_ref else pd.Series(dtype=float)
        )
        # The clean column counts as numeric when over 90% of it parses.
        ref_is_numeric = bool(
            has_clean_ref and len(clean_numeric) > 0 and clean_numeric.notna().mean() > 0.9
        )

        type_errs = 0
        if ref_is_numeric:
            # Values lost to coercion, not counting those already missing.
            type_errs = max(0, int(numeric.isna().sum() - null_count))

        outlier_count = 0
        if numeric.notna().sum() > 4:
            q1, q3 = numeric.quantile(0.25), numeric.quantile(0.75)
            iqr = q3 - q1
            if iqr > 0:
                outlier_count = int(((numeric < q1 - 1.5 * iqr) | (numeric > q3 + 1.5 * iqr)).sum())

        format_violation_count = 0
        if ref_is_numeric:
            # The mask already excludes nulls, so the null count must not be
            # subtracted a second time (doing so under-reported violations).
            format_violation_count = int((col_data.notna() & numeric.isna()).sum())
            if (clean_numeric.dropna() >= 0).all():
                # Negative values violate a column that is never negative
                # in the clean data.
                format_violation_count += int((numeric < 0).sum())
        elif is_text:
            str_vals = col_data.dropna().astype(str)
            if len(str_vals) > 0:
                # Leading/trailing whitespace counts as a format violation.
                format_violation_count = int((str_vals != str_vals.str.strip()).sum())

        encoding_issue_count = 0
        if is_text:
            sample_vals = col_data.dropna().astype(str)
            if len(sample_vals) > 200:
                sample_vals = sample_vals.sample(n=200, random_state=42)
            # Non-ASCII characters that legitimately occur in the clean data.
            clean_chars = set()
            if has_clean_ref:
                for v in clean_df[col].dropna().astype(str):
                    clean_chars.update(c for c in v if ord(c) > 127)
            for v in sample_vals:
                if any(ord(c) > 127 and c not in clean_chars for c in v):
                    encoding_issue_count += 1

        semantic_duplicate_count = 0
        if is_text:
            unique_vals = col_data.dropna().astype(str).unique()
            if len(unique_vals) < 500:
                # Group spellings that differ only by case or outer whitespace.
                normalized = {}
                for v in unique_vals:
                    normalized.setdefault(v.lower().strip(), []).append(v)
                semantic_duplicate_count = sum(
                    len(forms) - 1 for forms in normalized.values() if len(forms) > 1
                )

        column_issues[col] = ColumnIssues(
            missing_count=null_count,
            missing_pct=round(float(col_data.isna().mean()), 4),
            type_errors=type_errs,
            outlier_count=outlier_count,
            has_duplicates=n_dups > 0,
            format_violation_count=format_violation_count,
            encoding_issue_count=encoding_issue_count,
            semantic_duplicate_count=semantic_duplicate_count,
        ).model_dump()
        column_stats[col] = ColumnStats(
            mean=round(float(numeric.mean()), 4) if numeric.notna().any() else None,
            std=round(float(numeric.std()), 4) if numeric.notna().sum() > 1 else None,
            null_count=null_count,
            unique_count=int(col_data.nunique()),
        ).model_dump()

    sample_size = min(5, len(df))
    sample_rows = []
    if sample_size > 0:
        # Seeding with the step number gives a repeatable sample per step.
        sample = df.sample(n=sample_size, random_state=ep["step"])
        for _, row in sample.iterrows():
            sanitized = {}
            for k, v in row.items():
                if pd.isna(v):
                    sanitized[k] = None
                elif hasattr(v, "item"):
                    # Unwrap numpy scalars into plain Python values.
                    sanitized[k] = v.item()
                elif isinstance(v, str) and len(v) > 100:
                    sanitized[k] = v[:100]
                else:
                    sanitized[k] = v
            sample_rows.append(sanitized)

    action_history = ep.get("action_history", [])[-5:]
    # A profile is reported once, on the observation that follows it.
    profile_result = ep.pop("last_profile", None)

    return Observation(
        done=ep["done"],
        reward=reward,
        task=ep["task"],
        step=ep["step"],
        max_steps=ep["max_steps"],
        columns=list(df.columns),
        column_issues=column_issues,
        column_stats=column_stats,
        sample_rows=sample_rows,
        action_history=action_history,
        budget_remaining=max(0.0, 1.0 - ep.get("cost_used", 0.0) / ep.get("budget", 1.0)),
        profile_result=profile_result,
    )