# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Data-Centric RL Environment.
Action → plain text command string (like DataWranglerEnv)
Observation → rich structured observation with accuracy, quality, budget info
State → episode metadata
"""
from openenv.core.env_server.types import Action, Observation
from pydantic import Field
class DataCentricAction(Action):
    """A single text command sent by the agent to the Data-Centric environment.

    Each action is a natural-language-style command string. Commands let the
    agent inspect the dataset, query specialist sub-agents, apply their
    recommendations, and ultimately submit the cleaned dataset for scoring.

    Example commands:
        - "inspect_dataset"
        - "inspect_model"
        - "query_cleaner"
        - "query_augmenter class_1"
        - "query_balancer"
        - "query_validator"
        - "apply 1"
        - "reject 2"
        - "validate"
        - "submit"
    """

    # Required: there is no default, so an action without a command is invalid.
    message: str = Field(
        ...,
        description="Text command to execute in the environment",
    )
class DataCentricObservation(Observation):
    """Structured feedback handed back to the agent after every action.

    Carries the textual result of the command together with a snapshot of the
    episode: dataset health, model accuracy, remaining budget, and which
    specialist session is active. Every field has a default, so an observation
    can be built from any subset of values.
    """

    # --- Command result ---------------------------------------------------
    response: str = Field("", description="Text result of the executed command")

    # --- Accuracy and quality metrics -------------------------------------
    current_accuracy: float = Field(
        0.0,
        description="Last validated model accuracy (or baseline if not yet validated)",
    )
    baseline_accuracy: float = Field(
        0.0,
        description="Accuracy at episode start — never changes",
    )
    target_accuracy: float = Field(
        0.0,
        description="Accuracy threshold the agent must exceed to hit target",
    )
    estimated_quality: float = Field(
        0.0,
        description="Lightweight quality score without sklearn retraining (0.0-1.0)",
    )

    # --- Dataset footprint ------------------------------------------------
    dataset_shape: str = Field(
        "",
        description="Current dataset dimensions, e.g. '200 rows × 5 columns'",
    )
    rows_preserved_pct: float = Field(
        1.0,
        description="Fraction of original rows still present (1.0 = no data loss)",
    )

    # --- Step budget ------------------------------------------------------
    budget_remaining: int = Field(0, description="Steps remaining before forced submit")
    step_number: int = Field(0, description="Current step number in the episode")
    max_steps: int = Field(30, description="Maximum steps allowed for this task")

    # --- Specialist session and validation allowance ----------------------
    active_session: str = Field(
        "none",
        description="Which specialist agent was queried last (cleaner/augmenter/balancer/none)",
    )
    validate_calls_remaining: int = Field(
        3,
        description="How many more free validates remain before reward turns negative",
    )

    # --- Episode bookkeeping ----------------------------------------------
    done: bool = Field(False, description="Whether the episode has ended")
    reward: float = Field(0.0, description="Reward for this step")
    # default_factory gives each observation its own fresh dict.
    metadata: dict = Field(
        default_factory=dict,
        description="Optional extra metadata (e.g. debug info from server)",
    )
|