Spaces:
Sleeping
Sleeping
| """ | |
| Pydantic models for the Data Cleaning RL Environment. | |
| Defines the typed action, observation, and state structures. | |
| Inherits from openenv base types so the openenv SDK (WebSocket protocol, | |
| from_docker_image, GenericEnvClient) works out of the box. | |
| """ | |
| from __future__ import annotations | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import Field | |
| from openenv.core.env_server.types import ( | |
| Action as OEAction, | |
| Observation as OEObservation, | |
| ) | |
| # ---------------------- | |
| # Action enums | |
| # ----------------------- | |
class ActionType(str, Enum):
    """Every operation an agent may request through ``CleaningAction``.

    Each member subclasses ``str`` and its value equals its name, so a member
    compares equal to that plain string. Declaration order is kept as-is:
    iterating the enum yields members in this order.

    Trailing comments name the ``CleaningAction`` parameters that the
    model's own field descriptions tie to each action.
    """

    fill_missing = "fill_missing"  # strategy; constant_value when strategy=constant
    drop_duplicates = "drop_duplicates"
    fix_type = "fix_type"  # dtype
    normalize = "normalize"
    drop_outliers = "drop_outliers"  # method
    fix_schema_violation = "fix_schema_violation"  # constraint
    done = "done"
    rename_column = "rename_column"  # new_name
    cast_datetime = "cast_datetime"  # datetime_format
    deduplicate_fuzzy = "deduplicate_fuzzy"  # threshold
    split_column = "split_column"  # delimiter
    merge_columns = "merge_columns"  # column2, merge_strategy
    fix_format_regex = "fix_format_regex"  # pattern, replacement
    standardize_categories = "standardize_categories"
    undo = "undo"
    profile_column = "profile_column"  # output reported via Observation.profile_result
class FillStrategy(str, Enum):
    """How ``fill_missing`` picks the value written into empty cells.

    Supplied via ``CleaningAction.strategy``. With ``constant`` the value
    comes from ``CleaningAction.constant_value`` (per that field's
    description); the remaining members name a summary statistic.
    """

    mean = "mean"
    median = "median"
    mode = "mode"
    constant = "constant"
class DType(str, Enum):
    """Target types accepted by ``fix_type`` (``CleaningAction.dtype``).

    NOTE: the member names shadow the builtins ``int``/``float``/``str`` only
    for statements that follow them in this class body. The ``str`` base is
    evaluated before the body runs, and method bodies do not see class scope,
    so neither is affected. Access the members as ``DType.int`` etc.
    """

    int = "int"
    float = "float"
    str = "str"
class OutlierMethod(str, Enum):
    """Detection rule used by ``drop_outliers`` (``CleaningAction.method``)."""

    iqr = "iqr"
    zscore = "zscore"
class SchemaConstraint(str, Enum):
    """Constraint kinds for ``fix_schema_violation`` (``CleaningAction.constraint``).

    NOTE(review): no declared ``CleaningAction`` field carries bounds for
    ``clamp_range``. Because that model allows extra keys, bounds may arrive
    as undeclared fields or be fixed by the environment; confirm which.
    """

    non_negative = "non_negative"
    clamp_range = "clamp_range"
class MergeStrategy(str, Enum):
    """How ``merge_columns`` combines ``column`` with ``column2``.

    Supplied via ``CleaningAction.merge_strategy``. The member name ``sum``
    shadows the builtin only for later statements in this class body.
    """

    concat = "concat"
    first_non_null = "first_non_null"
    sum = "sum"
| # ---------------------- | |
| # Action model — inherits from openenv.Action | |
| # ----------------------- | |
class CleaningAction(OEAction):
    """One cleaning step requested by the agent.

    Only ``action_type`` is required. Every other field is an optional
    parameter that defaults to ``None``; its description names the action
    type it applies to. Keys not declared here are accepted and kept
    because ``extra`` is set to ``"allow"``.
    """

    # Equivalent to the dict literal {"extra": "allow"}.
    model_config = dict(extra="allow")

    # Required discriminator: which operation to run.
    action_type: ActionType = Field(
        default=...,
        description="Type of cleaning action to perform.",
    )

    # Optional per-action parameters. Declaration order is kept unchanged
    # because it fixes the serialization and schema field order.
    column: Optional[str] = Field(
        default=None,
        description="Target column name.",
    )
    strategy: Optional[FillStrategy] = Field(
        default=None,
        description="Fill strategy for fill_missing action.",
    )
    dtype: Optional[DType] = Field(
        default=None,
        description="Target dtype for fix_type action.",
    )
    method: Optional[OutlierMethod] = Field(
        default=None,
        description="Outlier detection method for drop_outliers action.",
    )
    constraint: Optional[SchemaConstraint] = Field(
        default=None,
        description="Constraint type for fix_schema_violation action.",
    )
    # Untyped on purpose: the fill value may be numeric or textual.
    constant_value: Optional[Any] = Field(
        default=None,
        description="Constant fill value for fill_missing with strategy=constant.",
    )
    new_name: Optional[str] = Field(
        default=None,
        description="New column name for rename_column action.",
    )
    datetime_format: Optional[str] = Field(
        default=None,
        description="DateTime format string for cast_datetime action.",
    )
    # NOTE(review): the documented 0.0-1.0 range is not enforced here.
    threshold: Optional[float] = Field(
        default=None,
        description="Similarity threshold for deduplicate_fuzzy (0.0-1.0).",
    )
    delimiter: Optional[str] = Field(
        default=None,
        description="Delimiter string for split_column action.",
    )
    column2: Optional[str] = Field(
        default=None,
        description="Second column name for merge_columns action.",
    )
    merge_strategy: Optional[MergeStrategy] = Field(
        default=None,
        description="Merge strategy for merge_columns.",
    )
    pattern: Optional[str] = Field(
        default=None,
        description="Regex pattern for fix_format_regex action.",
    )
    replacement: Optional[str] = Field(
        default=None,
        description="Replacement string for fix_format_regex action.",
    )
