# robot_twin/data_utils/dataset.py
# Uploaded by ljm2023 in commit ce425f4 (verified):
# "Upload RoboTwin-Challenge-RealWorld-Deployment"
import numpy as np
import torch
import os
import h5py
import fnmatch
import cv2
import torchvision.transforms as transforms
import copy
from aloha_scripts.utils import *
def flatten_list(l):
    """Concatenate the items of every sub-iterable in ``l`` into one flat list.

    Only one level of nesting is removed; the order of items is preserved.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
import gc
class EpisodicDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one sample per timestep of a set of HDF5 episodes.

    A flat index in ``[0, sum(episode_len))`` is mapped to ``(episode_id, start_ts)``.
    The sample holds:
      * the camera images at ``start_ts`` (or a stack of history frames),
      * the normalized qpos at ``start_ts``,
      * the normalized, zero-padded future action chunk of length ``chunk_size``
        together with its padding mask.

    For ``policy_class == 'ACT'`` an item is ``(image, qpos, action, is_pad)``.
    Otherwise the sample dict is passed through
    ``llava_pythia_process.forward_process``.
    """

    # Temporal stride (in timesteps) between stacked history frames.
    history_stride = 12

    def __init__(self, dataset_path_list, camera_names, norm_stats, episode_ids, episode_len, chunk_size, policy_class, robot=None, rank0_print=print, llava_pythia_process=None, data_args=None):
        """
        Args:
            dataset_path_list: all episode file paths; ``episode_ids`` index into it.
            camera_names: image keys, read from ``/observations/images/<cam_name>``.
            norm_stats: a stats dict from ``get_norm_stats``, or a per-task dict of
                them keyed by ``'fold_shirt'`` / ``'clean_table'`` / ``'others'``.
            episode_ids: ids of the episodes belonging to this split.
            episode_len: length of each episode in ``episode_ids`` (same order).
            chunk_size: number of future actions returned per sample.
            policy_class: policy name; image augmentation is enabled when it
                contains ``'diffusion'``.
            robot: robot name; ``'franka'`` images get a BGR->RGB channel swap.
            rank0_print: logging function.
            llava_pythia_process: object whose ``forward_process`` turns a sample
                dict into model inputs (required unless ``policy_class == 'ACT'``).
            data_args: config object (image size, history length, reasoning and
                delta-control flags).
        """
        super().__init__()
        self.episode_ids = episode_ids
        self.dataset_path_list = dataset_path_list
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.episode_len = episode_len
        self.chunk_size = chunk_size
        self.cumulative_len = np.cumsum(self.episode_len)
        self.max_episode_len = max(episode_len)
        self.policy_class = policy_class
        self.llava_pythia_process = llava_pythia_process
        self.data_args = data_args
        self.robot = robot
        self.rank0_print = rank0_print
        self.augment_images = 'diffusion' in self.policy_class
        self.transformations = None
        self.rank0_print(f"########################Current Image Size is [{self.data_args.image_size_stable}]###################################")
        self.rank0_print(f"{RED}policy class: {self.policy_class}; augument: {self.augment_images}{RESET}")
        # Load one sample eagerly: this builds self.transformations and surfaces
        # read errors at construction time.
        self.__getitem__(0)
        if len(self.camera_names) > 2:
            self.rank0_print(f"The robot is {RED} {self.robot} {RESET} | The camera views: {RED} {self.camera_names} {RESET} | The history length: {RED} {self.data_args.history_images_length} {RESET}")
        self.is_sim = False

    def __len__(self):
        """Total number of timesteps across all episodes in this split."""
        return sum(self.episode_len)

    def _locate_transition(self, index):
        """Map a flat sample index to ``(episode_id, start_ts)``."""
        assert index < self.cumulative_len[-1]
        episode_index = np.argmax(self.cumulative_len > index)  # argmax returns first True index
        start_ts = index - (self.cumulative_len[episode_index] - self.episode_len[episode_index])
        episode_id = self.episode_ids[episode_index]
        return episode_id, start_ts

    @staticmethod
    def _history_frame_count(history_len, stride):
        """Number of frames a full ``history_len`` window yields when sampled every ``stride`` steps."""
        return len(range(0, history_len, stride))

    @staticmethod
    def _left_pad_frames(frames, target_len):
        """Left-pad ``frames`` by repeating its first element until it has ``target_len`` items."""
        missing = target_len - len(frames)
        if missing <= 0:
            return list(frames)
        return [frames[0]] * missing + list(frames)

    @staticmethod
    def _prepare_frame(frame, size, compressed):
        """Decode ``frame`` if it is a compressed buffer, then resize it to ``size``.

        Decoding must happen before resizing: resizing the encoded byte buffer
        corrupts it.
        """
        if compressed:
            frame = np.array(cv2.imdecode(frame, 1))
        return cv2.resize(frame, size)

    def _read_images(self, root, start_ts, compressed):
        """Return ``{cam_name: image}`` for ``start_ts``.

        With ``history_images_length >= 2``, each value is instead a stack of
        frames sampled every ``history_stride`` steps up to ``start_ts``. Near the
        start of an episode the stack is left-padded with its earliest frame, so
        every sample has the same number of frames.
        """
        # NOTE: eval of a config string -- only acceptable because the config is trusted.
        size = eval(self.data_args.image_size_stable)
        history_len = self.data_args.history_images_length
        image_dict = dict()
        for cam_name in self.camera_names:
            frames_ds = root[f'/observations/images/{cam_name}']
            if history_len == 1:
                image_dict[cam_name] = self._prepare_frame(frames_ds[start_ts], size, compressed)
            elif history_len >= 2:
                window = frames_ds[max(0, start_ts - history_len + 1):start_ts + 1:self.history_stride]
                frames = [self._prepare_frame(each, size, compressed) for each in window]
                frames = self._left_pad_frames(frames, self._history_frame_count(history_len, self.history_stride))
                image_dict[cam_name] = np.stack(frames, axis=0)
        return image_dict

    def load_from_h5(self, dataset_path, start_ts):
        """Read the observation at ``start_ts`` and the remaining actions from one episode file.

        Returns:
            A tuple ``(original_action_shape, action, action_len, image_dict, qpos,
            qvel, raw_lang, reasoning)``.

            * ``action`` holds the actions from ``start_ts`` on; for non-sim data it
              starts one step earlier (timestep-alignment hack).
            * ``action_len`` is the length of ``action``.

        Raises:
            Exception: if the language instruction (or, with ``use_reasoning``, the
                reasoning text) cannot be read. ``__getitem__`` catches this and
                falls back to a neighbouring file.
        """
        with h5py.File(dataset_path, 'r') as root:
            # some legacy data does not have these attributes
            is_sim = root.attrs.get('sim', False)
            compressed = root.attrs.get('compress', False)
            try:
                raw_lang = root['language_raw'][0].decode('utf-8')
            except Exception as e:
                self.rank0_print(f"Read {dataset_path} happens {YELLOW}{e}{RESET}")
                # Re-raise instead of exit(0): exiting reported success and skipped
                # the fallback in __getitem__.
                raise
            reasoning = " "
            if self.data_args.use_reasoning:
                if 'substep_reasonings' in root:
                    reasoning = root['substep_reasonings'][start_ts].decode('utf-8')
                else:
                    try:
                        reasoning = root['reasoning'][0].decode('utf-8')
                    except Exception as e:
                        self.rank0_print(f"Read reasoning from {dataset_path} happens {YELLOW}{e}{RESET}")
                        raise
            action = root['/action'][()]
            original_action_shape = action.shape
            episode_len = original_action_shape[0]
            # get observation at start_ts only
            qpos = root['/observations/qpos'][start_ts]
            qvel = root['/observations/qvel'][start_ts]
            image_dict = self._read_images(root, start_ts, compressed)
        # get all actions after and including start_ts
        if is_sim:
            action_start = start_ts
        else:
            action_start = max(0, start_ts - 1)  # hack, to make timesteps more aligned
        action = action[action_start:]
        action_len = episode_len - action_start
        return original_action_shape, action, action_len, image_dict, qpos, qvel, raw_lang, reasoning

    def _select_norm_stats(self, dataset_path):
        """Pick the per-task stats for ``dataset_path`` when ``norm_stats`` is per-task, else the global stats."""
        if 'fold_shirt' not in self.norm_stats:
            return self.norm_stats
        if 'fold' in dataset_path or 'shirt' in dataset_path:
            key = 'fold_shirt'
        elif 'clean_table' in dataset_path and 'pick' not in dataset_path:
            key = 'clean_table'
        else:
            key = 'others'
        return self.norm_stats[key]

    def __getitem__(self, index):
        episode_id, start_ts = self._locate_transition(index)
        dataset_path = self.dataset_path_list[episode_id]
        try:
            loaded = self.load_from_h5(dataset_path, start_ts)
        except Exception as e:
            print(f"Read {dataset_path} happens {YELLOW}{e}{RESET}")
            # Fall back to the next file, or the previous one if this is the last file.
            # NOTE(review): start_ts is reused and may exceed the fallback episode's length.
            try:
                dataset_path = self.dataset_path_list[episode_id + 1]
            except IndexError:
                dataset_path = self.dataset_path_list[episode_id - 1]
            loaded = self.load_from_h5(dataset_path, start_ts)
        original_action_shape, action, action_len, image_dict, qpos, qvel, raw_lang, reasoning = loaded

        # Zero-pad the action trajectory to max_episode_len, then keep the first chunk.
        padded_action = np.zeros((self.max_episode_len, original_action_shape[1]), dtype=np.float32)
        if self.data_args.delta_control:
            padded_action[:action_len - 1] = action[1:] - action[:-1]
        else:
            padded_action[:action_len] = action
        is_pad = np.zeros(self.max_episode_len)
        is_pad[action_len:] = 1
        padded_action = padded_action[:self.chunk_size]
        is_pad = is_pad[:self.chunk_size]

        # new axis for different cameras
        all_cam_images = np.stack([image_dict[cam_name] for cam_name in self.camera_names], axis=0)

        # construct observations
        image_data = torch.from_numpy(all_cam_images)
        qpos_data = torch.from_numpy(qpos).float()
        action_data = torch.from_numpy(padded_action).float()
        is_pad = torch.from_numpy(is_pad).bool()

        if self.robot == 'franka':
            assert image_data.ndim == 4, f"image_data's shape is {image_data.shape}, maybe the reason of adding historical images"
            # Swap channel order per camera image (presumably stored as BGR -- confirm).
            image_data = torch.stack([torch.from_numpy(cv2.cvtColor(img.numpy(), cv2.COLOR_BGR2RGB)) for img in image_data], dim=0)

        # channel last -> channel first
        if image_data.ndim == 4:
            image_data = torch.einsum('k h w c -> k c h w', image_data)
        else:
            image_data = torch.einsum('k t h w c -> k t c h w', image_data)

        # augmentation
        if self.transformations is None:
            self.rank0_print('Initializing transformations')
            original_size = image_data.shape[-2:]
            ratio = 0.95
            self.transformations = [
                transforms.RandomCrop(size=[int(original_size[0] * ratio), int(original_size[1] * ratio)]),
                transforms.Resize(original_size, antialias=True),
                transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
                transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
            ]
        if self.augment_images:
            orig_shape = None
            if image_data.ndim != 4:
                # Fold the (camera, time) dims into one batch dim for the transforms.
                orig_shape = image_data.shape
                image_data = image_data.view(-1, *image_data.shape[-3:])
            for transform in self.transformations:
                image_data = transform(image_data)
            if orig_shape is not None:
                image_data = image_data.view(*orig_shape)

        norm_stats = self._select_norm_stats(dataset_path)
        if 'diffusion' in self.policy_class:
            # normalize to [-1, 1]
            action_data = ((action_data - norm_stats["action_min"]) / (norm_stats["action_max"] - norm_stats["action_min"])) * 2 - 1
        else:
            # normalize to mean 0 std 1
            action_data = (action_data - norm_stats["action_mean"]) / norm_stats["action_std"]
        qpos_data = (qpos_data - norm_stats["qpos_mean"]) / norm_stats["qpos_std"]

        if self.policy_class == 'ACT':
            return image_data, qpos_data, action_data, is_pad

        assert raw_lang is not None, ""
        sample = {
            'image': image_data,
            'state': qpos_data,
            'action': action_data,
            'is_pad': is_pad,
            'raw_lang': raw_lang,
            'reasoning': reasoning,
        }
        if index == 0:
            self.rank0_print(reasoning)
        # NOTE(review): a full GC plus cache flush on every sample is costly; kept
        # from the original, presumably as a memory workaround.
        gc.collect()
        torch.cuda.empty_cache()
        return self.llava_pythia_process.forward_process(sample, use_reasoning=self.data_args.use_reasoning)
