# (file-viewer page residue, not part of the program: "e-zorzi's picture", "download", "raw", "27.7 kB")
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from legged_gym import LEGGED_GYM_ROOT_DIR
import os
import math
import isaacgym
from isaacgym import gymapi
from isaacgym import gymutil
from legged_gym.envs import *
from legged_gym.utils import get_args, export_policy_as_jit, task_registry, Logger, set_seed
from legged_gym.envs.base.legged_robot_config import LeggedRobotCfg, LeggedRobotCfgPPO
import numpy as np
import torch
import cv2
from matplotlib import pyplot as plt
from tqdm import tqdm
# Height and width (pixels) of the recorded camera frames.
H, W = 480, 640

# Robot name prefixes; every task list below holds one "<robot>_<suffix>" entry per robot.
ROBOT_LIST = ['h1', 'g1', 'h1_2']


def _task_names(suffix):
    """Return ``"<robot>_<suffix>"`` for every robot in ``ROBOT_LIST``, in order."""
    return ["{}_{}".format(robot, suffix) for robot in ROBOT_LIST]


# Low-level skills.
WALKING = _task_names("walking")
REACHING = _task_names("reaching")
STEPPING = _task_names("stepping")
SQUATTING = _task_names("squatting")

# High-level tasks.
TASK_REACH = _task_names("task_reach")
TASK_BOX = _task_names("task_box")
TASK_BUTTON = _task_names("task_button")
TASK_LIFT = _task_names("task_lift")
TASK_BALL = _task_names("task_ball")
TASK_CARRY = _task_names("task_carry")
TASK_TRANSFER = _task_names("task_transfer")
TASK_CABINET = _task_names("task_cabinet")
def _pose_at(point, z=None):
    """Return a ``gymapi.Transform`` positioned at ``point`` with a default ``gymapi.Quat()``.

    ``point`` is indexed at 0, 1 and 2 for x, y and z. When ``z`` is given it
    replaces ``point[2]``, so 2-component points can be placed at a fixed height.
    """
    height = point[2] if z is None else z
    return gymapi.Transform(gymapi.Vec3(point[0], point[1], height), gymapi.Quat())


def _wire_sphere(radius, color):
    """Build the 12x12 wireframe sphere geometry used as a marker by the overlays."""
    sphere_rot = gymapi.Quat.from_euler_zyx(0.5 * math.pi, 0, 0)
    sphere_pose = gymapi.Transform(r=sphere_rot)
    return gymutil.WireframeSphereGeometry(radius, 12, 12, sphere_pose, color=color)


def _draw_marker(env, env_index, geom, pose):
    """Draw ``geom`` at ``pose`` in the viewer, inside environment ``env_index``."""
    gymutil.draw_lines(geom, env.gym, env.viewer, env.envs[env_index], pose)


