Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Data models for the Data Cleaning Environment. | |
| Uses openenv-core base classes for Action, Observation, and State. | |
| """ | |
| from typing import Optional, Dict, Any, List | |
| from openenv.core import Action, Observation, State | |
| from pydantic import BaseModel, Field | |
| # ============================================================ | |
| # Original Env Environment Models | |
| # ============================================================ | |
class EnvAction(Action):
    """Action model for the Env environment: carries one message to echo."""

    # Text payload; an empty string is used when the caller supplies none.
    message: str = Field(
        description="Message to echo back",
        default="",
    )
class EnvObservation(Observation):
    """Observation model for the Env environment: the echoed message and its length."""

    # The echoed text; defaults to an empty string.
    echoed_message: str = Field(
        description="The echoed message",
        default="",
    )
    # Length of the echoed text; defaults to 0.
    message_length: int = Field(
        description="Length of the echoed message",
        default=0,
    )
| # ============================================================ | |
| # Data Cleaning Environment Models | |
| # ============================================================ | |
class DataCleaningAction(Action):
    """
    Action model for the data cleaning environment.

    Describes one action: its type name, a free-form parameter mapping,
    and an optional task identifier. Every field has a default, so an
    instance can be built with no arguments.
    """

    # Name of the action; defaults to an empty string.
    action_type: str = Field(
        description="Type of action to execute (e.g., 'drop_nulls', 'fill_nulls')",
        default="",
    )
    # default_factory gives each instance its own fresh dict.
    params: Dict[str, Any] = Field(
        description="Parameters for the action",
        default_factory=dict,
    )
    # None when no task is associated with the action.
    task_id: Optional[str] = Field(
        description="Associated task ID",
        default=None,
    )
class DataCleaningObservation(Observation):
    """
    Observation model for the data cleaning environment.

    Intended as the observation handed back after a reset or a step.
    Every field has a default, so an instance can be built with no
    arguments.
    """

    # Free-form metadata mapping; a fresh dict per instance.
    dataset_info: Dict[str, Any] = Field(
        description="Current dataset metadata",
        default_factory=dict,
    )
    # Names of the currently valid actions; a fresh list per instance.
    available_actions: List[str] = Field(
        description="List of valid actions",
        default_factory=list,
    )
    # Step counter; starts at 0.
    step_count: int = Field(
        description="Number of steps taken",
        default=0,
    )
    # None when no task is set.
    task_id: Optional[str] = Field(
        description="Current task ID",
        default=None,
    )
    # Human-readable status text; defaults to an empty string.
    message: str = Field(
        description="Status message",
        default="",
    )
class DataCleaningState(State):
    """
    Full environment state, intended for serialization.

    All fields are optional at construction time: strings default to
    empty, optional values to None, and the history to a new empty list.
    """

    # Session identifier; empty string when unset.
    session_id: str = Field(
        default="",
    )
    # Task identifier; None when unset.
    task_id: Optional[str] = Field(
        default=None,
    )
    # One dict per recorded action; default_factory avoids a shared list.
    action_history: List[Dict[str, Any]] = Field(
        default_factory=list,
    )
    # Hash string for the dataset; None when unset.
    dataset_hash: Optional[str] = Field(
        default=None,
    )
    # Grade payload as a free-form mapping; None when unset.
    grade: Optional[Dict[str, Any]] = Field(
        default=None,
    )
| # ============================================================ | |
| # Supporting Data Models (not inheriting from openenv-core) | |
| # ============================================================ | |
class Reward(BaseModel):
    """
    Structured reward with components for quality, progress, and penalties.

    ``value`` holds the total reward and ``components`` holds the
    per-component breakdown. Use :meth:`create` to build an instance from
    the individual components.
    """

    value: float = Field(
        default=0.0,
        description="Total reward value",
    )
    components: Dict[str, float] = Field(
        default_factory=dict,
        description="Breakdown of reward components",
    )

    # Must be a classmethod: the signature takes ``cls`` and the body calls
    # ``cls(...)``. Without the decorator, ``Reward.create(quality=...)``
    # raises TypeError for the missing ``cls`` argument.
    @classmethod
    def create(
        cls,
        quality: float = 0.0,
        progress: float = 0.0,
        penalty: float = 0.0,
    ) -> "Reward":
        """
        Factory method to create a structured reward.

        Args:
            quality: Quality component, added to the total.
            progress: Progress component, added to the total.
            penalty: Penalty component, subtracted from the total.

        Returns:
            A new instance whose ``value`` is ``quality + progress - penalty``
            clamped to the range [0.0, 1.0] and rounded to 4 decimal places,
            and whose ``components`` maps "quality", "progress" and "penalty"
            to the respective inputs, each rounded to 4 decimal places.
        """
        # Clamp the combined score into [0.0, 1.0] before rounding.
        value = max(0.0, min(1.0, quality + progress - penalty))
        return cls(
            value=round(value, 4),
            components={
                "quality": round(quality, 4),
                "progress": round(progress, 4),
                "penalty": round(penalty, 4),
            },
        )
class TaskConfig(BaseModel):
    """
    Configuration record for one data cleaning task.

    ``task_id`` and ``difficulty`` are required; every other field has a
    default.
    """

    # Display name; defaults to an empty string.
    name: str = Field(
        description="Human-readable task name",
        default="",
    )
    # Required: Ellipsis marks the field as having no default.
    task_id: str = Field(..., description="Unique task identifier")
    # Required: Ellipsis marks the field as having no default.
    difficulty: str = Field(..., description="Task difficulty level (easy, medium, hard)")
    # Longer free-text description; defaults to an empty string.
    description: str = Field(
        description="Task description",
        default="",
    )
    # Free-form mapping; a fresh dict per instance.
    dataset_config: Dict[str, Any] = Field(
        description="Dataset generation configuration",
        default_factory=dict,
    )
    # Action names in order; a fresh list per instance.
    expected_actions: List[str] = Field(
        description="Expected sequence of actions for optimal solution",
        default_factory=list,
    )
    # Free-form mapping; a fresh dict per instance.
    grading_criteria: Dict[str, Any] = Field(
        description="Criteria for grading the task",
        default_factory=dict,
    )
    # Import path given as a string; defaults to an empty string.
    grader: str = Field(
        description="Import path for the task's grader implementation",
        default="",
    )
class GradeResult(BaseModel):
    """
    Outcome of grading a submitted solution.

    Every field has a default, so an instance can be built with no
    arguments.
    """

    # Overall score; the 0.0-1.0 range is stated in the description only
    # and is not validated here.
    final_score: float = Field(
        description="Final score (0.0 to 1.0)",
        default=0.0,
    )
    # Per-criterion scores; a fresh dict per instance.
    breakdown: Dict[str, float] = Field(
        description="Score breakdown by criterion",
        default_factory=dict,
    )
    # Free-text feedback; defaults to an empty string.
    feedback: str = Field(
        description="Feedback on the solution",
        default="",
    )