# dataclean-env / models.py
# Uploaded by Anuj424614 using huggingface_hub (commit 8345e43, verified).
"""Pydantic models for the DataClean-Env environment.
Defines typed Action, Observation, and State models following OpenEnv spec.
Observation uses issue-first design per Codex feedback.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from openenv.core.env_server import Action, Observation, State
# --- Supporting Types ---
class QualityIssue(BaseModel):
    """A single data-quality problem, located by row and column.

    ``row_id``, ``column``, ``issue_type`` and ``description`` are required;
    ``suggestion`` is optional.
    """

    # Row in which the issue was found.
    # NOTE(review): presumably the same stable row_id that actions use to
    # address rows -- confirm against the code that builds these objects.
    row_id: int

    # Name of the affected column.
    column: str

    # Category label. Typed as a plain ``str``: the expected labels are
    # listed in the field description but are not enforced by the type.
    issue_type: str = Field(
        description=(
            "One of: null, format, duplicate, case, type_violation, "
            "cross_field, anomaly"
        ),
    )

    # Free-text explanation of the issue.
    description: str

    # Optional hint attached to the issue; ``None`` when not provided.
    suggestion: Optional[str] = None
class IssueGroup(BaseModel):
    """A bucket of issues that share one issue type, for compact display.

    ``issue_type`` and ``count`` are required; ``examples`` defaults to an
    empty list.
    """

    # Issue-type label that this group covers.
    issue_type: str

    # Number of issues in the group. Stored independently of ``examples``;
    # nothing in this model ties it to ``len(examples)``.
    count: int

    # Sample issues from the group. A fresh list is created per instance.
    examples: List[QualityIssue] = Field(default_factory=list)
class DataSummary(BaseModel):
    """Compact, aggregate view of the dataset.

    Every field has a default, so ``DataSummary()`` builds an all-zero /
    all-empty summary.
    """

    # --- Size counters (all default to 0) ---
    row_count: int = 0
    column_count: int = 0
    # NOTE(review): presumably row_count * column_count; this model does not
    # compute or check it -- confirm with the producer.
    total_cells: int = 0

    # --- Quality counters (all default to 0) ---
    null_count: int = 0
    issue_count: int = 0

    # --- Schema overview ---
    # Column names; a fresh empty list per instance.
    columns: List[str] = Field(default_factory=list)
    # String-to-string mapping; a fresh empty dict per instance.
    # NOTE(review): presumably column name -> dtype name -- confirm.
    dtypes: Dict[str, str] = Field(default_factory=dict)
class ActionResult(BaseModel):
    """Outcome record for one executed action.

    ``action``, ``status`` and ``message`` are required; ``cells_modified``
    defaults to 0.
    """

    # Identifier of the action that was executed.
    action: str

    # Outcome label. Typed as a plain ``str``: the expected labels are
    # listed in the field description but are not enforced by the type.
    status: str = Field(
        description="One of: success, error, no_effect",
    )

    # Human-readable detail about the outcome.
    message: str

    # Count of cells changed by the action.
    cells_modified: int = 0
# --- Core Models ---
class DataCleanAction(Action):
    """A cleaning action submitted by the agent.

    Rows are addressed by a stable integer ``row_id`` rather than by
    position. The ``row_id`` is unique within an episode, survives
    delete/merge operations, appears in every observation row, and does
    NOT change during the episode.
    """

    # Required: which operation to perform. The supported names are listed
    # in the description; the field itself is an unconstrained ``str``.
    action_type: str = Field(
        ...,
        description=(
            "One of: fix_value, delete_row, fill_missing, standardize_format, "
            "merge_duplicates, flag_anomaly, split_column, rename_column, "
            "cast_type, escalate_to_human, mark_complete"
        ),
    )

    # Optional: per-action arguments. Defaults to a fresh empty dict; the
    # expected keys for each action type are spelled out in the description
    # and are not validated by this model.
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "Action-specific parameters. Use 'row_id' (not index) to reference rows. "
            "fix_value: {row_id, column, new_value}. "
            "delete_row: {row_id}. "
            "fill_missing: {row_id, column, value}. "
            "standardize_format: {column, format_type}. "
            "merge_duplicates: {row_id1, row_id2, strategy}. "
            "flag_anomaly: {row_id, column, reason}. "
            "split_column: {column, delimiter, new_names}. "
            "rename_column: {old_name, new_name}. "
            "cast_type: {column, target_type}. "
            "escalate_to_human: {row_id, column, confidence, reason}. "
            "mark_complete: {}."
        ),
    )
class DataCleanObservation(Observation):
    """Everything the agent is shown after a step.

    The layout is issue-first: the quality issues and the data summary are
    the primary payload, while the full row data is secondary and may be
    truncated for large datasets. Every field declared here has a default.

    The original author notes that ``done`` (bool, default False),
    ``reward`` (bool | int | float | None, default None) and ``metadata``
    (Dict[str, Any]) are inherited from the ``Observation`` base class and
    are therefore not redeclared below.
    """

    # ------------------------------------------------------------------
    # Primary payload: issues first
    # ------------------------------------------------------------------
    data_summary: DataSummary = Field(default_factory=DataSummary)
    quality_issues: List[QualityIssue] = Field(default_factory=list)
    issue_groups: List[IssueGroup] = Field(default_factory=list)
    issues_remaining: int = 0

    # ------------------------------------------------------------------
    # Secondary payload: tabular data (may be truncated)
    # ------------------------------------------------------------------
    columns: List[str] = Field(default_factory=list)
    # Each row is a list of cell values of arbitrary type.
    rows: List[List[Any]] = Field(default_factory=list)
    row_count: int = 0

    # ------------------------------------------------------------------
    # Schema information
    # ------------------------------------------------------------------
    schema_info: Dict[str, Any] = Field(default_factory=dict)

    # ------------------------------------------------------------------
    # Step context
    # ------------------------------------------------------------------
    step_number: int = 0
    max_steps: int = 30
    # Independent field: defaults to 30 and is not derived here from
    # max_steps and step_number.
    steps_remaining: int = 30

    # ------------------------------------------------------------------
    # Budget information
    # ------------------------------------------------------------------
    budget_spent: float = 0.0
    budget_remaining: float = 100.0
    # String-to-float mapping.
    # NOTE(review): presumably action type -> cost of that action -- confirm.
    action_costs: Dict[str, float] = Field(default_factory=dict)

    # ------------------------------------------------------------------
    # Action history
    # ------------------------------------------------------------------
    # ``None`` until a result is supplied.
    last_action_result: Optional[ActionResult] = None
    recent_actions: List[ActionResult] = Field(default_factory=list)

    # ------------------------------------------------------------------
    # Task information
    # ------------------------------------------------------------------
    task_id: str = ""
    task_name: str = ""
    difficulty: str = ""
class DataCleanState(State):
    """Internal state of the environment; not exposed to the agent directly.

    The original author notes that ``episode_id`` (str) and ``step_count``
    (int) are inherited from the ``State`` base class. Every field declared
    here has a default.
    """

    # --- Task identity ---
    task_id: str = ""
    difficulty: str = ""

    # --- Dataset snapshots (each a list of dict records) ---
    current_data: List[Dict[str, Any]] = Field(default_factory=list)
    ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
    original_dirty: List[Dict[str, Any]] = Field(default_factory=list)

    # --- Schema definition ---
    schema_def: Dict[str, Any] = Field(default_factory=dict)

    # --- Bookkeeping of what the agent has done ---
    action_log: List[Dict[str, Any]] = Field(default_factory=list)
    # NOTE(review): values are typed ``str`` here, whereas escalated_cells
    # allows ``Any`` and the flag_anomaly action takes a row_id. Confirm
    # that whatever is stored in these dicts is string-valued.
    flagged_cells: List[Dict[str, str]] = Field(default_factory=list)
    escalated_cells: List[Dict[str, Any]] = Field(default_factory=list)

    # --- Episode progress ---
    max_steps: int = 30
    is_complete: bool = False

    # --- Scoring ---
    # Used for delta-reward computation; mutates each step.
    previous_score: float = 0.0
    # Raw score of the dirty data at reset; treated as immutable afterwards.
    initial_raw_score: float = 0.0

    # --- Cost-aware intervention budget ---
    # The three values below are independent fields: nothing in this model
    # enforces budget_remaining == action_budget - budget_spent.
    # Total budget for the episode.
    action_budget: float = 100.0
    # Cost spent so far.
    budget_spent: float = 0.0
    # Budget left.
    budget_remaining: float = 100.0