burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
6dd47af verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the dm_control OpenEnv Environment.
This environment wraps dm_control.suite, providing access to all MuJoCo-based
continuous control tasks (cartpole, walker, humanoid, cheetah, etc.).
"""
from typing import Any, Dict, List, Optional
from pydantic import Field
try:
from openenv.core.env_server.types import Action, Observation, State
except ImportError:
from openenv.core.env_server.types import Action, Observation, State
class DMControlAction(Action):
    """Continuous control command sent to a dm_control.suite environment.

    Every suite task consumes a flat vector of floats. How many entries the
    vector needs, and the range each entry may take, is dictated by the
    ``action_spec`` of whichever domain/task combination is loaded.

    Examples:
        One-dimensional action (cartpole):

        >>> action = DMControlAction(values=[0.5])  # Push cart right

        Six-dimensional action (walker):

        >>> action = DMControlAction(values=[0.1, -0.2, 0.3, 0.0, -0.1, 0.2])

    Attributes:
        values: The continuous action vector (an empty list unless given).
            Its length and bounds must agree with the loaded environment's
            action_spec.
    """

    values: List[float] = Field(
        description="Continuous action values matching the environment's action_spec",
        default_factory=list,
    )
class DMControlObservation(Observation):
    """Observation emitted by a dm_control.suite environment.

    dm_control reports observations as a mapping from names to arrays. This
    model carries that mapping with every array flattened into a list of
    floats. Which names appear is determined by the loaded domain/task.

    Typical keys per domain:
        - cartpole: 'position' (cos/sin of angle), 'velocity'
        - walker: 'orientations', 'height', 'velocity'
        - humanoid: 'joint_angles', 'head_height', 'extremities',
          'torso_vertical', 'com_velocity'

    Attributes:
        observations: Observation name -> flattened float values
            (empty dict unless populated).
        pixels: Rendered frame as a base64-encoded PNG string. Stays ``None``
            unless rendering was requested (``render=True``).
    """

    observations: Dict[str, List[float]] = Field(
        description="Named observation arrays from the environment",
        default_factory=dict,
    )
    pixels: Optional[str] = Field(
        None,
        description="Base64-encoded PNG image (when render=True)",
    )
class DMControlState(State):
    """Metadata describing the dm_control environment that is currently loaded.

    Adds to the base ``State`` the identity of the active suite task plus the
    specs and timing information a client needs to form valid actions.

    Attributes:
        episode_id: Identifier of the running episode (inherited from
            ``State``; not redeclared here).
        step_count: Steps taken so far in the episode (inherited from
            ``State``; not redeclared here).
        domain_name: dm_control domain, e.g. 'cartpole' or 'walker'
            (defaults to 'cartpole').
        task_name: Task within the domain, e.g. 'balance' or 'walk'
            (defaults to 'balance').
        action_spec: Action-space description (shape, dtype, bounds).
        observation_spec: Observation-space description.
        physics_timestep: Seconds per physics simulation step
            (defaults to 0.002).
        control_timestep: Seconds between consecutive actions
            (defaults to 0.02).
    """

    # Which suite task is active.
    domain_name: str = Field("cartpole", description="The dm_control domain name")
    task_name: str = Field("balance", description="The task name within the domain")

    # Space descriptions; start empty and are presumably filled in once an
    # environment is loaded (NOTE(review): confirm against the server code).
    action_spec: Dict[str, Any] = Field(
        description="Specification of the action space (shape, dtype, bounds)",
        default_factory=dict,
    )
    observation_spec: Dict[str, Any] = Field(
        description="Specification of the observation space",
        default_factory=dict,
    )

    # Simulation timing, in seconds.
    physics_timestep: float = Field(
        0.002,
        description="Physics simulation timestep in seconds",
    )
    control_timestep: float = Field(
        0.02,
        description="Control timestep (time between actions) in seconds",
    )
# dm_control.suite domain/task combinations exposed by this environment.
# AVAILABLE_ENVIRONMENTS holds (domain_name, task_name) tuples, generated from
# the per-domain table below in table order (dicts preserve insertion order).
_SUITE_TASKS_BY_DOMAIN = {
    "cartpole": ("balance", "balance_sparse", "swingup", "swingup_sparse"),
    "pendulum": ("swingup",),
    "point_mass": ("easy", "hard"),
    "reacher": ("easy", "hard"),
    "ball_in_cup": ("catch",),
    "finger": ("spin", "turn_easy", "turn_hard"),
    "fish": ("upright", "swim"),
    "cheetah": ("run",),
    "walker": ("stand", "walk", "run"),
    "hopper": ("stand", "hop"),
    "swimmer": ("swimmer6", "swimmer15"),
    "humanoid": ("stand", "walk", "run"),
    "manipulator": ("bring_ball", "bring_peg", "insert_ball", "insert_peg"),
    "acrobot": ("swingup", "swingup_sparse"),
    "stacker": ("stack_2", "stack_4"),
    "dog": ("stand", "walk", "trot", "run", "fetch"),
    "quadruped": ("walk", "run", "escape", "fetch"),
}

AVAILABLE_ENVIRONMENTS = [
    (domain, task)
    for domain, tasks in _SUITE_TASKS_BY_DOMAIN.items()
    for task in tasks
]