# env/models.py
# Hugging Face Hub listing (sairaj2): "Upload folder using huggingface_hub", commit 40e4201 (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Data Cleaning Environment.
Uses openenv-core base classes for Action, Observation, and State.
"""
import math
from typing import Optional, Dict, Any, List

from openenv.core import Action, Observation, State
from pydantic import BaseModel, Field
# ============================================================
# Original Env Environment Models
# ============================================================
class EnvAction(Action):
    """
    Action accepted by the Env environment.

    Carries a single ``message`` string (empty by default) that is meant
    to be echoed back.
    """

    message: str = Field(
        description="Message to echo back",
        default="",
    )
class EnvObservation(Observation):
    """
    Observation produced by the Env environment.

    Holds the echoed message together with its length; both fields have
    defaults (empty string and 0 respectively).
    """

    echoed_message: str = Field(
        description="The echoed message",
        default="",
    )
    message_length: int = Field(
        description="Length of the echoed message",
        default=0,
    )
# ============================================================
# Data Cleaning Environment Models
# ============================================================
class DataCleaningAction(Action):
    """
    One data-cleaning action to execute in the environment (OpenEnv-compliant).

    Every field has a default, so the model can be constructed with no
    arguments: ``action_type`` is then an empty string, ``params`` an
    empty dict and ``task_id`` ``None``.
    """

    # Name of the operation to run.
    action_type: str = Field(
        description="Type of action to execute (e.g., 'drop_nulls', 'fill_nulls')",
        default="",
    )
    # Free-form arguments for the operation; a fresh dict per instance.
    params: Dict[str, Any] = Field(
        description="Parameters for the action",
        default_factory=dict,
    )
    # Optional link to the task this action belongs to.
    task_id: Optional[str] = Field(
        description="Associated task ID",
        default=None,
    )
class DataCleaningObservation(Observation):
    """
    Observation returned after a reset or a step (OpenEnv-compliant).

    All fields are optional at construction time; container fields get a
    fresh empty ``dict`` / ``list`` per instance.
    """

    dataset_info: Dict[str, Any] = Field(
        description="Current dataset metadata",
        default_factory=dict,
    )
    available_actions: List[str] = Field(
        description="List of valid actions",
        default_factory=list,
    )
    step_count: int = Field(
        description="Number of steps taken",
        default=0,
    )
    task_id: Optional[str] = Field(
        description="Current task ID",
        default=None,
    )
    message: str = Field(
        description="Status message",
        default="",
    )
class DataCleaningState(State):
    """
    Full environment state, intended for serialization.

    Every field carries a default: an empty string for ``session_id``,
    ``None`` for the optional fields, and a fresh empty list for
    ``action_history``.
    """

    session_id: str = Field(
        default="",
    )
    task_id: Optional[str] = Field(
        default=None,
    )
    # One dict per recorded action; a new list is created per instance.
    action_history: List[Dict[str, Any]] = Field(
        default_factory=list,
    )
    dataset_hash: Optional[str] = Field(
        default=None,
    )
    grade: Optional[Dict[str, Any]] = Field(
        default=None,
    )
# ============================================================
# Supporting Data Models (not inheriting from openenv-core)
# ============================================================
class Reward(BaseModel):
    """
    Structured reward: a total value plus a per-component breakdown.

    Instances built through :meth:`create` carry the components
    ``quality``, ``progress`` and ``penalty`` and a total clamped to
    the range 0.0-1.0.
    """

    value: float = Field(
        default=0.0,
        description="Total reward value"
    )
    components: Dict[str, float] = Field(
        default_factory=dict,
        description="Breakdown of reward components"
    )

    @classmethod
    def create(
        cls,
        quality: float = 0.0,
        progress: float = 0.0,
        penalty: float = 0.0
    ) -> "Reward":
        """
        Factory method to create a structured reward.

        Args:
            quality: Quality contribution, added to the total.
            progress: Progress contribution, added to the total.
            penalty: Penalty, subtracted from the total.

        Returns:
            A ``Reward`` whose ``value`` is ``quality + progress - penalty``
            clamped to [0.0, 1.0] and rounded to 4 decimals, and whose
            ``components`` hold the three inputs rounded to 4 decimals.
            If the total is NaN, ``value`` is 0.0.
        """
        total = quality + progress - penalty
        if math.isnan(total):
            # min(1.0, nan) returns 1.0, so without this guard a NaN input
            # would be reported as the maximum reward.
            value = 0.0
        else:
            value = max(0.0, min(1.0, total))
        return cls(
            value=round(value, 4),
            components={
                "quality": round(quality, 4),
                "progress": round(progress, 4),
                "penalty": round(penalty, 4)
            }
        )
class TaskConfig(BaseModel):
    """
    Settings describing one data cleaning task.

    ``task_id`` and ``difficulty`` are required; every other field has a
    default (empty string, empty dict or empty list).
    """

    name: str = Field(
        description="Human-readable task name",
        default="",
    )
    # Required: no default is supplied.
    task_id: str = Field(
        ...,
        description="Unique task identifier",
    )
    # Required: no default is supplied.
    difficulty: str = Field(
        ...,
        description="Task difficulty level (easy, medium, hard)",
    )
    description: str = Field(
        description="Task description",
        default="",
    )
    dataset_config: Dict[str, Any] = Field(
        description="Dataset generation configuration",
        default_factory=dict,
    )
    expected_actions: List[str] = Field(
        description="Expected sequence of actions for optimal solution",
        default_factory=list,
    )
    grading_criteria: Dict[str, Any] = Field(
        description="Criteria for grading the task",
        default_factory=dict,
    )
    grader: str = Field(
        description="Import path for the task's grader implementation",
        default="",
    )
class GradeResult(BaseModel):
    """
    Outcome of grading a submitted solution.

    Bundles the final score, a per-criterion breakdown and textual
    feedback. All fields have defaults.
    """

    # The 0.0-1.0 range is stated in the description only; this model
    # declares no validation for it.
    final_score: float = Field(
        description="Final score (0.0 to 1.0)",
        default=0.0,
    )
    breakdown: Dict[str, float] = Field(
        description="Score breakdown by criterion",
        default_factory=dict,
    )
    feedback: str = Field(
        description="Feedback on the solution",
        default="",
    )