|
|
import numpy as np |
|
|
import torch |
|
|
from scipy.ndimage import gaussian_filter |
|
|
from scipy.signal import argrelextrema |
|
|
|
|
|
|
|
|
@torch.no_grad()
def compute_camcoord_metrics(batch, pelvis_idxs=[1, 2], fps=30, mask=None):
    """Camera-coordinate error metrics computed after pelvis alignment.

    Args:
        batch (dict): {
            "pred_j3d": (F, J, 3) tensor
            "target_j3d": tensor with the same layout as "pred_j3d"
            "pred_verts": (F, V, 3) tensor
            "target_verts": tensor with the same layout as "pred_verts"
        }
        pelvis_idxs (list[int]): joint indices whose mean is used as the
            pelvis that every point set is centred on.
        fps (int): frame rate forwarded to ``compute_error_accel``.
        mask (tensor, optional): selector applied to the leading dimension of
            all four tensors before any metric is computed. ``batch`` itself
            must not also contain a "mask" key.

    Returns:
        cam_coord_metrics (dict): {
            "pa_mpjpe": (F,) numpy array, joint error after Procrustes
                alignment, scaled by 1000 (m2mm)
            "mpjpe": (F,) numpy array, joint error scaled by 1000
            "pve": (F,) numpy array, vertex error scaled by 1000
            "accel": numpy array returned by ``compute_error_accel``
        }
    """
    keys = ("pred_j3d", "target_j3d", "pred_verts", "target_verts")
    tensors = [batch[key].cpu() for key in keys]

    if mask is not None:
        mask = mask.cpu()
        tensors = [tensor[mask].clone() for tensor in tensors]
        assert "mask" not in batch

    # Centre predictions and targets on their own pelvis.
    pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
        tensors, pelvis_idxs=pelvis_idxs
    )

    m2mm = 1000
    procrustes_pred = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
    return {
        "pa_mpjpe": compute_jpe(procrustes_pred, target_j3d) * m2mm,
        "mpjpe": compute_jpe(pred_j3d, target_j3d) * m2mm,
        "pve": compute_jpe(pred_verts, target_verts) * m2mm,
        "accel": compute_error_accel(
            joints_pred=pred_j3d, joints_gt=target_j3d, fps=fps
        ),
    }
|
|
|
|
|
|
|
|
@torch.no_grad()
def compute_music_metrics(batch, mask=None, fps=30):
    """Physical foot contact (PFC) and beat alignment score (BAS).

    Args:
        batch (dict): {
            "pred_j3d_glob": (F, J, 3) tensor of global joint positions.
                Joint 0 is used as the root, joints [7, 10] and [8, 11] as the
                two feet (presumably SMPL ankle/foot pairs -- confirm), and
                axis 1 is treated as the up direction.
            "music_beats": (T,) numpy array whose nonzero entries mark beat
                frames.
        }
        mask: unused; accepted for signature parity with the other metric
            functions.
        fps (float): frame rate used to turn frame differences into root
            velocity and acceleration. Defaults to 30, the previously
            hard-coded value.

    Returns:
        music_metrics (dict): {
            "PFC": mean over frames of
                (min left-foot speed) * (min right-foot speed) * (root
                acceleration normalised by its maximum), times 10000
            "BAS": mean over music beats of exp(-d^2 / (2 * 3^2)), where d is
                the frame distance to the nearest motion beat; lies in [0, 1]
        }
    """
    pred_j3d_glob = batch["pred_j3d_glob"].cpu().numpy()
    assert pred_j3d_glob.ndim == 3

    up_dir = 1
    flat_dirs = [i for i in range(3) if i != up_dir]
    dt = 1.0 / fps

    # ---- PFC ----------------------------------------------------------------
    root_v = (pred_j3d_glob[1:, 0, :] - pred_j3d_glob[:-1, 0, :]) / dt  # (F-1, 3)
    root_a = (root_v[1:, :] - root_v[:-1, :]) / dt  # (F-2, 3)
    # Only upward acceleration counts along the up axis.
    root_a[:, up_dir] = np.maximum(root_a[:, up_dir], 0)
    root_a = np.linalg.norm(root_a, axis=-1)  # (F-2,)
    scaling = root_a.max()
    # A root that never accelerates would otherwise produce 0 / 0 = NaN.
    if scaling > 0:
        root_a = root_a / scaling

    foot_idx = [7, 10, 8, 11]
    feet = pred_j3d_glob[:, foot_idx, :]
    # Horizontal per-frame foot displacement, (F-2, 4).
    foot_v = np.linalg.norm(
        feet[2:, :, flat_dirs] - feet[1:-1, :, flat_dirs], axis=-1
    )
    foot_mins = np.zeros((len(foot_v), 2))
    foot_mins[:, 0] = np.minimum(foot_v[:, 0], foot_v[:, 1])
    foot_mins[:, 1] = np.minimum(foot_v[:, 2], foot_v[:, 3])

    foot_loss = foot_mins[:, 0] * foot_mins[:, 1] * root_a
    pfc = foot_loss.mean() * 10000

    # ---- BAS ----------------------------------------------------------------
    motion_beats = compute_motion_beats(pred_j3d_glob)[0]
    music_beats = compute_music_beats(batch["music_beats"])
    ba = 0
    for bb in music_beats:
        # Gaussian score (sigma = 3 frames) of the distance to the nearest motion beat.
        ba += np.exp(-np.min((motion_beats - bb) ** 2) / 2 / 9)
    bas = ba / len(music_beats)

    return {
        "PFC": pfc,
        "BAS": bas,
    }
