# data-validation-env/env/environment.py
# Uploaded by kush5699 with huggingface_hub (commit 1bac517, verified).
import uuid
from typing import Any, Dict, List, Optional
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from env.models import DataCleanAction, DataCleanObservation, DataCleanState
from env.tasks import generate_task, get_task_names, grade_action
class DataValidationEnvironment(Environment):
    """Episode-based environment in which an agent repairs errors in a dataset.

    ``reset()`` loads a task via ``generate_task`` and starts a new episode;
    ``step()`` grades one ``DataCleanAction`` via ``grade_action`` and returns
    an updated ``DataCleanObservation``. An episode ends when no error entry
    is left unfixed or when ``max_steps`` actions have been taken.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Per-step and cumulative rewards reported in observations are clamped to
    # this closed range.
    _MIN_REWARD: float = 0.01
    _MAX_REWARD: float = 0.99
    # Task loaded when reset() is called without an explicit task name.
    _DEFAULT_TASK: str = "easy_missing_values"

    def __init__(self):
        """Create an environment with no task loaded; call ``reset()`` next."""
        super().__init__()
        self._state = DataCleanState()
        self._ground_truth: List[Dict[str, Any]] = []
        self._errors: List[Dict[str, Any]] = []
        # Stays empty until reset() stores the generated task here; step()
        # uses that to detect "no episode has been started yet".
        self._task_info: Dict[str, Any] = {}
        self._field_names: List[str] = []

    def reset(self, task_name: Optional[str] = None, seed: int = 42,
              episode_id: Optional[str] = None, **kwargs) -> DataCleanObservation:
        """Start a new episode and return its initial observation.

        Args:
            task_name: Task to load; defaults to ``"easy_missing_values"``.
            seed: Forwarded to ``generate_task`` together with the task name.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted or empty.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The observation describing the freshly generated task.
        """
        if task_name is None:
            task_name = self._DEFAULT_TASK

        task = generate_task(task_name, seed)
        self._ground_truth = task["ground_truth"]
        self._errors = task["errors"]
        self._task_info = task
        self._field_names = task["field_names"]

        self._state = DataCleanState(
            episode_id=episode_id or str(uuid.uuid4()),
            task_name=task_name,
            step_count=0,
            max_steps=task["max_steps"],
            done=False,
            reward_history=[],
            cumulative_reward=self._MIN_REWARD,
            dataset=task["dataset"],
            ground_truth=self._ground_truth,
            errors=self._errors,
            errors_fixed=0,
            total_errors=len(self._errors),
            last_actions=[],
        )

        # Build the initial observation through the same helper that step()
        # uses so both code paths always report the same set of fields.
        return self._make_observation(
            self._MIN_REWARD,
            "Environment reset. Examine errors and fix them.",
        )

    def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
        """Grade one action and advance the episode.

        A repeated action (same type, field, row and new value as an earlier
        action in this episode) is not graded and receives the minimum reward.
        Calling ``step()`` before ``reset()`` or after the episode has ended
        returns an explanatory observation and leaves the state untouched.

        Args:
            action: The cleaning action to apply.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The observation after the action, including reward and done flag.
        """
        if not self._task_info:
            # Without a loaded task there are no errors to count, so grading
            # would immediately (and wrongly) report "All errors fixed!".
            return self._make_observation(
                self._MIN_REWARD,
                "No active episode. Call reset() before step().",
            )
        if self._state.done:
            return self._make_observation(
                self._MIN_REWARD, "Episode already done. Call reset()."
            )

        self._state.step_count += 1

        action_key = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
        is_repeat = action_key in self._state.last_actions
        self._state.last_actions.append(action_key)

        if is_repeat:
            reward = self._MIN_REWARD
            message = "Penalty: repeated identical action"
        else:
            # NOTE(review): grade_action presumably updates the dataset and
            # flags the matching entry in self._errors as "fixed" -- this
            # class only reads that key and never writes it. Confirm in
            # env.tasks.
            reward, message, fixed = grade_action(
                action.action_type,
                action.target_field,
                action.target_row,
                action.new_value,
                self._state.dataset,
                self._ground_truth,
                self._errors,
            )
            if fixed:
                self._state.errors_fixed += 1

        # State keeps the raw reward; clamping is applied only to the values
        # reported in the observation.
        self._state.cumulative_reward += reward
        self._state.reward_history.append(reward)

        errors_remaining = len(self._unfixed_errors())
        if errors_remaining == 0:
            self._state.done = True
            message += " | All errors fixed! Episode complete."
        elif self._state.step_count >= self._state.max_steps:
            self._state.done = True
            message += f" | Max steps reached. {errors_remaining} errors remaining."

        return self._make_observation(reward, message)

    @property
    def state(self) -> DataCleanState:
        """Return the current episode state object."""
        return self._state

    def _unfixed_errors(self) -> List[Dict[str, Any]]:
        """Return the error entries whose ``"fixed"`` flag is missing or falsy."""
        return [e for e in self._errors if not e.get("fixed", False)]

    def _clamp(self, value: float) -> float:
        """Limit *value* to the closed range of reportable rewards."""
        return max(self._MIN_REWARD, min(self._MAX_REWARD, value))

    def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
        """Build an observation from the current state.

        Args:
            reward: Raw reward for the latest action; clamped before reporting.
            message: Human-readable result of the latest action.

        Returns:
            An observation in which only still-unfixed errors are listed.
        """
        unfixed_errors = self._unfixed_errors()
        # Avoid dividing by zero when the task contains no errors at all.
        total = max(self._state.total_errors, 1)
        progress = (self._state.errors_fixed / total) * 100

        return DataCleanObservation(
            task_name=self._state.task_name,
            task_description=self._task_info.get("description", ""),
            dataset=self._state.dataset,
            errors_found=unfixed_errors,
            errors_remaining=len(unfixed_errors),
            errors_total=self._state.total_errors,
            errors_fixed=self._state.errors_fixed,
            step_count=self._state.step_count,
            max_steps=self._state.max_steps,
            reward=self._clamp(reward),
            cumulative_reward=self._clamp(self._state.cumulative_reward),
            done=self._state.done,
            last_action_result=message,
            task_hint=self._task_info.get("hint", ""),
            progress_pct=progress,
            field_names=self._field_names,
        )