| import numpy as np |
| import torch |
| from scipy.ndimage import gaussian_filter |
| from scipy.signal import argrelextrema |
|
|
|
|
@torch.no_grad()
def compute_camcoord_metrics(batch, pelvis_idxs=[1, 2], fps=30, mask=None):
    """Camera-coordinate pose metrics: PA-MPJPE, MPJPE, PVE and accel error.

    Args:
        batch (dict): {
            "pred_j3d": (..., J, 3) tensor of predicted 3D joints
            "target_j3d": ground-truth joints, same shape
            "pred_verts": predicted mesh vertices
            "target_verts": ground-truth mesh vertices
        }
        pelvis_idxs (list[int]): joint indices averaged to form the pelvis.
        fps (float): frame rate, used to scale the acceleration error.
        mask (BoolTensor, optional): frame-validity mask applied to all inputs.
    Returns:
        dict: per-frame numpy arrays under "pa_mpjpe", "mpjpe", "pve"
        (scaled by 1000, i.e. mm assuming meter inputs) and "accel".
    """
    keys = ("pred_j3d", "target_j3d", "pred_verts", "target_verts")
    data = {k: batch[k].cpu() for k in keys}

    if mask is not None:
        # If "mask" were in batch, the inputs would already be filtered.
        assert "mask" not in batch
        mask = mask.cpu()
        data = {k: v[mask].clone() for k, v in data.items()}

    # Root-center joints and vertices before measuring errors.
    pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
        [data["pred_j3d"], data["target_j3d"], data["pred_verts"], data["target_verts"]],
        pelvis_idxs=pelvis_idxs,
    )

    m2mm = 1000
    # Procrustes-aligned prediction for PA-MPJPE.
    S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
    return {
        "pa_mpjpe": compute_jpe(S1_hat, target_j3d) * m2mm,
        "mpjpe": compute_jpe(pred_j3d, target_j3d) * m2mm,
        "pve": compute_jpe(pred_verts, target_verts) * m2mm,
        "accel": compute_error_accel(
            joints_pred=pred_j3d, joints_gt=target_j3d, fps=fps
        ),
    }
|
|
|
|
@torch.no_grad()
def compute_music_metrics(batch, mask=None):
    """Dance metrics: Physical Foot Contact (PFC) and Beat Alignment (BAS).

    Args:
        batch (dict): {
            "pred_j3d_glob": (T, J, 3) tensor of global 3D joints
            "music_beats": (T,) array-like, nonzero/True at beat frames
        }
        mask: unused; kept for signature parity with the other metric fns.
    Returns:
        dict: {"PFC": float, "BAS": float}
    """
    pred_j3d_glob = batch["pred_j3d_glob"].cpu().numpy()

    up_dir = 1  # axis 1 is treated as "up"; the other two span the ground
    flat_dirs = [i for i in range(3) if i != up_dir]

    DT = 1 / 30  # assumes 30 fps input -- TODO confirm against callers
    assert pred_j3d_glob.ndim == 3

    # --- PFC: root should not accelerate while both feet are planted ---
    # Root (joint 0) velocity and acceleration via finite differences.
    root_v = (
        pred_j3d_glob[1:, 0, :] - pred_j3d_glob[:-1, 0, :]
    ) / DT
    root_a = (root_v[1:, :] - root_v[:-1, :]) / DT

    # Downward acceleration can be supplied by gravity, so only keep the
    # upward component of the vertical axis.
    root_a[:, up_dir] = np.maximum(root_a[:, up_dir], 0)

    root_a = np.linalg.norm(root_a, axis=-1)
    # NOTE(review): a perfectly static root gives max() == 0 and a divide
    # by zero here, matching the original behavior.
    scaling = root_a.max()
    root_a = root_a / scaling

    # Planar speed of the foot joints; per (left, right) pair keep the
    # slower point, i.e. the one most likely in contact.
    foot_idx = [7, 10, 8, 11]  # presumably SMPL ankles/toes -- verify
    feet = pred_j3d_glob[:, foot_idx, :]
    foot_v = np.linalg.norm(
        feet[2:, :, flat_dirs] - feet[1:-1, :, flat_dirs], axis=-1
    )
    foot_mins = np.zeros((len(foot_v), 2))
    foot_mins[:, 0] = np.minimum(foot_v[:, 0], foot_v[:, 1])
    foot_mins[:, 1] = np.minimum(foot_v[:, 2], foot_v[:, 3])
    # (Removed a dead `foot_v = np.maximum(foot_mins, 0)` statement: its
    # result was never read, and norms are non-negative anyway.)

    foot_loss = (
        foot_mins[:, 0] * foot_mins[:, 1] * root_a
    )
    pfc = foot_loss.mean() * 10000

    # --- BAS: proximity of detected motion beats to music beats ---
    motion_beats = compute_motion_beats(pred_j3d_glob)[0]
    music_beats = compute_music_beats(batch["music_beats"])
    ba = 0
    for bb in music_beats:
        ba += np.exp(-np.min((motion_beats - bb) ** 2) / 2 / 9)
    bas = ba / len(music_beats)
    return {
        "PFC": pfc,
        "BAS": bas,
    }