def visualize_task(task, env):
    """Draw task-specific debug geometry for every environment into ``env.viewer``.

    Each drawing branch first clears the lines already in the viewer and then
    redraws its overlay, so this is meant to be called once per simulation step.
    It draws through ``env.viewer`` and is therefore only meaningful when a
    display/viewer is available.

    Overlays:
      * walking: a red line from ``env.ori_root_states[:, :3]`` along the first
        three command components (third one zeroed), scaled by 1e3.
      * reaching / task_reach: axes at each current wrist position, a yellow
        sphere at the reference wrist position, a purple sphere at the original one.
      * stepping: the same three markers for the feet, drawn at height 0.
      * squatting: small axes at the root position, larger axes at the same x/y
        but at ``env.ref_root_height``.
      * task_box / task_transfer / task_button / task_ball: a red sphere at the goal.
      * task_lift / task_carry: a yellow sphere at the box, a red one at its goal.
      * task_cabinet, ``h1_imitation``, ``h1_exbody``: nothing is drawn.

    Args:
        task: Task name, e.g. ``"h1_walking"``.
        env: Environment exposing ``gym``, ``viewer``, ``envs``, ``num_envs`` and
            the task-specific state tensors read by the selected branch.

    Raises:
        NotImplementedError: If ``task`` is not one of the known task names.
    """
    ### Low-level skills
    if task in WALKING:
        env.gym.clear_lines(env.viewer)
        commands = env.commands.clone()
        commands[:, 2] = 0  # drop the third command component before drawing
        root_states = env.ori_root_states[:, :3]
        root_end = root_states + 1e3 * commands[:, :3]
        red = gymapi.Vec3(1, 0, 0)
        for i in range(env.num_envs):
            start = gymapi.Vec3(root_states[i, 0], root_states[i, 1], root_states[i, 2])
            end = gymapi.Vec3(root_end[i, 0], root_end[i, 1], root_end[i, 2])
            gymutil.draw_line(start, end, red, env.gym, env.viewer, env.envs[i])
    elif task in REACHING+TASK_REACH:
        env.gym.clear_lines(env.viewer)
        axes_geom = gymutil.AxesGeometry(0.15)
        yellow_geom = _wire_sphere(0.05, (1, 1, 0))
        purple_geom = _wire_sphere(0.02, (1, 0, 1))
        wrist_pos = env.rigid_state[:, env.wrist_indices, :7]  # [num_envs, 2, 7]
        ref_wrist_pos = env.ref_wrist_pos  # [num_envs, 2, 7]
        for i in range(env.num_envs):
            for j in range(2):
                # current wrist (axes), reference wrist (yellow), original wrist (purple)
                _draw_marker(env, i, axes_geom, _pose_at(wrist_pos[i][j]))
                _draw_marker(env, i, yellow_geom, _pose_at(ref_wrist_pos[i][j]))
                _draw_marker(env, i, purple_geom, _pose_at(env.ori_wrist_pos[i, j]))
    elif task in STEPPING:
        env.gym.clear_lines(env.viewer)
        axes_geom = gymutil.AxesGeometry(0.15)
        yellow_geom = _wire_sphere(0.05, (1, 1, 0))
        purple_geom = _wire_sphere(0.02, (1, 0, 1))
        feet_pos = env.rigid_state[:, env.feet_indices, :2]  # [num_envs, 2, 2]
        ref_feet_pos = env.ref_feet_pos  # [num_envs, 2, 2]
        for i in range(env.num_envs):
            for j in range(2):
                # Foot positions only carry x/y here, so every marker is drawn at z = 0.
                _draw_marker(env, i, axes_geom, _pose_at(feet_pos[i][j], z=0))
                _draw_marker(env, i, yellow_geom, _pose_at(ref_feet_pos[i][j], z=0))
                _draw_marker(env, i, purple_geom, _pose_at(env.ori_feet_pos[i, j], z=0))
    elif task in SQUATTING:
        env.gym.clear_lines(env.viewer)
        axes_geom_cur = gymutil.AxesGeometry(0.25)
        axes_geom_tgt = gymutil.AxesGeometry(0.5)
        root_pos = env.root_states[:, :3].clone()  # [num_envs, 3]
        # The target shares the current x/y and only differs in height.
        ref_root_pos = env.root_states[:, :3].clone()  # [num_envs, 3]
        ref_root_pos[:, 2] = env.ref_root_height
        for i in range(env.num_envs):
            _draw_marker(env, i, axes_geom_cur, _pose_at(root_pos[i]))
            _draw_marker(env, i, axes_geom_tgt, _pose_at(ref_root_pos[i]))
    ### High-level tasks
    elif task in TASK_BOX+TASK_TRANSFER:
        env.gym.clear_lines(env.viewer)
        red_geom = _wire_sphere(0.05, (1, 0, 0))
        for i in range(env.num_envs):
            _draw_marker(env, i, red_geom, _pose_at(env.box_goal_pos[i, :3]))
    elif task in TASK_BUTTON:
        env.gym.clear_lines(env.viewer)
        red_geom = _wire_sphere(0.05, (1, 0, 0))
        for i in range(env.num_envs):
            _draw_marker(env, i, red_geom, _pose_at(env.button_goal_pos[i, :3]))
    elif task in TASK_LIFT+TASK_CARRY:
        env.gym.clear_lines(env.viewer)
        yellow_geom = _wire_sphere(0.1, (1, 1, 0))
        red_geom = _wire_sphere(0.1, (1, 0, 0))
        for i in range(env.num_envs):
            # One yellow marker at the box position, one red marker at its goal.
            _draw_marker(env, i, yellow_geom, _pose_at(env.box_root_states[i, :3]))
            _draw_marker(env, i, red_geom, _pose_at(env.box_goal_pos[i]))
    elif task in TASK_BALL:
        env.gym.clear_lines(env.viewer)
        red_geom = _wire_sphere(0.2, (1, 0, 0))
        for i in range(env.num_envs):
            _draw_marker(env, i, red_geom, _pose_at(env.goal_pos[i, :3]))
    elif task in TASK_CABINET:
        # no need to visualize
        pass
    elif task in ['h1_imitation', 'h1_exbody']:
        pass
    else:
        raise NotImplementedError(f"Task {task} is not supported for visualization")
