Spaces:
Sleeping
Sleeping
File size: 5,208 Bytes
9c195fe 842577f 9c195fe 842577f 42757ca 9c195fe 842577f 9c195fe 42757ca 842577f 9c195fe 42757ca 9c195fe 42757ca 9c195fe 42757ca 9c195fe 842577f 9c195fe 1bac517 9c195fe 42757ca 9c195fe 842577f 9c195fe 1bac517 9c195fe 842577f 9c195fe 42757ca 842577f 9c195fe 1bac517 42757ca 9c195fe 42757ca 9c195fe 42757ca 9c195fe 1bac517 9c195fe 42757ca 9c195fe 42757ca 9c195fe 42757ca 9c195fe 42757ca 9c195fe 42757ca 842577f 9c195fe 42757ca 9c195fe 42757ca 9c195fe 42757ca 1bac517 9c195fe 842577f 9c195fe 1bac517 9c195fe 842577f 9c195fe | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | import uuid
from typing import Any, Dict, List, Optional
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from env.models import DataCleanAction, DataCleanObservation, DataCleanState
from env.tasks import generate_task, get_task_names, grade_action
class DataValidationEnvironment(Environment):
    """Environment that serves a generated task and scores actions against it.

    A task is produced by ``generate_task`` on ``reset``; every ``step``
    forwards the action to ``grade_action`` and folds the result into the
    episode state. Rewards reported in observations are clamped to the
    range [0.01, 0.99].
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Start with an empty state; ``reset`` loads an actual task."""
        super().__init__()
        self._ground_truth: List[Dict[str, Any]] = []
        self._errors: List[Dict[str, Any]] = []
        self._field_names: List[str] = []
        self._task_info: Dict[str, Any] = {}
        self._state = DataCleanState()

    def reset(self, task_name: Optional[str] = None, seed: int = 42,
              episode_id: Optional[str] = None, **kwargs) -> DataCleanObservation:
        """Load a task and begin a new episode.

        Args:
            task_name: Task to generate; ``"easy_missing_values"`` is used
                when ``None``.
            seed: Passed through to ``generate_task``.
            episode_id: Identifier for the episode; a random UUID4 string is
                used when this is falsy.
            **kwargs: Accepted but not used here.

        Returns:
            The initial observation listing every error of the task.
        """
        chosen_task = "easy_missing_values" if task_name is None else task_name
        task = generate_task(chosen_task, seed)

        # Cache the pieces of the task that later steps need.
        self._ground_truth = task["ground_truth"]
        self._errors = task["errors"]
        self._task_info = task
        self._field_names = task["field_names"]

        step_budget = task["max_steps"]
        rows = task["dataset"]
        error_count = len(self._errors)
        new_episode_id = episode_id or str(uuid.uuid4())

        self._state = DataCleanState(
            episode_id=new_episode_id,
            task_name=chosen_task,
            step_count=0,
            max_steps=step_budget,
            done=False,
            reward_history=[],
            cumulative_reward=0.01,
            dataset=rows,
            ground_truth=self._ground_truth,
            errors=self._errors,
            errors_fixed=0,
            total_errors=error_count,
            last_actions=[],
        )

        return DataCleanObservation(
            task_name=chosen_task,
            task_description=task["description"],
            dataset=rows,
            errors_found=self._errors,
            errors_remaining=error_count,
            errors_total=error_count,
            errors_fixed=0,
            step_count=0,
            max_steps=step_budget,
            reward=0.01,
            cumulative_reward=0.01,
            done=False,
            last_action_result="Environment reset. Examine errors and fix them.",
            task_hint=task["hint"],
            progress_pct=0.0,
            field_names=self._field_names,
        )

    def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
        """Apply one action and return the resulting observation.

        An action identical to one already taken this episode is not graded
        and receives the floor reward of 0.01. The episode ends when no
        unfixed errors remain or when the step budget is used up.

        Args:
            action: The action to grade.
            **kwargs: Accepted but not used here.

        Returns:
            Observation reflecting the state after the action. If the
            episode was already done, the state is left untouched.
        """
        state = self._state
        if state.done:
            return self._make_observation(0.01, "Episode already done. Call reset().")

        state.step_count += 1

        # The key is checked against history *before* being recorded, so the
        # first occurrence of an action is never treated as a repeat.
        signature = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
        seen_before = signature in state.last_actions
        state.last_actions.append(signature)

        if not seen_before:
            reward, message, fixed = grade_action(
                action.action_type,
                action.target_field,
                action.target_row,
                action.new_value,
                state.dataset,
                self._ground_truth,
                self._errors,
            )
            if fixed:
                state.errors_fixed += 1
        else:
            reward = 0.01
            message = "Penalty: repeated identical action"

        state.cumulative_reward += reward
        state.reward_history.append(reward)

        # NOTE(review): relies on error dicts carrying a "fixed" flag,
        # presumably set by grade_action -- confirm against env.tasks.
        pending = len(self._pending_errors())
        suffix = ""
        if pending == 0:
            state.done = True
            suffix = " | All errors fixed! Episode complete."
        elif state.step_count >= state.max_steps:
            state.done = True
            suffix = f" | Max steps reached. {pending} errors remaining."

        return self._make_observation(reward, message + suffix)

    @property
    def state(self) -> DataCleanState:
        """The current episode state object."""
        return self._state

    def _pending_errors(self) -> List[Dict[str, Any]]:
        """Return the errors whose ``"fixed"`` entry is missing or falsy."""
        return [err for err in self._errors if not err.get("fixed", False)]

    @staticmethod
    def _clamp_score(value: float) -> float:
        """Clamp ``value`` into the range [0.01, 0.99]."""
        return max(0.01, min(0.99, value))

    def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
        """Build an observation from the current state.

        Args:
            reward: Reward for the latest action; clamped before reporting.
            message: Text reported as ``last_action_result``.

        Returns:
            Observation listing only the still-unfixed errors, with the
            clamped step reward and clamped cumulative reward.
        """
        state = self._state
        outstanding = self._pending_errors()

        # Guard the division for tasks that report no errors at all.
        if state.total_errors > 0:
            denominator = state.total_errors
        else:
            denominator = 1
        percent_done = (state.errors_fixed / denominator) * 100

        return DataCleanObservation(
            task_name=state.task_name,
            task_description=self._task_info.get("description", ""),
            dataset=state.dataset,
            errors_found=outstanding,
            errors_remaining=len(outstanding),
            errors_total=state.total_errors,
            errors_fixed=state.errors_fixed,
            step_count=state.step_count,
            max_steps=state.max_steps,
            reward=self._clamp_score(reward),
            cumulative_reward=self._clamp_score(state.cumulative_reward),
            done=state.done,
            last_action_result=message,
            task_hint=self._task_info.get("hint", ""),
            progress_pct=percent_done,
            field_names=self._field_names,
        )
|