Spaces:
Runtime error
Runtime error
# Copyright 2019 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """A task where the goal is to move the hand close to a target prop or site.""" | |
| import collections | |
| from dm_control import composer | |
| from dm_control.composer import initializers | |
| from dm_control.composer.observation import observable | |
| from dm_control.composer.variation import distributions | |
| from dm_control.entities import props | |
| from dm_control.manipulation.shared import arenas | |
| from dm_control.manipulation.shared import cameras | |
| from dm_control.manipulation.shared import constants | |
| from dm_control.manipulation.shared import observations | |
| from dm_control.manipulation.shared import registry | |
| from dm_control.manipulation.shared import robots | |
| from dm_control.manipulation.shared import tags | |
| from dm_control.manipulation.shared import workspaces | |
| from dm_control.utils import rewards | |
| import numpy as np | |
# Spawn volumes for the target (`target_bbox`) and the tool center point
# (`tcp_bbox`), plus the offset at which the arm is attached to the arena.
_ReachWorkspace = collections.namedtuple(
    '_ReachWorkspace', ('target_bbox', 'tcp_bbox', 'arm_offset'))

# Ensures that the props are not touching the table before settling.
_PROP_Z_OFFSET = 0.001

# Workspace used when the target is a Duplo prop.
_DUPLO_WORKSPACE = _ReachWorkspace(
    target_bbox=workspaces.BoundingBox(
        lower=(-0.1, -0.1, _PROP_Z_OFFSET), upper=(0.1, 0.1, _PROP_Z_OFFSET)),
    tcp_bbox=workspaces.BoundingBox(
        lower=(-0.1, -0.1, 0.2), upper=(0.1, 0.1, 0.4)),
    arm_offset=robots.ARM_OFFSET,
)

# Workspace used when the target is a site; target and TCP spawn volumes are
# identical here.
_SITE_WORKSPACE = _ReachWorkspace(
    target_bbox=workspaces.BoundingBox(
        lower=(-0.2, -0.2, 0.02), upper=(0.2, 0.2, 0.4)),
    tcp_bbox=workspaces.BoundingBox(
        lower=(-0.2, -0.2, 0.02), upper=(0.2, 0.2, 0.4)),
    arm_offset=robots.ARM_OFFSET,
)

# Used both as the radius of the target site and as the reward's tolerance
# bound and margin.
_TARGET_RADIUS = 0.05

# Default episode time limit handed to `composer.Environment`.
_TIME_LIMIT = 10

# One task per corner: each box is degenerate (lower == upper), so the target
# is always placed at exactly (x, y, _PROP_Z_OFFSET).
TASKS = {
    task_name: workspaces.BoundingBox(
        lower=(x, y, _PROP_Z_OFFSET), upper=(x, y, _PROP_Z_OFFSET))
    for task_name, x, y in (
        ('reach_top_left', -0.09, 0.09),
        ('reach_top_right', 0.09, 0.09),
        ('reach_bottom_left', -0.09, -0.09),
        ('reach_bottom_right', 0.09, -0.09),
    )
}
def make(task_id, obs_type, seed, img_size=64):
    """Builds a `composer.Environment` for one multi-task reach task.

    Args:
      task_id: Key into `TASKS` selecting the target position.
      obs_type: 'pixels' selects `observations.VISION`; any other value selects
        `observations.PERFECT_FEATURES`. 'states' also gets a slightly longer
        time limit (see below).
      seed: Passed to `composer.Environment` as `random_state`.
      img_size: Width and height of the camera observation.

    Returns:
      A `composer.Environment` wrapping an `MTReach` task whose target is a
      Duplo prop.
    """
    if obs_type == 'pixels':
        obs_settings = observations.VISION
    else:
        obs_settings = observations.PERFECT_FEATURES
    # Resize the camera via its named field rather than `obs_settings[-1]`, so
    # this does not depend on `camera` being the last tuple field.
    obs_settings = obs_settings._replace(
        camera=obs_settings.camera._replace(width=img_size, height=img_size))

    # Note: the extra 0.04 fixes the problem of having 249 steps with action
    # repeat = 1 for state observations. It is computed per call; rebinding the
    # module-level `_TIME_LIMIT` used to leak the longer limit into every
    # subsequent `make` call, including pixel environments.
    time_limit = 10.04 if obs_type == 'states' else _TIME_LIMIT

    task = _reach(task_id, obs_settings=obs_settings, use_site=False)
    return composer.Environment(task, time_limit=time_limit, random_state=seed)