def override_env_cfg(env_cfg: LeggedRobotCfg, args):
    """Adjust ``env_cfg`` in place for playback/testing and return it.

    Sets the episode length according to ``args.task``, caps the number of
    environments (at 1 for walking tasks or when ``args.visualize`` is truthy,
    at 50 otherwise), shrinks the terrain grid to 5x5 without curriculum, and
    switches off observation noise, friction randomization and robot pushes.

    Args:
        env_cfg: Environment configuration to modify.
        args: Parsed arguments; ``args.task`` and ``args.visualize`` are read.

    Returns:
        The same ``env_cfg`` object, after modification.
    """
    print('====> URDF file:', env_cfg.asset.file)

    task = args.task
    short_episode_tasks = (
        TASK_BOX + TASK_TRANSFER + TASK_BUTTON
        + TASK_LIFT + TASK_BALL + TASK_CARRY + TASK_CABINET
    )

    # Pick the per-task episode length (seconds) and the cap on parallel envs.
    max_envs = 50
    if task in WALKING:
        max_envs = 1
        episode_seconds = 24
    elif task in TASK_REACH:
        episode_seconds = 20
        env_cfg.human.freq = 2
    elif task in short_episode_tasks:
        episode_seconds = 2.5
    else:
        episode_seconds = 8
    env_cfg.env.episode_length_s = episode_seconds

    if args.visualize:
        max_envs = 1
    env_cfg.env.num_envs = min(env_cfg.env.num_envs, max_envs)

    # Small, fixed terrain.
    terrain = env_cfg.terrain
    terrain.num_rows = 5
    terrain.num_cols = 5
    terrain.curriculum = False

    # No noise or domain randomization during playback.
    env_cfg.noise.add_noise = False
    env_cfg.domain_rand.randomize_friction = False
    env_cfg.domain_rand.push_robots = False
    return env_cfg
def get_camera_pose(task):
    """Return the ``gymapi.Transform`` used to attach the recording camera.

    The module-level ``EGO_CENTRIC`` flag selects between one fixed pose
    (offset ``(0.1, 0, 0.9)``, 45 degrees about ``(0, 1, 0)``) and two
    task-dependent poses: button/ball/cabinet tasks use offset ``(-1, -2, 1)``
    with a 45 degree rotation, every other task uses ``(1, -1, 1)`` with
    135 degrees, both about the axis ``(-0.3, 0.2, 1)``.

    Args:
        task: Task name, e.g. ``"h1_task_button"``.
    """
    if EGO_CENTRIC:
        offset = gymapi.Vec3(0.1, 0, 0.9)
        axis = gymapi.Vec3(0, 1, 0)
        angle_deg = 45
    elif task in TASK_BUTTON+TASK_BALL+TASK_CABINET:
        offset = gymapi.Vec3(-1, -2, 1)
        axis = gymapi.Vec3(-0.3, 0.2, 1)
        angle_deg = 45
    else:
        offset = gymapi.Vec3(1, -1, 1)
        axis = gymapi.Vec3(-0.3, 0.2, 1)
        angle_deg = 135

    rotation = gymapi.Quat.from_axis_angle(axis, np.deg2rad(angle_deg))
    return gymapi.Transform(offset, rotation)
def _attach_recording_camera(env, env_index, task):
    """Create a W x H camera sensor in env ``env_index`` and attach it to the robot.

    The camera is attached to rigid body 0 of actor 0 using the pose from
    ``get_camera_pose``; it follows only the body position unless
    ``EGO_CENTRIC`` is set, in which case it follows the full transform.

    Returns:
        The camera handle returned by ``create_camera_sensor``.
    """
    camera_properties = gymapi.CameraProperties()
    camera_properties.width = W
    camera_properties.height = H
    cam = env.gym.create_camera_sensor(env.envs[env_index], camera_properties)
    camera_pose = get_camera_pose(task)
    actor_handle = env.gym.get_actor_handle(env.envs[env_index], 0)
    body_handle = env.gym.get_actor_rigid_body_handle(env.envs[env_index], actor_handle, 0)
    env.gym.attach_camera_to_body(
        cam,
        env.envs[env_index],
        body_handle,
        camera_pose,
        gymapi.FOLLOW_POSITION if not EGO_CENTRIC else gymapi.FOLLOW_TRANSFORM
    )
    return cam