|
|
|
|
@torch.no_grad()
def compute_global_metrics(batch, mask=None):
    """Follow WHAM, the input has skipped invalid frames.

    Args:
        batch (dict): {
            "pred_j3d_glob": (F, J, 3) tensor of world-frame joints
            "target_j3d_glob": ground truth, same shape
            "pred_verts_glob": world-frame vertices
            "target_verts_glob": ground truth, same shape
        }
        mask (BoolTensor, optional): frame-validity mask applied to inputs.
    Returns:
        dict: {"wa2_mpjpe", "waa_mpjpe"} per-frame arrays in mm,
        "rte" (per-frame, x100), "jitter", "fs" (mm).
    """
    keys = (
        "pred_j3d_glob",
        "target_j3d_glob",
        "pred_verts_glob",
        "target_verts_glob",
    )
    data = {k: batch[k].cpu() for k in keys}
    if mask is not None:
        # If "mask" were in batch, the inputs would already be filtered.
        assert "mask" not in batch
        mask = mask.cpu()
        data = {k: v[mask].clone() for k, v in data.items()}

    pred_j3d_glob = data["pred_j3d_glob"]
    target_j3d_glob = data["target_j3d_glob"]
    pred_verts_glob = data["pred_verts_glob"]
    target_verts_glob = data["target_verts_glob"]

    seq_length = pred_j3d_glob.shape[0]

    # Evaluate world-frame joint error over 100-frame chunks.
    chunk_length = 100
    wa2_mpjpe, waa_mpjpe = [], []
    for start in range(0, seq_length, chunk_length):
        stop = min(seq_length, start + chunk_length)

        target_chunk = target_j3d_glob[start:stop].clone().cpu()
        pred_chunk = pred_j3d_glob[start:stop].clone().cpu()

        # WA2: align using only the first two frames; WAA: align whole chunk.
        aligned_first = first_align_joints(target_chunk, pred_chunk)
        aligned_all = global_align_joints(target_chunk, pred_chunk)

        wa2_mpjpe.append(compute_jpe(target_chunk, aligned_first))
        waa_mpjpe.append(compute_jpe(target_chunk, aligned_all))

    m2mm = 1000
    wa2_mpjpe = np.concatenate(wa2_mpjpe) * m2mm
    waa_mpjpe = np.concatenate(waa_mpjpe) * m2mm

    # Root trajectory error (joint 0), jitter, and foot sliding.
    rte = compute_rte(target_j3d_glob[:, 0].cpu(), pred_j3d_glob[:, 0].cpu()) * 1e2
    jitter = compute_jitter(pred_j3d_glob, fps=30)
    foot_sliding = compute_foot_sliding(target_verts_glob, pred_verts_glob) * m2mm

    return {
        "wa2_mpjpe": wa2_mpjpe,
        "waa_mpjpe": waa_mpjpe,
        "rte": rte,
        "jitter": jitter,
        "fs": foot_sliding,
    }
