"""Typed Pydantic models for the Data Cleaning RL Environment.

This module declares the structures exchanged between an agent and the
environment:

* ``CleaningAction`` - one cleaning operation issued by the agent, together
  with the enums that constrain its parameters.
* ``Observation`` - what ``reset()`` and ``step()`` return.
* ``ColumnIssues`` / ``ColumnStats`` - per-column quality and statistics
  summaries.
* ``EpisodeState`` - metadata about the current episode.

Every model derives from an openenv base type so that the openenv SDK
(WebSocket protocol, ``from_docker_image``, ``GenericEnvClient``) works with
these classes out of the box.
"""

from __future__ import annotations

from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import Field

# The openenv base classes are aliased so this module can define its own
# ``Observation`` class without shadowing the SDK type.
from openenv.core.env_server.types import (
    Action as OEAction,
    Observation as OEObservation,
)


# ============================================================================
# Enumerations constraining CleaningAction parameters
# (every member's value equals its name)
# ============================================================================


class ActionType(str, Enum):
    # The full set of operations an agent may request.
    fill_missing = "fill_missing"
    drop_duplicates = "drop_duplicates"
    fix_type = "fix_type"
    normalize = "normalize"
    drop_outliers = "drop_outliers"
    fix_schema_violation = "fix_schema_violation"
    done = "done"
    rename_column = "rename_column"
    cast_datetime = "cast_datetime"
    deduplicate_fuzzy = "deduplicate_fuzzy"
    split_column = "split_column"
    merge_columns = "merge_columns"
    fix_format_regex = "fix_format_regex"
    standardize_categories = "standardize_categories"
    undo = "undo"
    profile_column = "profile_column"


class FillStrategy(str, Enum):
    # Used by fill_missing; ``constant`` pairs with CleaningAction.constant_value.
    mean = "mean"
    median = "median"
    mode = "mode"
    constant = "constant"


class DType(str, Enum):
    # Target types for fix_type.
    # NOTE(review): member names shadow the builtins int/float/str inside this
    # class body; harmless here (nothing after them uses the builtins), but
    # renaming would break callers using DType.int etc.
    int = "int"
    float = "float"
    str = "str"


class OutlierMethod(str, Enum):
    # Detection methods for drop_outliers.
    iqr = "iqr"
    zscore = "zscore"


class SchemaConstraint(str, Enum):
    # Constraint kinds for fix_schema_violation.
    non_negative = "non_negative"
    clamp_range = "clamp_range"


class MergeStrategy(str, Enum):
    # How merge_columns combines ``column`` and ``column2``.
    concat = "concat"
    first_non_null = "first_non_null"
    sum = "sum"  # NOTE(review): shadows builtin ``sum`` inside this class body only.


# ============================================================================
# Action model (derives from openenv Action)
# ============================================================================


class CleaningAction(OEAction):
    """A single cleaning action issued by the agent.

    Only ``action_type`` is required. Every other field is an optional
    parameter that is relevant to particular action types (see each field's
    description) and defaults to ``None``.
    """

    # Unknown keys are kept on the model instead of being rejected.
    model_config = {"extra": "allow"}

    # --- required ------------------------------------------------------------
    action_type: ActionType = Field(
        ...,
        description="Type of cleaning action to perform.",
    )

    # --- generic target ------------------------------------------------------
    column: Optional[str] = Field(
        default=None,
        description="Target column name.",
    )

    # --- fill_missing / fix_type / drop_outliers / fix_schema_violation -----
    strategy: Optional[FillStrategy] = Field(
        default=None,
        description="Fill strategy for fill_missing action.",
    )
    dtype: Optional[DType] = Field(
        default=None,
        description="Target dtype for fix_type action.",
    )
    method: Optional[OutlierMethod] = Field(
        default=None,
        description="Outlier detection method for drop_outliers action.",
    )
    constraint: Optional[SchemaConstraint] = Field(
        default=None,
        description="Constraint type for fix_schema_violation action.",
    )
    # Any-typed so the constant can match whatever dtype the column holds.
    constant_value: Optional[Any] = Field(
        default=None,
        description="Constant fill value for fill_missing with strategy=constant.",
    )

    # --- rename_column / cast_datetime / deduplicate_fuzzy ------------------
    new_name: Optional[str] = Field(
        default=None,
        description="New column name for rename_column action.",
    )
    datetime_format: Optional[str] = Field(
        default=None,
        description="DateTime format string for cast_datetime action.",
    )
    # NOTE(review): the documented 0.0-1.0 range is not enforced here (no
    # ge/le bounds); out-of-range values must be handled by the consumer.
    threshold: Optional[float] = Field(
        default=None,
        description="Similarity threshold for deduplicate_fuzzy (0.0-1.0).",
    )

    # --- split_column / merge_columns ---------------------------------------
    delimiter: Optional[str] = Field(
        default=None,
        description="Delimiter string for split_column action.",
    )
    column2: Optional[str] = Field(
        default=None,
        description="Second column name for merge_columns action.",
    )
    merge_strategy: Optional[MergeStrategy] = Field(
        default=None,
        description="Merge strategy for merge_columns.",
    )

    # --- fix_format_regex ----------------------------------------------------
    pattern: Optional[str] = Field(
        default=None,
        description="Regex pattern for fix_format_regex action.",
    )
    replacement: Optional[str] = Field(
        default=None,
        description="Replacement string for fix_format_regex action.",
    )


