| |
| |
| |
|
|
| import logging |
| from typing import List |
|
|
| import numpy as np |
| from omegaconf import DictConfig |
| from rlbench.backend.observation import Observation |
| from rlbench.observation_config import ObservationConfig |
| import rlbench.utils as rlbench_utils |
| from rlbench.demo import Demo |
| from yarr.replay_buffer.prioritized_replay_buffer import ( |
| PrioritizedReplayBuffer, |
| ObservationElement, |
| ) |
| from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer |
| from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer |
| from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer |
|
|
| from helpers import demo_loading_utils, utils |
| from helpers import observation_utils |
| from agents.baselines.bc_lang.bc_lang_agent import BCLangAgent |
| from helpers.custom_rlbench_env import CustomRLBenchEnv |
| from helpers.network_utils import SiameseNet, CNNLangAndFcsNet |
| from helpers.preprocess_agent import PreprocessAgent |
|
|
| import torch |
| from torch.multiprocessing import Process, Value, Manager |
| from helpers.clip.core.clip import build_model, load_clip, tokenize |
|
|
# Length of the "low_dim_state" observation vector. The same value sizes the
# replay-buffer element in create_replay and is passed to the actor network
# as low_dim_state_len in create_agent, so the two must stay in sync.
LOW_DIM_SIZE = 4
|
|
|
|
def create_replay(
    batch_size: int,
    timesteps: int,
    prioritisation: bool,
    task_uniform: bool,
    save_dir: str,
    cameras: list,
    image_size=(128, 128),
    replay_size=3e5,
):
    """Build the replay buffer used to store language-conditioned demos.

    Args:
        batch_size: Forwarded to the replay buffer as its batch size.
        timesteps: Forwarded to the replay buffer as its timestep count.
        prioritisation: Accepted for interface compatibility; not used here.
        task_uniform: Accepted for interface compatibility; not used here.
            A ``TaskUniformReplayBuffer`` is always created.
        save_dir: Forwarded to the replay buffer as its save directory.
        cameras: Camera names. For every name, rgb, point-cloud, extrinsics
            and intrinsics elements are registered.
        image_size: Two-element sequence appended after a leading 3 to form
            the shape of each rgb / point-cloud element.
        replay_size: Buffer capacity; converted to ``int`` before use.

    Returns:
        The configured ``TaskUniformReplayBuffer``.
    """
    # Presumably the width of the CLIP RN50 text feature that fill_replay
    # produces - confirm if the language encoder is ever swapped.
    lang_feat_dim = 1024
    image_shape = (3, *image_size)

    observation_elements = [
        ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
    ]

    # Per-camera image data plus the camera matrices that go with it.
    for cname in cameras:
        observation_elements.extend(
            [
                ObservationElement("%s_rgb" % cname, image_shape, np.float32),
                ObservationElement(
                    "%s_point_cloud" % cname, image_shape, np.float32
                ),
                ObservationElement(
                    "%s_camera_extrinsics" % cname, (4, 4), np.float32
                ),
                ObservationElement(
                    "%s_camera_intrinsics" % cname, (3, 3), np.float32
                ),
            ]
        )

    # Language goal (embedding and raw text) and the task name.
    observation_elements.extend(
        [
            ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
            ReplayElement("task", (), str),
            ReplayElement("lang_goal", (1,), object),
        ]
    )

    # `np.bool` was only an alias of the builtin `bool` and was removed in
    # NumPy 1.24; the builtin is the exact same dtype on every NumPy version.
    extra_replay_elements = [
        ReplayElement("demo", (), bool),
    ]

    return TaskUniformReplayBuffer(
        save_dir=save_dir,
        batch_size=batch_size,
        timesteps=timesteps,
        replay_capacity=int(replay_size),
        # 8 = 3 position + 4 quaternion + 1 gripper-open, see _get_action.
        action_shape=(8,),
        action_dtype=np.float32,
        reward_shape=(),
        reward_dtype=np.float32,
        update_horizon=1,
        observation_elements=observation_elements,
        extra_replay_elements=extra_replay_elements,
    )
