| |
| |
| |
| |
| |
|
|
| """ |
| Data models for the Data Cleaning Environment. |
| Uses openenv-core base classes for Action, Observation, and State. |
| """ |
|
|
| from typing import Optional, Dict, Any, List |
|
|
| from openenv.core import Action, Observation, State |
| from pydantic import BaseModel, Field |
|
|
|
|
| |
| |
| |
|
|
|
|
class EnvAction(Action):
    """Echo-environment action: carries the message that should be echoed back."""

    message: str = Field(
        "",
        description="Message to echo back",
    )
|
|
|
|
class EnvObservation(Observation):
    """Echo-environment observation: the echoed message together with its length."""

    echoed_message: str = Field(
        "",
        description="The echoed message",
    )
    message_length: int = Field(
        0,
        description="Length of the echoed message",
    )
|
|
|
|
| |
| |
| |
|
|
|
|
class DataCleaningAction(Action):
    """
    One data-cleaning action to be executed in the environment.

    OpenEnv-compliant: derives from the openenv-core ``Action`` base class.
    """

    # Name of the operation to run; defaults to the empty string.
    action_type: str = Field(
        "",
        description="Type of action to execute (e.g., 'drop_nulls', 'fill_nulls')",
    )
    # Free-form keyword parameters; a fresh dict is created per instance.
    params: Dict[str, Any] = Field(description="Parameters for the action", default_factory=dict)
    task_id: Optional[str] = Field(None, description="Associated task ID")
|
|
|
|
class DataCleaningObservation(Observation):
    """
    Observation handed back by the data-cleaning environment after reset or step.

    OpenEnv-compliant: derives from the openenv-core ``Observation`` base class.
    """

    # Container fields use default_factory so each instance gets its own object.
    dataset_info: Dict[str, Any] = Field(
        description="Current dataset metadata",
        default_factory=dict,
    )
    available_actions: List[str] = Field(
        description="List of valid actions",
        default_factory=list,
    )
    step_count: int = Field(0, description="Number of steps taken")
    task_id: Optional[str] = Field(None, description="Current task ID")
    message: str = Field("", description="Status message")
|
|
|
|
class DataCleaningState(State):
    """
    Full environment state, in a form intended for serialization.

    Every field declared here is optional or has an empty default.
    """

    session_id: str = Field("")
    task_id: Optional[str] = Field(None)
    # A new list is created per instance; entries are free-form dicts.
    action_history: List[Dict[str, Any]] = Field(default_factory=list)
    # NOTE(review): presumably a digest of the current dataset — confirm with the environment code.
    dataset_hash: Optional[str] = Field(None)
    grade: Optional[Dict[str, Any]] = Field(None)
|
|
|
|
| |
| |
| |
|
|
|
|
class Reward(BaseModel):
    """
    Structured reward: a total value plus a named breakdown of its components.
    """

    value: float = Field(0.0, description="Total reward value")
    components: Dict[str, float] = Field(
        description="Breakdown of reward components",
        default_factory=dict,
    )

    @classmethod
    def create(
        cls,
        quality: float = 0.0,
        progress: float = 0.0,
        penalty: float = 0.0
    ) -> "Reward":
        """
        Build a reward from quality, progress, and penalty contributions.

        The total is ``quality + progress - penalty`` clamped to the range
        [0.0, 1.0] and rounded to 4 decimal places. The individual components
        are stored rounded to 4 decimal places but are not clamped.
        """
        raw_total = quality + progress - penalty
        # Cap at 1.0 first, then floor at 0.0 (same order as a nested max/min).
        capped = min(1.0, raw_total)
        clamped = max(0.0, capped)

        parts = (
            ("quality", quality),
            ("progress", progress),
            ("penalty", penalty),
        )
        breakdown = {label: round(amount, 4) for label, amount in parts}
        return cls(value=round(clamped, 4), components=breakdown)
|
|
|
|
class TaskConfig(BaseModel):
    """
    Settings that define one data-cleaning task.

    ``task_id`` and ``difficulty`` are required; every other field has a default.
    """

    name: str = Field("", description="Human-readable task name")
    # Required fields: Ellipsis marks them as having no default.
    task_id: str = Field(..., description="Unique task identifier")
    difficulty: str = Field(
        ...,
        description="Task difficulty level (easy, medium, hard)",
    )
    description: str = Field("", description="Task description")
    dataset_config: Dict[str, Any] = Field(
        description="Dataset generation configuration",
        default_factory=dict,
    )
    expected_actions: List[str] = Field(
        description="Expected sequence of actions for optimal solution",
        default_factory=list,
    )
    grading_criteria: Dict[str, Any] = Field(
        description="Criteria for grading the task",
        default_factory=dict,
    )
    grader: str = Field(
        "",
        description="Import path for the task's grader implementation",
    )
|
|
|
|
class GradeResult(BaseModel):
    """
    Outcome of grading a submitted solution: score, per-criterion breakdown, feedback.
    """

    # The 0.0–1.0 range is documented in the description only; no bound is enforced here.
    final_score: float = Field(0.0, description="Final score (0.0 to 1.0)")
    breakdown: Dict[str, float] = Field(
        description="Score breakdown by criterion",
        default_factory=dict,
    )
    feedback: str = Field("", description="Feedback on the solution")
|
|