def get_norm_stats(dataset_path_list, rank0_print=print):
    """Compute action/qpos normalization statistics over a set of episode files.

    Args:
        dataset_path_list: paths of HDF5 episode files; each must contain
            ``/observations/qpos`` and ``/action``.
        rank0_print: logging function.

    Returns:
        ``(stats, all_episode_len)``.

        ``stats`` holds per-dimension statistics over all timesteps:

        * ``action_mean`` and ``action_std`` (std clipped to >= 1e-2);
        * ``action_min`` and ``action_max``, widened by 1e-4;
        * ``qpos_mean`` and ``qpos_std`` (std clipped to >= 1e-2);
        * ``example_qpos``, the qpos array of the last episode read.

        ``all_episode_len`` lists the number of qpos rows of each episode, in
        input order.

    Raises:
        ValueError: if ``dataset_path_list`` is empty.
        Exception: whatever error occurred while reading a file (logged first).
    """
    if len(dataset_path_list) == 0:
        raise ValueError('get_norm_stats needs at least one dataset path')
    all_qpos_data = []
    all_action_data = []
    all_episode_len = []
    for dataset_path in dataset_path_list:
        try:
            with h5py.File(dataset_path, 'r') as root:
                qpos = root['/observations/qpos'][()]
                action = root['/action'][()]
        except Exception as e:
            rank0_print(f'Error loading {dataset_path} in get_norm_stats')
            rank0_print(e)
            # Propagate instead of quit(): quit() exits with status 0 and hides the failure.
            raise
        all_qpos_data.append(torch.from_numpy(qpos))
        all_action_data.append(torch.from_numpy(action))
        all_episode_len.append(len(qpos))
    all_qpos_data = torch.cat(all_qpos_data, dim=0)
    all_action_data = torch.cat(all_action_data, dim=0)

    # normalize action data
    action_mean = all_action_data.mean(dim=[0]).float()
    action_std = all_action_data.std(dim=[0]).float()
    action_std = torch.clip(action_std, 1e-2, np.inf)  # clipping

    # normalize qpos data
    qpos_mean = all_qpos_data.mean(dim=[0]).float()
    qpos_std = all_qpos_data.std(dim=[0]).float()
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # clipping

    action_min = all_action_data.min(dim=0).values.float()
    action_max = all_action_data.max(dim=0).values.float()

    eps = 0.0001  # keeps (max - min) strictly positive for [-1, 1] scaling
    stats = {"action_mean": action_mean.numpy(), "action_std": action_std.numpy(),
             "action_min": action_min.numpy() - eps, "action_max": action_max.numpy() + eps,
             "qpos_mean": qpos_mean.numpy(), "qpos_std": qpos_std.numpy(),
             "example_qpos": qpos}
    return stats, all_episode_len
