# Fix Phase 2: add [START]/[STEP]/[END] structured output to inference.py
# (commit 48e9b06)
"""
Data Clean Environment - Typed Models
======================================
Pydantic models for actions, observations, and state.
"""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Base classes – use openenv-core when available, plain Pydantic otherwise
# ---------------------------------------------------------------------------
# Prefer the openenv-core base types; degrade gracefully to plain Pydantic
# models when the package is not installed.
try:
    from openenv.core.env_server.types import Action as _Action
    from openenv.core.env_server.types import Observation as _Observation
    from openenv.core.env_server.types import State as _State
except ImportError:
    # openenv-core is absent: every model simply extends BaseModel.
    # (If any one of the three names is missing above, ImportError fires and
    # all three fall back together — same end state as the original.)
    _Action = _Observation = _State = BaseModel
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class DataCleanAction(_Action):
    """A single cleaning operation the agent applies to the dataset.

    Supported ``action_type`` values:
        fix_value    -- overwrite a cell with a corrected value
        delete_row   -- remove a duplicate / invalid row
        fill_missing -- fill an empty cell
        flag_anomaly -- mark a cell as suspicious (partial credit)
        submit       -- end the episode and finalise the score
        noop         -- do nothing this step
    """

    # Which of the six operations to perform this step.
    action_type: str = Field(
        ...,
        description="One of: fix_value, delete_row, fill_missing, flag_anomaly, submit, noop",
    )
    # Target cell/row; both are irrelevant for submit / noop.
    row_index: Optional[int] = Field(None, description="0-based row index to act on")
    column_name: Optional[str] = Field(None, description="Column name to act on")
    # Payload for the value-writing actions.
    new_value: Optional[str] = Field(None, description="Replacement value (for fix_value / fill_missing)")
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class DataCleanObservation(_Observation):
    """What the agent sees after each step.

    Carries the task framing, a textual view of the current dataset, the
    issue report, the running score, and the step budget.
    """

    # Task framing.
    task_name: str = Field(..., description="Current task identifier")
    task_description: str = Field(..., description="Human-readable task goal")
    difficulty: str = Field(..., description="easy / medium / hard")
    # Current view of the data.
    data_preview: str = Field(
        ..., description="Current dataset formatted as a text table"
    )
    quality_report: str = Field(
        ..., description="Summary of detected data-quality issues"
    )
    columns_info: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Per-column metadata: name, dtype, nulls, unique count",
    )
    action_history: List[str] = Field(
        default_factory=list, description="Log of previous actions and outcomes"
    )
    # Progress tracking. NOTE: max_steps is the fixed per-episode budget, not
    # a remaining count — the previous description ("Budget of remaining
    # steps") contradicted the field name and step_number semantics.
    step_number: int = Field(0, description="Current step (1-based)")
    max_steps: int = Field(0, description="Maximum number of steps allowed in the episode")
    current_score: float = Field(0.0, description="Running score 0.0-1.0")
    # Default action menu; mirrors DataCleanAction.action_type options.
    available_actions: List[str] = Field(
        default_factory=lambda: [
            "fix_value",
            "delete_row",
            "fill_missing",
            "flag_anomaly",
            "submit",
            "noop",
        ],
        description="Action types the agent may emit this step",
    )
# ---------------------------------------------------------------------------
# State (episode metadata)
# ---------------------------------------------------------------------------
class DataCleanState(_State):
    """Episode-level metadata returned by state().

    A flat snapshot of episode progress; all fields have safe defaults so an
    empty/uninitialised state can be constructed without arguments.
    """

    # Unique identifier of the current episode (None before reset).
    episode_id: Optional[str] = None
    # Task framing (mirrors the observation's task_name / difficulty).
    task_name: str = ""
    difficulty: str = ""
    # Step accounting.
    step_count: int = 0
    max_steps: int = 0
    # Scoring: how many of the seeded issues exist and how many were fixed.
    total_issues: int = 0
    issues_fixed: int = 0
    current_score: float = 0.0
    # True once the episode has ended (submit or budget exhausted —
    # presumably; exact trigger lives in the environment, not shown here).
    done: bool = False