|
|
|
|
|
|
|
|
@torch.no_grad()
def compute_global_metrics(batch, mask=None, fps=30):
    """World-coordinate metrics, following WHAM.

    Invalid frames are expected to be skipped already, or dropped via ``mask``.
    The sequence is processed in chunks of 100 frames. Inside each chunk the
    prediction is aligned to the ground truth in two ways:
      * "wa2": a transform fitted on the chunk's first two frames only
        (``first_align_joints``);
      * "waa": a transform fitted on the whole chunk (``global_align_joints``).

    Args:
        batch (dict): {
            "pred_j3d_glob": (F, J, 3) tensor
            "target_j3d_glob": (F, J, 3) tensor
            "pred_verts_glob": (F, 6890, 3) tensor (the vertex count is
                asserted by ``compute_foot_sliding``)
            "target_verts_glob": same shape as "pred_verts_glob"
        }
        mask (tensor, optional): frame selector applied to all four tensors.
            ``batch`` itself must not also contain a "mask" key.
        fps (float): frame rate used by the jitter metric. Defaults to 30, the
            value that used to be hard-coded.

    Returns:
        global_metrics (dict): {
            "wa2_mpjpe": (F,) numpy array, scaled by 1000 (m2mm)
            "waa_mpjpe": (F,) numpy array, scaled by 1000
            "rte": (F,) numpy array, root error as a percentage of the
                ground-truth path length
            "jitter": (F-3,) numpy array from ``compute_jitter``
            "fs": foot-sliding distances scaled by 1000, one entry per
                (frame, foot vertex) pair in ground-truth contact
        }
    """
    pred_j3d_glob = batch["pred_j3d_glob"].cpu()
    target_j3d_glob = batch["target_j3d_glob"].cpu()
    pred_verts_glob = batch["pred_verts_glob"].cpu()
    target_verts_glob = batch["target_verts_glob"].cpu()
    if mask is not None:
        mask = mask.cpu()
        pred_j3d_glob = pred_j3d_glob[mask].clone()
        target_j3d_glob = target_j3d_glob[mask].clone()
        pred_verts_glob = pred_verts_glob[mask].clone()
        target_verts_glob = target_verts_glob[mask].clone()
        assert "mask" not in batch

    seq_length = pred_j3d_glob.shape[0]

    # Chunked world-aligned MPJPE (all tensors are already on the CPU).
    chunk_length = 100
    wa2_mpjpe, waa_mpjpe = [], []
    for start in range(0, seq_length, chunk_length):
        end = min(seq_length, start + chunk_length)
        target_j3d = target_j3d_glob[start:end].clone()
        pred_j3d = pred_j3d_glob[start:end].clone()

        w_j3d = first_align_joints(target_j3d, pred_j3d)
        wa_j3d = global_align_joints(target_j3d, pred_j3d)

        wa2_mpjpe.append(compute_jpe(target_j3d, w_j3d))
        waa_mpjpe.append(compute_jpe(target_j3d, wa_j3d))

    m2mm = 1000
    wa2_mpjpe = np.concatenate(wa2_mpjpe) * m2mm
    waa_mpjpe = np.concatenate(waa_mpjpe) * m2mm

    # compute_rte returns error / path length; * 1e2 turns it into a percentage.
    rte = compute_rte(target_j3d_glob[:, 0], pred_j3d_glob[:, 0]) * 1e2
    jitter = compute_jitter(pred_j3d_glob, fps=fps)
    foot_sliding = compute_foot_sliding(target_verts_glob, pred_verts_glob) * m2mm

    return {
        "wa2_mpjpe": wa2_mpjpe,
        "waa_mpjpe": waa_mpjpe,
        "rte": rte,
        "jitter": jitter,
        "fs": foot_sliding,
    }