# Compute normalization stats separately for each kind of task (e.g. folding a shirt, cleaning a table, ...).
def get_norm_stats_by_tasks(dataset_path_list):
    """Group episode paths by task name and compute ``get_norm_stats`` for each group.

    A path belongs to ``'fold_shirt'`` if it contains ``'fold'`` or ``'shirt'``.
    Otherwise it belongs to ``'clean_table'`` if it contains ``'clean_table'``
    but not ``'pick'``, and to ``'others'`` in every remaining case. Groups with
    no paths map to ``None``.
    """
    def _task_of(path):
        if 'fold' in path or 'shirt' in path:
            return 'fold_shirt'
        if 'clean_table' in path and 'pick' not in path:
            return 'clean_table'
        return 'others'

    paths_by_task = {'fold_shirt': [], 'clean_table': [], 'others': []}
    for path in dataset_path_list:
        paths_by_task[_task_of(path)].append(path)

    norm_stats_tasks = {}
    for task, paths in paths_by_task.items():
        if paths:
            stats, _ = get_norm_stats(paths)
        else:
            stats = None
        norm_stats_tasks[task] = stats
    return norm_stats_tasks
def find_all_hdf5(dataset_dir, skip_mirrored_data, rank0_print=print):
    """Recursively collect ``*.hdf5`` episode files under ``dataset_dir``.

    The following are skipped:

    * directories whose path contains ``'pointcloud'``;
    * files whose name contains ``'features'``;
    * files whose name contains ``'mirror'``, when ``skip_mirrored_data`` is true.

    Returns:
        List of file paths, in ``os.walk`` order.

    Raises:
        FileNotFoundError: if no matching file is found. The original called
            ``exit(0)``, which terminated the process with a success status.
    """
    hdf5_files = []
    for dirpath, _dirnames, filenames in os.walk(dataset_dir):
        if 'pointcloud' in dirpath:
            continue
        for filename in fnmatch.filter(filenames, '*.hdf5'):
            if 'features' in filename:
                continue
            if skip_mirrored_data and 'mirror' in filename:
                continue
            hdf5_files.append(os.path.join(dirpath, filename))
    if len(hdf5_files) == 0:
        rank0_print(f"{RED} Found 0 hdf5 datasets found in {dataset_dir} {RESET}")
        raise FileNotFoundError(f"No hdf5 files found in {dataset_dir}")
    rank0_print(f'Found {len(hdf5_files)} hdf5 files')
    return hdf5_files
