lsnu's picture
Add files using upload-large-folder tool
0d89eb9 verified
# Adapted from ARM
# Source: https://github.com/stepjam/ARM
# License: https://github.com/stepjam/ARM/LICENSE
import logging
from typing import List
import numpy as np
from omegaconf import DictConfig
from rlbench.backend.observation import Observation
from rlbench.observation_config import ObservationConfig
import rlbench.utils as rlbench_utils
from rlbench.demo import Demo
from yarr.replay_buffer.prioritized_replay_buffer import (
PrioritizedReplayBuffer,
ObservationElement,
)
from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
from helpers import demo_loading_utils, utils
from helpers import observation_utils
from agents.baselines.bc_lang.bc_lang_agent import BCLangAgent
from helpers.custom_rlbench_env import CustomRLBenchEnv
from helpers.network_utils import SiameseNet, CNNLangAndFcsNet
from helpers.preprocess_agent import PreprocessAgent
import torch
from torch.multiprocessing import Process, Value, Manager
from helpers.clip.core.clip import build_model, load_clip, tokenize
# Length of the per-step "low_dim_state" vector: used both as the replay
# element shape in create_replay and as the actor's low_dim_state_len in
# create_agent. NOTE(review): must match what observation_utils.extract_obs
# produces — confirm there if this changes.
LOW_DIM_SIZE: int = 4
def create_replay(
    batch_size: int,
    timesteps: int,
    prioritisation: bool,
    task_uniform: bool,
    save_dir: str,
    cameras: list,
    image_size=(128, 128),
    replay_size=3e5,
):
    """Build the task-uniform replay buffer used for language-conditioned BC.

    Args:
        batch_size: Number of samples drawn per batch.
        timesteps: Number of consecutive timesteps per sample.
        prioritisation: Currently ignored; a ``TaskUniformReplayBuffer`` is
            always created (the multi-task filling code in this module
            replaces its ``_task_idxs`` attribute).
        task_uniform: Currently ignored, see ``prioritisation``.
        save_dir: Storage directory passed through to the buffer.
        cameras: Camera names; each camera contributes an rgb image, a point
            cloud, a 4x4 extrinsics matrix and a 3x3 intrinsics matrix.
        image_size: Trailing spatial dimensions of the rgb and point-cloud
            elements, which are stored with shape ``(3, *image_size)``.
        replay_size: Replay capacity; converted with ``int``.

    Returns:
        A ``TaskUniformReplayBuffer`` with float32 actions of shape ``(8,)``,
        scalar float32 rewards, update horizon 1, and an extra boolean
        ``demo`` element.
    """
    # Width of the stored "lang_goal_emb" sentence feature. NOTE(review):
    # presumably the text-feature size of the RN50 CLIP model used when
    # filling the replay — confirm if the CLIP variant changes.
    lang_feat_dim = 1024

    observation_elements = [
        ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
    ]

    # Per camera: rgb, point cloud, extrinsics, intrinsics (no depth image is
    # stored; see pyrep/objects/vision_sensor.py on how point clouds are
    # extracted from depth frames).
    for cname in cameras:
        observation_elements.extend(
            [
                ObservationElement(
                    "%s_rgb" % cname, (3, *image_size), np.float32
                ),
                ObservationElement(
                    "%s_point_cloud" % cname, (3, *image_size), np.float32
                ),
                ObservationElement(
                    "%s_camera_extrinsics" % cname, (4, 4), np.float32
                ),
                ObservationElement(
                    "%s_camera_intrinsics" % cname, (3, 3), np.float32
                ),
            ]
        )

    observation_elements.extend(
        [
            ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
            ReplayElement("task", (), str),
            # Language goal string, kept for debugging and visualization.
            ReplayElement("lang_goal", (1,), object),
        ]
    )

    # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was
    # removed in NumPy 1.24, so use ``bool`` directly.
    extra_replay_elements = [
        ReplayElement("demo", (), bool),
    ]

    return TaskUniformReplayBuffer(
        save_dir=save_dir,
        batch_size=batch_size,
        timesteps=timesteps,
        replay_capacity=int(replay_size),
        action_shape=(8,),
        action_dtype=np.float32,
        reward_shape=(),
        reward_dtype=np.float32,
        update_horizon=1,
        observation_elements=observation_elements,
        extra_replay_elements=extra_replay_elements,
    )