|
|
|
|
|
|
|
|
@torch.no_grad()
def compute_camcoord_perjoint_metrics(batch, pelvis_idxs=[1, 2]):
    """Per-joint camera-coordinate MPJPE after pelvis alignment.

    Args:
        batch (dict): {
            "pred_j3d": (F, J, 3) tensor
            "target_j3d": tensor with the same layout as "pred_j3d"
            "pred_verts": tensor; required because it is pelvis-aligned along
                with the joints, but it does not enter the returned metric
            "target_verts": same as "pred_verts"
        }
        pelvis_idxs (list[int]): joint indices whose mean forms the pelvis.

    Returns:
        camcoord_perjoint_metrics (dict): {
            "mpjpe": (F, J) numpy array of per-joint errors scaled by 1000
                (m2mm)
        }
    """
    pred_j3d = batch["pred_j3d"].cpu()
    target_j3d = batch["target_j3d"].cpu()
    pred_verts = batch["pred_verts"].cpu()
    target_verts = batch["target_verts"].cpu()

    # Centre predictions and targets on their own pelvis.
    pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
        [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs=pelvis_idxs
    )

    m2mm = 1000
    perjoint_mpjpe = compute_perjoint_jpe(pred_j3d, target_j3d) * m2mm

    return {
        "mpjpe": perjoint_mpjpe,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_jpe(S1, S2):
    """Mean Euclidean distance between corresponding points.

    Args:
        S1, S2: CPU tensors of identical shape (..., N, D).

    Returns:
        numpy array of shape (...,): the per-point L2 distance averaged over N.
    """
    per_point = ((S1 - S2) ** 2).sum(dim=-1).sqrt()
    return per_point.mean(dim=-1).numpy()
|
|
|
|
|
|
|
|
def compute_perjoint_jpe(S1, S2):
    """Euclidean distance for every corresponding point (no averaging).

    Args:
        S1, S2: CPU tensors of identical shape (..., N, D).

    Returns:
        numpy array of shape (..., N).
    """
    squared = (S1 - S2) ** 2
    return squared.sum(dim=-1).sqrt().numpy()
|
|
|
|
|
|
|
|
def batch_align_by_pelvis(data_list, pelvis_idxs=[1, 2]):
    """Centre predictions and targets on their own pelvis.

    Args:
        data_list: sequence [pred_j3d, target_j3d, pred_verts, target_verts],
            each a tensor of shape (frames, num_points, 3).
        pelvis_idxs: one or more joint indices; their mean (per frame) is the
            pelvis location.

    Returns:
        tuple (pred_j3d, target_j3d, pred_verts, target_verts). Predicted
        joints and vertices have the predicted pelvis subtracted; target joints
        and vertices have the target pelvis subtracted.
    """
    pred_j3d, target_j3d, pred_verts, target_verts = data_list

    def _pelvis(joints):
        # (frames, 1, 3): mean of the selected joints, kept broadcastable.
        return joints[:, pelvis_idxs].mean(dim=1, keepdim=True)

    pred_root = _pelvis(pred_j3d)
    target_root = _pelvis(target_j3d)

    return (
        pred_j3d - pred_root,
        target_j3d - target_root,
        pred_verts - pred_root,
        target_verts - target_root,
    )
|
|
|
|
|
|
|
|
def batch_compute_similarity_transform_torch(S1, S2):
    """Procrustes-align S1 onto S2 with one similarity transform per batch item.

    For every batch element, solves the orthogonal Procrustes problem with
    scaling: find scale s, rotation R (det = +1) and translation t minimising
    ||s * R @ S1 + t - S2||, then returns the transformed S1.

    Args:
        S1: (B, N, D) or (B, D, N) tensor with D in {2, 3}. The layout is
            detected from dim 1: if it is not 2 or 3 the input is treated as
            (B, N, D). Inputs with N in {2, 3} points are therefore read as
            (B, D, N).
        S2: tensor with the same shape as S1.

    Returns:
        S1_hat: the aligned S1, in the same layout as the input.
    """
    # The original check looked at dim 0 (the batch size), so batches of
    # exactly 2 or 3 samples were silently left untransposed.
    transposed = S1.shape[1] not in (2, 3)
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove the centroids.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Total variance of X1 (denominator of the optimal scale).
    var1 = (X1**2).sum(dim=(1, 2))

    # 3. Cross-covariance (B, D, D).
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Optimal rotation from the SVD K = U diag(s) V^T. Z flips the last
    #    singular direction when needed so that R is a rotation, not a
    #    reflection. Z follows U's dtype/device so float64 inputs work.
    U, _, V = torch.svd(K)
    Z = torch.eye(U.shape[1], dtype=U.dtype, device=U.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Scale = trace(R K) / var1, then the translation.
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    t = mu2 - scale[:, None, None] * R.bmm(mu1)

    # 6. Apply the transform.
    S1_hat = scale[:, None, None] * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)
    return S1_hat
|
|
|
|
|
|
|
|
def batch_compute_scale_trans_torch(S1, S2):
    """Per-batch similarity transform (s, R, t) that best maps S1 onto S2.

    Same Procrustes solution as ``batch_compute_similarity_transform_torch``,
    but the transform parameters are returned instead of the aligned points.

    Args:
        S1: (B, N, D) or (B, D, N) tensor with D in {2, 3}. The layout is
            detected from dim 1: if it is not 2 or 3 the input is treated as
            (B, N, D).
        S2: tensor with the same shape as S1.

    Returns:
        scale: (B,) tensor.
        t: (B, D, 1) tensor.
        R: (B, D, D) rotation.
        Both t and R are expressed for column-vector points, so
        scale * R @ P + t maps a (B, D, N) point set P.
    """
    # Detect the layout from dim 1; checking dim 0 (the batch size)
    # misclassified batches of 2 or 3 samples.
    transposed = S1.shape[1] not in (2, 3)
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove the centroids.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Total variance of X1 (denominator of the optimal scale).
    var1 = (X1**2).sum(dim=(1, 2))

    # 3. Cross-covariance (B, D, D).
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Rotation from the SVD, with a reflection fix. Z follows U's
    #    dtype/device so float64 inputs work.
    U, _, V = torch.svd(K)
    Z = torch.eye(U.shape[1], dtype=U.dtype, device=U.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Scale = trace(R K) / var1, then the translation.
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    t = mu2 - scale[:, None, None] * R.bmm(mu1)

    return scale, t, R
|
|
|
|
|
|
|
|
def compute_error_accel(joints_gt, joints_pred, valid_mask=None, fps=None):
    """
    Acceleration error from central differences over frames [i-1, i, i+1]:

        accel_i = X_{i-1} - 2 * X_i + X_{i+1},    i = 1 .. F-2

    For each such frame, the returned value is the L2 norm of
    (accel_pred - accel_gt), averaged over joints. A frame that is not visible
    invalidates all three accelerations that use it (entries -1, 0, +1 around
    it).

    Args:
        joints_gt : (F, J, 3) numpy array or CPU tensor
        joints_pred : (F, J, 3) numpy array or CPU tensor
        valid_mask : (F,) boolean, optional
        fps : optional frame rate; when given, the error is multiplied by
            fps**2 (per-frame^2 -> per-second^2).

    Returns:
        error_accel (F-2,) when valid_mask is None, else (F',), F' <= F-2
    """
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]
    normed = np.linalg.norm(accel_pred - accel_gt, axis=-1).mean(axis=-1)
    if fps is not None:
        normed = normed * fps**2

    if valid_mask is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        # Acceleration k uses frames k, k+1, k+2; it is invalid if any of
        # them is invisible. The wrapped-around tail of np.roll is dropped by [:-2].
        invis = np.logical_not(valid_mask)
        invis1 = np.roll(invis, -1)
        invis2 = np.roll(invis, -2)
        new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
        new_vis = np.logical_not(new_invis)
        if new_vis.sum() == 0:
            print("Warning!!! no valid acceleration error to compute.")

    return normed[new_vis]
|
|
|
|
|
|
|
|
def compute_rte(target_trans, pred_trans):
    """Root trajectory error relative to the ground-truth path length.

    The predicted trajectory is aligned to the target with a rotation and
    translation (``align_pcl`` with ``fixed_scale=True``). The per-frame
    Euclidean error is then divided by the total distance travelled by the
    target.

    Args:
        target_trans: (T, 3) tensor of ground-truth root positions.
        pred_trans: (T, 3) tensor of predicted root positions.

    Returns:
        numpy array (T,): per-frame error / ground-truth path length.
    """
    # Rigid alignment (scale fixed to 1).
    _, rot, trans = align_pcl(
        target_trans[None, :], pred_trans[None, :], fixed_scale=True
    )
    pred_trans_hat = (
        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    )[0]

    # Total path length of the ground-truth trajectory (vectorised; replaces
    # a per-frame Python accumulation loop and its unused `disps` list).
    disp = (target_trans[1:] - target_trans[:-1]).norm(2, dim=-1).sum()

    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)
    return (rte / disp).numpy()
