GlitchGhost's picture
Fix Phase 2: add [START]/[STEP]/[END] structured output to inference.py
48e9b06
"""
DataClean Environment – core simulation logic.
===============================================
Implements reset(), step(), state for the data-cleaning agent.
"""
from __future__ import annotations
import copy
import uuid
from typing import Any, Dict, List, Optional, Tuple
from ..models import DataCleanAction, DataCleanObservation, DataCleanState
from .tasks import get_task, Row
class DataCleanEnvironment:
    """Simulates a data-quality review session.

    The agent inspects a dirty tabular dataset and issues actions
    (fix_value, fill_missing, delete_row, flag_anomaly, noop, submit)
    to resolve a known list of issues within a step budget; progress
    is graded by _compute_score().
    """

    # All episode state lives on the instance, so independent sessions
    # can run side by side.
    SUPPORTS_CONCURRENT_SESSIONS = True
# ── lifecycle ──────────────────────────────────────────────────────────
def __init__(self) -> None:
self._task: dict = {}
self._data: List[Row] = []
self._clean: List[Row] = []
self._issues: list = []
self._columns: List[str] = []
self._max_steps: int = 0
self._step_count: int = 0
self._done: bool = True
self._episode_id: str = ""
self._action_log: List[str] = []
self._deleted_rows: set = set()
self._fixed_issues: set = set()
self._wrong_fixes: int = 0
# ── reset ──────────────────────────────────────────────────────────────
def reset(self, task_name: str = "easy") -> dict:
    """Begin a fresh episode of the named task.

    Returns a dict with the initial observation payload, a zero
    reward, and done=False.
    """
    task = get_task(task_name)
    self._task = task
    # Work on a private copy of the dirty rows: the agent mutates them.
    # The clean rows are ground truth and are only read, so no copy.
    self._data = copy.deepcopy(task["dirty_data"])
    self._clean = task["clean_data"]
    self._issues = task["issues"]
    self._columns = task["columns"]
    self._max_steps = task["max_steps"]
    # Fresh bookkeeping for the new episode.
    self._step_count = 0
    self._done = False
    self._episode_id = uuid.uuid4().hex[:12]
    self._action_log = []
    self._deleted_rows = set()
    self._fixed_issues = set()
    self._wrong_fixes = 0
    return {
        "observation": self._build_observation().model_dump(),
        "reward": 0.0,
        "done": False,
    }
# ── step ───────────────────────────────────────────────────────────────
def step(self, action: DataCleanAction) -> dict:
if self._done:
obs = self._build_observation()
return {
"observation": obs.model_dump(),
"reward": self._compute_score(),
"done": True,
}
self._step_count += 1
msg = self._apply_action(action)
self._action_log.append(f"Step {self._step_count}: {action.action_type} -> {msg}")
# episode ends on submit, max steps, or all issues fixed
if (
action.action_type == "submit"
or self._step_count >= self._max_steps
or len(self._fixed_issues) == len(self._issues)
):
self._done = True
score = self._compute_score()
obs = self._build_observation()
return {
"observation": obs.model_dump(),
"reward": round(score, 4),
"done": self._done,
}
# ── state ──────────────────────────────────────────────────────────────
@property
def state(self) -> dict:
    """Snapshot of episode progress, serialized through DataCleanState."""
    fields = {
        "episode_id": self._episode_id,
        "task_name": self._task.get("name", ""),
        "difficulty": self._task.get("difficulty", ""),
        "step_count": self._step_count,
        "max_steps": self._max_steps,
        "total_issues": len(self._issues),
        "issues_fixed": len(self._fixed_issues),
        "current_score": round(self._compute_score(), 4),
        "done": self._done,
    }
    return DataCleanState(**fields).model_dump()
