e-zorzi's picture
download
raw
6.02 kB
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
def load_target_jt(device, filename, offset=None, rng=None):
    """Load target joint trajectories from ``data/<filename>`` as one padded batch.

    The file is read with ``np.load``. A 2-D array is treated as a single
    sequence and gets a leading batch axis; any other array is iterated along
    its first axis, one sequence per entry. Each sequence is cast to float32,
    truncated to its first ``rng`` rows (``None`` keeps every row) and moved
    to ``device``.

    Args:
        device: Torch device (or device string) for the returned tensors.
        filename: File name, resolved relative to the ``data/`` directory.
        offset: Optional value added in place to the whole padded tensor,
            padding entries included.
        rng: Optional upper bound on the number of rows kept per sequence.

    Returns:
        ``(padded, lengths)``: the sequences zero-padded to a common length
        with the batch axis first, and a 1-D tensor holding each sequence's
        row count (stored in torch's default floating dtype).
    """
    # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code -- only load files from trusted sources.
    raw = np.load(f"data/{filename}", allow_pickle=True)
    if raw.ndim == 2:  # a single sequence (not 1-D or 3-D): add the batch axis
        raw = raw[None, :]

    sequences = []
    for seq in raw:
        clipped = seq.astype(np.float32)[:rng]
        sequences.append(torch.tensor(clipped).to(device))

    # Row count of every sequence before padding.
    lengths = torch.tensor(
        [seq.size(0) for seq in sequences],
        dtype=torch.get_default_dtype(),
    ).to(device)

    padded = pad_sequence(sequences, batch_first=True)
    if offset is not None:
        padded += offset
    return padded, lengths