|
|
|
|
|
|
|
|
def compute_jitter(joints, fps=30):
    """Mean jerk magnitude of a joint sequence.

    Jerk is the third finite difference over frames, scaled by fps**3. The
    per-joint L2 norm is averaged over joints and the result is divided by 10
    (presumably to report units of 10 m/s^3 for metre positions -- confirm).

    Args:
        joints (N, J, 3) tensor.
        fps (float).

    Returns:
        jitter (N-3,) numpy array.
    """
    third_diff = joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]
    jerk_per_joint = torch.norm(third_diff * (fps**3), dim=2)
    per_frame = jerk_per_joint.mean(dim=-1)
    return per_frame.cpu().numpy() / 10.0
|
|
|
|
|
|
|
|
def compute_foot_sliding(target_verts, pred_verts, thr=1e-2):
    """Foot-sliding error of the prediction while the ground-truth foot is planted.

    A fixed set of four foot vertices is tracked (presumably SMPL foot
    vertices, given the 6890-vertex check -- confirm). A vertex counts as in
    contact between two frames when its ground-truth displacement is below
    ``thr`` (1 cm/frame with the default, for metre units).

    Args:
        target_verts (N, 6890, 3) tensor.
        pred_verts (N, 6890, 3) tensor.
        thr (float): contact threshold on per-frame displacement.

    Returns:
        numpy array of predicted displacements, one per
        (frame transition, foot vertex) pair in ground-truth contact.
    """
    assert target_verts.shape == pred_verts.shape
    assert target_verts.shape[-2] == 6890

    foot_idxs = [3216, 3387, 6617, 6787]

    def _step_lengths(verts):
        # (N-1, 4): displacement of each foot vertex between consecutive frames.
        feet = verts[:, foot_idxs]
        return (feet[1:] - feet[:-1]).norm(2, dim=-1)

    in_contact = _step_lengths(target_verts) < thr
    sliding = _step_lengths(pred_verts)[in_contact]
    return sliding.cpu().numpy()