# ── internal: apply actions ────────────────────────────────────────────
def _apply_action(self, action: DataCleanAction) -> str:
at = action.action_type
if at == "noop":
return "No action taken."
if at == "submit":
return "Submitted for grading."
if at in ("fix_value", "fill_missing"):
return self._do_fix(action)
if at == "delete_row":
return self._do_delete(action)
if at == "flag_anomaly":
return self._do_flag(action)
return f"Unknown action_type '{at}'. No effect."
def _do_fix(self, action: DataCleanAction) -> str:
ri = action.row_index
col = action.column_name
val = action.new_value
if ri is None or col is None or val is None:
return "fix_value requires row_index, column_name, and new_value."
if ri < 0 or ri >= len(self._data):
return f"row_index {ri} out of range (0-{len(self._data)-1})."
if ri in self._deleted_rows:
return f"Row {ri} was already deleted."
if col not in self._columns:
return f"Unknown column '{col}'. Valid: {self._columns}"
# apply the edit
old_val = str(self._data[ri].get(col, ""))
self._data[ri][col] = self._coerce(val, self._data[ri][col])
# check whether this fixes a known issue
matched = self._match_fix(ri, col, val)
if matched is not None:
self._fixed_issues.add(matched)
return f"Fixed row {ri} [{col}]: '{old_val}' -> '{val}' (issue resolved)"
else:
# check if the edit made things worse
if old_val == str(self._ground_truth_value(ri, col)):
self._wrong_fixes += 1
return f"Changed row {ri} [{col}]: '{old_val}' -> '{val}' (WARNING: was already correct!)"
return f"Changed row {ri} [{col}]: '{old_val}' -> '{val}'"
def _do_delete(self, action: DataCleanAction) -> str:
ri = action.row_index
if ri is None:
return "delete_row requires row_index."
if ri < 0 or ri >= len(self._data):
return f"row_index {ri} out of range."
if ri in self._deleted_rows:
return f"Row {ri} already deleted."
self._deleted_rows.add(ri)
matched = self._match_delete(ri)
if matched is not None:
self._fixed_issues.add(matched)
return f"Deleted row {ri} (duplicate removed)"
else:
self._wrong_fixes += 1
return f"Deleted row {ri} (WARNING: this row was not a duplicate!)"
def _do_flag(self, action: DataCleanAction) -> str:
ri = action.row_index
col = action.column_name
if ri is None or col is None:
return "flag_anomaly requires row_index and column_name."
# partial credit: flagging the right cell earns 0.5 of the fix
for idx, issue in enumerate(self._issues):
if issue["row"] == ri and issue.get("col") == col and idx not in self._fixed_issues:
self._fixed_issues.add(idx)
return f"Flagged row {ri} [{col}] as anomalous (partial credit)"
return f"Flagged row {ri} [{col}] β€” no matching issue found."
# ── grading helpers ────────────────────────────────────────────────────
def _match_fix(self, row: int, col: str, val: str) -> Optional[int]:
"""Return issue index if this fix resolves a known issue, else None."""
for idx, issue in enumerate(self._issues):
if idx in self._fixed_issues:
continue
if issue["row"] == row and issue.get("col") == col:
expected = str(issue["fix"])
if self._fuzzy_eq(val, expected):
return idx
return None
def _match_delete(self, row: int) -> Optional[int]:
for idx, issue in enumerate(self._issues):
if idx in self._fixed_issues:
continue
if issue["row"] == row and issue["fix"] == "__DELETE__":
return idx
return None
def _compute_score(self) -> float:
if not self._issues:
return 1.0
total = len(self._issues)
fixed = len(self._fixed_issues)
# base score from fixed issues
base = fixed / total
# penalty for wrong fixes (capped so score stays >= 0)
penalty = min(self._wrong_fixes * 0.05, base)
# small efficiency bonus if done early
if self._done and self._max_steps > 0:
remaining_ratio = max(0, (self._max_steps - self._step_count)) / self._max_steps
efficiency = remaining_ratio * 0.05
else:
efficiency = 0.0
score = base - penalty + efficiency
return max(0.0, min(1.0, score))
def _ground_truth_value(self, dirty_row_idx: int, col: str) -> Any:
"""Look up the expected clean value for a dirty-data row."""
# map dirty index to clean index (accounting for deleted rows in ground truth)
clean_idx = self._dirty_to_clean_idx(dirty_row_idx)
if clean_idx is not None and clean_idx < len(self._clean):
return self._clean[clean_idx].get(col)
return None
def _dirty_to_clean_idx(self, dirty_idx: int) -> Optional[int]:
"""Map a dirty-data row index to the clean-data row index."""
# find rows that should be deleted
delete_rows = {
issue["row"]
for issue in self._issues
if issue["fix"] == "__DELETE__"
}
# count non-deleted rows before dirty_idx
if dirty_idx in delete_rows:
return None
clean_i = 0
for i in range(dirty_idx):
if i not in delete_rows:
clean_i += 1
return clean_i
@staticmethod
def _fuzzy_eq(a: str, b: str) -> bool:
"""Lenient comparison for grading (strip, lower, remove leading zeros)."""
a = str(a).strip().lower()
b = str(b).strip().lower()
if a == b:
return True
# numeric comparison
try:
return abs(float(a) - float(b)) < 0.01
except (ValueError, TypeError):
pass
return False
@staticmethod
def _coerce(val_str: str, existing: Any) -> Any:
"""Try to coerce the string value to the same type as the existing cell."""
if isinstance(existing, int):
try:
return int(float(val_str))
except (ValueError, TypeError):
return val_str
if isinstance(existing, float):
try:
return float(val_str)
except (ValueError, TypeError):
return val_str
return val_str
# ── observation builder ────────────────────────────────────────────────
def _build_observation(self) -> DataCleanObservation:
    """Assemble the observation payload shown to the agent this step."""
    return DataCleanObservation(
        task_name=self._task.get("name", ""),
        task_description=self._task.get("description", ""),
        difficulty=self._task.get("difficulty", ""),
        data_preview=self._render_table(),
        quality_report=self._render_quality_report(),
        columns_info=self._render_columns_info(),
        # Only the ten most recent log entries are surfaced; the slice
        # already yields a fresh list.
        action_history=self._action_log[-10:],
        step_number=self._step_count,
        max_steps=self._max_steps,
        current_score=round(self._compute_score(), 4),
    )