|
|
|
|
def _get_action(obs_tp1: Observation):
    """Turn the gripper state of an observation into a flat action vector.

    The result concatenates the first three entries of ``gripper_pose``,
    the normalised quaternion taken from the remaining entries, and the
    gripper-open flag as a float.
    """
    pose = obs_tp1.gripper_pose
    position = pose[:3]
    rotation = utils.normalize_quaternion(pose[3:])
    # q and -q describe the same rotation; pick the representative whose
    # last component is non-negative so targets are unambiguous.
    if rotation[-1] < 0:
        rotation = -rotation
    gripper_state = [float(obs_tp1.gripper_open)]
    return np.concatenate([position, rotation, gripper_state])
|
|
|
|
def _add_keypoints_to_replay(
    cfg: DictConfig,
    task: str,
    replay: ReplayBuffer,
    inital_obs: Observation,
    demo: Demo,
    episode_keypoints: List[int],
    cameras: List[str],
    description: str = "",
    clip_model=None,
    device="cpu",
):
    """Add one keypoint-to-keypoint trajectory of a demo to the replay.

    Starting from ``inital_obs``, each keypoint contributes one transition
    whose action is derived from the observation at that keypoint. After the
    last keypoint the final observation is stored with ``replay.add_final``.

    Args:
        cfg: Config; ``cfg.rlbench.episode_length`` and
            ``cfg.method.robot_name`` are forwarded to ``extract_obs``.
        task: Task name stored with every transition.
        replay: Replay buffer that receives the transitions.
        inital_obs: Observation the trajectory starts from.
        demo: Demo indexed by the keypoint indices.
        episode_keypoints: Indices into ``demo`` of the remaining keypoints.
        cameras: Camera names forwarded to ``extract_obs``.
        description: Language goal; tokenised and encoded with ``clip_model``.
        clip_model: Model exposing ``encode_text_with_embeddings``. Required
            in practice even though the default is ``None``.
        device: Device the token tensor is moved to before encoding.

    Returns:
        List of the actions added, one per keypoint. Empty if there are no
        keypoints, in which case the replay is left untouched.
    """
    # Nothing to add; also avoids touching loop variables that would
    # otherwise be unbound in the final-step code below.
    if len(episode_keypoints) == 0:
        return []

    # The description is the same for every keypoint, so encode it once.
    # no_grad: the features are only stored, never back-propagated through.
    tokens = tokenize([description]).numpy()
    token_tensor = torch.from_numpy(tokens).to(device)
    with torch.no_grad():
        lang_feats, _ = clip_model.encode_text_with_embeddings(token_tensor)
    lang_goal_emb = lang_feats[0].float().detach().cpu().numpy()

    final_obs = {
        "task": task,
        "lang_goal": np.array([description], dtype=object),
    }

    prev_action = None
    obs = inital_obs
    all_actions = []
    last_index = len(episode_keypoints) - 1
    for k, keypoint in enumerate(episode_keypoints):
        obs_tp1 = demo[keypoint]
        action = _get_action(obs_tp1)
        all_actions.append(action)
        # Only the transition into the last keypoint is terminal / rewarded.
        terminal = k == last_index
        reward = float(terminal)

        obs_dict = observation_utils.extract_obs(
            obs,
            t=k,
            prev_action=prev_action,
            cameras=cameras,
            episode_length=cfg.rlbench.episode_length,
            robot_name=cfg.method.robot_name,
        )
        # Not declared as a replay element in create_replay.
        del obs_dict["ignore_collisions"]
        obs_dict["lang_goal_emb"] = lang_goal_emb

        prev_action = np.copy(action)
        others = {"demo": True}
        others.update(final_obs)
        others.update(obs_dict)
        timeout = False
        replay.add(action, reward, terminal, timeout, **others)
        obs = obs_tp1

    # Final step: store the observation reached after the last action.
    obs_dict_tp1 = observation_utils.extract_obs(
        obs_tp1,
        t=k + 1,
        prev_action=prev_action,
        cameras=cameras,
        episode_length=cfg.rlbench.episode_length,
        robot_name=cfg.method.robot_name,
    )
    obs_dict_tp1["lang_goal_emb"] = lang_goal_emb
    del obs_dict_tp1["ignore_collisions"]
    obs_dict_tp1.update(final_obs)
    replay.add_final(**obs_dict_tp1)
    return all_actions