|
|
|
|
@torch.no_grad()
def compute_camcoord_perjoint_metrics(batch, pelvis_idxs=[1, 2]):
    """Per-joint camera-coordinate MPJPE after pelvis alignment.

    Args:
        batch (dict): {
            "pred_j3d": (..., J, 3) tensor
            "target_j3d": ground truth, same shape
            "pred_verts": predicted mesh vertices
            "target_verts": ground-truth mesh vertices
        }
        pelvis_idxs (list[int]): joint indices averaged to form the pelvis.
    Returns:
        dict: {"mpjpe": (..., J) numpy array}, per-joint error scaled by
        1000 (mm assuming meter inputs).
    """
    aligned = batch_align_by_pelvis(
        [
            batch["pred_j3d"].cpu(),
            batch["target_j3d"].cpu(),
            batch["pred_verts"].cpu(),
            batch["target_verts"].cpu(),
        ],
        pelvis_idxs=pelvis_idxs,
    )
    pred_j3d, target_j3d = aligned[0], aligned[1]

    m2mm = 1000
    return {"mpjpe": compute_perjoint_jpe(pred_j3d, target_j3d) * m2mm}
|
|
|
|
| |
|
|
|
|
def compute_jpe(S1, S2):
    """Mean per-joint position error: joint-wise Euclidean distance, averaged over joints."""
    per_joint = torch.linalg.norm(S1 - S2, dim=-1)
    return per_joint.mean(dim=-1).numpy()
|
|
|
|
def compute_perjoint_jpe(S1, S2):
    """Per-joint position error: Euclidean distance per joint, no averaging."""
    return torch.linalg.norm(S1 - S2, dim=-1).numpy()
|
|
|
|
def batch_align_by_pelvis(data_list, pelvis_idxs=[1, 2]):
    """Center joints and vertices on their own per-frame pelvis.

    Assumes data_list is [pred_j3d, target_j3d, pred_verts, target_verts],
    each of shape (frames, num_points, 3). The pelvis is the mean of the
    joints at ``pelvis_idxs``; predictions and targets are each centered on
    their respective pelvis.
    """
    pred_j3d, target_j3d, pred_verts, target_verts = data_list

    # Per-frame pelvis location, kept as (frames, 1, 3) for broadcasting.
    pred_pelvis = pred_j3d[:, pelvis_idxs].mean(dim=1, keepdim=True).clone()
    target_pelvis = target_j3d[:, pelvis_idxs].mean(dim=1, keepdim=True).clone()

    return (
        pred_j3d - pred_pelvis,
        target_j3d - target_pelvis,
        pred_verts - pred_pelvis,
        target_verts - target_pelvis,
    )
|
|
|
|
def batch_compute_similarity_transform_torch(S1, S2):
    """
    Computes a similarity transform (sR, t) that takes
    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrutes problem.

    Args:
        S1: (B, N, 3) or (B, 3, N) batch of point sets to align.
        S2: batch of target point sets, same layout as S1.
    Returns:
        S1_hat: S1 mapped through the least-squares similarity transform,
            returned in the same layout as the input.
    """
    # Layout detection: if points are given as rows (B, N, 3), move them to
    # columns (B, 3, N). NOTE(review): this inspects S1.shape[0], the BATCH
    # dim, so a batch of exactly 2 or 3 items would skip the permute --
    # confirm callers never pass such batch sizes in the (B, N, 3) layout.
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove the per-batch centroid of each point set.
    mu1 = S1.mean(axis=-1, keepdims=True)
    mu2 = S2.mean(axis=-1, keepdims=True)

    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Total variance of the centered source points (used for the scale).
    var1 = torch.sum(X1**2, dim=1).sum(dim=1)

    # 3. Cross-covariance between the centered point sets.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. The rotation maximizing trace(R'K) is R = V * U' for the SVD
    # K = U s V'.
    U, s, V = torch.svd(K)

    # Sign-fix matrix Z so the result is a proper rotation (det(R) = +1),
    # not a reflection.
    Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # 5. Assemble the rotation.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 6. Optimal scale = trace(R K) / var(X1), computed per batch element.
    scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1

    # 7. Translation maps the transformed source centroid onto the target's.
    t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1)))

    # 8. Apply the full similarity transform to the source points.
    S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t

    # Restore the caller's original layout.
    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat
