|
|
import torch |
|
|
|
|
|
|
|
|
def replace_constant(minibatch_pose_input, mask_start_frame):
    """Return a constant-filled copy of a pose batch that keeps only the keyframes.

    Every entry of the result is 0.1 except for three frames along dim 1,
    which are copied verbatim from the input: the first frame, the frame at
    ``mask_start_frame`` and the last frame.

    :param minibatch_pose_input: tensor with at least three dimensions whose
        dim 1 indexes frames (presumably (batch, frame, feature) -- confirm
        against the caller).
    :param mask_start_frame: index along dim 1 of the intermediate keyframe.
    :return: new tensor shaped like the input; the input is not modified.
    """
    seq_len = minibatch_pose_input.size(1)

    # ones_like already inherits dtype and device from the input.
    interpolated = torch.ones_like(minibatch_pose_input) * 0.1

    # When mask_start_frame is 0 or seq_len - 1 it coincides with a boundary
    # keyframe, so a single loop covers both the "one segment" and the
    # "two segments" case.
    for frame in (0, mask_start_frame, seq_len - 1):
        interpolated[:, frame, :] = minibatch_pose_input[:, frame, :]

    return interpolated
|
|
|
|
|
|
|
|
def slerp(x, y, a):
    """
    Performs spherical linear interpolation (SLERP) between x and y, with proportion a

    The last dimension of ``x`` and ``y`` holds the quaternion components; any
    number of leading dimensions is accepted.

    NOTE: ``y`` is modified in place. Wherever the dot product of ``x`` and
    ``y`` is negative, ``y`` is negated so that the interpolation follows the
    shorter arc. Pass a clone if the caller's tensor must stay untouched.

    :param x: quaternion tensor
    :param y: quaternion tensor (same shape as x; sign-flipped in place, see above)
    :param a: indicator (between 0 and 1) of completion of the interpolation.
    :return: tensor of interpolation results
    """
    cos_omega = torch.sum(x * y, dim=-1)

    # Take the shorter arc: flip y wherever it lies in the opposite hemisphere.
    neg = cos_omega < 0.0
    cos_omega = torch.abs(cos_omega)
    y[neg] = -y[neg]

    # Broadcast the scalar proportion to one value per quaternion.
    a = torch.zeros_like(x[..., 0]) + a

    # Keep the blend weights at least at default float precision (as before),
    # but do not truncate higher-precision (e.g. float64) inputs.
    amount_dtype = torch.promote_types(a.dtype, torch.get_default_dtype())
    amount0 = torch.zeros(a.shape, dtype=amount_dtype, device=x.device)
    amount1 = torch.zeros(a.shape, dtype=amount_dtype, device=x.device)

    # Nearly parallel quaternions: sin(omega) is close to zero, so fall back to
    # plain linear weights there to avoid an ill-conditioned division.
    linear = (1.0 - cos_omega) < 0.01
    curved = ~linear
    omegas = torch.arccos(cos_omega[curved])
    sinoms = torch.sin(omegas)

    amount0[linear] = 1.0 - a[linear]
    amount0[curved] = torch.sin((1.0 - a[curved]) * omegas) / sinoms

    amount1[linear] = a[linear]
    amount1[curved] = torch.sin(a[curved] * omegas) / sinoms

    # unsqueeze(-1) broadcasts the per-quaternion weights over the components
    # regardless of how many leading dimensions the inputs have.
    return amount0.unsqueeze(-1) * x + amount1.unsqueeze(-1) * y
|
|
|
|
|
|
|
|
def _fill_slerp_segment(interpolated, start, end, first, last):
    """SLERP from ``start`` to ``end`` into frames ``first..last`` (inclusive) of ``interpolated``."""
    span = last - first
    # A zero-length segment (single-frame sequence) has nothing to
    # interpolate; a step of 0 simply reproduces the start frame.
    dt = 1 / span if span else 0.0
    for i in range(first, last + 1):
        interpolated[:, i : i + 1, :] = slerp(start, end, dt * (i - first))

    # Sanity check: the segment must reproduce its two key frames.
    assert torch.allclose(interpolated[:, first : first + 1], start)
    assert torch.allclose(interpolated[:, last : last + 1], end)


