Spaces:
Sleeping
Sleeping
| """ | |
| DataClean Environment β core simulation logic. | |
| =============================================== | |
| Implements reset(), step(), state for the data-cleaning agent. | |
| """ | |
| from __future__ import annotations | |
| import copy | |
| import uuid | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from ..models import DataCleanAction, DataCleanObservation, DataCleanState | |
| from .tasks import get_task, Row | |
class DataCleanEnvironment:
    """Simulates a data-quality review session."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    # -- lifecycle ----------------------------------------------------------
    def __init__(self) -> None:
        """Create an idle environment; reset() must be called before step()."""
        # Episode identity / liveness (done=True until reset() is called).
        self._task: dict = {}
        self._episode_id: str = ""
        self._done: bool = True
        # Dataset under review, its ground truth, and the known issue list.
        self._data: List[Row] = []
        self._clean: List[Row] = []
        self._issues: list = []
        self._columns: List[str] = []
        # Step accounting and the rolling action log.
        self._max_steps: int = 0
        self._step_count: int = 0
        self._action_log: List[str] = []
        # Grading state accumulated as the agent acts.
        self._deleted_rows: set = set()
        self._fixed_issues: set = set()
        self._wrong_fixes: int = 0
| # ββ reset ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def reset(self, task_name: str = "easy") -> dict: | |
| """Start a fresh episode for the given task.""" | |
| self._task = get_task(task_name) | |
| self._data = copy.deepcopy(self._task["dirty_data"]) | |
| self._clean = self._task["clean_data"] | |
| self._issues = self._task["issues"] | |
| self._columns = self._task["columns"] | |
| self._max_steps = self._task["max_steps"] | |
| self._step_count = 0 | |
| self._done = False | |
| self._episode_id = uuid.uuid4().hex[:12] | |
| self._action_log = [] | |
| self._deleted_rows = set() | |
| self._fixed_issues = set() | |
| self._wrong_fixes = 0 | |
| obs = self._build_observation() | |
| return { | |
| "observation": obs.model_dump(), | |
| "reward": 0.0, | |
| "done": False, | |
| } | |
| # ββ step βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def step(self, action: DataCleanAction) -> dict: | |
| if self._done: | |
| obs = self._build_observation() | |
| return { | |
| "observation": obs.model_dump(), | |
| "reward": self._compute_score(), | |
| "done": True, | |
| } | |
| self._step_count += 1 | |
| msg = self._apply_action(action) | |
| self._action_log.append(f"Step {self._step_count}: {action.action_type} -> {msg}") | |
| # episode ends on submit, max steps, or all issues fixed | |
| if ( | |
| action.action_type == "submit" | |
| or self._step_count >= self._max_steps | |
| or len(self._fixed_issues) == len(self._issues) | |
| ): | |
| self._done = True | |
| score = self._compute_score() | |
| obs = self._build_observation() | |
| return { | |
| "observation": obs.model_dump(), | |
| "reward": round(score, 4), | |
| "done": self._done, | |
| } | |
| # ββ state ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def state(self) -> dict: | |
| return DataCleanState( | |
| episode_id=self._episode_id, | |
| task_name=self._task.get("name", ""), | |
| difficulty=self._task.get("difficulty", ""), | |
| step_count=self._step_count, | |
| max_steps=self._max_steps, | |
| total_issues=len(self._issues), | |
| issues_fixed=len(self._fixed_issues), | |
| current_score=round(self._compute_score(), 4), | |
| done=self._done, | |
| ).model_dump() | |
| # ββ internal: apply actions ββββββββββββββββββββββββββββββββββββββββββββ | |
| def _apply_action(self, action: DataCleanAction) -> str: | |
| at = action.action_type | |
| if at == "noop": | |
| return "No action taken." | |
| if at == "submit": | |
| return "Submitted for grading." | |
| if at in ("fix_value", "fill_missing"): | |
| return self._do_fix(action) | |
| if at == "delete_row": | |
| return self._do_delete(action) | |
| if at == "flag_anomaly": | |
| return self._do_flag(action) | |
| return f"Unknown action_type '{at}'. No effect." | |
| def _do_fix(self, action: DataCleanAction) -> str: | |
| ri = action.row_index | |
| col = action.column_name | |
| val = action.new_value | |
| if ri is None or col is None or val is None: | |
| return "fix_value requires row_index, column_name, and new_value." | |
| if ri < 0 or ri >= len(self._data): | |
| return f"row_index {ri} out of range (0-{len(self._data)-1})." | |
| if ri in self._deleted_rows: | |
| return f"Row {ri} was already deleted." | |
| if col not in self._columns: | |
| return f"Unknown column '{col}'. Valid: {self._columns}" | |
| # apply the edit | |
| old_val = str(self._data[ri].get(col, "")) | |
| self._data[ri][col] = self._coerce(val, self._data[ri][col]) | |
| # check whether this fixes a known issue | |
| matched = self._match_fix(ri, col, val) | |
| if matched is not None: | |
| self._fixed_issues.add(matched) | |
| return f"Fixed row {ri} [{col}]: '{old_val}' -> '{val}' (issue resolved)" | |
| else: | |
| # check if the edit made things worse | |
| if old_val == str(self._ground_truth_value(ri, col)): | |
| self._wrong_fixes += 1 | |
| return f"Changed row {ri} [{col}]: '{old_val}' -> '{val}' (WARNING: was already correct!)" | |
| return f"Changed row {ri} [{col}]: '{old_val}' -> '{val}'" | |
| def _do_delete(self, action: DataCleanAction) -> str: | |
| ri = action.row_index | |
| if ri is None: | |
| return "delete_row requires row_index." | |
| if ri < 0 or ri >= len(self._data): | |
| return f"row_index {ri} out of range." | |
| if ri in self._deleted_rows: | |
| return f"Row {ri} already deleted." | |
| self._deleted_rows.add(ri) | |
| matched = self._match_delete(ri) | |
| if matched is not None: | |
| self._fixed_issues.add(matched) | |
| return f"Deleted row {ri} (duplicate removed)" | |
| else: | |
| self._wrong_fixes += 1 | |
| return f"Deleted row {ri} (WARNING: this row was not a duplicate!)" | |
| def _do_flag(self, action: DataCleanAction) -> str: | |
| ri = action.row_index | |
| col = action.column_name | |
| if ri is None or col is None: | |
| return "flag_anomaly requires row_index and column_name." | |
| # partial credit: flagging the right cell earns 0.5 of the fix | |
| for idx, issue in enumerate(self._issues): | |
| if issue["row"] == ri and issue.get("col") == col and idx not in self._fixed_issues: | |
| self._fixed_issues.add(idx) | |
| return f"Flagged row {ri} [{col}] as anomalous (partial credit)" | |
| return f"Flagged row {ri} [{col}] β no matching issue found." | |
| # ββ grading helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _match_fix(self, row: int, col: str, val: str) -> Optional[int]: | |
| """Return issue index if this fix resolves a known issue, else None.""" | |
| for idx, issue in enumerate(self._issues): | |
| if idx in self._fixed_issues: | |
| continue | |
| if issue["row"] == row and issue.get("col") == col: | |
| expected = str(issue["fix"]) | |
| if self._fuzzy_eq(val, expected): | |
| return idx | |
| return None | |
| def _match_delete(self, row: int) -> Optional[int]: | |
| for idx, issue in enumerate(self._issues): | |
| if idx in self._fixed_issues: | |
| continue | |
| if issue["row"] == row and issue["fix"] == "__DELETE__": | |
| return idx | |
| return None | |
| def _compute_score(self) -> float: | |
| if not self._issues: | |
| return 1.0 | |
| total = len(self._issues) | |
| fixed = len(self._fixed_issues) | |
| # base score from fixed issues | |
| base = fixed / total | |
| # penalty for wrong fixes (capped so score stays >= 0) | |
| penalty = min(self._wrong_fixes * 0.05, base) | |
| # small efficiency bonus if done early | |
| if self._done and self._max_steps > 0: | |
| remaining_ratio = max(0, (self._max_steps - self._step_count)) / self._max_steps | |
| efficiency = remaining_ratio * 0.05 | |
| else: | |
| efficiency = 0.0 | |
| score = base - penalty + efficiency | |
| return max(0.0, min(1.0, score)) | |
| def _ground_truth_value(self, dirty_row_idx: int, col: str) -> Any: | |
| """Look up the expected clean value for a dirty-data row.""" | |
| # map dirty index to clean index (accounting for deleted rows in ground truth) | |
| clean_idx = self._dirty_to_clean_idx(dirty_row_idx) | |
| if clean_idx is not None and clean_idx < len(self._clean): | |
| return self._clean[clean_idx].get(col) | |
| return None | |
| def _dirty_to_clean_idx(self, dirty_idx: int) -> Optional[int]: | |
| """Map a dirty-data row index to the clean-data row index.""" | |
| # find rows that should be deleted | |
| delete_rows = { | |
| issue["row"] | |
| for issue in self._issues | |
| if issue["fix"] == "__DELETE__" | |
| } | |
| # count non-deleted rows before dirty_idx | |
| if dirty_idx in delete_rows: | |
| return None | |
| clean_i = 0 | |
| for i in range(dirty_idx): | |
| if i not in delete_rows: | |
| clean_i += 1 | |
| return clean_i | |
| def _fuzzy_eq(a: str, b: str) -> bool: | |
| """Lenient comparison for grading (strip, lower, remove leading zeros).""" | |
| a = str(a).strip().lower() | |
| b = str(b).strip().lower() | |
| if a == b: | |
| return True | |
| # numeric comparison | |
| try: | |
| return abs(float(a) - float(b)) < 0.01 | |
| except (ValueError, TypeError): | |
| pass | |
| return False | |
| def _coerce(val_str: str, existing: Any) -> Any: | |
| """Try to coerce the string value to the same type as the existing cell.""" | |
| if isinstance(existing, int): | |
| try: | |
| return int(float(val_str)) | |
| except (ValueError, TypeError): | |
| return val_str | |
| if isinstance(existing, float): | |
| try: | |
| return float(val_str) | |
| except (ValueError, TypeError): | |
| return val_str | |
| return val_str | |
| # ββ observation builder ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _build_observation(self) -> DataCleanObservation: | |
| return DataCleanObservation( | |
| task_name=self._task.get("name", ""), | |
| task_description=self._task.get("description", ""), | |
| difficulty=self._task.get("difficulty", ""), | |
| data_preview=self._render_table(), | |
| quality_report=self._render_quality_report(), | |
| columns_info=self._render_columns_info(), | |
| action_history=list(self._action_log[-10:]), | |
| step_number=self._step_count, | |
| max_steps=self._max_steps, | |
| current_score=round(self._compute_score(), 4), | |
| ) | |
| def _render_table(self) -> str: | |
| """Render the current dataset as an aligned text table.""" | |
| if not self._data: | |
| return "(empty dataset)" | |
| cols = self._columns | |
| # compute column widths | |
| widths = {c: len(c) for c in cols} | |
| widths["row"] = 3 | |
| active_rows: List[Tuple[int, Row]] = [ | |
| (i, row) for i, row in enumerate(self._data) if i not in self._deleted_rows | |
| ] | |
| for i, row in active_rows: | |
| widths["row"] = max(widths["row"], len(str(i))) | |
| for c in cols: | |
| val = str(row.get(c, "")) | |
| if val == "": | |
| val = "[EMPTY]" | |
| widths[c] = max(widths[c], min(len(val), 30)) | |
| # header | |
| hdr = "| " + " | ".join( | |
| ["row".ljust(widths["row"])] + [c.ljust(widths[c]) for c in cols] | |
| ) + " |" | |
| sep = "|-" + "-|-".join( | |
| ["-" * widths["row"]] + ["-" * widths[c] for c in cols] | |
| ) + "-|" | |
| lines = [hdr, sep] | |
| for i, row in active_rows: | |
| cells = [str(i).ljust(widths["row"])] | |
| for c in cols: | |
| val = str(row.get(c, "")) | |
| if val == "": | |
| val = "[EMPTY]" | |
| cells.append(val[:30].ljust(widths[c])) | |
| lines.append("| " + " | ".join(cells) + " |") | |
| return "\n".join(lines) | |
| def _render_quality_report(self) -> str: | |
| """Generate a quality-report hinting at (but not solving) issues.""" | |
| if not self._data: | |
| return "No data loaded." | |
| lines = ["DATA QUALITY REPORT", "=" * 40] | |
| cols = self._columns | |
| active_rows = [ | |
| (i, row) for i, row in enumerate(self._data) if i not in self._deleted_rows | |
| ] | |
| num_rows = len(active_rows) | |
| lines.append(f"Total rows: {num_rows} (original: {len(self._data)}, deleted: {len(self._deleted_rows)})") | |
| # per-column stats | |
| for c in cols: | |
| vals = [str(row.get(c, "")) for _, row in active_rows] | |
| empties = sum(1 for v in vals if v.strip() == "" or v.strip().upper() == "NULL") | |
| unique = len(set(vals)) | |
| if empties: | |
| lines.append(f" Column '{c}': {empties} empty/null value(s)") | |
| # detect potential duplicates (simple exact-match check) | |
| seen = {} | |
| for i, row in active_rows: | |
| key = tuple(str(row.get(c, "")) for c in cols) | |
| if key in seen: | |
| lines.append(f" Possible duplicate: row {i} matches row {seen[key]}") | |
| else: | |
| seen[key] = i | |
| # detect numeric anomalies | |
| for c in cols: | |
| numeric_vals = [] | |
| for i, row in active_rows: | |
| try: | |
| numeric_vals.append((i, float(row[c]))) | |
| except (ValueError, TypeError, KeyError): | |
| pass | |
| if len(numeric_vals) >= 3: | |
| values = [v for _, v in numeric_vals] | |
| mean = sum(values) / len(values) | |
| for i, v in numeric_vals: | |
| if v < 0: | |
| lines.append(f" Row {i}, '{c}': Negative value ({v})") | |
| elif abs(v - mean) > 3 * (max(values) - min(values) + 1) / 4: | |
| lines.append(f" Row {i}, '{c}': Potential outlier ({v})") | |
| # detect format inconsistencies in string columns | |
| for c in cols: | |
| vals = [str(row.get(c, "")) for _, row in active_rows] | |
| non_empty = [v for v in vals if v.strip() and v.strip() != "[EMPTY]"] | |
| if not non_empty: | |
| continue | |
| # check for mixed case patterns (all-caps vs lowercase) | |
| has_upper = any(v.isupper() for v in non_empty) | |
| has_lower = any(v.islower() or (not v.isupper() and not v.istitle()) for v in non_empty) | |
| if has_upper and has_lower and c in ("email",): | |
| lines.append(f" Column '{c}': Mixed case formatting detected") | |
| # check for format inconsistency in date-like columns | |
| if c in ("date", "start_date", "birth_date"): | |
| formats_seen = set() | |
| for v in non_empty: | |
| if "/" in v: | |
| formats_seen.add("slash") | |
| elif "." in v and v.count(".") == 2: | |
| formats_seen.add("dot") | |
| elif "-" in v: | |
| formats_seen.add("dash") | |
| if len(formats_seen) > 1: | |
| lines.append(f" Column '{c}': Inconsistent date formats ({', '.join(formats_seen)})") | |
| lines.append(f"\nProgress: {len(self._fixed_issues)}/{len(self._issues)} issues resolved") | |
| lines.append(f"Steps used: {self._step_count}/{self._max_steps}") | |
| return "\n".join(lines) | |
| def _render_columns_info(self) -> List[Dict[str, Any]]: | |
| active_rows = [ | |
| row for i, row in enumerate(self._data) if i not in self._deleted_rows | |
| ] | |
| info = [] | |
| for c in self._columns: | |
| vals = [row.get(c, "") for row in active_rows] | |
| non_empty = [v for v in vals if str(v).strip() not in ("", "NULL")] | |
| info.append({ | |
| "name": c, | |
| "total": len(vals), | |
| "non_empty": len(non_empty), | |
| "empty": len(vals) - len(non_empty), | |
| "unique": len(set(str(v) for v in vals)), | |
| }) | |
| return info | |