|
|
|
|
|
|
|
|
def convert_joints22_to_24(joints22, ratio2220=0.3438, ratio2321=0.3345):
    """Extend a 22-joint skeleton to 24 joints by extrapolating two extra joints.

    Joint 22 is placed beyond joint 20 along the direction 18 -> 20, at
    ``ratio2220`` times that bone length. Joint 23 is placed beyond joint 21
    along 19 -> 21, at ``ratio2321`` times its length. (With SMPL ordering these
    would be the hands extrapolated from elbow -> wrist -- confirm.)

    Args:
        joints22: (..., 22, 3) tensor.
        ratio2220, ratio2321: extrapolation ratios for joints 22 and 23.

    Returns:
        (..., 24, 3) float32 tensor on the same device as ``joints22``.
    """
    left_extra = joints22[..., 20, :] + ratio2220 * (
        joints22[..., 20, :] - joints22[..., 18, :]
    )
    right_extra = joints22[..., 21, :] + ratio2321 * (
        joints22[..., 21, :] - joints22[..., 19, :]
    )

    out = torch.zeros((*joints22.shape[:-2], 24, 3), device=joints22.device)
    out[..., :22, :] = joints22
    out[..., 22, :] = left_extra
    out[..., 23, :] = right_extra
    return out
|
|
|
|
|
|
|
|
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """Align X with Y by a similarity transform using the Umeyama method.

    X' = s * R * X + t is aligned with Y in the least-squares sense.

    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale if True, s is fixed to 1 and only R, t are estimated
    :returns s (*, 1), R (*, 3, 3), t (*, 3)
    """
    *dims, N, _ = Y.shape
    # Helper tensors follow Y's dtype/device. They used to be created with
    # CPU float32 defaults, which made float64 or GPU inputs fail in matmul.
    factory = {"dtype": Y.dtype, "device": Y.device}
    N = torch.full((*dims, 1, 1), float(N), **factory)

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)

    # Centroids and centred point sets.
    my = Y.sum(dim=-2) / N[..., 0]
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # Cross-covariance (*, 3, 3) and its SVD.
    C = torch.matmul(y0.transpose(-1, -2), x0) / N
    U, D, Vh = torch.linalg.svd(C)

    # Flip the last singular direction when det(U) * det(V) < 0 so that R is
    # a proper rotation rather than a reflection.
    S = torch.eye(3, **factory).expand(*dims, 3, 3).clone()
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))

    if fixed_scale:
        s = torch.ones(*dims, 1, **factory)
    else:
        # s = trace(D S) / var(X)
        D = torch.diag_embed(D)
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]

    return s, R, t
