# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the dm_control OpenEnv Environment.

This environment wraps dm_control.suite, providing access to all MuJoCo-based
continuous control tasks (cartpole, walker, humanoid, cheetah, etc.).
"""

from typing import Any, Dict, List, Optional, Tuple

from pydantic import Field

# NOTE: this import used to be wrapped in ``try/except ImportError`` whose
# fallback branch repeated the exact same import. That handler could never
# succeed where the first attempt failed, so a single plain import is used.
from openenv.core.env_server.types import Action, Observation, State


class DMControlAction(Action):
    """
    Action for dm_control environments.

    All dm_control.suite environments use continuous actions represented as
    a list of float values. The size and bounds depend on the specific
    domain/task combination.

    Example (cartpole - 1D action):
        >>> action = DMControlAction(values=[0.5])  # Push cart right

    Example (walker - 6D action):
        >>> action = DMControlAction(values=[0.1, -0.2, 0.3, 0.0, -0.1, 0.2])

    Attributes:
        values: List of continuous action values. Shape and bounds depend on
            the loaded environment's action_spec. Defaults to an empty list.
    """

    values: List[float] = Field(
        default_factory=list,
        description="Continuous action values matching the environment's action_spec",
    )


class DMControlObservation(Observation):
    """
    Observation from dm_control environments.

    dm_control environments return observations as a dictionary of named
    arrays. Common observation keys include 'position', 'velocity',
    'orientations', etc. The exact keys depend on the domain/task combination.

    Example observation keys by domain:
        - cartpole: 'position' (cos/sin of angle), 'velocity'
        - walker: 'orientations', 'height', 'velocity'
        - humanoid: 'joint_angles', 'head_height', 'extremities',
          'torso_vertical', 'com_velocity'

    Attributes:
        observations: Dictionary mapping observation names to their values.
            Each value is a flattened list of floats. Defaults to an empty
            dict.
        pixels: Optional base64-encoded PNG image of the rendered scene.
            Only included when render=True is passed to reset/step;
            otherwise None.
    """

    observations: Dict[str, List[float]] = Field(
        default_factory=dict,
        description="Named observation arrays from the environment",
    )
    pixels: Optional[str] = Field(
        default=None,
        description="Base64-encoded PNG image (when render=True)",
    )


class DMControlState(State):
    """
    Extended state for dm_control environments.

    Provides metadata about the currently loaded environment including
    the domain/task names and action/observation specifications.

    Attributes:
        episode_id: Unique identifier for the current episode. Not declared
            in this class; presumably inherited from ``State``.
        step_count: Number of steps taken in the current episode. Not
            declared in this class; presumably inherited from ``State``.
        domain_name: The dm_control domain (e.g., 'cartpole', 'walker').
        task_name: The specific task (e.g., 'balance', 'walk').
        action_spec: Specification of the action space including shape and
            bounds.
        observation_spec: Specification of the observation space.
        physics_timestep: The physics simulation timestep in seconds.
        control_timestep: The control timestep (time between actions) in
            seconds.
    """

    domain_name: str = Field(
        default="cartpole",
        description="The dm_control domain name",
    )
    task_name: str = Field(
        default="balance",
        description="The task name within the domain",
    )
    action_spec: Dict[str, Any] = Field(
        default_factory=dict,
        description="Specification of the action space (shape, dtype, bounds)",
    )
    observation_spec: Dict[str, Any] = Field(
        default_factory=dict,
        description="Specification of the observation space",
    )
    physics_timestep: float = Field(
        default=0.002,
        description="Physics simulation timestep in seconds",
    )
    control_timestep: float = Field(
        default=0.02,
        description="Control timestep (time between actions) in seconds",
    )


# Available dm_control.suite environments
# Format: (domain_name, task_name)
AVAILABLE_ENVIRONMENTS: List[Tuple[str, str]] = [
    # Cartpole
    ("cartpole", "balance"),
    ("cartpole", "balance_sparse"),
    ("cartpole", "swingup"),
    ("cartpole", "swingup_sparse"),
    # Pendulum
    ("pendulum", "swingup"),
    # Point mass
    ("point_mass", "easy"),
    ("point_mass", "hard"),
    # Reacher
    ("reacher", "easy"),
    ("reacher", "hard"),
    # Ball in cup
    ("ball_in_cup", "catch"),
    # Finger
    ("finger", "spin"),
    ("finger", "turn_easy"),
    ("finger", "turn_hard"),
    # Fish
    ("fish", "upright"),
    ("fish", "swim"),
    # Cheetah
    ("cheetah", "run"),
    # Walker
    ("walker", "stand"),
    ("walker", "walk"),
    ("walker", "run"),
    # Hopper
    ("hopper", "stand"),
    ("hopper", "hop"),
    # Swimmer
    ("swimmer", "swimmer6"),
    ("swimmer", "swimmer15"),
    # Humanoid
    ("humanoid", "stand"),
    ("humanoid", "walk"),
    ("humanoid", "run"),
    # Manipulator
    ("manipulator", "bring_ball"),
    ("manipulator", "bring_peg"),
    ("manipulator", "insert_ball"),
    ("manipulator", "insert_peg"),
    # Acrobot
    ("acrobot", "swingup"),
    ("acrobot", "swingup_sparse"),
    # Stacker
    ("stacker", "stack_2"),
    ("stacker", "stack_4"),
    # Dog
    ("dog", "stand"),
    ("dog", "walk"),
    ("dog", "trot"),
    ("dog", "run"),
    ("dog", "fetch"),
    # Quadruped
    ("quadruped", "walk"),
    ("quadruped", "run"),
    ("quadruped", "escape"),
    ("quadruped", "fetch"),
]