def _render_table(self) -> str:
"""Render the current dataset as an aligned text table."""
if not self._data:
return "(empty dataset)"
cols = self._columns
# compute column widths
widths = {c: len(c) for c in cols}
widths["row"] = 3
active_rows: List[Tuple[int, Row]] = [
(i, row) for i, row in enumerate(self._data) if i not in self._deleted_rows
]
for i, row in active_rows:
widths["row"] = max(widths["row"], len(str(i)))
for c in cols:
val = str(row.get(c, ""))
if val == "":
val = "[EMPTY]"
widths[c] = max(widths[c], min(len(val), 30))
# header
hdr = "| " + " | ".join(
["row".ljust(widths["row"])] + [c.ljust(widths[c]) for c in cols]
) + " |"
sep = "|-" + "-|-".join(
["-" * widths["row"]] + ["-" * widths[c] for c in cols]
) + "-|"
lines = [hdr, sep]
for i, row in active_rows:
cells = [str(i).ljust(widths["row"])]
for c in cols:
val = str(row.get(c, ""))
if val == "":
val = "[EMPTY]"
cells.append(val[:30].ljust(widths[c]))
lines.append("| " + " | ".join(cells) + " |")
return "\n".join(lines)
def _render_quality_report(self) -> str:
    """Generate a quality-report hinting at (but not solving) issues.

    The report lists row counts, per-column empty/NULL counts, exact
    duplicate rows, numeric negatives/outliers, and simple formatting
    inconsistencies, followed by a progress footer.  It deliberately
    describes symptoms without revealing the graded fixes.
    """
    if not self._data:
        return "No data loaded."
    lines = ["DATA QUALITY REPORT", "=" * 40]
    cols = self._columns
    # Deleted rows are excluded from every statistic below.
    active_rows = [
        (i, row) for i, row in enumerate(self._data) if i not in self._deleted_rows
    ]
    num_rows = len(active_rows)
    lines.append(f"Total rows: {num_rows} (original: {len(self._data)}, deleted: {len(self._deleted_rows)})")
    # per-column stats: blank strings and the literal "NULL" (any case)
    # both count as empty
    for c in cols:
        vals = [str(row.get(c, "")) for _, row in active_rows]
        empties = sum(1 for v in vals if v.strip() == "" or v.strip().upper() == "NULL")
        unique = len(set(vals))  # NOTE(review): computed but never reported
        if empties:
            lines.append(f" Column '{c}': {empties} empty/null value(s)")
    # detect potential duplicates (simple exact-match check on the full
    # stringified row; later occurrences point back to the first)
    seen = {}
    for i, row in active_rows:
        key = tuple(str(row.get(c, "")) for c in cols)
        if key in seen:
            lines.append(f" Possible duplicate: row {i} matches row {seen[key]}")
        else:
            seen[key] = i
    # detect numeric anomalies (only when >= 3 cells parse as floats)
    for c in cols:
        numeric_vals = []
        for i, row in active_rows:
            try:
                numeric_vals.append((i, float(row[c])))
            except (ValueError, TypeError, KeyError):
                pass  # non-numeric / missing cells are simply skipped
        if len(numeric_vals) >= 3:
            values = [v for _, v in numeric_vals]
            mean = sum(values) / len(values)
            for i, v in numeric_vals:
                if v < 0:
                    lines.append(f" Row {i}, '{c}': Negative value ({v})")
                # Ad-hoc rule: flag values farther from the mean than
                # 3/4 of (range + 1); not a statistical test.
                elif abs(v - mean) > 3 * (max(values) - min(values) + 1) / 4:
                    lines.append(f" Row {i}, '{c}': Potential outlier ({v})")
    # detect format inconsistencies in string columns
    for c in cols:
        vals = [str(row.get(c, "")) for _, row in active_rows]
        non_empty = [v for v in vals if v.strip() and v.strip() != "[EMPTY]"]
        if not non_empty:
            continue
        # check for mixed case patterns (all-caps vs lowercase);
        # only the 'email' column is actually flagged
        has_upper = any(v.isupper() for v in non_empty)
        has_lower = any(v.islower() or (not v.isupper() and not v.istitle()) for v in non_empty)
        if has_upper and has_lower and c in ("email",):
            lines.append(f" Column '{c}': Mixed case formatting detected")
        # check for format inconsistency in date-like columns;
        # classification order matters: '/' wins over '.', which wins over '-'
        if c in ("date", "start_date", "birth_date"):
            formats_seen = set()
            for v in non_empty:
                if "/" in v:
                    formats_seen.add("slash")
                elif "." in v and v.count(".") == 2:
                    formats_seen.add("dot")
                elif "-" in v:
                    formats_seen.add("dash")
            if len(formats_seen) > 1:
                lines.append(f" Column '{c}': Inconsistent date formats ({', '.join(formats_seen)})")
    lines.append(f"\nProgress: {len(self._fixed_issues)}/{len(self._issues)} issues resolved")
    lines.append(f"Steps used: {self._step_count}/{self._max_steps}")
    return "\n".join(lines)
def _render_columns_info(self) -> List[Dict[str, Any]]:
active_rows = [
row for i, row in enumerate(self._data) if i not in self._deleted_rows
]
info = []
for c in self._columns:
vals = [row.get(c, "") for row in active_rows]
non_empty = [v for v in vals if str(v).strip() not in ("", "NULL")]
info.append({
"name": c,
"total": len(vals),
"non_empty": len(non_empty),
"empty": len(vals) - len(non_empty),
"unique": len(set(str(v) for v in vals)),
})
return info