# Last commit: 1a55ff4 "refactor: full openenv protocol compliance" (author: yashmarathe)
"""
Pydantic models for the Data Cleaning RL Environment.
Defines the typed action, observation, and state structures.
Inherits from openenv base types so the openenv SDK (WebSocket protocol,
from_docker_image, GenericEnvClient) works out of the box.
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import Field
from openenv.core.env_server.types import (
Action as OEAction,
Observation as OEObservation,
)
# ----------------------
# Action enums
# -----------------------
class ActionType(str, Enum):
    """Every command an agent may issue via ``CleaningAction.action_type``.

    Members subclass ``str`` and each value equals its member name, so
    ``ActionType.undo == "undo"``. Declaration order is the iteration order
    and is intentionally left unchanged.
    """

    fill_missing = "fill_missing"  # CleaningAction params: strategy, constant_value
    drop_duplicates = "drop_duplicates"
    fix_type = "fix_type"  # CleaningAction params: dtype
    normalize = "normalize"
    drop_outliers = "drop_outliers"  # CleaningAction params: method
    fix_schema_violation = "fix_schema_violation"  # CleaningAction params: constraint
    done = "done"
    rename_column = "rename_column"  # CleaningAction params: new_name
    cast_datetime = "cast_datetime"  # CleaningAction params: datetime_format
    deduplicate_fuzzy = "deduplicate_fuzzy"  # CleaningAction params: threshold
    split_column = "split_column"  # CleaningAction params: delimiter
    merge_columns = "merge_columns"  # CleaningAction params: column2, merge_strategy
    fix_format_regex = "fix_format_regex"  # CleaningAction params: pattern, replacement
    standardize_categories = "standardize_categories"
    undo = "undo"
    profile_column = "profile_column"  # output surfaces in Observation.profile_result
class FillStrategy(str, Enum):
    """How ``fill_missing`` chooses replacement values (``CleaningAction.strategy``)."""

    mean = "mean"
    median = "median"
    mode = "mode"
    # Paired with CleaningAction.constant_value.
    constant = "constant"
class DType(str, Enum):
    """Target dtype for the ``fix_type`` action (``CleaningAction.dtype``)."""

    # These member names shadow the builtins int/float/str, but only inside
    # this class body; the ``str`` base above is resolved before the body runs.
    int = "int"
    float = "float"
    str = "str"
class OutlierMethod(str, Enum):
    """Detection rule used by ``drop_outliers`` (``CleaningAction.method``)."""

    iqr = "iqr"
    zscore = "zscore"
class SchemaConstraint(str, Enum):
    """Constraint applied by ``fix_schema_violation`` (``CleaningAction.constraint``)."""

    non_negative = "non_negative"
    # NOTE(review): no CleaningAction field visible here carries range bounds;
    # presumably they come from the task definition — confirm.
    clamp_range = "clamp_range"
class MergeStrategy(str, Enum):
    """How ``merge_columns`` combines two columns (``CleaningAction.merge_strategy``)."""

    concat = "concat"
    first_non_null = "first_non_null"
    sum = "sum"  # shadows the builtin only inside this class body
# ----------------------
# Action model — inherits from openenv.Action
# -----------------------
class CleaningAction(OEAction):
    """A single cleaning command sent by the agent.

    Only ``action_type`` is mandatory. Every other field is an optional
    parameter that defaults to ``None``; its description names the action
    type that reads it. Undeclared keys are still accepted because
    ``model_config`` sets ``extra="allow"``.
    """

    model_config = {"extra": "allow"}

    # Required: which action to run.
    action_type: ActionType = Field(
        default=...,
        description="Type of cleaning action to perform.",
    )

    # Column the action targets.
    column: Optional[str] = Field(
        default=None,
        description="Target column name.",
    )

    # Parameters for fill_missing, fix_type, drop_outliers and
    # fix_schema_violation.
    strategy: Optional[FillStrategy] = Field(
        default=None,
        description="Fill strategy for fill_missing action.",
    )
    dtype: Optional[DType] = Field(
        default=None,
        description="Target dtype for fix_type action.",
    )
    method: Optional[OutlierMethod] = Field(
        default=None,
        description="Outlier detection method for drop_outliers action.",
    )
    constraint: Optional[SchemaConstraint] = Field(
        default=None,
        description="Constraint type for fix_schema_violation action.",
    )
    constant_value: Optional[Any] = Field(
        default=None,
        description="Constant fill value for fill_missing with strategy=constant.",
    )

    # rename_column / cast_datetime parameters.
    new_name: Optional[str] = Field(
        default=None,
        description="New column name for rename_column action.",
    )
    datetime_format: Optional[str] = Field(
        default=None,
        description="DateTime format string for cast_datetime action.",
    )

    # deduplicate_fuzzy parameter.
    # NOTE(review): the 0.0-1.0 range is only documented, not enforced
    # (no ge/le bounds on this Field).
    threshold: Optional[float] = Field(
        default=None,
        description="Similarity threshold for deduplicate_fuzzy (0.0-1.0).",
    )

    # split_column / merge_columns parameters.
    delimiter: Optional[str] = Field(
        default=None,
        description="Delimiter string for split_column action.",
    )
    column2: Optional[str] = Field(
        default=None,
        description="Second column name for merge_columns action.",
    )
    merge_strategy: Optional[MergeStrategy] = Field(
        default=None,
        description="Merge strategy for merge_columns.",
    )

    # fix_format_regex parameters.
    pattern: Optional[str] = Field(
        default=None,
        description="Regex pattern for fix_format_regex action.",
    )
    replacement: Optional[str] = Field(
        default=None,
        description="Replacement string for fix_format_regex action.",
    )
# ---------------
# Observation — inherits from openenv.Observation
# -------------------
class ColumnIssues(OEObservation):
    """Data-quality problem counts reported for one column.

    All counters default to zero and ``has_duplicates`` to ``False``.
    ``has_duplicates`` describes the whole dataset, not just this column.
    """

    # NOTE(review): subclasses openenv's Observation rather than a plain model,
    # so it presumably also inherits that base's fields — confirm intended.
    model_config = {"extra": "allow"}

    # Missing-value signals.
    missing_count: int = Field(default=0, description="Number of missing (NaN) values.")
    missing_pct: float = Field(default=0.0, description="Fraction of missing values [0, 1].")

    # Type and distribution signals.
    type_errors: int = Field(default=0, description="Cells that cannot be parsed as expected dtype.")
    outlier_count: int = Field(default=0, description="Outliers detected via IQR rule.")

    # Dataset-level duplicate flag.
    has_duplicates: bool = Field(default=False, description="True if dataset has duplicate rows.")

    # Formatting / encoding / near-duplicate signals.
    format_violation_count: int = Field(default=0, description="Values not matching expected format.")
    encoding_issue_count: int = Field(default=0, description="Values with unexpected non-ASCII chars.")
    semantic_duplicate_count: int = Field(default=0, description="Near-duplicate distinct string values.")
class ColumnStats(OEObservation):
    """Compact statistical summary for a single column.

    ``mean`` and ``std`` are optional and default to ``None``; the counts
    default to zero. Fields are declared with ``Field(description=...)``
    like the other models in this module, so the generated schema
    documents them.
    """

    model_config = {"extra": "allow"}

    mean: Optional[float] = Field(
        None, description="Column mean, or None when not available.",
    )
    std: Optional[float] = Field(
        None, description="Column standard deviation, or None when not available.",
    )
    null_count: int = Field(0, description="Number of null (missing) values.")
    unique_count: int = Field(0, description="Number of unique values.")
class Observation(OEObservation):
    """Snapshot the agent receives from ``reset()`` and from each ``step()``.

    Carries episode progress (task, step, max_steps, budget_remaining), a
    view of the data (columns, sample_rows), per-column quality signals
    (column_issues, column_stats), recent history (action_history), and the
    output of the last ``profile_column`` action (profile_result). Unknown
    keys are kept because ``model_config`` sets ``extra="allow"``.
    """

    model_config = {"extra": "allow"}

    # Episode progress.
    task: str = Field(
        default="",
        description="Task tier: 'easy', 'medium', 'hard', or 'expert'.",
    )
    step: int = Field(default=0, description="Current step number (0-indexed).")
    max_steps: int = Field(default=0, description="Maximum steps allowed in this episode.")

    # Dataset view and quality signals.
    columns: List[str] = Field(
        default_factory=list,
        description="Column names in the dataset.",
    )
    # NOTE(review): values are typed Any; presumably serialized ColumnIssues /
    # ColumnStats keyed by column name — confirm against the environment.
    column_issues: Dict[str, Any] = Field(
        default_factory=dict,
        description="Data quality issues per column.",
    )
    column_stats: Dict[str, Any] = Field(
        default_factory=dict,
        description="Compact statistics per column.",
    )
    sample_rows: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Up to 5 sampled rows.",
    )

    # Recent history and remaining budget.
    action_history: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Last 5 actions with rewards.",
    )
    budget_remaining: Optional[float] = Field(
        default=None,
        description="Remaining action budget (0.0-1.0).",
    )

    # Populated from the most recent profile_column action.
    profile_result: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Extended column statistics from last profile_column.",
    )
# ----------------
# Episode state
# --------------------
class EpisodeState(OEObservation):
    """Bookkeeping metadata for the episode in progress.

    ``score`` defaults to ``None``; per its description it holds the final
    grader score.
    """

    # NOTE(review): derives from openenv's Observation rather than a
    # dedicated state base type — confirm this is intentional.
    model_config = {"extra": "allow"}

    episode_id: str = Field(default="", description="Unique episode identifier (UUID).")
    task: str = Field(default="", description="Task tier.")
    step: int = Field(default=0, description="Current step number.")
    max_steps: int = Field(default=0, description="Maximum steps allowed.")
    score: Optional[float] = Field(
        default=None,
        description="Final grader score (0.0-1.0).",
    )