File size: 5,374 Bytes
56ddfd4
 
 
 
 
 
 
1a72194
 
56ddfd4
 
 
 
1a72194
56ddfd4
 
 
1a72194
 
 
 
 
 
56ddfd4
 
1a72194
56ddfd4
 
1a72194
56ddfd4
 
 
 
 
 
 
1a72194
56ddfd4
 
 
1a72194
56ddfd4
 
 
 
 
1a72194
56ddfd4
 
 
 
 
 
 
 
 
 
 
 
1a72194
56ddfd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a72194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56ddfd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4201
56ddfd4
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4201
 
 
 
56ddfd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4201
 
 
 
56ddfd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4201
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Data Cleaning Environment.
Uses openenv-core base classes for Action, Observation, and State.
"""

from typing import Optional, Dict, Any, List

from openenv.core import Action, Observation, State
from pydantic import BaseModel, Field


# ============================================================
# Original Env Environment Models
# ============================================================


class EnvAction(Action):
    """Echo action for the Env environment: carries the text to send back."""

    # Defaults to an empty string when the caller supplies nothing.
    message: str = Field(
        description="Message to echo back",
        default="",
    )


class EnvObservation(Observation):
    """Echo observation for the Env environment: the message and its length."""

    echoed_message: str = Field(
        description="The echoed message",
        default="",
    )
    message_length: int = Field(
        description="Length of the echoed message",
        default=0,
    )


# ============================================================
# Data Cleaning Environment Models
# ============================================================


class DataCleaningAction(Action):
    """
    A single data-cleaning action to execute in the environment.
    OpenEnv-compliant: derives from the openenv-core ``Action`` base class.
    """

    # Name of the operation to run; empty string when unset.
    action_type: str = Field(
        description="Type of action to execute (e.g., 'drop_nulls', 'fill_nulls')",
        default="",
    )
    # default_factory gives every instance its own dict.
    params: Dict[str, Any] = Field(description="Parameters for the action", default_factory=dict)
    # Optional link to a task; None when no task is associated.
    task_id: Optional[str] = Field(description="Associated task ID", default=None)


class DataCleaningObservation(Observation):
    """
    Observation returned by the data cleaning environment after reset or step.
    OpenEnv-compliant: derives from the openenv-core ``Observation`` base class.
    """

    # Container fields use default_factory so instances never share state.
    dataset_info: Dict[str, Any] = Field(description="Current dataset metadata", default_factory=dict)
    available_actions: List[str] = Field(description="List of valid actions", default_factory=list)

    # Scalar fields with simple defaults.
    step_count: int = Field(description="Number of steps taken", default=0)
    task_id: Optional[str] = Field(description="Current task ID", default=None)
    message: str = Field(description="Status message", default="")


class DataCleaningState(State):
    """
    Full environment state, intended for serialization.

    Every field has a default, so an instance can be created with no arguments.
    """

    session_id: str = ""
    task_id: Optional[str] = None
    # default_factory gives each state object its own history list.
    action_history: List[Dict[str, Any]] = Field(default_factory=list)
    dataset_hash: Optional[str] = None
    grade: Optional[Dict[str, Any]] = None


# ============================================================
# Supporting Data Models (not inheriting from openenv-core)
# ============================================================


class Reward(BaseModel):
    """
    Reward value together with a named breakdown of its components.
    """

    value: float = Field(description="Total reward value", default=0.0)
    components: Dict[str, float] = Field(
        description="Breakdown of reward components",
        default_factory=dict,
    )

    @classmethod
    def create(
        cls,
        quality: float = 0.0,
        progress: float = 0.0,
        penalty: float = 0.0
    ) -> "Reward":
        """
        Build a reward from quality, progress, and penalty terms.

        The total is ``quality + progress - penalty`` limited to the range
        [0.0, 1.0]. The total and each component are rounded to four
        decimal places; components are stored under the keys "quality",
        "progress", and "penalty".
        """
        raw_total = quality + progress - penalty
        # Cap at 1.0 first, then floor at 0.0 (order kept deliberately).
        capped = min(1.0, raw_total)
        clamped = max(0.0, capped)

        rounded_total = round(clamped, 4)
        parts = {
            "quality": round(quality, 4),
            "progress": round(progress, 4),
            "penalty": round(penalty, 4),
        }
        return cls(value=rounded_total, components=parts)


class TaskConfig(BaseModel):
    """
    Settings that describe one data cleaning task.

    ``task_id`` and ``difficulty`` are required; all other fields have defaults.
    """

    name: str = Field(description="Human-readable task name", default="")

    # Required fields (Ellipsis marks them as having no default).
    task_id: str = Field(..., description="Unique task identifier")
    difficulty: str = Field(..., description="Task difficulty level (easy, medium, hard)")

    description: str = Field(description="Task description", default="")
    dataset_config: Dict[str, Any] = Field(
        description="Dataset generation configuration",
        default_factory=dict,
    )
    expected_actions: List[str] = Field(
        description="Expected sequence of actions for optimal solution",
        default_factory=list,
    )
    grading_criteria: Dict[str, Any] = Field(
        description="Criteria for grading the task",
        default_factory=dict,
    )
    grader: str = Field(
        description="Import path for the task's grader implementation",
        default="",
    )


class GradeResult(BaseModel):
    """
    Outcome of grading a submitted solution: score, breakdown, and feedback.
    """

    # The 0.0-1.0 range is documented in the description; it is not enforced here.
    final_score: float = Field(description="Final score (0.0 to 1.0)", default=0.0)
    breakdown: Dict[str, float] = Field(description="Score breakdown by criterion", default_factory=dict)
    feedback: str = Field(description="Feedback on the solution", default="")