|
|
|
|
def batch_compute_scale_trans_torch(S1, S2):
    """
    Computes a similarity transform (sR, t) that takes
    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrutes problem.

    Near-duplicate of ``batch_compute_similarity_transform_torch`` that
    returns the transform parameters instead of the transformed points.

    Args:
        S1: (B, N, 3) or (B, 3, N) batch of point sets to align.
        S2: batch of target point sets, same layout as S1.
    Returns:
        (scale, t, R): per-batch scale (B,), translation (B, 3, 1), and
            rotation (B, 3, 3) such that s * R @ S1 + t best matches S2.
    """
    # Layout detection as in batch_compute_similarity_transform_torch.
    # NOTE(review): `transposed` is never read afterwards here, since no
    # transformed points are returned.
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove the per-batch centroid of each point set.
    mu1 = S1.mean(axis=-1, keepdims=True)
    mu2 = S2.mean(axis=-1, keepdims=True)

    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Total variance of the centered source points (used for the scale).
    var1 = torch.sum(X1**2, dim=1).sum(dim=1)

    # 3. Cross-covariance between the centered point sets.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. The rotation maximizing trace(R'K) is R = V * U' for the SVD
    # K = U s V'.
    U, s, V = torch.svd(K)

    # Sign-fix matrix Z so the result is a proper rotation (det(R) = +1).
    Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # 5. Assemble the rotation.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 6. Optimal scale = trace(R K) / var(X1), computed per batch element.
    scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1

    # 7. Translation maps the transformed source centroid onto the target's.
    t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1)))

    return scale, t, R
|
|
|
|
def compute_error_accel(joints_gt, joints_pred, valid_mask=None, fps=None):
    """Acceleration error between predicted and ground-truth joints.

    The acceleration at frame i is the second finite difference
    X_{i-1} - 2 X_i + X_{i+1}; the error is the norm of the difference of
    predicted and GT accelerations, averaged over joints. Every frame whose
    3-frame window touches an invisible frame is dropped.

    Args:
        joints_gt : (F, J, 3)
        joints_pred : (F, J, 3)
        valid_mask : (F,) boolean visibility, optional
        fps : if given, rescale from per-frame^2 to per-second^2 units
    Returns:
        error_accel (F-2) when valid_mask is None, else (F'), F' <= F-2
    """
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]
    normed = np.linalg.norm(accel_pred - accel_gt, axis=-1).mean(axis=-1)
    if fps is not None:
        normed = normed * fps**2

    if valid_mask is None:
        keep = np.ones(len(normed), dtype=bool)
    else:
        invis = np.logical_not(valid_mask)
        # Frame i's acceleration needs frames i, i+1, i+2 of the original
        # sequence (i.e. i-1, i, i+1 around the center) all visible.
        bad = np.logical_or(
            invis, np.logical_or(np.roll(invis, -1), np.roll(invis, -2))
        )[:-2]
        keep = np.logical_not(bad)
    if keep.sum() == 0:
        print("Warning!!! no valid acceleration error to compute.")

    return normed[keep]