|
|
|
|
|
|
|
|
def global_align_joints(gt_joints, pred_joints):
    """Align a whole prediction sequence to the ground truth with one transform.

    ``align_pcl`` fits a single similarity transform (scale, rotation,
    translation) on all joints of all frames pooled together. That transform
    is then applied to every frame of ``pred_joints``.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) aligned prediction
    """
    pooled_gt = gt_joints.reshape(-1, 3)
    pooled_pred = pred_joints.reshape(-1, 3)
    scale, rot, shift = align_pcl(pooled_gt, pooled_pred)

    rotated = torch.einsum("ij,tnj->tni", rot, pred_joints)
    return scale * rotated + shift[None, None]
|
|
|
|
|
|
|
|
def first_align_joints(gt_joints, pred_joints):
    """Align a prediction sequence using only its first two frames.

    ``align_pcl`` fits a similarity transform on the joints of frames 0 and 1
    pooled together. That transform is then applied to every frame of
    ``pred_joints``.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) aligned prediction
    """
    head_gt = gt_joints[:2].reshape(1, -1, 3)
    head_pred = pred_joints[:2].reshape(1, -1, 3)
    scale, rot, shift = align_pcl(head_gt, head_pred)

    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale * rotated + shift[:, None]
|
|
|
|
|
|
|
|
def rearrange_by_mask(x, mask):
    """Scatter the rows of ``x`` back to the positions where ``mask`` is True.

    Args:
        x: (L, *) tensor.
        mask: (M,) boolean tensor with M >= L and exactly L True entries.

    Returns:
        ``x`` itself when M == L. Otherwise a new (M, *) tensor with x's dtype
        and device, zero where ``mask`` is False.
    """
    total, kept = mask.size(0), x.size(0)
    if total == kept:
        return x

    assert total > kept
    assert mask.sum() == kept

    scattered = x.new_zeros((total, *x.size()[1:]))
    scattered[mask] = x
    return scattered
|
|
|
|
|
|
|
|
def as_np_array(d):
    """Return ``d`` as a numpy array.

    numpy arrays are returned unchanged (same object). Tensors are moved to
    the CPU and converted. Anything else goes through ``np.array``.
    """
    if isinstance(d, np.ndarray):
        return d
    if isinstance(d, torch.Tensor):
        return d.cpu().numpy()
    return np.array(d)
|
|
|
|
|
|
|
|
def compute_motion_beats(keypoints):
    """Detect motion beats as local minima of the smoothed kinetic velocity.

    Args:
        keypoints: numpy array reshapeable to (F, 24, 3).

    Returns:
        The tuple returned by ``scipy.signal.argrelextrema``. Element 0 holds
        the indices (into the F-1 velocity samples) of strict local minima of
        the mean per-joint speed after Gaussian smoothing (sigma = 5).
    """
    joints = keypoints.reshape(-1, 24, 3)
    step = joints[1:] - joints[:-1]
    speed_per_joint = np.sqrt(np.sum(step**2, axis=2))
    kinetic_vel = np.mean(speed_per_joint, axis=1)

    smoothed = gaussian_filter(kinetic_vel, sigma=5)
    return argrelextrema(smoothed, np.less)
|
|
|
|
|
|
|
|
def compute_music_beats(beats):
    """Return the frame indices of the nonzero entries of a 1-D beat indicator.

    Args:
        beats: (T,) numpy array; nonzero (truthy) entries mark beat frames.

    Returns:
        Integer numpy array of beat frame indices in increasing order.
    """
    is_beat = beats.astype(bool)
    frame_ids = np.arange(len(is_beat))
    return frame_ids[is_beat]
|
|
|