class MTReach(composer.Task):
    """Bring the hand close to a target prop or site.

    The target position is sampled from `TASKS[task_id]` (every entry defined
    in this module is a degenerate box, i.e. a fixed point). The workspace's
    `target_bbox` is only used to draw a visualization site.
    """

    def __init__(
            self, task_id, arena, arm, hand, prop, obs_settings, workspace,
            control_timestep):
        """Initializes a new `MTReach` task.

        Args:
          task_id: Key into `TASKS` selecting the bounding box from which the
            target position is sampled.
          arena: `composer.Entity` instance.
          arm: `robot_base.RobotArm` instance.
          hand: `robot_base.RobotHand` instance.
          prop: `composer.Entity` instance specifying the prop to reach to, or
            None in which case the target is a site whose position is sampled
            from `TASKS[task_id]` at the start of each episode.
          obs_settings: `observations.ObservationSettings` instance.
          workspace: `_ReachWorkspace` specifying the TCP spawn area, the arm
            offset and the target area that is drawn for visualization.
          control_timestep: Float specifying the control timestep in seconds.

        Raises:
          KeyError: If `task_id` is not a key of `TASKS`.
        """
        # Validate first so an unknown task does not leave the passed-in arm,
        # hand and arena partially wired together.
        if task_id not in TASKS:
            raise KeyError('Unknown task_id {!r}; expected one of {}.'.format(
                task_id, sorted(TASKS)))
        self._task_id = task_id
        self._arena = arena
        self._arm = arm
        self._hand = hand
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand, self._arm,
            position=distributions.Uniform(*workspace.tcp_bbox),
            quaternion=workspaces.DOWN_QUATERNION)

        # Add custom camera observable.
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, cameras.FRONT_CLOSE)

        # The target position depends on the task id, not on
        # `workspace.target_bbox`.
        target_pos_distribution = distributions.Uniform(*TASKS[task_id])
        self._prop = prop
        if prop:
            # The prop itself is used to visualize the target location.
            self._make_target_site(parent_entity=prop, visible=False)
            self._target = self._arena.add_free_entity(prop)
            self._prop_placer = initializers.PropPlacer(
                props=[prop],
                position=target_pos_distribution,
                quaternion=workspaces.uniform_z_rotation,
                settle_physics=True)
        else:
            self._target = self._make_target_site(
                parent_entity=arena, visible=True)
            self._target_placer = target_pos_distribution
            # Target-position observable intentionally disabled to match EXORL:
            # obs = observable.MJCFFeature('pos', self._target)
            # obs.configure(**obs_settings.prop_pose._asdict())
            # self._task_observables['target_position'] = obs

        # Add sites for visualizing the prop and target bounding boxes.
        workspaces.add_bbox_site(
            body=self.root_entity.mjcf_model.worldbody,
            lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
            rgba=constants.GREEN, name='tcp_spawn_area')
        workspaces.add_bbox_site(
            body=self.root_entity.mjcf_model.worldbody,
            lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper,
            rgba=constants.BLUE, name='target_spawn_area')

    def _make_target_site(self, parent_entity, visible):
        """Adds a red `target_site` of radius `_TARGET_RADIUS` to the entity."""
        return workspaces.add_target_site(
            body=parent_entity.mjcf_model.worldbody,
            radius=_TARGET_RADIUS, visible=visible,
            rgba=constants.RED, name='target_site')

    # These must be properties: composer reads them as attributes, and
    # `__init__` itself does `self.root_entity.mjcf_model`.
    @property
    def root_entity(self):
        """The arena, used as the root entity of this task."""
        return self._arena

    @property
    def arm(self):
        """The robot arm passed to the constructor."""
        return self._arm

    @property
    def hand(self):
        """The robot hand attached to the arm."""
        return self._hand

    @property
    def task_observables(self):
        """Task-level observables (the added camera observables)."""
        return self._task_observables

    def get_reward(self, physics):
        """Returns a tolerance reward on the hand-to-target distance.

        The distance is the Euclidean norm between the hand's tool center point
        and the target. It is passed to `rewards.tolerance` with bounds
        `(0, _TARGET_RADIUS)` and margin `_TARGET_RADIUS`.
        """
        hand_pos = physics.bind(self._hand.tool_center_point).xpos
        target_pos = physics.bind(self._target).xpos
        # This was used exceptionally for the PT reward predictor experiments:
        # target_pos = distributions.Uniform(*TASKS[self._task_id])()
        distance = np.linalg.norm(hand_pos - target_pos)
        return rewards.tolerance(
            distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)

    def initialize_episode(self, physics, random_state):
        """Randomizes grasp and TCP pose, then places the prop or target site."""
        self._hand.set_grasp(physics, close_factors=random_state.uniform())
        self._tcp_initializer(physics, random_state)
        if self._prop:
            self._prop_placer(physics, random_state)
        else:
            physics.bind(self._target).pos = (
                self._target_placer(random_state=random_state))
def _reach(task_id, obs_settings, use_site):
    """Configures and instantiates an `MTReach` task.

    Args:
      task_id: Key into `TASKS` selecting the target position.
      obs_settings: An `observations.ObservationSettings` instance.
      use_site: Boolean, if True then the target will be a fixed site, otherwise
        it will be a moveable Duplo brick.

    Returns:
      An instance of `MTReach`.
    """
    arena = arenas.Standard()
    arm = robots.make_arm(obs_settings=obs_settings)
    hand = robots.make_hand(obs_settings=obs_settings)
    if use_site:
        workspace = _SITE_WORKSPACE
        prop = None
    else:
        workspace = _DUPLO_WORKSPACE
        prop = props.Duplo(observable_options=observations.make_options(
            obs_settings, observations.FREEPROP_OBSERVABLES))
    return MTReach(
        task_id,
        arena=arena,
        arm=arm,
        hand=hand,
        prop=prop,
        obs_settings=obs_settings,
        workspace=workspace,
        control_timestep=constants.CONTROL_TIMESTEP)