def slerp_input_repr(minibatch_pose_input, mask_start_frame):
    """Fill a quaternion pose sequence by SLERP between its key frames.

    The input is viewed as (batch, frame, -1, 4), i.e. the feature dimension
    is split into groups of four quaternion components. If
    ``mask_start_frame`` is the first or last frame, the whole sequence is
    interpolated from the first to the last frame. Otherwise two segments are
    interpolated: first frame -> ``mask_start_frame`` and
    ``mask_start_frame`` -> last frame. The result is L2-normalised per
    quaternion and flattened back to (batch, frame, -1).

    :param minibatch_pose_input: tensor whose dim 0 and dim 1 are kept and
        whose remaining elements are regrouped into quaternions of size 4.
    :param mask_start_frame: frame index of the intermediate key frame.
    :return: interpolated tensor of shape (batch, frame, -1). The input tensor
        is left unmodified.
    """
    batch_size = minibatch_pose_input.size(0)
    seq_len = minibatch_pose_input.size(1)
    last_frame = seq_len - 1

    quats = minibatch_pose_input.reshape(batch_size, seq_len, -1, 4)
    interpolated = torch.zeros_like(quats)

    start = quats[:, 0:1]
    # slerp() sign-flips its second argument in place. Clone the key frames
    # used as interpolation targets so the caller's tensor (which `quats` may
    # be a view of) is never modified.
    end = quats[:, last_frame:].clone()

    if mask_start_frame == 0 or mask_start_frame == last_frame:
        _fill_slerp_segment(interpolated, start, end, 0, last_frame)
    else:
        # One shared clone: the sign chosen for the middle key frame in the
        # first segment is reused as the start of the second segment, keeping
        # the two segments continuous at mask_start_frame.
        middle = quats[:, mask_start_frame : mask_start_frame + 1].clone()
        _fill_slerp_segment(interpolated, start, middle, 0, mask_start_frame)
        _fill_slerp_segment(
            interpolated, middle, end, mask_start_frame, last_frame
        )

    interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
    return interpolated.reshape(batch_size, seq_len, -1)
|
|
|
|
|
|
|
|
def _fill_lerp_segment(interpolated, start, end, first, last):
    """Linearly interpolate from ``start`` to ``end`` into frames ``first..last`` (inclusive)."""
    span = last - first
    # A zero-length segment (single-frame sequence) has nothing to
    # interpolate; a step of 0 simply reproduces the start frame.
    dt = 1 / span if span else 0.0
    for i in range(first, last + 1):
        interpolated[:, i, :] = torch.lerp(start, end, dt * (i - first))

    # Sanity check: the segment must reproduce its two key frames.
    assert torch.allclose(interpolated[:, first, :], start)
    assert torch.allclose(interpolated[:, last, :], end)


def lerp_input_repr(minibatch_pose_input, mask_start_frame):
    """Fill a pose sequence by linear interpolation between its key frames.

    If ``mask_start_frame`` is the first or last frame, the whole sequence is
    interpolated from the first to the last frame. Otherwise two segments are
    interpolated: first frame -> ``mask_start_frame`` and
    ``mask_start_frame`` -> last frame. Only the key frames of the input are
    read; all other frames are ignored.

    :param minibatch_pose_input: floating-point tensor whose dim 1 indexes
        frames (presumably (batch, frame, feature) -- confirm against caller).
    :param mask_start_frame: frame index of the intermediate key frame.
    :return: new tensor shaped like the input holding the interpolated frames.
    """
    seq_len = minibatch_pose_input.size(1)
    last_frame = seq_len - 1
    interpolated = torch.zeros_like(minibatch_pose_input)

    start = minibatch_pose_input[:, 0, :]
    end = minibatch_pose_input[:, last_frame, :]

    if mask_start_frame == 0 or mask_start_frame == last_frame:
        _fill_lerp_segment(interpolated, start, end, 0, last_frame)
    else:
        middle = minibatch_pose_input[:, mask_start_frame, :]
        _fill_lerp_segment(interpolated, start, middle, 0, mask_start_frame)
        _fill_lerp_segment(
            interpolated, middle, end, mask_start_frame, last_frame
        )

    return interpolated
|
|
|
|
|
|
|
|
def vectorize_representation(global_position, global_rotation):
    """Flatten per-frame position and rotation features and concatenate them.

    Both inputs are reshaped to (batch, frame, -1), using the sizes of the
    first two dimensions of ``global_position``, and then joined along the
    last dimension with the position features first.

    :param global_position: tensor whose first two dimensions are kept.
    :param global_rotation: tensor reshaped with the same two leading sizes.
    :return: tensor of shape (batch, frame, position_features + rotation_features).
    """
    batch_size = global_position.shape[0]
    seq_len = global_position.shape[1]

    flat_position = global_position.reshape(batch_size, seq_len, -1)
    flat_rotation = global_rotation.reshape(batch_size, seq_len, -1)

    # torch.cat always writes into a freshly allocated tensor, so forcing the
    # operands to be contiguous first would only add a potential extra copy.
    return torch.cat([flat_position, flat_rotation], dim=2)
|
|
|