def _get_action(obs_tp1: Observation):
    """Return the action that reaches ``obs_tp1``'s gripper state.

    The result concatenates the gripper position (``gripper_pose[:3]``), the
    normalized quaternion (``gripper_pose[3:]``) and ``gripper_open`` as a
    float. The quaternion's sign is flipped when needed so that its last
    component is non-negative.
    """
    pose = obs_tp1.gripper_pose
    rotation = utils.normalize_quaternion(pose[3:])
    # q and -q describe the same rotation; pick one sign canonically.
    if rotation[-1] < 0:
        rotation = -rotation
    position = pose[:3]
    gripper_state = [float(obs_tp1.gripper_open)]
    return np.concatenate([position, rotation, gripper_state])
def _add_keypoints_to_replay(
    cfg: DictConfig,
    task: str,
    replay: ReplayBuffer,
    inital_obs: Observation,
    demo: Demo,
    episode_keypoints: List[int],
    cameras: List[str],
    description: str = "",
    clip_model=None,
    device="cpu",
):
    """Store one keypoint-to-keypoint trajectory of ``demo`` in ``replay``.

    Starting from ``inital_obs``, each keypoint in ``episode_keypoints``
    produces one transition. The observation is the current one, and the
    action is ``_get_action`` of the demo frame at that keypoint. Only the last
    keypoint is terminal, with reward 1.0; all others get reward 0.0. The
    frame at the last keypoint is then stored with ``replay.add_final``.

    Every stored step also carries the CLIP sentence feature of
    ``description`` ("lang_goal_emb"), the task name, and the raw goal string.
    The ``ignore_collisions`` entry returned by ``extract_obs`` is dropped
    because the replay defines no such element.

    Args:
        cfg: Config; reads ``rlbench.episode_length`` and
            ``method.robot_name``.
        task: Task name stored with every step.
        replay: Buffer receiving ``add`` / ``add_final`` calls.
        inital_obs: Observation the trajectory starts from.
        demo: Indexable demo whose frames are looked up by keypoint.
        episode_keypoints: Frame indices of the keypoints to visit, in order.
        cameras: Camera names forwarded to ``extract_obs``.
        description: Natural-language goal.
        clip_model: Model exposing ``encode_text_with_embeddings``; required
            despite the ``None`` default.
        device: Device the text tokens are moved to.

    Returns:
        The list of actions added, one per keypoint. If ``episode_keypoints``
        is empty, returns ``[]`` and stores nothing.
    """
    if len(episode_keypoints) == 0:
        # No transitions; storing only a final observation would be invalid.
        return []

    # The goal text is the same for every step, so encode it once. The
    # result is detached anyway, so no autograd graph is needed.
    with torch.no_grad():
        tokens = tokenize([description]).numpy()
        token_tensor = torch.from_numpy(tokens).to(device)
        lang_feats, _ = clip_model.encode_text_with_embeddings(token_tensor)
    lang_goal_emb = lang_feats[0].float().detach().cpu().numpy()

    final_obs = {
        "task": task,
        "lang_goal": np.array([description], dtype=object),
    }

    prev_action = None
    obs = inital_obs
    all_actions = []
    last_k = len(episode_keypoints) - 1
    for k, keypoint in enumerate(episode_keypoints):
        obs_tp1 = demo[keypoint]
        action = _get_action(obs_tp1)
        all_actions.append(action)
        terminal = k == last_k
        reward = float(terminal)  # sparse: 1.0 only on the last keypoint

        obs_dict = observation_utils.extract_obs(
            obs,
            t=k,
            prev_action=prev_action,
            cameras=cameras,
            episode_length=cfg.rlbench.episode_length,
            robot_name=cfg.method.robot_name,
        )
        del obs_dict["ignore_collisions"]
        obs_dict["lang_goal_emb"] = lang_goal_emb

        prev_action = np.copy(action)

        others = {"demo": True}
        others.update(final_obs)
        others.update(obs_dict)
        timeout = False
        replay.add(action, reward, terminal, timeout, **others)
        obs = obs_tp1  # advance to the keypoint just reached

    # Final step: the observation at the last keypoint.
    obs_dict_tp1 = observation_utils.extract_obs(
        obs,
        t=len(episode_keypoints),
        prev_action=prev_action,
        cameras=cameras,
        episode_length=cfg.rlbench.episode_length,
        robot_name=cfg.method.robot_name,
    )
    obs_dict_tp1["lang_goal_emb"] = lang_goal_emb
    del obs_dict_tp1["ignore_collisions"]
    obs_dict_tp1.update(final_obs)
    replay.add_final(**obs_dict_tp1)
    return all_actions
