# Adapted from ARM # Source: https://github.com/stepjam/ARM # License: https://github.com/stepjam/ARM/LICENSE import logging from typing import List import numpy as np from omegaconf import DictConfig from rlbench.backend.observation import Observation from rlbench.observation_config import ObservationConfig import rlbench.utils as rlbench_utils from rlbench.demo import Demo from yarr.replay_buffer.prioritized_replay_buffer import ( PrioritizedReplayBuffer, ObservationElement, ) from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer from helpers import utils from helpers import observation_utils from agents.act_bc_lang.act_bc_lang_agent import ActBCLangAgent from helpers.custom_rlbench_env import CustomRLBenchEnv from helpers.preprocess_agent import PreprocessAgent from agents.act_bc_lang.act_policy import ACTPolicy, CNNMLPPolicy import torch from torch.multiprocessing import Process, Value, Manager from helpers.clip.core.clip import build_model, load_clip, tokenize LOW_DIM_SIZE = 8 def create_replay( batch_size: int, timesteps: int, prioritisation: bool, task_uniform: bool, save_dir: str, cameras: list, image_size=[128, 128], replay_size=3e5, prev_action_horizon: int = 1, next_action_horizon: int = 1, ): lang_feat_dim = 1024 # low_dim_state observation_elements = [] observation_elements.append( ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32) ) # action sequences action_seq_sizes = { "right_prev_joint_positions": 7, "right_prev_gripper_joint_positions": 2, "right_prev_gripper_poses": 7, "right_next_joint_positions": 7, "right_next_gripper_joint_positions": 2, "right_next_gripper_poses": 7, "left_prev_joint_positions": 7, "left_prev_gripper_joint_positions": 2, "left_prev_gripper_poses": 7, "left_next_joint_positions": 7, "left_next_gripper_joint_positions": 2, "left_next_gripper_poses": 7, } for seq_name, seq_size in action_seq_sizes.items(): horizon = prev_action_horizon if "prev" in seq_name else next_action_horizon observation_elements.append( ObservationElement( seq_name, ( horizon, seq_size, ), np.float32, ) ) # action is_pad observation_elements.append( ObservationElement("is_pad", (next_action_horizon,), np.int32) ) # rgb, depth, point cloud, intrinsics, extrinsics for cname in cameras: observation_elements.append( ObservationElement( "%s_rgb" % cname, ( 3, *image_size, ), np.float32, ) ) observation_elements.append( ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32) ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames observation_elements.append( ObservationElement( "%s_camera_extrinsics" % cname, ( 4, 4, ), np.float32, ) ) observation_elements.append( ObservationElement( "%s_camera_intrinsics" % cname, ( 3, 3, ), np.float32, ) ) observation_elements.extend( [ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32), ReplayElement("task", (), str), ReplayElement( "lang_goal", (1,), object ), # language goal string for debugging and visualization ] ) extra_replay_elements = [ ReplayElement("demo", (), bool), ] replay_buffer = TaskUniformReplayBuffer( save_dir=save_dir, batch_size=batch_size, timesteps=timesteps, replay_capacity=int(replay_size), action_shape=(8 * 2,), action_dtype=np.float32, reward_shape=(), reward_dtype=np.float32, update_horizon=1, observation_elements=observation_elements, extra_replay_elements=extra_replay_elements, ) return replay_buffer def _get_action(obs_tp1: Observation): quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:]) if quat[-1] < 0: quat = -quat return np.concatenate( [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]] ) def _get_action_seq( demo: Demo, timestep: int, prev_action_horizon: int, next_action_horizon: int, robot_name: str, ): action_seq = { "right_prev_joint_positions": [], "right_prev_gripper_joint_positions": [], "right_prev_gripper_poses": [], "left_prev_joint_positions": [], "left_prev_gripper_joint_positions": [], "left_prev_gripper_poses": [], "right_next_joint_positions": [], "right_next_gripper_joint_positions": [], "right_next_gripper_poses": [], "left_next_joint_positions": [], "left_next_gripper_joint_positions": [], "left_next_gripper_poses": [], "is_pad": [], } for prev_t in list(reversed(range(prev_action_horizon))): t = timestep - prev_t t = max(0, t) obs = demo[t] action_seq["right_prev_joint_positions"].append(obs.right.joint_positions) action_seq["right_prev_gripper_joint_positions"].append( obs.right.gripper_joint_positions ) action_seq["right_prev_gripper_poses"].append(obs.right.gripper_pose) action_seq["left_prev_joint_positions"].append(obs.left.joint_positions) action_seq["left_prev_gripper_joint_positions"].append( obs.left.gripper_joint_positions ) action_seq["left_prev_gripper_poses"].append(obs.left.gripper_pose) action_seq["is_pad"] = np.zeros(next_action_horizon) for idx, next_t in enumerate(range(0, next_action_horizon)): t = timestep + next_t t = min(t, len(demo) - 1) obs = demo[t] if timestep + next_t > len(demo) - 1: action_seq["is_pad"][idx] = 1 action_seq["right_next_joint_positions"].append(obs.right.joint_positions) action_seq["right_next_gripper_joint_positions"].append( obs.right.gripper_joint_positions ) action_seq["right_next_gripper_poses"].append(obs.right.gripper_pose) action_seq["left_next_joint_positions"].append(obs.left.joint_positions) action_seq["left_next_gripper_joint_positions"].append( obs.left.gripper_joint_positions ) action_seq["left_next_gripper_poses"].append(obs.left.gripper_pose) # convert to numpy arrays return {k: np.array(v) for k, v in action_seq.items()} def _add_keypoints_to_replay( step: int, cfg: DictConfig, task: str, replay: ReplayBuffer, inital_obs: Observation, demo: Demo, description: str = "", clip_model=None, device="cpu", ): cameras = cfg.rlbench.cameras robot_name = cfg.method.robot_name prev_action = None obs = inital_obs all_actions = [] k = step k_tp1 = min(k + 1, len(demo) - 1) obs_tp1 = demo[k_tp1] if obs_tp1.is_bimanual and robot_name == "bimanual": right_action = _get_action(obs_tp1.right) left_action = _get_action(obs_tp1.left) action = np.append(right_action, left_action) elif robot_name == "unimanual": action = _get_action(obs_tp1) elif obs_tp1.is_bimanual and robot_name == "right": action = _get_action(obs_tp1.right) elif obs_tp1.is_bimanual and robot_name == "left": action = _get_action(obs_tp1.left) else: logging.error("Invalid robot name %s", cfg.method.robot_name) raise Exception("Invalid robot name.") all_actions.append(action) terminal = k == len(demo) - 1 reward = float(terminal) if terminal else 0 obs_dict = observation_utils.extract_obs( obs, t=k, prev_action=prev_action, cameras=cameras, episode_length=cfg.rlbench.episode_length, robot_name=robot_name, ) if obs_tp1.is_bimanual and robot_name == "bimanual": obs_dict["low_dim_state"] = np.concatenate( [obs_dict["right_low_dim_state"], obs_dict["left_low_dim_state"]] ) del obs_dict["right_low_dim_state"] del obs_dict["left_low_dim_state"] del obs_dict["right_ignore_collisions"] del obs_dict["left_ignore_collisions"] else: del obs_dict["ignore_collisions"] tokens = tokenize([description]).numpy() token_tensor = torch.from_numpy(tokens).to(device) lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor) obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy() final_obs = { "task": task, "lang_goal": np.array([description], dtype=object), } action_seq = _get_action_seq( demo, step, cfg.method.prev_action_horizon, cfg.method.next_action_horizon, robot_name, ) obs_dict.update(action_seq) prev_action = np.copy(action) others = {"demo": True} others.update(final_obs) others.update(obs_dict) timeout = False replay.add(action, reward, terminal, timeout, **others) return all_actions def fill_replay( cfg: DictConfig, obs_config: ObservationConfig, rank: int, replay: ReplayBuffer, task: str, num_demos: int, demo_augmentation: bool, demo_augmentation_every_n: int, cameras: List[str], clip_model=None, device="cpu", ): if clip_model is None: model, _ = load_clip("RN50", jit=False, device=device) clip_model = build_model(model.state_dict()) clip_model.to(device) del model logging.debug("Filling %s replay ..." % task) all_actions = [] for d_idx in range(num_demos): # load demo from disk demo = rlbench_utils.get_stored_demos( amount=1, image_paths=False, dataset_root=cfg.rlbench.demo_path, variation_number=-1, task_name=task, obs_config=obs_config, random_selection=False, from_episode_number=d_idx, )[0] descs = demo._observations[0].misc["descriptions"] if rank == 0: logging.info(f"Loading Demo({d_idx})") for i in range(len(demo) - 1): obs = demo[i] desc = descs[0] # stopped = np.allclose(obs.joint_velocities, 0, atol=0.1) # if stopped: # continue all_actions.extend( _add_keypoints_to_replay( i, cfg, task, replay, obs, demo, description=desc, clip_model=clip_model, device=device, ) ) logging.debug("Replay filled with demos.") return all_actions def fill_multi_task_replay( cfg: DictConfig, obs_config: ObservationConfig, rank: int, replay: ReplayBuffer, tasks: List[str], num_demos: int, demo_augmentation: bool, demo_augmentation_every_n: int, cameras: List[str], clip_model=None, ): manager = Manager() store = manager.dict() # create a MP dict for storing indicies # TODO(mohit): this shouldn't be initialized here del replay._task_idxs task_idxs = manager.dict() replay._task_idxs = task_idxs replay._create_storage(store) replay.add_count = Value("i", 0) # fill replay buffer in parallel across tasks max_parallel_processes = cfg.replay.max_parallel_processes processes = [] n = np.arange(len(tasks)) split_n = utils.split_list(n, max_parallel_processes) for split in split_n: for e_idx, task_idx in enumerate(split): task = tasks[int(task_idx)] model_device = torch.device( "cuda:%s" % (e_idx % torch.cuda.device_count()) if torch.cuda.is_available() else "cpu" ) p = Process( target=fill_replay, args=( cfg, obs_config, rank, replay, task, num_demos, demo_augmentation, demo_augmentation_every_n, cameras, clip_model, model_device, ), ) p.start() processes.append(p) for p in processes: p.join() logging.debug("Replay filled with multi demos.") def create_agent(cfg: DictConfig): actor_net = ACTPolicy(cfg.method) bc_agent = ActBCLangAgent( actor_network=actor_net, camera_names=cfg.rlbench.cameras, lr=cfg.method.lr, weight_decay=cfg.method.weight_decay, grad_clip=cfg.method.grad_clip, episode_length=cfg.rlbench.episode_length, train_demo_path=cfg.method.train_demo_path, task_name=cfg.rlbench.tasks[0], ) return PreprocessAgent(pose_agent=bc_agent, norm_type="imagenet")