File size: 5,208 Bytes
9c195fe
 
 
842577f
 
 
9c195fe
 
 
 
842577f
 
 
42757ca
9c195fe
842577f
9c195fe
 
 
 
 
42757ca
842577f
 
9c195fe
 
42757ca
9c195fe
42757ca
9c195fe
 
 
 
42757ca
9c195fe
842577f
9c195fe
 
 
 
 
1bac517
9c195fe
 
 
 
 
 
 
42757ca
9c195fe
 
 
 
 
842577f
 
 
9c195fe
 
1bac517
 
9c195fe
 
 
842577f
9c195fe
 
42757ca
842577f
9c195fe
1bac517
42757ca
9c195fe
42757ca
9c195fe
 
 
42757ca
9c195fe
1bac517
9c195fe
 
 
 
 
 
 
 
 
 
 
 
 
42757ca
9c195fe
 
42757ca
9c195fe
42757ca
9c195fe
 
 
 
 
 
42757ca
9c195fe
42757ca
842577f
9c195fe
 
42757ca
9c195fe
 
 
 
42757ca
9c195fe
42757ca
1bac517
 
 
9c195fe
 
 
 
 
842577f
 
 
9c195fe
 
1bac517
 
9c195fe
 
 
842577f
9c195fe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import uuid
from typing import Any, Dict, List, Optional

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from env.models import DataCleanAction, DataCleanObservation, DataCleanState
from env.tasks import generate_task, get_task_names, grade_action


class DataValidationEnvironment(Environment):
    """OpenEnv environment in which an agent repairs errors in a generated dataset.

    Each episode is built by ``generate_task`` from a task name and a seed.
    The agent submits ``DataCleanAction`` objects which are scored by
    ``grade_action``. An episode ends when no error dict remains un-``fixed``
    or when ``max_steps`` actions have been taken.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Bounds applied to every reward reported in an observation (and the
    # floor used for reset / repeat / already-done rewards). Presumably these
    # keep reported scores strictly inside (0, 1) for downstream graders --
    # TODO confirm against the evaluation harness.
    _MIN_REWARD: float = 0.01
    _MAX_REWARD: float = 0.99

    # Task used when reset() is called without an explicit task_name.
    _DEFAULT_TASK: str = "easy_missing_values"

    def __init__(self):
        super().__init__()
        self._state = DataCleanState()
        self._ground_truth: List[Dict[str, Any]] = []
        self._errors: List[Dict[str, Any]] = []
        self._task_info: Dict[str, Any] = {}
        self._field_names: List[str] = []

    def reset(self, task_name: Optional[str] = None, seed: int = 42,
              episode_id: Optional[str] = None, **kwargs) -> DataCleanObservation:
        """Start a new episode and return its initial observation.

        Args:
            task_name: Task to generate; defaults to ``"easy_missing_values"``.
            seed: Seed forwarded to ``generate_task``.
            episode_id: Episode identifier; a random UUID4 string is used
                when omitted.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The observation for step 0, listing every unfixed error.
        """
        if task_name is None:
            task_name = self._DEFAULT_TASK

        task = generate_task(task_name, seed)

        self._ground_truth = task["ground_truth"]
        self._errors = task["errors"]
        self._task_info = task
        self._field_names = task["field_names"]

        self._state = DataCleanState(
            episode_id=episode_id or str(uuid.uuid4()),
            task_name=task_name,
            step_count=0,
            max_steps=task["max_steps"],
            done=False,
            reward_history=[],
            cumulative_reward=self._MIN_REWARD,
            dataset=task["dataset"],
            ground_truth=self._ground_truth,
            errors=self._errors,
            errors_fixed=0,
            total_errors=len(self._errors),
            last_actions=[],
        )

        # Build the initial observation through the same path as step() so
        # both report errors, progress and rewards identically.
        return self._make_observation(
            self._MIN_REWARD,
            "Environment reset. Examine errors and fix them.",
        )

    def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
        """Apply one agent action and return the resulting observation.

        An action identical (type, field, row, new value) to one already taken
        this episode is not graded and earns the minimum reward. Otherwise the
        action is scored by ``grade_action``. Calling ``step`` after the
        episode is done returns an observation without advancing the state.

        Args:
            action: The action to apply.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The observation after the action.
        """
        if self._state.done:
            return self._make_observation(self._MIN_REWARD, "Episode already done. Call reset().")

        self._state.step_count += 1

        # Repeat detection compares against every action taken this episode,
        # not just the previous one.
        action_key = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
        is_repeat = action_key in self._state.last_actions
        self._state.last_actions.append(action_key)

        if is_repeat:
            reward = self._MIN_REWARD
            message = "Penalty: repeated identical action"
        else:
            reward, message, fixed = grade_action(
                action.action_type,
                action.target_field,
                action.target_row,
                action.new_value,
                self._state.dataset,
                self._ground_truth,
                self._errors,
            )
            if fixed:
                self._state.errors_fixed += 1

        # The raw (unclamped) reward is what accumulates in state; clamping
        # only happens when an observation is built.
        self._state.cumulative_reward += reward
        self._state.reward_history.append(reward)

        errors_remaining = len(self._unfixed_errors())

        if errors_remaining == 0:
            self._state.done = True
            message += " | All errors fixed! Episode complete."
        elif self._state.step_count >= self._state.max_steps:
            self._state.done = True
            message += f" | Max steps reached. {errors_remaining} errors remaining."

        return self._make_observation(reward, message)

    @property
    def state(self) -> DataCleanState:
        """Return the current episode state object."""
        return self._state

    def _unfixed_errors(self) -> List[Dict[str, Any]]:
        """Return the error dicts whose ``"fixed"`` flag is missing or falsy."""
        # The "fixed" flag is presumably set in place by grade_action -- it is
        # never written in this class.
        return [e for e in self._errors if not e.get("fixed", False)]

    @classmethod
    def _clamp_reward(cls, value: float) -> float:
        """Clamp ``value`` into ``[_MIN_REWARD, _MAX_REWARD]``."""
        return max(cls._MIN_REWARD, min(cls._MAX_REWARD, value))

    def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
        """Build an observation from the current state.

        Args:
            reward: Reward for the latest transition; clamped before reporting.
            message: Human-readable result of the latest action.

        Returns:
            A ``DataCleanObservation`` whose ``errors_found`` lists only the
            errors that are still unfixed.
        """
        unfixed_errors = self._unfixed_errors()
        # Avoid division by zero for tasks that generate no errors.
        total = self._state.total_errors if self._state.total_errors > 0 else 1
        progress = (self._state.errors_fixed / total) * 100

        return DataCleanObservation(
            task_name=self._state.task_name,
            task_description=self._task_info.get("description", ""),
            dataset=self._state.dataset,
            errors_found=unfixed_errors,
            errors_remaining=len(unfixed_errors),
            errors_total=self._state.total_errors,
            errors_fixed=self._state.errors_fixed,
            step_count=self._state.step_count,
            max_steps=self._state.max_steps,
            reward=self._clamp_reward(reward),
            cumulative_reward=self._clamp_reward(self._state.cumulative_reward),
            done=self._state.done,
            last_action_result=message,
            task_hint=self._task_info.get("hint", ""),
            progress_pct=progress,
            field_names=self._field_names,
        )