|
|
|
|
def compute_rte(target_trans, pred_trans):
    """Root Trajectory Error normalized by the GT path length.

    The predicted root trajectory is rigidly aligned (rotation +
    translation, no scale) to the target, then the per-frame position
    error is divided by the total distance the target root travels.

    Args:
        target_trans: (F, 3) ground-truth root translations.
        pred_trans: (F, 3) predicted root translations.
    Returns:
        (F,) numpy array of normalized errors (callers scale by 100).
    """
    # Fixed-scale (rigid) alignment of the whole trajectory.
    _, rot, trans = align_pcl(
        target_trans[None, :], pred_trans[None, :], fixed_scale=True
    )
    pred_trans_hat = (
        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    )[0]

    # Total GT path length. (The previous version accumulated this in a
    # Python loop and also built an unused list of cumulative distances.)
    disp = (target_trans[1:] - target_trans[:-1]).norm(2, dim=-1).sum()

    # Per-frame error, normalized by the path length. NOTE(review): a
    # static target (disp == 0) divides by zero, as before.
    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)
    return (rte / disp).numpy()
|
|
|
|
def compute_jitter(joints, fps=30):
    """compute jitter of the motion

    Jitter is the magnitude of the jerk (third finite difference of
    position), scaled by fps^3, averaged over joints, divided by 10.

    Args:
        joints (N, J, 3).
        fps (float).
    Returns:
        jitter (N-3) numpy array.
    """
    third_diff = joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]
    jerk_mag = torch.norm(third_diff * (fps**3), dim=2).mean(dim=-1)
    return jerk_mag.cpu().numpy() / 10.0
|
|
|
|
def compute_foot_sliding(target_verts, pred_verts, thr=1e-2):
    """compute foot sliding error

    Frames where the ground-truth foot vertices move less than ``thr``
    (1 cm per frame by default) are labeled as in ground contact; the
    error is the predicted displacement of those vertices on those frames.

    Args:
        target_verts (N, 6890, 3).
        pred_verts (N, 6890, 3).
        thr (float): per-frame displacement threshold for contact.
    Returns:
        error: numpy array over (frame, vertex) pairs in GT contact.
    """
    assert target_verts.shape == pred_verts.shape
    assert target_verts.shape[-2] == 6890

    # Foot vertices on the 6890-vertex mesh -- presumably heels/toes, verify.
    foot_idxs = [3216, 3387, 6617, 6787]

    # Contact label from the GT per-frame foot displacement.
    gt_feet = target_verts[:, foot_idxs]
    gt_disp = (gt_feet[1:] - gt_feet[:-1]).norm(2, dim=-1)
    contact = gt_disp < thr

    pred_feet = pred_verts[:, foot_idxs]
    pred_disp = (pred_feet[1:] - pred_feet[:-1]).norm(2, dim=-1)

    return pred_disp[contact].cpu().numpy()
|
|
|
|
def convert_joints22_to_24(joints22, ratio2220=0.3438, ratio2321=0.3345):
    """Extend a 22-joint set to 24 joints by linear extrapolation.

    Joint 22 is placed beyond joint 20 along the 18->20 direction scaled by
    ``ratio2220``; joint 23 beyond joint 21 along 19->21 scaled by
    ``ratio2321`` (presumably extrapolating hands from wrists/elbows --
    verify against the skeleton definition).

    Args:
        joints22: (..., 22, 3) tensor.
    Returns:
        (..., 24, 3) tensor with the first 22 joints copied through.
    """
    extra_22 = joints22[..., 20, :] + ratio2220 * (
        joints22[..., 20, :] - joints22[..., 18, :]
    )
    extra_23 = joints22[..., 21, :] + ratio2321 * (
        joints22[..., 21, :] - joints22[..., 19, :]
    )
    joints24 = torch.zeros(*joints22.shape[:-2], 24, 3).to(joints22.device)
    joints24[..., :22, :] = joints22
    joints24[..., 22, :] = extra_22
    joints24[..., 23, :] = extra_23
    return joints24