def BatchSampler(batch_size, episode_len_l, sample_weights):
    """Yield batches of ``batch_size`` flat sample indices, forever.

    Groups in ``episode_len_l`` are laid out back to back in index space. Each
    group covers ``sum`` of its episode lengths. For every slot:

    1. A group is drawn, uniformly or with probability proportional to
       ``sample_weights``.
    2. An index is drawn uniformly from that group's range.
    """
    if sample_weights is None:
        sample_probs = None
    else:
        sample_probs = np.array(sample_weights) / np.sum(sample_weights)
    boundaries = np.cumsum([0] + [np.sum(lens) for lens in episode_len_l])
    num_groups = len(episode_len_l)
    while True:
        indices = []
        while len(indices) < batch_size:
            group = np.random.choice(num_groups, p=sample_probs)
            indices.append(np.random.randint(boundaries[group], boundaries[group + 1]))
        yield indices
def load_data(dataset_dir_l, name_filter, camera_names, batch_size_train, batch_size_val, chunk_size, config, rank0_print=print, skip_mirrored_data=False, policy_class=None, stats_dir_l=None, sample_weights=None, train_ratio=0.99, return_dataset=False, llava_pythia_process=None):
    """Build the training dataset, normalization stats and batch-sampler parameters.

    Episodes are collected from every directory in ``dataset_dir_l`` and kept only
    if ``name_filter(path)`` is truthy. Only the first directory is split into
    train/val (by ``train_ratio``); all other directories are train-only.
    Normalization stats are computed over every hdf5 file in ``stats_dir_l``
    (which defaults to ``dataset_dir_l``; ``name_filter`` is not applied there).

    Args:
        config: mapping with ``'data_args'``, ``'action_head_args'`` and
            ``'training_args'`` entries.
        return_dataset: unused.

    Returns:
        ``(train_dataset, None, norm_stats, sampler_params)``. No validation
        dataset is built, but ``sampler_params['eval']`` still describes the
        validation episodes.
    """
    if isinstance(dataset_dir_l, str):
        dataset_dir_l = [dataset_dir_l]
    # Apply name_filter per directory *before* counting episodes, so that the
    # episode ids below index the same flat list that is handed to
    # get_norm_stats and EpisodicDataset.
    dataset_path_list_list = [
        [path for path in find_all_hdf5(dataset_dir, skip_mirrored_data, rank0_print=rank0_print) if name_filter(path)]
        for dataset_dir in dataset_dir_l
    ]
    for d, dpl in zip(dataset_dir_l, dataset_path_list_list):
        if len(dpl) == 0:
            rank0_print("#2" * 20)
            rank0_print(d)
    num_episodes_l = [len(dpl) for dpl in dataset_path_list_list]
    num_episodes_0 = num_episodes_l[0]
    num_episodes_cumsum = np.cumsum(num_episodes_l)
    dataset_path_list = flatten_list(dataset_path_list_list)

    # obtain train/val split on dataset_dir_l[0]; the other directories are train-only
    shuffled_episode_ids_0 = np.random.permutation(num_episodes_0)
    num_train_0 = int(train_ratio * num_episodes_0)
    train_episode_ids_0 = shuffled_episode_ids_0[:num_train_0]
    val_episode_ids_0 = shuffled_episode_ids_0[num_train_0:]
    # Directory idx+1 occupies flat ids starting at num_episodes_cumsum[idx].
    train_episode_ids_l = [train_episode_ids_0] + [np.arange(num_episodes) + num_episodes_cumsum[idx] for idx, num_episodes in enumerate(num_episodes_l[1:])]
    val_episode_ids_l = [val_episode_ids_0]
    train_episode_ids = np.concatenate(train_episode_ids_l)
    rank0_print(f'\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n')

    _, all_episode_len = get_norm_stats(dataset_path_list, rank0_print=rank0_print)
    rank0_print(f"{RED}All images: {sum(all_episode_len)}, Trajectories: {len(all_episode_len)} {RESET}")
    train_episode_len_l = [[all_episode_len[i] for i in ids] for ids in train_episode_ids_l]
    val_episode_len_l = [[all_episode_len[i] for i in ids] for ids in val_episode_ids_l]
    train_episode_len = flatten_list(train_episode_len_l)

    if stats_dir_l is None:
        stats_dir_l = dataset_dir_l
    elif isinstance(stats_dir_l, str):
        stats_dir_l = [stats_dir_l]
    # calculate norm stats across all episodes in the stats directories
    stats_path_list = flatten_list([find_all_hdf5(stats_dir, skip_mirrored_data, rank0_print=rank0_print) for stats_dir in stats_dir_l])
    norm_stats, _ = get_norm_stats(stats_path_list, rank0_print=rank0_print)
    rank0_print(f'Norm stats from: {[each.split("/")[-1] for each in stats_dir_l]}')
    rank0_print(f'train_episode_len_l: {train_episode_len_l}')

    robot = 'aloha' if config['action_head_args'].action_dim == 14 or ('aloha' in config['training_args'].output_dir) else 'franka'
    # construct dataset; rank0_print is forwarded so only rank 0 logs
    train_dataset = EpisodicDataset(dataset_path_list, camera_names, norm_stats, train_episode_ids, train_episode_len, chunk_size, policy_class, robot=robot, rank0_print=rank0_print, llava_pythia_process=llava_pythia_process, data_args=config['data_args'])
    sampler_params = {
        'train': {"batch_size": batch_size_train, 'episode_len_l': train_episode_len_l, 'sample_weights': sample_weights, 'episode_first': config['data_args'].episode_first},
        'eval': {"batch_size": batch_size_val, 'episode_len_l': val_episode_len_l, 'sample_weights': None, 'episode_first': config['data_args'].episode_first},
    }
    return train_dataset, None, norm_stats, sampler_params