def fill_replay(
    cfg: DictConfig,
    obs_config: ObservationConfig,
    rank: int,
    replay: ReplayBuffer,
    task: str,
    num_demos: int,
    demo_augmentation: bool,
    demo_augmentation_every_n: int,
    cameras: List[str],
    clip_model=None,
    device="cpu",
):
    """Load ``num_demos`` stored demos of ``task`` into ``replay``.

    If ``clip_model`` is None, an RN50 CLIP model is loaded onto ``device``
    to embed the goal text. For each demo (episodes ``0 .. num_demos-1``,
    from ``cfg.rlbench.demo_path``), keypoints come from
    ``demo_loading_utils.keypoint_discovery``. A keypoint trajectory is then
    added from start frame 0. If ``demo_augmentation`` is set, trajectories
    are also added from every frame whose index is a multiple of
    ``demo_augmentation_every_n``. Keypoints already reached by the start
    frame are skipped, and the demo stops once none remain. Only the first
    stored description of each demo is used as the goal.

    Returns:
        All actions added, across demos and start frames.
    """
    if clip_model is None:
        base_model, _ = load_clip("RN50", jit=False, device=device)
        clip_model = build_model(base_model.state_dict())
        clip_model.to(device)
        del base_model

    logging.debug("Filling %s replay ...", task)
    collected = []
    for demo_idx in range(num_demos):
        # Read exactly one episode (number demo_idx) from disk.
        demo = rlbench_utils.get_stored_demos(
            amount=1,
            image_paths=False,
            dataset_root=cfg.rlbench.demo_path,
            variation_number=-1,
            task_name=task,
            obs_config=obs_config,
            random_selection=False,
            from_episode_number=demo_idx,
        )[0]
        descriptions = demo._observations[0].misc["descriptions"]

        keypoints = demo_loading_utils.keypoint_discovery(demo)
        if rank == 0:
            logging.info(
                f"Loading Demo({demo_idx}) - found {len(keypoints)} keypoints - {task}"
            )

        # Without augmentation, frame 0 is the only start candidate.
        num_starts = len(demo) - 1
        if not demo_augmentation:
            num_starts = min(num_starts, 1)

        for start in range(num_starts):
            if start % demo_augmentation_every_n != 0:
                continue
            start_obs = demo[start]
            goal_text = descriptions[0]

            # Drop leading keypoints that the start frame has already reached.
            n_passed = 0
            while n_passed < len(keypoints) and start >= keypoints[n_passed]:
                n_passed += 1
            if n_passed:
                keypoints = keypoints[n_passed:]
            if len(keypoints) == 0:
                break

            collected.extend(
                _add_keypoints_to_replay(
                    cfg,
                    task,
                    replay,
                    start_obs,
                    demo,
                    keypoints,
                    cameras,
                    description=goal_text,
                    clip_model=clip_model,
                    device=device,
                )
            )

    logging.debug("Replay filled with demos.")
    return collected