|
|
|
|
def fill_replay(
    cfg: DictConfig,
    obs_config: ObservationConfig,
    rank: int,
    replay: ReplayBuffer,
    task: str,
    num_demos: int,
    demo_augmentation: bool,
    demo_augmentation_every_n: int,
    cameras: List[str],
    clip_model=None,
    device="cpu",
):
    """Load stored demos of one task and add them to the replay buffer.

    Args:
        cfg: Config; ``cfg.rlbench.demo_path`` is the dataset root.
        obs_config: Observation config used when loading demos.
        rank: Process rank; per-demo progress is logged only for rank 0.
        replay: Replay buffer to fill.
        task: Name of the task whose demos are loaded.
        num_demos: Number of episodes to load (episode numbers 0..num_demos-1).
        demo_augmentation: If false, only frame 0 of each demo is used as a
            starting point.
        demo_augmentation_every_n: With augmentation, only frames whose index
            is a multiple of this value are used as starting points.
        cameras: Camera names forwarded to the keypoint helper.
        clip_model: Text encoder; when ``None`` an RN50 CLIP model is loaded.
        device: Device for the CLIP model and the token tensors.

    Returns:
        All actions added to the replay, across demos and starting frames.
    """
    if clip_model is None:
        # No encoder supplied: load RN50 CLIP and rebuild it from its weights.
        loaded_model, _ = load_clip("RN50", jit=False, device=device)
        clip_model = build_model(loaded_model.state_dict())
        clip_model.to(device)
        del loaded_model

    logging.debug("Filling %s replay ..." % task)
    all_actions = []
    for d_idx in range(num_demos):
        # Load exactly one stored episode from disk.
        stored = rlbench_utils.get_stored_demos(
            amount=1,
            image_paths=False,
            dataset_root=cfg.rlbench.demo_path,
            variation_number=-1,
            task_name=task,
            obs_config=obs_config,
            random_selection=False,
            from_episode_number=d_idx,
        )
        demo = stored[0]

        descriptions = demo._observations[0].misc["descriptions"]

        episode_keypoints = demo_loading_utils.keypoint_discovery(demo)

        if rank == 0:
            logging.info(
                f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
            )

        for frame_idx in range(len(demo) - 1):
            # Without augmentation only the very first frame is a start point.
            if frame_idx > 0 and not demo_augmentation:
                break
            # With augmentation, sample only every n-th frame.
            if frame_idx % demo_augmentation_every_n != 0:
                continue

            start_obs = demo[frame_idx]
            lang_goal = descriptions[0]

            # Drop the keypoints that lie at or before the starting frame.
            passed = 0
            while (
                passed < len(episode_keypoints)
                and frame_idx >= episode_keypoints[passed]
            ):
                passed += 1
            if passed:
                episode_keypoints = episode_keypoints[passed:]
            if len(episode_keypoints) == 0:
                break

            new_actions = _add_keypoints_to_replay(
                cfg,
                task,
                replay,
                start_obs,
                demo,
                episode_keypoints,
                cameras,
                description=lang_goal,
                clip_model=clip_model,
                device=device,
            )
            all_actions.extend(new_actions)
    logging.debug("Replay filled with demos.")
    return all_actions