def calibrate_linear_vel(base_action, c=None):
    """Return a copy of ``base_action`` with channel 0 corrected by ``-c * channel 1``.

    Channel 0 is the linear velocity and channel 1 the angular velocity. ``c``
    defaults to 0.0, which leaves the values unchanged; an alternative value of
    0.19 is noted in the code. The input array is not modified.
    """
    coeff = 0.0 if c is None else c  # 0.19
    calibrated = base_action.copy()
    calibrated[..., 0] = base_action[..., 0] - coeff * base_action[..., 1]
    return calibrated
def smooth_base_action(base_action):
    """Smooth every column of a 2-D ``base_action`` with a 5-tap moving average.

    Each column is convolved with ``np.convolve(..., mode='same')``, so the output
    has the input's shape (edges are zero-padded). The result is float32.
    """
    kernel = np.ones(5) / 5
    smoothed_columns = [np.convolve(column, kernel, mode='same') for column in base_action.T]
    return np.stack(smoothed_columns, axis=-1).astype(np.float32)
def preprocess_base_action(base_action):
    """Prepare a raw base-action sequence for training by smoothing it with ``smooth_base_action``.

    ``calibrate_linear_vel`` is not applied here (its call is commented out).
    """
    # base_action = calibrate_linear_vel(base_action)
    smoothed = smooth_base_action(base_action)
    return smoothed