def sample_int_from_float(x):
    """Stochastically round ``x`` to one of its two neighbouring integers.

    The result is ``floor(x) + 1`` with probability equal to the fractional
    part of ``x`` and ``floor(x)`` otherwise, so the expected value of the
    result equals ``x``. Integral inputs are returned unchanged (as ``int``)
    without drawing a random number.

    Args:
        x: A real number.

    Returns:
        An ``int`` that is either ``floor(x)`` or ``floor(x) + 1``.
    """
    # Floor division rounds toward -inf; int() alone truncates toward zero,
    # which picks the wrong lower neighbour for negative inputs.
    lower = int(x // 1)
    frac = x - lower
    if frac == 0:
        return lower
    # Round up with probability `frac` so the sample is unbiased.
    return lower + 1 if np.random.rand() < frac else lower
def _sample_wrist_positions(num_points, radius, x_range, y_range, z_range):
    """Draw points on a sphere of ``radius`` and keep those strictly inside a box."""
    points = torch.randn(num_points, 3)
    # Normalising Gaussian samples yields unit-length directions; scaling
    # places every point at distance `radius` from the origin (on the sphere
    # surface, not inside it).
    points = points / points.norm(dim=-1, keepdim=True) * radius
    keep = torch.ones(points.size(0), dtype=torch.bool)
    for axis, bounds in enumerate((x_range, y_range, z_range)):
        coord = points[:, axis]
        keep &= (coord > bounds[0]) & (coord < bounds[1])
    return points[keep]


def sample_wp(device, num_points, num_wp, ranges):
    """Sample wrist waypoints, relative to the starting point.

    ``num_points`` candidate positions are drawn independently for the left
    and the right wrist on a sphere of radius ``ranges.wrist_max_radius``.
    Candidates outside the per-wrist open intervals
    ``ranges.{l,r}_wrist_pos_{x,y,z}`` (each indexed as ``[0]`` lower and
    ``[1]`` upper bound) are discarded, and the survivors are paired up in
    order. Every wrist also gets a random unit quaternion.

    Args:
        device: Torch device the returned tensor is moved to.
        num_points: Number of candidates drawn per wrist before filtering.
        num_wp: Number of times each sampled pair is repeated along axis 1.
        ranges: Object exposing ``wrist_max_radius`` and the six
            ``{l,r}_wrist_pos_{x,y,z}`` bounds.

    Returns:
        ``(wp, num_pairs, num_wp)`` where ``wp`` has shape
        ``(num_pairs, num_wp, 2, 7)``: axis 2 is (left, right) and the last
        axis holds the 3-D position followed by the 4 quaternion components.
        ``num_pairs`` is the number of pairs that survived filtering and may
        be smaller than ``num_points`` (possibly zero).
    """
    # Sampling order (left, right, quaternions) is fixed so that seeded runs
    # stay reproducible.
    l_positions = _sample_wrist_positions(
        num_points,
        ranges.wrist_max_radius,
        ranges.l_wrist_pos_x,
        ranges.l_wrist_pos_y,
        ranges.l_wrist_pos_z,
    )
    r_positions = _sample_wrist_positions(
        num_points,
        ranges.wrist_max_radius,
        ranges.r_wrist_pos_x,
        ranges.r_wrist_pos_y,
        ranges.r_wrist_pos_z,
    )

    # Filtering removes a different number of candidates per wrist; keep only
    # as many as can be paired.
    num_pairs = min(l_positions.size(0), r_positions.size(0))
    positions = torch.stack(
        [l_positions[:num_pairs], r_positions[:num_pairs]], dim=1
    )  # (num_pairs, 2, 3)

    # rotation (quaternion), normalised to unit length
    quaternions = torch.randn(num_pairs, 2, 4)
    quaternions = quaternions / quaternions.norm(dim=-1, keepdim=True)

    wp = torch.cat([positions, quaternions], dim=-1)  # (num_pairs, 2, 7)
    # repeat the same target for every waypoint slot
    wp = wp.unsqueeze(1).repeat(1, num_wp, 1, 1)  # (num_pairs, num_wp, 2, 7)
    print("===> [sample_wp] return shape:", wp.shape)
    return wp.to(device), num_pairs, num_wp
def sample_rp(device, num_points, num_wp, ranges):
    """Sample reach points: wrist waypoints shifted by a random centre offset.

    Waypoints come from ``sample_wp``. One 3-D offset per pair is drawn as
    ``torch.rand * ranges.max_center_distance``, clamped per axis to
    ``ranges.center_offset_{x,y,z}`` (``[0]`` lower, ``[1]`` upper bound), and
    added to the position part of both wrists at every waypoint of that pair.

    Args:
        device: Torch device used for the offsets and the returned tensor.
        num_points: Number of candidates per wrist passed to ``sample_wp``.
        num_wp: Number of waypoints per sampled pair.
        ranges: Object exposing ``max_center_distance``, the three
            ``center_offset_*`` bounds and everything ``sample_wp`` reads.

    Returns:
        ``(wp, num_pairs, num_wp)`` with ``wp`` of shape
        ``(num_pairs, num_wp, 2, 7)``.
    """
    wp, num_pairs, num_wp = sample_wp(device, num_points, num_wp, ranges)

    centers = torch.rand(num_pairs, 3) * ranges.max_center_distance
    centers = centers.to(device)
    axis_limits = (
        ranges.center_offset_x,
        ranges.center_offset_y,
        ranges.center_offset_z,
    )
    for axis, limits in enumerate(axis_limits):
        centers[:, axis] = torch.clamp(centers[:, axis], limits[0], limits[1])

    # Broadcast each pair's offset over the waypoint and wrist axes:
    # (num_pairs, 1, 1, 3) -> (num_pairs, num_wp, 2, 3).
    wp[:, :, :, :3] += centers[:, None, None, :]
    print("===> [sample_rp] return shape:", wp.shape)
    return wp.to(device), num_pairs, num_wp
def _sample_foot_offsets(count, radius):
    """Draw ``count`` 2-D points lying on a circle of ``radius`` around the origin."""
    offsets = torch.randn(count, 2)
    return offsets / offsets.norm(dim=-1, keepdim=True) * radius


def sample_fp(device, num_points, num_wp, ranges):
    """Sample feet waypoints (xy), moving exactly one foot per sample.

    In the first ``num_points // 2`` samples the left foot stays at the origin
    and the right foot is placed on a circle of radius
    ``ranges.feet_max_radius``; in the remaining samples the roles are
    swapped. When ``num_points`` is odd the second group receives the extra
    sample, so the tensor always has exactly ``num_points`` rows and matches
    the returned count.

    Args:
        device: Torch device the returned tensor is moved to.
        num_points: Total number of samples.
        num_wp: Number of times each sample is repeated along axis 1.
        ranges: Object exposing ``feet_max_radius``.

    Returns:
        ``(wp, num_points, num_wp)`` where ``wp`` has shape
        ``(num_points, num_wp, 2, 2)``; axis 2 is (left, right) and the last
        axis is the xy position.
    """
    num_right_moving = num_points // 2
    # Using the remainder (instead of a second `num_points // 2`) keeps the
    # row count equal to num_points for odd inputs.
    num_left_moving = num_points - num_right_moving

    # left foot still, right foot moves
    l_positions_s = torch.zeros(num_right_moving, 2)
    r_positions_m = _sample_foot_offsets(num_right_moving, ranges.feet_max_radius)
    # right foot still, left foot moves
    r_positions_s = torch.zeros(num_left_moving, 2)
    l_positions_m = _sample_foot_offsets(num_left_moving, ranges.feet_max_radius)

    l_positions = torch.cat([l_positions_s, l_positions_m], dim=0)  # (num_points, 2)
    r_positions = torch.cat([r_positions_m, r_positions_s], dim=0)  # (num_points, 2)
    wp = torch.stack([l_positions, r_positions], dim=1)  # (num_points, 2, 2)
    wp = wp.unsqueeze(1).repeat(1, num_wp, 1, 1)  # (num_points, num_wp, 2, 2)
    print("===> [sample_fp] return shape:", wp.shape)
    return wp.to(device), num_points, num_wp
def sample_root_height(device, num_points, num_wp, ranges, base_height_target):
    """Sample root heights around ``base_height_target``.

    Each of the ``num_points`` samples is ``base_height_target`` plus Gaussian
    noise scaled by ``ranges.root_height_std``, clamped to
    ``[ranges.min_root_height, ranges.max_root_height]`` and then repeated for
    every waypoint.

    Args:
        device: Torch device the returned tensor is moved to.
        num_points: Number of independent height samples.
        num_wp: Number of times each sample is repeated along axis 1.
        ranges: Object exposing ``root_height_std``, ``min_root_height`` and
            ``max_root_height``.
        base_height_target: Centre of the sampled distribution.

    Returns:
        ``(root_height, num_points, num_wp)`` where ``root_height`` has shape
        ``(num_points, num_wp, 1)``.
    """
    noise = torch.randn(num_points, 1)
    heights = noise * ranges.root_height_std + base_height_target
    heights = torch.clamp(
        heights, min=ranges.min_root_height, max=ranges.max_root_height
    )
    # Same height for every waypoint of a sample: (num_points, num_wp, 1).
    tiled = heights[:, None, :].repeat(1, num_wp, 1)
    print("===> [sample_root_height] return shape:", tiled.shape)
    return tiled.to(device), num_points, num_wp

Xet Storage Details

Size:
6.02 kB
·
Xet hash:
97d04e82cc9fd7f3d4476399e34a076cb49f92b5bb947a16d2447b0cf257da9b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.