"""Pydantic models for the DataClean-Env environment.
Defines typed Action, Observation, and State models following OpenEnv spec.
Observation uses issue-first design per Codex feedback.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from openenv.core.env_server import Action, Observation, State
# --- Supporting Types ---
class QualityIssue(BaseModel):
    """A detected quality issue in the dataset."""

    # Row identifier and column name of the affected cell.
    row_id: int
    column: str
    # Category tag. It is a plain string: the allowed values are advertised
    # through the field description only and are not enforced here.
    issue_type: str = Field(
        description=(
            "One of: null, format, duplicate, case, "
            "type_violation, cross_field, anomaly"
        ),
    )
    # Free-text explanation, plus an optional proposed remedy.
    description: str
    suggestion: Optional[str] = Field(default=None)
class IssueGroup(BaseModel):
    """Issues grouped by type for compact display."""

    # The shared issue category and how many issues fall under it.
    issue_type: str
    count: int
    # Sample issues for the group; every instance starts with its own
    # empty list.
    examples: List[QualityIssue] = Field(default_factory=list)
class DataSummary(BaseModel):
    """Compact summary of the dataset state."""

    # Size counters; all default to zero so an empty summary is valid.
    row_count: int = Field(default=0)
    column_count: int = Field(default=0)
    total_cells: int = Field(default=0)
    # Quality counters.
    null_count: int = Field(default=0)
    issue_count: int = Field(default=0)
    # Column names, and a mapping of string keys to dtype strings.
    columns: List[str] = Field(default_factory=list)
    dtypes: Dict[str, str] = Field(default_factory=dict)
class ActionResult(BaseModel):
    """Result of executing an action."""

    # Name of the action this result belongs to.
    action: str
    # Outcome tag; the description lists the expected values, but the
    # field itself accepts any string.
    status: str = Field(
        description="One of: success, error, no_effect",
    )
    message: str
    # Number of cells changed; zero unless set explicitly.
    cells_modified: int = Field(default=0)
# --- Core Models ---
class DataCleanAction(Action):
    """Agent's action to clean data.
    Actions reference rows by stable `row_id` (integer, unique within
    episode, survives delete/merge operations). The row_id is visible
    in every observation row and does NOT change during the episode.
    """

    # Required field. The description enumerates the supported action
    # names; the joined text is identical to a single hand-written literal.
    action_type: str = Field(
        ...,
        description="One of: " + ", ".join(
            (
                "fix_value",
                "delete_row",
                "fill_missing",
                "standardize_format",
                "merge_duplicates",
                "flag_anomaly",
                "split_column",
                "rename_column",
                "cast_type",
                "escalate_to_human",
                "mark_complete",
            )
        ),
    )
    # Free-form parameter bag, empty by default. The description documents
    # the expected keys per action type, one sentence per action.
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description=" ".join(
            (
                "Action-specific parameters. Use 'row_id' (not index) to reference rows.",
                "fix_value: {row_id, column, new_value}.",
                "delete_row: {row_id}.",
                "fill_missing: {row_id, column, value}.",
                "standardize_format: {column, format_type}.",
                "merge_duplicates: {row_id1, row_id2, strategy}.",
                "flag_anomaly: {row_id, column, reason}.",
                "split_column: {column, delimiter, new_names}.",
                "rename_column: {old_name, new_name}.",
                "cast_type: {column, target_type}.",
                "escalate_to_human: {row_id, column, confidence, reason}.",
                "mark_complete: {}.",
            )
        ),
    )
class DataCleanObservation(Observation):
    """What the agent sees after each step.
    Issue-first design: quality_issues and data_summary are primary.
    Full row data is secondary (truncated for large datasets).
    """

    # ---- Primary: issue-first view ----
    data_summary: DataSummary = Field(default_factory=DataSummary)
    quality_issues: List[QualityIssue] = Field(default_factory=list)
    issue_groups: List[IssueGroup] = Field(default_factory=list)
    issues_remaining: int = 0

    # ---- Secondary: tabular data (may be truncated) ----
    columns: List[str] = Field(default_factory=list)
    rows: List[List[Any]] = Field(default_factory=list)
    row_count: int = 0

    # ---- Schema information ----
    schema_info: Dict[str, Any] = Field(default_factory=dict)

    # ---- Step context (defaults describe a fresh 30-step episode) ----
    step_number: int = 0
    max_steps: int = 30
    steps_remaining: int = 30

    # ---- Budget information ----
    budget_spent: float = 0.0
    budget_remaining: float = 100.0
    action_costs: Dict[str, float] = Field(default_factory=dict)

    # ---- Action history ----
    last_action_result: Optional[ActionResult] = None
    recent_actions: List[ActionResult] = Field(default_factory=list)

    # ---- Task information ----
    task_id: str = ""
    task_name: str = ""
    difficulty: str = ""

    # Not redeclared here; per the original author's note these come from
    # the Observation base class:
    #   done (bool, default False)
    #   reward (bool | int | float | None, default None)
    #   metadata (Dict[str, Any], default {})
class DataCleanState(State):
    """Internal environment state. Not exposed to agent directly."""

    # Per the original author's note, episode_id (str) and step_count (int)
    # are inherited from the State base class and not redeclared here.
    task_id: str = ""
    difficulty: str = ""

    # Dataset snapshots, each stored as a list of row dictionaries.
    current_data: List[Dict[str, Any]] = Field(default_factory=list)
    ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
    original_dirty: List[Dict[str, Any]] = Field(default_factory=list)
    schema_def: Dict[str, Any] = Field(default_factory=dict)

    # Bookkeeping of what the agent has done so far.
    action_log: List[Dict[str, Any]] = Field(default_factory=list)
    # Values are typed Any (not str) so an integer row_id validates. Row ids
    # are integers elsewhere in this module, and pydantic v2 does not coerce
    # int to str. This also matches escalated_cells below.
    flagged_cells: List[Dict[str, Any]] = Field(default_factory=list)
    escalated_cells: List[Dict[str, Any]] = Field(default_factory=list)

    max_steps: int = 30
    is_complete: bool = False
    previous_score: float = 0.0  # For delta reward computation (mutates each step)
    initial_raw_score: float = 0.0  # Raw score of dirty data at reset (immutable)

    # Cost-aware intervention budget
    action_budget: float = 100.0  # Total budget for the episode
    budget_spent: float = 0.0  # Cost spent so far
    budget_remaining: float = 100.0  # Budget left