# ============================================================================
# Observation models (derive from openenv Observation)
# ============================================================================


class ColumnIssues(OEObservation):
    """Per-column data quality issues detected in the current state.

    NOTE(review): this summary derives from the openenv Observation base, so it
    also carries whatever fields that base declares; confirm that is intended.
    """

    model_config = {"extra": "allow"}

    missing_count: int = Field(default=0, description="Number of missing (NaN) values.")
    missing_pct: float = Field(default=0.0, description="Fraction of missing values [0, 1].")
    type_errors: int = Field(
        default=0,
        description="Cells that cannot be parsed as expected dtype.",
    )
    outlier_count: int = Field(default=0, description="Outliers detected via IQR rule.")
    # NOTE(review): a dataset-wide flag stored on a per-column record.
    has_duplicates: bool = Field(
        default=False,
        description="True if dataset has duplicate rows.",
    )
    format_violation_count: int = Field(
        default=0,
        description="Values not matching expected format.",
    )
    encoding_issue_count: int = Field(
        default=0,
        description="Values with unexpected non-ASCII chars.",
    )
    semantic_duplicate_count: int = Field(
        default=0,
        description="Near-duplicate distinct string values.",
    )


class ColumnStats(OEObservation):
    """Compact statistical summary for a column.

    NOTE(review): like ColumnIssues, this derives from the openenv Observation
    base and inherits its fields.
    """

    model_config = {"extra": "allow"}

    # mean/std stay None when not computed (e.g. presumably for non-numeric
    # columns - confirm with the producer).
    mean: Optional[float] = None
    std: Optional[float] = None
    null_count: int = 0
    unique_count: int = 0


class Observation(OEObservation):
    """Observation returned by reset() and step().

    NOTE(review): ``column_issues`` and ``column_stats`` are typed as
    ``Dict[str, Any]``, so their values are not validated against
    ColumnIssues / ColumnStats here.
    """

    model_config = {"extra": "allow"}

    # --- episode progress ----------------------------------------------------
    task: str = Field(
        default="",
        description="Task tier: 'easy', 'medium', 'hard', or 'expert'.",
    )
    step: int = Field(default=0, description="Current step number (0-indexed).")
    max_steps: int = Field(default=0, description="Maximum steps allowed in this episode.")

    # --- dataset view (mutable defaults use factories) -----------------------
    columns: List[str] = Field(
        default_factory=list,
        description="Column names in the dataset.",
    )
    column_issues: Dict[str, Any] = Field(
        default_factory=dict,
        description="Data quality issues per column.",
    )
    column_stats: Dict[str, Any] = Field(
        default_factory=dict,
        description="Compact statistics per column.",
    )
    sample_rows: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Up to 5 sampled rows.",
    )
    action_history: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Last 5 actions with rewards.",
    )

    # --- optional extras -----------------------------------------------------
    budget_remaining: Optional[float] = Field(
        default=None,
        description="Remaining action budget (0.0-1.0).",
    )
    profile_result: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Extended column statistics from last profile_column.",
    )


# ============================================================================
# Episode state
# ============================================================================


class EpisodeState(OEObservation):
    """Metadata about the current episode."""

    model_config = {"extra": "allow"}

    episode_id: str = Field(default="", description="Unique episode identifier (UUID).")
    task: str = Field(default="", description="Task tier.")
    step: int = Field(default=0, description="Current step number.")
    max_steps: int = Field(default=0, description="Maximum steps allowed.")
    # Defaults to None; presumably populated once the episode is graded - confirm.
    score: Optional[float] = Field(
        default=None,
        description="Final grader score (0.0-1.0).",
    )