def fill_multi_task_replay(
    cfg: DictConfig,
    obs_config: ObservationConfig,
    rank: int,
    replay: ReplayBuffer,
    tasks: List[str],
    num_demos: int,
    demo_augmentation: bool,
    demo_augmentation_every_n: int,
    cameras: List[str],
    clip_model=None,
):
    """Fill ``replay`` with demos of several tasks using worker processes.

    The replay's storage and its ``_task_idxs`` map are replaced with
    ``Manager`` dict proxies, and ``add_count`` with a shared ``Value``, so
    that child processes running ``fill_replay`` can write into them.

    Task indices are grouped with ``utils.split_list(...,
    cfg.replay.max_parallel_processes)``. NOTE(review): presumably this yields
    chunks of at most that many tasks — confirm in ``utils.split_list``. One
    process is started per task in a group, and the whole group is joined
    before the next group is launched. Within a group, the i-th process uses
    GPU ``i % device_count`` when CUDA is available, otherwise the CPU.

    A worker that exits with a non-zero code is logged as an error. Its task's
    demos may then be missing from the replay.
    """
    manager = Manager()
    store = manager.dict()
    # Replace the task-index map with a Manager-backed dict shared across
    # processes.
    # TODO(mohit): this shouldn't be initialized here
    del replay._task_idxs
    task_idxs = manager.dict()
    replay._task_idxs = task_idxs
    replay._create_storage(store)
    replay.add_count = Value("i", 0)

    max_parallel_processes = cfg.replay.max_parallel_processes
    task_groups = utils.split_list(np.arange(len(tasks)), max_parallel_processes)

    for group in task_groups:
        group_processes = []
        for e_idx, task_idx in enumerate(group):
            task = tasks[int(task_idx)]
            model_device = torch.device(
                "cuda:%s" % (e_idx % torch.cuda.device_count())
                if torch.cuda.is_available()
                else "cpu"
            )
            p = Process(
                target=fill_replay,
                args=(
                    cfg,
                    obs_config,
                    rank,
                    replay,
                    task,
                    num_demos,
                    demo_augmentation,
                    demo_augmentation_every_n,
                    cameras,
                    clip_model,
                    model_device,
                ),
            )
            p.start()
            group_processes.append((task, p))

        # Wait for this group before starting the next. Otherwise every task
        # would run at once and max_parallel_processes would bound nothing.
        for task, p in group_processes:
            p.join()
            if p.exitcode != 0:
                logging.error(
                    "Replay filling for task %s exited with code %s; "
                    "its demos may be missing from the replay.",
                    task,
                    p.exitcode,
                )

    logging.debug("Replay filled with multi demos.")
def create_agent(cfg: DictConfig):
    """Build the language-conditioned behaviour-cloning agent from ``cfg``.

    A ``SiameseNet`` has two 3-channel image inputs and a single 16-filter
    5x5 conv. It feeds a ``CNNLangAndFcsNet`` whose last FC layer has
    3 + 4 + 1 = 8 outputs (position, quaternion, gripper open), matching the
    8-dim actions written to the replay. The actor is trained by a
    ``BCLangAgent`` wrapped in a ``PreprocessAgent``.

    Reads ``cfg.rlbench.{cameras, camera_resolution}`` and
    ``cfg.method.{activation, lr, weight_decay, grad_clip}``.
    """
    method_cfg = cfg.method
    cameras = cfg.rlbench.cameras
    act = method_cfg.activation
    lr, weight_decay = method_cfg.lr, method_cfg.weight_decay
    resolution = cfg.rlbench.camera_resolution
    clip_norm = method_cfg.grad_clip

    encoder = SiameseNet(
        input_channels=[3, 3],
        filters=[16],
        kernel_sizes=[5],
        strides=[1],
        activation=act,
        norm=None,
    )

    # 3 position + 4 quaternion + 1 gripper-open outputs.
    action_dim = 3 + 4 + 1
    policy = CNNLangAndFcsNet(
        siamese_net=encoder,
        input_resolution=resolution,
        filters=[32, 64, 64],
        kernel_sizes=[3] * 3,
        strides=[2] * 3,
        norm=None,
        activation=act,
        fc_layers=[128, 64, action_dim],
        low_dim_state_len=LOW_DIM_SIZE,
    )

    learner = BCLangAgent(
        actor_network=policy,
        camera_name=cameras,
        lr=lr,
        weight_decay=weight_decay,
        grad_clip=clip_norm,
    )
    return PreprocessAgent(pose_agent=learner)