def postprocess_base_action(base_action):
    """Scale a ``(linear_vel, angular_vel)`` pair by fixed gains and return it as an array.

    Both gains are currently 1.0, so the values pass through unchanged.
    """
    linear_gain = 1.0
    angular_gain = 1.0
    linear_vel, angular_vel = base_action
    linear_vel *= linear_gain
    angular_vel *= angular_gain
    # angular_vel = 0
    # if np.abs(linear_vel) < 0.05:
    #     linear_vel = 0
    return np.array([linear_vel, angular_vel])
### env utils
def sample_box_pose():
    """Sample a random cube pose.

    The pose is a uniformly random position followed by the fixed quaternion
    ``[1, 0, 0, 0]``:

    * x in [0.0, 0.2]
    * y in [0.4, 0.6]
    * z fixed at 0.05
    """
    lows = np.array([0.0, 0.4, 0.05])
    highs = np.array([0.2, 0.6, 0.05])
    position = np.random.uniform(lows, highs)
    orientation = np.array([1, 0, 0, 0])
    return np.concatenate([position, orientation])
def sample_insertion_pose():
    """Sample random peg and socket poses; returns ``(peg_pose, socket_pose)``.

    Each pose is a uniform random position followed by the quaternion
    ``[1, 0, 0, 0]``. The peg is drawn first, then the socket.

    * peg: x in [0.1, 0.2], y in [0.4, 0.6], z = 0.05
    * socket: x in [-0.2, -0.1], y in [0.4, 0.6], z = 0.05
    """
    def _pose(x_range, y_range, z_range):
        bounds = np.vstack([x_range, y_range, z_range])
        position = np.random.uniform(bounds[:, 0], bounds[:, 1])
        return np.concatenate([position, np.array([1, 0, 0, 0])])

    peg_pose = _pose([0.1, 0.2], [0.4, 0.6], [0.05, 0.05])
    socket_pose = _pose([-0.2, -0.1], [0.4, 0.6], [0.05, 0.05])
    return peg_pose, socket_pose
### helper functions
def compute_dict_mean(epoch_dicts):
    """Average each key over a list of dicts.

    The keys are taken from the first dict, in its key order. Raises IndexError
    for an empty list.
    """
    num_items = len(epoch_dicts)

    def _mean_of(key):
        running = 0
        for entry in epoch_dicts:
            running += entry[key]
        return running / num_items

    return {key: _mean_of(key) for key in epoch_dicts[0]}
def detach_dict(d):
    """Return a new dict with ``.detach()`` applied to every value of ``d``."""
    return {key: value.detach() for key, value in d.items()}
def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's RNG with ``seed``."""
    np.random.seed(seed)
    torch.manual_seed(seed)