|
|
|
|
def fill_multi_task_replay(
    cfg: DictConfig,
    obs_config: ObservationConfig,
    rank: int,
    replay: ReplayBuffer,
    tasks: List[str],
    num_demos: int,
    demo_augmentation: bool,
    demo_augmentation_every_n: int,
    cameras: List[str],
    clip_model=None,
):
    """Fill the replay buffer for several tasks using worker processes.

    One ``fill_replay`` process is started per task. Tasks are processed in
    batches produced by ``utils.split_list``; each batch is joined before the
    next one starts.

    Args:
        cfg: Config; ``cfg.replay.max_parallel_processes`` is passed to
            ``utils.split_list`` to form the batches.
        obs_config: Forwarded to ``fill_replay``.
        rank: Forwarded to ``fill_replay``.
        replay: Replay buffer; its storage and counters are replaced with
            multiprocessing-backed objects before workers start.
        tasks: Names of the tasks to load.
        num_demos: Forwarded to ``fill_replay``.
        demo_augmentation: Forwarded to ``fill_replay``.
        demo_augmentation_every_n: Forwarded to ``fill_replay``.
        cameras: Forwarded to ``fill_replay``.
        clip_model: Forwarded to ``fill_replay``.
    """
    manager = Manager()
    store = manager.dict()

    # Replace the buffer's bookkeeping with manager / shared-memory objects,
    # presumably so writes made inside worker processes reach the parent.
    del replay._task_idxs
    replay._task_idxs = manager.dict()
    replay._create_storage(store)
    replay.add_count = Value("i", 0)

    max_parallel_processes = cfg.replay.max_parallel_processes
    task_indices = np.arange(len(tasks))
    cuda_available = torch.cuda.is_available()

    # NOTE(review): assumes split_list yields groups of task indices sized by
    # max_parallel_processes - confirm against helpers.utils.
    for split in utils.split_list(task_indices, max_parallel_processes):
        batch = []
        for e_idx, task_idx in enumerate(split):
            task = tasks[int(task_idx)]
            # Spread the workers of a batch round-robin over visible GPUs.
            model_device = torch.device(
                "cuda:%s" % (e_idx % torch.cuda.device_count())
                if cuda_available
                else "cpu"
            )
            p = Process(
                target=fill_replay,
                args=(
                    cfg,
                    obs_config,
                    rank,
                    replay,
                    task,
                    num_demos,
                    demo_augmentation,
                    demo_augmentation_every_n,
                    cameras,
                    clip_model,
                    model_device,
                ),
            )
            p.start()
            batch.append((task, p))

        # Wait for this batch before launching the next one so the number of
        # concurrent workers stays bounded.
        for task, p in batch:
            p.join()
            # A crashed worker would otherwise leave the replay silently
            # incomplete; report it but keep loading the remaining tasks.
            if p.exitcode != 0:
                logging.error(
                    "Replay fill worker for task %s exited with code %s",
                    task,
                    p.exitcode,
                )

    logging.debug("Replay filled with multi demos.")
|
|
|
|
def create_agent(cfg: DictConfig):
    """Assemble the language-conditioned behaviour-cloning agent.

    Reads the camera name and resolution from ``cfg.rlbench`` and the
    activation, learning rate, weight decay and gradient clip from
    ``cfg.method``, builds the actor network, wraps it in a ``BCLangAgent``
    and returns that agent wrapped in a ``PreprocessAgent``.
    """
    method_cfg = cfg.method
    rlbench_cfg = cfg.rlbench
    nonlinearity = method_cfg.activation

    # Two 3-channel input streams - presumably rgb and point cloud, which
    # are both stored with 3 leading channels in create_replay.
    twin_encoder = SiameseNet(
        input_channels=[3, 3],
        filters=[16],
        kernel_sizes=[5],
        strides=[1],
        activation=nonlinearity,
        norm=None,
    )

    policy_net = CNNLangAndFcsNet(
        siamese_net=twin_encoder,
        input_resolution=rlbench_cfg.camera_resolution,
        filters=[32, 64, 64],
        kernel_sizes=[3, 3, 3],
        strides=[2, 2, 2],
        norm=None,
        activation=nonlinearity,
        # Output width 3 + 4 + 1 matches the action built by _get_action:
        # position, quaternion and gripper-open flag.
        fc_layers=[128, 64, 3 + 4 + 1],
        low_dim_state_len=LOW_DIM_SIZE,
    )

    agent = BCLangAgent(
        actor_network=policy_net,
        camera_name=rlbench_cfg.cameras,
        lr=method_cfg.lr,
        weight_decay=method_cfg.weight_decay,
        grad_clip=method_cfg.grad_clip,
    )
    return PreprocessAgent(pose_agent=agent)
|
|