data-centric-env / models.py
Aswini-Kumar's picture
Audit fixes: remove duplicate torch import, add metadata field, fix stale strings, fix test assertions, update reward docs
36f4bdf
Raw
History Blame Contribute Delete
3.47 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Data-Centric RL Environment.
Action β†’ plain text command string (like DataWranglerEnv)
Observation β†’ rich structured observation with accuracy, quality, budget info
State β†’ episode metadata
"""
from openenv.core.env_server.types import Action, Observation
from pydantic import Field
class DataCentricAction(Action):
    """A single text command issued by the agent to the Data-Centric environment.

    The entire action is carried in ``message`` as one command string. Commands
    let the agent inspect the dataset, query specialist sub-agents, apply their
    recommendations, and finally submit the cleaned dataset for scoring.

    Example commands:
        - "inspect_dataset"
        - "inspect_model"
        - "query_cleaner"
        - "query_augmenter class_1"
        - "query_balancer"
        - "query_validator"
        - "apply 1"
        - "reject 2"
        - "validate"
        - "submit"
    """

    # Required field: the Ellipsis default means no implicit value is supplied.
    message: str = Field(
        ...,
        description="Text command to execute in the environment",
    )
class DataCentricObservation(Observation):
    """Observation returned after each action in the Data-Centric environment.

    Carries the text result of the executed command together with structured
    feedback about the episode: model accuracy figures, dataset health, step
    budget, specialist session info, and the step's reward / termination flag.
    Every field has a default, so an observation can be constructed with only
    the values that are known at the time.
    """

    # --- Command result -----------------------------------------------------
    response: str = Field(
        default="",
        description="Text result of the executed command",
    )

    # --- Accuracy tracking --------------------------------------------------
    current_accuracy: float = Field(
        default=0.0,
        description="Last validated model accuracy (or baseline if not yet validated)",
    )
    baseline_accuracy: float = Field(
        default=0.0,
        description="Accuracy at episode start — never changes",
    )
    target_accuracy: float = Field(
        default=0.0,
        description="Accuracy threshold the agent must exceed to hit target",
    )

    # --- Dataset health -----------------------------------------------------
    estimated_quality: float = Field(
        default=0.0,
        description="Lightweight quality score without sklearn retraining (0.0-1.0)",
    )
    dataset_shape: str = Field(
        default="",
        description="Current dataset dimensions, e.g. '200 rows × 5 columns'",
    )
    rows_preserved_pct: float = Field(
        default=1.0,
        description="Fraction of original rows still present (1.0 = no data loss)",
    )

    # --- Step budget --------------------------------------------------------
    budget_remaining: int = Field(
        default=0,
        description="Steps remaining before forced submit",
    )
    step_number: int = Field(
        default=0,
        description="Current step number in the episode",
    )
    max_steps: int = Field(
        default=30,
        description="Maximum steps allowed for this task",
    )

    # --- Specialist session / validation allowance --------------------------
    # NOTE(review): the description lists cleaner/augmenter/balancer/none but
    # not "validator", although "query_validator" is a documented command —
    # confirm with the server whether that value can appear here.
    active_session: str = Field(
        default="none",
        description="Which specialist agent was queried last (cleaner/augmenter/balancer/none)",
    )
    validate_calls_remaining: int = Field(
        default=3,
        description="How many more free validates remain before reward turns negative",
    )

    # --- Episode outcome ----------------------------------------------------
    # NOTE(review): the base Observation may already declare done / reward /
    # metadata; they are redeclared here explicitly — confirm against the
    # installed openenv version that the overriding types stay compatible.
    done: bool = Field(
        default=False,
        description="Whether the episode has ended",
    )
    reward: float = Field(
        default=0.0,
        description="Reward for this step",
    )
    # default_factory gives each instance its own fresh dict.
    metadata: dict = Field(
        default_factory=dict,
        description="Optional extra metadata (e.g. debug info from server)",
    )