# Listing metadata from the file viewer: size 8,390 bytes, revision d899b9f.
from typing import Callable, List, Type
import gymnasium as gym
import numpy as np
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, gym_utils
import argparse
import yaml
import torch
from collections import deque
from PIL import Image
import cv2
import imageio
from functools import partial
from diffusion_policy.workspace.robotworkspace import RobotWorkspace
def parse_args(args=None):
    """Parse the command-line options of the policy evaluation script.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e", "--env-id", type=str, default="PickCube-v1",
        help="Environment to run motion planning solver on. ",
    )
    parser.add_argument(
        "-o", "--obs-mode", type=str, default="rgb",
        help="Observation mode to use. Usually this is kept as 'none' as observations are not necesary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.",
    )
    parser.add_argument(
        "-n", "--num-traj", type=int, default=25,
        help="Number of trajectories to generate.",
    )
    parser.add_argument(
        "--only-count-success", action="store_true",
        help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos",
    )
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument(
        "-b", "--sim-backend", type=str, default="auto",
        help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'",
    )
    parser.add_argument(
        "--render-mode", type=str, default="rgb_array",
        help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos",
    )
    parser.add_argument(
        "--vis", action="store_true",
        help="whether or not to open a GUI to visualize the solution live",
    )
    parser.add_argument(
        "--save-video", action="store_true",
        help="whether or not to save videos locally",
    )
    parser.add_argument(
        "--traj-name", type=str,
        help="The name of the trajectory .h5 file that will be created.",
    )
    parser.add_argument(
        "--shader", default="default", type=str,
        help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer",
    )
    parser.add_argument(
        "--record-dir", type=str, default="demos",
        help="where to save the recorded trajectories",
    )
    parser.add_argument(
        "--num-procs", type=int, default=1,
        help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.",
    )
    parser.add_argument(
        "--random_seed", type=int, default=0,
        help="Random seed for the environment.",
    )
    parser.add_argument(
        "--pretrained_path", type=str, default=None,
        help="Path to the pretrained policy checkpoint to evaluate.",
    )
    # Forward the caller-supplied list; passing None keeps the sys.argv default.
    return parser.parse_args(args)
# Natural-language instruction for each environment id this script supports;
# indexing with any other id raises KeyError.
# NOTE(review): the "PlugCharger-v1" text describes inserting shapes into a
# board/kit rather than plugging a charger - confirm it is the intended text.
task2lang = {
    "PegInsertionSide-v1": (
        "Pick up a orange-white peg and insert the orange end "
        "into the box with a hole in it."
    ),
    "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
    "StackCube-v1": (
        "Pick up a red cube and stack it on top of a green cube "
        "and let go of the cube without it falling."
    ),
    "PlugCharger-v1": (
        "Pick up one of the misplaced shapes on the board/kit "
        "and insert it into the correct empty slot."
    ),
    "PushCube-v1": "Push and move a cube to a goal region in front of it.",
}
import random
import os
# Read the CLI options, then seed every random number generator the script
# touches with the same value.
args = parse_args()
seed = args.random_seed
os.environ['PYTHONHASHSEED'] = str(seed)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)
# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Create the environment to evaluate in. The control mode is fixed to
# "pd_joint_pos"; the reward mode falls back to "dense" when the flag is absent.
env_id = args.env_id
env_kwargs = {
    "obs_mode": args.obs_mode,
    "control_mode": "pd_joint_pos",
    "render_mode": args.render_mode,
    "reward_mode": args.reward_mode if args.reward_mode is not None else "dense",
    # Each camera group gets its own config dict carrying the chosen shader.
    "sensor_configs": {"shader_pack": args.shader},
    "human_render_camera_configs": {"shader_pack": args.shader},
    "viewer_camera_configs": {"shader_pack": args.shader},
    "sim_backend": args.sim_backend,
}
env = gym.make(env_id, **env_kwargs)
from diffusion_policy.workspace.robotworkspace import RobotWorkspace
import hydra
import dill
# Checkpoint to evaluate; both the policy loader and the results file name
# are derived from it.
checkpoint_path = args.pretrained_path
if checkpoint_path is None:
    # --pretrained_path defaults to None; without this guard the script only
    # fails later with an opaque TypeError when the path is opened.
    raise SystemExit("--pretrained_path is required: pass the policy checkpoint to evaluate.")
print(f"Loading policy from {checkpoint_path}. Task is {task2lang[env_id]}")
def get_policy(output_dir, device):
    """Load the policy stored in the checkpoint at ``checkpoint_path``.

    The workspace class named by ``cfg._target_`` in the checkpoint is
    instantiated, the saved payload is loaded into it, and its model is
    returned.

    Args:
        output_dir: Passed to the workspace constructor as ``output_dir``.
        device: Device spec accepted by ``torch.device`` (e.g. ``'cuda'``).

    Returns:
        The policy, moved to ``device`` and switched to eval mode. This is
        ``workspace.ema_model`` when ``cfg.training.use_ema`` is truthy,
        otherwise ``workspace.model``.
    """
    # load checkpoint
    # NOTE: unpickling with dill can execute arbitrary code; only load
    # checkpoints from a trusted source.
    # The context manager closes the file handle, which was previously leaked.
    with open(checkpoint_path, 'rb') as ckpt_file:
        payload = torch.load(ckpt_file, pickle_module=dill)
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    workspace: RobotWorkspace = cls(cfg, output_dir=output_dir)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # get policy from workspace
    policy = workspace.model
    if cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(device)
    policy.to(device)
    policy.eval()
    return policy
# Load the policy onto the GPU once, up front.
policy = get_policy('./', device='cuda')

# Evaluation settings and running counters.
MAX_EPISODE_STEPS = 400
total_episodes = args.num_traj
success_count = 0
base_seed = 20241201
instr = task2lang[env_id]

import tqdm

# Per-dimension statistics used to normalize states and de-normalize actions.
DATA_STAT = {
    'state_min': [-0.7463043928146362, -0.0801204964518547, -0.4976441562175751, -2.657780647277832, -0.5742632150650024, 1.8309762477874756, -2.2423808574676514, 0.0, 0.0],
    'state_max': [0.7645499110221863, 1.4967026710510254, 0.4650936424732208, -0.3866899907588959, 0.5505855679512024, 3.2900545597076416, 2.5737812519073486, 0.03999999910593033, 0.03999999910593033],
    'action_min': [-0.7472005486488342, -0.08631071448326111, -0.4995281398296356, -2.658363103866577, -0.5751323103904724, 1.8290787935256958, -2.245187997817993, -1.0],
    'action_max': [0.7654682397842407, 1.4984270334243774, 0.46786263585090637, -0.38181185722351074, 0.5517147779464722, 3.291581630706787, 2.575840711593628, 1.0],
    'action_std': [0.2199309915304184, 0.18780815601348877, 0.13044124841690063, 0.30669933557510376, 0.1340624988079071, 0.24968451261520386, 0.9589747190475464, 0.9827960729598999],
    'action_mean': [-0.00885344110429287, 0.5523102879524231, -0.007564723491668701, -2.0108158588409424, 0.004714342765510082, 2.615924596786499, 0.08461848646402359, -0.19301606714725494],
}

# Only the min/max bounds are needed on the GPU.
state_min, state_max, action_min, action_max = (
    torch.tensor(DATA_STAT[stat_key]).cuda()
    for stat_key in ('state_min', 'state_max', 'action_min', 'action_max')
)
# The normalization spans never change, so compute them once instead of on
# every environment step.
state_range = state_max - state_min
action_range = action_max - action_min


def _make_policy_obs(raw_obs, frame):
    """Build the policy input from an env observation and a rendered frame.

    Joint positions (``raw_obs['agent']['qpos']``) are min-max scaled with
    the dataset bounds, which maps in-range values to [-1, 1]. The frame is
    moved to the GPU as float and its axes are reordered with
    ``permute(0, 3, 1, 2)`` (presumably batch-first channels-last to
    channels-first - TODO confirm against the render output).
    """
    proprio = raw_obs['agent']['qpos'][:].cuda()
    proprio = (proprio - state_min) / state_range * 2 - 1
    return {
        'agent_pos': proprio,
        "head_cam": frame.cuda().float().permute(0, 3, 1, 2),
    }


for episode in tqdm.trange(total_episodes):
    obs_window = deque(maxlen=2)
    # A distinct, reproducible seed per episode.
    obs, _ = env.reset(seed=episode + base_seed)

    # Fill the window by repeating the first observation.
    first_obs = _make_policy_obs(obs, env.render())
    obs_window.append(first_obs)
    obs_window.append(dict(first_obs))

    global_steps = 0
    done = False
    while global_steps < MAX_EPISODE_STEPS and not done:
        # Only the most recent observation is handed to the policy.
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            actions = policy.predict_action(obs_window[-1])['action_pred'].squeeze(0)
        # Map the policy output from [-1, 1] back to the action bounds.
        actions = (actions + 1) / 2 * action_range + action_min
        # Execute at most the first 8 predicted actions before re-planning.
        actions = actions.detach().cpu().numpy()[:8]
        if actions.shape[0] == 0:
            # Without any action the step counter would never advance.
            raise RuntimeError("policy returned an empty action sequence")

        for action in actions:
            obs, reward, terminated, truncated, info = env.step(action)
            # Render once per step and reuse the frame for the observation.
            obs_window.append(_make_policy_obs(obs, env.render()))
            global_steps += 1
            if terminated or truncated:
                assert "success" in info, sorted(info.keys())
                # NOTE(review): only success ends the episode; after a
                # termination/truncation without success the outer loop keeps
                # stepping until MAX_EPISODE_STEPS - confirm this is intended.
                if info['success']:
                    done = True
                    success_count += 1
                break
    print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")

# Guard against --num-traj 0, which would otherwise divide by zero.
success_rate = success_count / total_episodes * 100 if total_episodes else 0.0
print(f"Tested {total_episodes} episodes, success rate: {success_rate:.2f}%")

# Append one "env:seed:successes" line to a per-checkpoint results file.
log_file = f"results_dp_{checkpoint_path.split('/')[-1].split('.')[0]}.txt"
with open(log_file, 'a') as f:
    f.write(f"{args.env_id}:{seed}:{success_count}\n")

# Release simulator and renderer resources.
env.close()