def _write_camera_frame(env, env_index, cam, video):
    """Render the camera sensors and append camera ``cam``'s colour image to ``video``."""
    env.gym.fetch_results(env.sim, True)
    env.gym.step_graphics(env.sim)
    env.gym.render_all_camera_sensors(env.sim)
    img = env.gym.get_camera_image(env.sim, env.envs[env_index], cam, gymapi.IMAGE_COLOR)
    img = np.reshape(img, (H, W, 4))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    video.write(img[..., :3])


def _report_metric(logger, name, value):
    """Store ``value`` under ``name`` in the logger's metrics and echo it to stdout."""
    logger.log_metrics({name: value})
    print(f'{name}:', value)


def _final_task_error(task, env, robot_index):
    """Return ``(metric_name, mean_abs_error)`` for a high-level task.

    The error is the mean absolute difference between the current state and the
    goal, read from env ``robot_index`` (the cabinet metric averages over all
    environments instead).

    Raises:
        ValueError: If ``task`` is not one of the high-level tasks handled here.
    """
    if task in TASK_BOX+TASK_TRANSFER:
        box_pos = env.box_root_states[robot_index, :3].cpu().numpy()
        box_goal_pos = env.box_goal_pos[robot_index, :3].cpu().numpy()
        return 'box_pos_diff', np.mean(np.abs(box_pos - box_goal_pos))
    if task in TASK_BUTTON:
        # Only components 1:3 of the positions are compared.
        button_goal_pos = env.button_goal_pos[robot_index, 1:3].cpu().numpy()
        wrist_pos = env.rigid_state[robot_index, env.wrist_indices, 1:3].cpu().numpy()
        wrist_pos = wrist_pos[0]  # left hand
        return 'button_pos_diff', np.mean(np.abs(button_goal_pos - wrist_pos))
    if task in TASK_LIFT:
        box_goal_pos = env.box_goal_pos[robot_index].cpu().numpy()
        box_pos = env.box_root_states[robot_index, :3].cpu().numpy()
        return 'box_pos_diff', np.mean(np.abs(box_pos - box_goal_pos)[2:3])  # z axis only
    if task in TASK_BALL:
        ball_goal_pos = env.goal_pos[robot_index].cpu().numpy()
        ball_pos = env.ball_root_states[robot_index, :3].cpu().numpy()
        return 'ball_pos_diff', np.mean(np.abs(ball_pos - ball_goal_pos))
    if task in TASK_CARRY:
        box_goal_pos = env.box_goal_pos[robot_index].cpu().numpy()
        box_pos = env.box_root_states[robot_index, :3].cpu().numpy()
        return 'box_pos_diff', np.mean(np.abs(box_pos - box_goal_pos))
    if task in TASK_CABINET:
        arti_obj_dof_diff = (env.arti_obj_dof_state[:, :, 0] - env.arti_obj_dof_goal).cpu().numpy()  # [num_envs, 2]
        return 'arti_obj_dof_diff', np.mean(np.abs(arti_obj_dof_diff))
    raise ValueError(f"Task {task} has no final-error metric")