|
|
|
|
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """align similarity transform to align X with Y using umeyama method
    X' = s * R * X + t is aligned with Y
    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale: if True, force s = 1 (rigid alignment only)
    :returns s (*, 1), R (*, 3, 3), t (*, 3)
    """
    # Effective point count per batch element, shaped (*, 1, 1) for
    # broadcasting. NOTE(review): created on the default device -- assumes
    # Y is a CPU tensor, confirm for GPU callers.
    *dims, N, _ = Y.shape
    N = torch.ones(*dims, 1, 1) * N

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)

    # Weighted centroids, then center both point sets.
    my = Y.sum(dim=-2) / N[..., 0]
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # Cross-covariance between the centered sets, solved via SVD.
    C = torch.matmul(y0.transpose(-1, -2), x0) / N
    U, D, Vh = torch.linalg.svd(C)

    # Sign-fix matrix S guarantees det(R) = +1 (rotation, not reflection).
    S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))

    D = torch.diag_embed(D)
    if fixed_scale:
        # Rigid alignment: unit scale.
        s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32)
    else:
        # Umeyama scale: trace(D S) / var(x0).
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )

    # Translation maps the transformed X centroid onto Y's centroid.
    t = my - s * torch.matmul(R, mx[..., None])[..., 0]

    return s, R, t
|
|
|
|
def global_align_joints(gt_joints, pred_joints):
    """Align the whole prediction to GT with one similarity transform.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns pred_joints mapped through the best-fit transform, (T, J, 3)
    """
    # Flatten all frames/joints into one point cloud for the alignment.
    s, R, t = align_pcl(gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3))
    aligned = s * torch.einsum("ij,tnj->tni", R, pred_joints) + t[None, None]
    return aligned
|
|
|
|
def first_align_joints(gt_joints, pred_joints):
    """
    align the first two frames

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns pred_joints mapped through that transform, (T, J, 3)
    """
    # Fit the transform on frames 0-1 only, then apply it to all frames.
    s, R, t = align_pcl(
        gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3)
    )
    aligned = s * torch.einsum("tij,tnj->tni", R, pred_joints) + t[:, None]
    return aligned
|
|
|
|
def rearrange_by_mask(x, mask):
    """Scatter x into a zero-filled tensor at the True positions of mask.

    Args:
        x: (L, *) tensor.
        mask: (M,) boolean tensor, M >= L, with exactly L True entries.
    Returns:
        (M, *) tensor whose masked rows hold x and whose other rows are 0.
        When M == L, x itself is returned unchanged.
    """
    M, L = mask.size(0), x.size(0)
    if M == L:
        return x
    assert M > L
    assert mask.sum() == L
    out = torch.zeros((M, *x.size()[1:]), dtype=x.dtype, device=x.device)
    out[mask] = x
    return out
|
|
|
|
def as_np_array(d):
    """Coerce a torch tensor, ndarray, or array-like into a numpy array."""
    if isinstance(d, np.ndarray):
        return d
    if isinstance(d, torch.Tensor):
        return d.cpu().numpy()
    return np.array(d)
|
|
|
|
def compute_motion_beats(keypoints):
    """Detect motion beats as local minima of smoothed kinetic velocity.

    Args:
        keypoints: array reshapeable to (T, 24, 3).
    Returns:
        tuple of arrays (as produced by scipy's argrelextrema); element 0
        holds the frame indices of the detected beats.
    """
    kps = keypoints.reshape(-1, 24, 3)
    # Per-frame speed of each joint, averaged over joints.
    per_joint_speed = np.sqrt(((kps[1:] - kps[:-1]) ** 2).sum(axis=2))
    kinetic_vel = gaussian_filter(per_joint_speed.mean(axis=1), sigma=5)
    return argrelextrema(kinetic_vel, np.less)
|
|
|
|
def compute_music_beats(beats):
    """Return the frame indices whose music-beat flag is set.

    Args:
        beats: (T,) numpy array; nonzero entries mark beat frames.
    Returns:
        (K,) int array of beat frame indices.
    """
    return np.flatnonzero(beats.astype(bool))
|
|