Spaces:
Sleeping
Sleeping
| import uuid | |
| from typing import Any, Dict, List, Optional | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from env.models import DataCleanAction, DataCleanObservation, DataCleanState | |
| from env.tasks import generate_task, get_task_names, grade_action | |
class DataValidationEnvironment(Environment):
    """Interactive data-cleaning environment.

    Each episode presents a dataset seeded with known errors. The agent
    submits cleaning actions, which are graded against a ground-truth
    dataset; the episode ends when every error is fixed or the step
    budget is exhausted. Rewards reported in observations are clamped
    to the [0.01, 0.99] band.
    """

    # Each session owns its own instance, so concurrent sessions are safe.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        super().__init__()
        # Episode-scoped bookkeeping; everything here is replaced on reset().
        self._state = DataCleanState()
        self._ground_truth: List[Dict[str, Any]] = []
        self._errors: List[Dict[str, Any]] = []
        self._task_info: Dict[str, Any] = {}
        self._field_names: List[str] = []

    def reset(
        self,
        task_name: Optional[str] = None,
        seed: int = 42,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> DataCleanObservation:
        """Begin a new episode and return the initial observation.

        Args:
            task_name: Task to generate; defaults to "easy_missing_values".
            seed: Seed forwarded to the task generator for reproducibility.
            episode_id: Optional external id; a UUID4 is minted otherwise.

        Returns:
            The initial observation with the full error list and zero progress.
        """
        chosen = task_name if task_name is not None else "easy_missing_values"
        spec = generate_task(chosen, seed)

        self._task_info = spec
        self._ground_truth = spec["ground_truth"]
        self._errors = spec["errors"]
        self._field_names = spec["field_names"]

        error_count = len(self._errors)
        self._state = DataCleanState(
            episode_id=episode_id or str(uuid.uuid4()),
            task_name=chosen,
            step_count=0,
            max_steps=spec["max_steps"],
            done=False,
            reward_history=[],
            # Cumulative reward starts at the 0.01 floor, matching the
            # reset reward reported below.
            cumulative_reward=0.01,
            dataset=spec["dataset"],
            ground_truth=self._ground_truth,
            errors=self._errors,
            errors_fixed=0,
            total_errors=error_count,
            last_actions=[],
        )
        return DataCleanObservation(
            task_name=chosen,
            task_description=spec["description"],
            dataset=spec["dataset"],
            errors_found=self._errors,
            errors_remaining=error_count,
            errors_total=error_count,
            errors_fixed=0,
            step_count=0,
            max_steps=spec["max_steps"],
            reward=0.01,
            cumulative_reward=0.01,
            done=False,
            last_action_result="Environment reset. Examine errors and fix them.",
            task_hint=spec["hint"],
            progress_pct=0.0,
            field_names=self._field_names,
        )

    def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
        """Apply one cleaning action, grade it, and return the new observation.

        Exact repeats of a previous action in this episode earn only the
        floor reward instead of being graded. The episode finishes when no
        errors remain or the step budget is hit.
        """
        if self._state.done:
            return self._make_observation(0.01, "Episode already done. Call reset().")

        self._state.step_count += 1

        key = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
        repeated = key in self._state.last_actions
        self._state.last_actions.append(key)

        if repeated:
            # Duplicate action: floor reward, no grading.
            step_reward, note = 0.01, "Penalty: repeated identical action"
        else:
            # grade_action mutates the dataset/errors in place and reports
            # whether this action fixed one of the seeded errors.
            step_reward, note, did_fix = grade_action(
                action.action_type,
                action.target_field,
                action.target_row,
                action.new_value,
                self._state.dataset,
                self._ground_truth,
                self._errors,
            )
            if did_fix:
                self._state.errors_fixed += 1

        self._state.cumulative_reward += step_reward
        self._state.reward_history.append(step_reward)

        remaining = len([e for e in self._errors if not e.get("fixed", False)])
        if remaining == 0:
            self._state.done = True
            note += " | All errors fixed! Episode complete."
        elif self._state.step_count >= self._state.max_steps:
            self._state.done = True
            note += f" | Max steps reached. {remaining} errors remaining."
        return self._make_observation(step_reward, note)

    def state(self) -> DataCleanState:
        """Return the current (mutable) episode state."""
        return self._state

    def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
        """Assemble an observation from current state.

        Both the step reward and the cumulative reward are clamped to
        [0.01, 0.99] before being reported.
        """
        pending = [e for e in self._errors if not e.get("fixed", False)]
        # Guard against division by zero when a task has no errors.
        denominator = self._state.total_errors if self._state.total_errors > 0 else 1
        pct_done = (self._state.errors_fixed / denominator) * 100
        return DataCleanObservation(
            task_name=self._state.task_name,
            task_description=self._task_info.get("description", ""),
            dataset=self._state.dataset,
            errors_found=pending,
            errors_remaining=len(pending),
            errors_total=self._state.total_errors,
            errors_fixed=self._state.errors_fixed,
            step_count=self._state.step_count,
            max_steps=self._state.max_steps,
            reward=max(0.01, min(0.99, reward)),
            cumulative_reward=max(0.01, min(0.99, self._state.cumulative_reward)),
            done=self._state.done,
            last_action_result=message,
            task_hint=self._task_info.get("hint", ""),
            progress_pct=pct_done,
            field_names=self._field_names,
        )