def play(args):
    """Load a trained policy for ``args.task`` and roll it out, logging the results.

    The function builds the environment and runner through ``task_registry``,
    optionally exports the policy as a TorchScript module, then runs
    ``N_rollouts`` rollouts. During each rollout it draws the task overlay,
    steps the policy, logs robot states and task metrics, and optionally writes
    an ``.mp4``. After each rollout it saves the logger's figure as a ``.png``
    in ``<LEGGED_GYM_ROOT_DIR>/logs/<experiment>/exported/frames``; that
    directory is always created.

    Behaviour is controlled by the module-level flags ``EXPORT_POLICY``,
    ``RECORD_FRAMES``, ``FIX_COMMAND`` and (through the camera helpers)
    ``EGO_CENTRIC``.

    Args:
        args: Parsed command-line arguments; ``task``, ``load_run`` and
            ``experiment_name`` are read here, the rest is forwarded.
    """
    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task, load_run=args.load_run, experiment_name=args.experiment_name)
    HRL = "_task_" in args.task and "Hierarchical" in train_cfg.runner.policy_class_name
    env_cfg = override_env_cfg(env_cfg, args)
    # prepare environment
    env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
    obs = env.get_observations()
    # load policy
    train_cfg.runner.resume = True
    train_cfg.runner.run_name = 'play'
    ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args, train_cfg=train_cfg)
    policy = ppo_runner.get_inference_policy(device=env.device, hrl=HRL)
    model_name = f'{args.load_run}_{train_cfg.runner.resume_path.split("_")[-1].split(".")[0]}'

    export_dir = os.path.join(LEGGED_GYM_ROOT_DIR, 'logs', train_cfg.runner.experiment_name, 'exported')

    # export policy as a jit module (used to run it from C++)
    if EXPORT_POLICY:
        path = os.path.join(export_dir, 'policies')
        filename_pt = f'{model_name}.pt'
        export_policy_as_jit(ppo_runner.alg.actor_critic, path, filename_pt)
        print('Exported policy as jit script to: ', path)

    robot_index = 0  # which robot (env) is used for logging
    state_log_interval = 1000  # number of steps before plotting states
    rew_log_interval = env.max_episode_length - 1  # number of steps before print average episode rewards
    N_rollouts = 10

    # The per-rollout figure is written here whether or not frames are recorded,
    # so the directory must exist in both cases.
    frame_path = os.path.join(export_dir, 'frames')
    os.makedirs(frame_path, exist_ok=True)

    cam = None
    camera_env_index = robot_index
    if RECORD_FRAMES:
        cam = _attach_recording_camera(env, camera_env_index, args.task)

    # References seen on the previous step, used to detect when a low-level
    # skill's goal has been updated.
    if args.task in REACHING+TASK_REACH:
        last_ref_wrist_pos = env.ref_wrist_pos[robot_index][:, :3].cpu().numpy()
    if args.task in STEPPING:
        last_ref_feet_pos = env.ref_feet_pos[robot_index][:, :2].cpu().numpy()
    if args.task in SQUATTING:
        last_ref_root_height = env.ref_root_height[robot_index].cpu().numpy()

    # High-level tasks run for ten episodes per rollout and report a final error
    # once per episode.
    max_steps = int(env.max_episode_length)
    reports_final_error = args.task in TASK_BOX+TASK_BUTTON+TASK_LIFT+TASK_BALL+TASK_CARRY+TASK_CABINET+TASK_TRANSFER
    if reports_final_error:
        max_steps = int(env.max_episode_length) * 10

    # These two cabinet variants skip logging/recording for the first 1001 steps.
    skips_early_steps = args.task == 'g1_task_cabinet' or args.task == 'h1_2_task_cabinet'

    for i_rollout in range(N_rollouts):
        print(f"====> Rollout {i_rollout+1}/{N_rollouts}")
        robot_index = i_rollout % env_cfg.env.num_envs
        logger = Logger(env.dt)
        env.ori_root_states = env.root_states.clone()

        video = None
        if RECORD_FRAMES:
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            filename_mp4 = f'{args.task}_{model_name}_{i_rollout}.mp4'
            video = cv2.VideoWriter(os.path.join(frame_path, filename_mp4), fourcc, 25.0, (W, H))

        try:
            for i in tqdm(range(max_steps)):
                # visualize task-related information
                visualize_task(args.task, env)
                actions = policy(obs.detach())
                actions = actions["actions_mean"] if HRL else actions
                if FIX_COMMAND:
                    env.commands[:, 0] = 1.0
                    env.commands[:, 1] = 0.0
                    env.commands[:, 2] = 0.0
                    env.commands[:, 3] = 0.0
                obs, _, _, _, infos = env.step(actions.detach())

                if skips_early_steps and i <= 1000:
                    continue  # still need to fix accuracy, but all methods are 0.000 so it's okay

                if video is not None and i % 4 == 0:
                    # The camera handle belongs to the env it was created in, so
                    # frames are always read from that env even though the logged
                    # robot changes between rollouts.
                    # NOTE(review): with num_envs > 1 the video therefore shows
                    # env `camera_env_index`, not `robot_index` - confirm this is
                    # the desired view.
                    _write_camera_frame(env, camera_env_index, cam, video)

                logger.log_states(
                    {
                        'dof_pos_target': actions[robot_index, :].detach().cpu().numpy() * env.cfg.control.action_scale,
                        'dof_pos': env.dof_pos[robot_index, :].detach().cpu().numpy(),
                        'dof_vel': env.dof_vel[robot_index, :].detach().cpu().numpy(),
                        'dof_torque': env.torques[robot_index, :].detach().cpu().numpy(),
                        'command_x': env.commands[robot_index, 0].detach().cpu().numpy(),
                        'command_y': env.commands[robot_index, 1].detach().cpu().numpy(),
                        'command_yaw': env.commands[robot_index, 2].detach().cpu().numpy(),
                        'base_vel_x': env.base_lin_vel[robot_index, 0].detach().cpu().numpy(),
                        'base_vel_y': env.base_lin_vel[robot_index, 1].detach().cpu().numpy(),
                        'base_vel_z': env.base_lin_vel[robot_index, 2].detach().cpu().numpy(),
                        'base_vel_yaw': env.base_ang_vel[robot_index, 2].detach().cpu().numpy(),
                        'contact_forces_z': env.contact_forces[robot_index, env.feet_indices, 2].detach().cpu().numpy(),
                        'base_roll': env.base_euler_xyz[robot_index, 0].detach().cpu().numpy(),
                        'base_pitch': env.base_euler_xyz[robot_index, 1].detach().cpu().numpy(),
                        'base_height': env.root_states[robot_index, 2].detach().cpu().numpy(),
                    }
                )

                # Low-level skills: when the reference changes, score how close the
                # robot got to the reference that was just replaced.
                if args.task in REACHING+TASK_REACH:
                    ref_wrist_pos = env.ref_wrist_pos[robot_index][:, :3].cpu().numpy()
                    wrist_pos = env.rigid_state[robot_index, env.wrist_indices, :3].cpu().numpy()
                    if not np.allclose(ref_wrist_pos, last_ref_wrist_pos):
                        print("=> update reach goal")
                        _report_metric(logger, 'wrist_pos_diff', np.mean(np.abs(last_ref_wrist_pos - wrist_pos)))
                        last_ref_wrist_pos = ref_wrist_pos
                if args.task in STEPPING:
                    ref_feet_pos = env.ref_feet_pos[robot_index][:, :2].cpu().numpy()
                    feet_pos = env.rigid_state[robot_index, env.feet_indices, :2].cpu().numpy()
                    if not np.allclose(ref_feet_pos, last_ref_feet_pos):
                        print("=> update kicking goal")
                        _report_metric(logger, 'feet_pos_diff', np.mean(np.abs(last_ref_feet_pos - feet_pos)))
                        last_ref_feet_pos = ref_feet_pos
                if args.task in SQUATTING:
                    ref_root_height = env.ref_root_height[robot_index].cpu().numpy()
                    root_height = env.root_states[robot_index, 2].cpu().numpy()
                    if not np.allclose(ref_root_height, last_ref_root_height):
                        print("=> update squating goal")
                        _report_metric(logger, 'root_height_diff', np.mean(np.abs(last_ref_root_height - root_height)))
                        last_ref_root_height = ref_root_height

                # High-level tasks: report the distance to the goal once per
                # (rew_log_interval - 1) steps.
                if reports_final_error and i > 0 and i % (rew_log_interval-1) == 0:
                    metric_name, metric_value = _final_task_error(args.task, env, robot_index)
                    _report_metric(logger, metric_name, metric_value)
                ### logging end ###

                if i > 0 and i % state_log_interval == 0:
                    logger.plot_states()
                if infos["episode"]:
                    num_episodes = torch.sum(env.reset_buf).item()
                    if num_episodes > 0:
                        logger.log_rewards(infos["episode"], num_episodes)
                if i > 0 and i % rew_log_interval == 0:
                    logger.print_rewards()
        finally:
            # Finalize the video file even if the rollout is interrupted.
            if video is not None:
                video.release()

        fig = logger._plot()
        filename_png = f'{model_name}_{i_rollout}.png'
        fig.savefig(os.path.join(frame_path, filename_png))
if __name__ == '__main__':
    # Module-level switches read as globals by play() and get_camera_pose().
    EXPORT_POLICY, RECORD_FRAMES = True, True
    EGO_CENTRIC, FIX_COMMAND = False, True

    args = get_args(test=True)
    play(args)

# (file-viewer page residue, not part of the program)
# Xet Storage Details
#
# Size: 27.7 kB · Xet hash: 689e1f3a7c4cfaf4cff703138f139498a9071fe1d91a5de163d9eda8e71312fe
#
# Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.