| # --------------- | |
| # Observation — inherits from openenv.Observation | |
| # ------------------- | |
class ColumnIssues(OEObservation):
    """Data-quality problems detected for a single column.

    Every counter starts at zero and the flag at ``False``. Per its own
    description, ``has_duplicates`` reports on the whole dataset rather
    than on this column alone.

    NOTE(review): this per-column summary subclasses the openenv Observation
    base, so it also carries whatever fields that base declares. Confirm that
    is intended rather than a plain pydantic model.
    """

    model_config = dict(extra="allow")

    missing_count: int = Field(default=0, description="Number of missing (NaN) values.")
    missing_pct: float = Field(default=0.0, description="Fraction of missing values [0, 1].")
    type_errors: int = Field(default=0, description="Cells that cannot be parsed as expected dtype.")
    outlier_count: int = Field(default=0, description="Outliers detected via IQR rule.")
    has_duplicates: bool = Field(default=False, description="True if dataset has duplicate rows.")
    format_violation_count: int = Field(default=0, description="Values not matching expected format.")
    encoding_issue_count: int = Field(default=0, description="Values with unexpected non-ASCII chars.")
    semantic_duplicate_count: int = Field(default=0, description="Near-duplicate distinct string values.")
class ColumnStats(OEObservation):
    """Compact statistical summary for one column.

    ``mean`` and ``std`` are optional (default ``None``); presumably they
    stay unset for non-numeric columns. Confirm against the producer.
    """

    model_config = dict(extra="allow")

    mean: Optional[float] = Field(default=None)
    std: Optional[float] = Field(default=None)
    null_count: int = Field(default=0)
    unique_count: int = Field(default=0)
class Observation(OEObservation):
    """Observation returned by reset() and step().

    This module's own observation type; the openenv base is imported under
    the alias ``OEObservation`` so the two names do not collide.

    ``column_issues`` and ``column_stats`` are typed loosely as
    ``Dict[str, Any]``. Presumably their values follow the ``ColumnIssues``
    and ``ColumnStats`` shapes; verify against the environment.
    """

    model_config = dict(extra="allow")

    # Episode progress.
    task: str = Field(default="", description="Task tier: 'easy', 'medium', 'hard', or 'expert'.")
    step: int = Field(default=0, description="Current step number (0-indexed).")
    max_steps: int = Field(default=0, description="Maximum steps allowed in this episode.")

    # Dataset view. default_factory gives each instance fresh containers.
    columns: List[str] = Field(
        default_factory=list,
        description="Column names in the dataset.",
    )
    column_issues: Dict[str, Any] = Field(
        default_factory=dict,
        description="Data quality issues per column.",
    )
    column_stats: Dict[str, Any] = Field(
        default_factory=dict,
        description="Compact statistics per column.",
    )
    sample_rows: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Up to 5 sampled rows.",
    )

    # Agent feedback.
    action_history: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Last 5 actions with rewards.",
    )
    budget_remaining: Optional[float] = Field(
        default=None,
        description="Remaining action budget (0.0-1.0).",
    )
    profile_result: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Extended column statistics from last profile_column.",
    )
| # ---------------- | |
| # Episode state | |
| # -------------------- | |
class EpisodeState(OEObservation):
    """Metadata about the current episode.

    ``score`` stays ``None`` until a final grade is recorded.

    NOTE(review): this subclasses the openenv Observation base rather than a
    dedicated state type. Confirm that is intentional.
    """

    model_config = dict(extra="allow")

    episode_id: str = Field(default="", description="Unique episode identifier (UUID).")
    task: str = Field(default="", description="Task tier.")
    step: int = Field(default=0, description="Current step number.")
    max_steps: int = Field(default=0, description="Maximum steps allowed.")
    score: Optional[float] = Field(default=None, description="Final grader score (0.0-1.0).")