Spaces:
Sleeping
Sleeping
| # Some functions are borrowed from https://github.com/akanazawa/human_dynamics/blob/master/src/evaluation/eval_util.py | |
| # Adhere to their licence to use these functions | |
| from pathlib import Path | |
| import torch | |
| import numpy as np | |
| from matplotlib import pyplot as plt | |
def compute_accel(joints):
    """
    Compute the mean acceleration magnitude of a 3D joint sequence per frame.

    The acceleration is the second finite difference along axis 0 (time).
    Its Euclidean norm is taken over the last axis and averaged over joints.

    Args:
        joints (N x J x 3): joint positions, e.g. J = 25.
    Returns:
        Accelerations (N-2): mean acceleration norm for each interior frame.
    """
    # First and second finite differences along the time axis.
    vel = np.diff(joints, axis=0)
    acc = np.diff(vel, axis=0)
    per_joint = np.linalg.norm(acc, axis=2)
    return np.mean(per_joint, axis=1)
def compute_error_accel(joints_gt, joints_pred, vis=None):
    r"""
    Compute the per-frame acceleration error between predicted and GT joints.

    For N frames X_0 .. X_{N-1}, the acceleration at interior frame i
    (i = 1 .. N-2) is the second finite difference X_{i-1} - 2 X_i + X_{i+1}.
    The error for that frame is || accel_pred - accel_gt ||_2 per joint,
    averaged over joints (no averaging over frames is done here).

    Frames that are not visible invalidate every acceleration term they
    contribute to. That is up to three entries, because each term spans three
    consecutive frames.

    Args:
        joints_gt (N x J x 3), e.g. J = 14.
        joints_pred (N x J x 3).
        vis (N), optional: per-frame visibility flags.
    Returns:
        error_accel (M): M = N-2 when vis is None. Otherwise M is the number
        of acceleration terms whose three contributing frames are all visible.
    """
    # (N-2) x J x 3
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]

    normed = np.linalg.norm(accel_pred - accel_gt, axis=2)

    if vis is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        invis = np.logical_not(vis)
        # Term k uses frames k, k+1, k+2; drop it if any of them is invisible.
        new_invis = invis[:-2] | invis[1:-1] | invis[2:]
        new_vis = np.logical_not(new_invis)

    return np.mean(normed[new_vis], axis=1)
def compute_error_verts(pred_verts, target_verts=None, target_theta=None):
    """
    Compute the per-frame mean Euclidean error over mesh vertices.

    Args:
        pred_verts (N x V x 3): predicted vertices (e.g. V = 6890 for SMPL).
        target_verts (N x V x 3), optional: ground-truth vertices. If omitted,
            they are generated from ``target_theta`` with the SMPL model.
        target_theta (N x D), optional: required when ``target_verts`` is
            None. Columns 3:75 hold the pose: the first 3 of those are used as
            global orientation and the remaining 69 as body pose. Columns 75:
            hold the shape betas. Columns 0:3 are not used here.
    Returns:
        error_verts (N): mean vertex distance per frame, in input units.
    Raises:
        ValueError: if neither ``target_verts`` nor ``target_theta`` is given,
            or if the number of frames of prediction and target differ.
    """
    if target_verts is None:
        if target_theta is None:
            raise ValueError("Either target_verts or target_theta must be provided.")

        from lib.models.smpl import SMPL_MODEL_DIR
        from lib.models.smpl import SMPL

        device = 'cpu'
        smpl = SMPL(
            SMPL_MODEL_DIR,
            batch_size=1,  # target_theta.shape[0],
        ).to(device)

        betas = torch.from_numpy(target_theta[:, 75:]).to(device)
        pose = torch.from_numpy(target_theta[:, 3:75]).to(device)

        chunks = []
        # Chunks of 5000 frames; presumably to bound peak memory of the SMPL pass.
        # no_grad: only the vertex values are needed, not an autograd graph.
        with torch.no_grad():
            for b, p in zip(torch.split(betas, 5000), torch.split(pose, 5000)):
                output = smpl(betas=b, body_pose=p[:, 3:], global_orient=p[:, :3], pose2rot=True)
                chunks.append(output.vertices.detach().cpu().numpy())
        target_verts = np.concatenate(chunks, axis=0)

    if len(pred_verts) != len(target_verts):
        raise ValueError(
            f"Frame count mismatch: {len(pred_verts)} predicted vs {len(target_verts)} target."
        )

    error_per_vert = np.sqrt(np.sum((target_verts - pred_verts) ** 2, axis=2))
    return np.mean(error_per_vert, axis=1)
def compute_similarity_transform(S1, S2):
    """
    Align S1 to S2 with the best-fitting similarity transform (s, R, t).

    Solves the orthogonal Procrustes problem with scaling. It finds a scale s,
    a rotation R with det(R) = +1 and a translation t that bring S1 closest
    to S2. It then returns s * R @ S1 + t.

    Args:
        S1, S2: point sets of equal shape, either (D x N) with D in {2, 3} or
            (N x D). The latter is detected when the first axis is neither 2
            nor 3, and is transposed internally.
    Returns:
        S1_hat: the transformed S1, in the same layout as the input S1.
    """
    needs_transpose = S1.shape[0] not in (2, 3)
    if needs_transpose:
        S1, S2 = S1.T, S2.T
    assert S2.shape[1] == S1.shape[1]

    # Center both point sets on their centroids.
    centroid1 = S1.mean(axis=1, keepdims=True)
    centroid2 = S2.mean(axis=1, keepdims=True)
    A = S1 - centroid1
    B = S2 - centroid2

    # Total squared norm of the centered source; the denominator of the scale.
    src_energy = np.sum(A ** 2)

    # Cross-covariance; R = V U' maximizes trace(R' K) for K = U S V'.
    cov = A.dot(B.T)
    U, _, Vh = np.linalg.svd(cov)
    V = Vh.T

    # Flip the last singular direction when needed so that det(R) = +1.
    flip = np.eye(U.shape[0])
    flip[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
    R = V.dot(flip.dot(U.T))

    s = np.trace(R.dot(cov)) / src_energy
    t = centroid2 - s * R.dot(centroid1)

    S1_hat = s * R.dot(S1) + t
    return S1_hat.T if needs_transpose else S1_hat
def compute_similarity_transform_torch(S1, S2):
    """
    Torch version of the similarity (Procrustes) alignment of S1 onto S2.

    Finds a scale s, a rotation R with det(R) = +1 and a translation t that
    bring S1 closest to S2. It then returns s * R @ S1 + t.

    Args:
        S1, S2 (torch.Tensor): point sets of equal shape, either (D x N) with
            D in {2, 3} or (N x D). The latter is detected when the first axis
            is neither 2 nor 3, and is transposed internally.
    Returns:
        S1_hat: the transformed S1, in the input's layout, dtype and device.
    """
    transposed = S1.shape[0] not in (2, 3)
    if transposed:
        S1 = S1.T
        S2 = S2.T
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Squared norm of centered X1, used to recover the scale.
    var1 = torch.sum(X1 ** 2)

    # 3. Cross-covariance of X1 and X2.
    K = X1.mm(X2.T)

    # 4. R = V U' maximizes trace(R' K), where K = U S V'.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.T

    # Z fixes the orientation so that det(R) = +1. It must match K's dtype,
    # otherwise mm() fails for float64 inputs.
    Z = torch.eye(U.shape[0], device=S1.device, dtype=K.dtype)
    Z[-1, -1] *= torch.sign(torch.det(U @ V.T))
    R = V.mm(Z.mm(U.T))

    # 5. Recover scale.
    scale = torch.trace(R.mm(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale * R.mm(mu1)

    # 7. Apply the transform.
    S1_hat = scale * R.mm(S1) + t
    if transposed:
        S1_hat = S1_hat.T
    return S1_hat
def batch_compute_similarity_transform_torch(S1, S2):
    """
    Batched similarity (Procrustes) alignment of S1 onto S2.

    For every batch element this finds a scale s, a rotation R with
    det(R) = +1 and a translation t that bring S1 closest to S2. It then
    returns s * R @ S1 + t.

    Args:
        S1, S2 (torch.Tensor): equal shapes, either (B x D x N) with D in
            {2, 3} or (B x N x D). The layout is detected from axis 1, not
            axis 0, which is the batch size. If axis 1 is neither 2 nor 3, the
            inputs are treated as (B x N x D) and permuted internally. Point
            sets with only 2 or 3 points in (B x N x D) layout are therefore
            ambiguous.
    Returns:
        S1_hat: the transformed S1, in the input's layout, dtype and device.
    """
    transposed = S1.shape[1] not in (2, 3)
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean (over points).
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Per-batch squared norm of centered X1, used to recover the scale.
    var1 = (X1 ** 2).sum(dim=(1, 2))

    # 3. Per-batch cross-covariance, (B x D x D).
    K = X1.bmm(X2.transpose(1, 2))

    # 4. R = V U' maximizes trace(R' K), where K = U S V'.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.transpose(1, 2)

    # Z fixes the orientation so that det(R) = +1; dtype must match K for bmm.
    Z = torch.eye(U.shape[1], device=S1.device, dtype=K.dtype).repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.transpose(1, 2))))
    R = V.bmm(Z.bmm(U.transpose(1, 2)))

    # 5. Recover scale: batched trace of R @ K divided by var1.
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    scale = scale[:, None, None]

    # 6. Recover translation.
    t = mu2 - scale * R.bmm(mu1)

    # 7. Apply the transform.
    S1_hat = scale * R.bmm(S1) + t
    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)
    return S1_hat
def align_by_pelvis(joints):
    """
    Root-center a single skeleton on the midpoint of joints 2 and 3.

    Assumes ``joints`` is (14 x 3) in LSP order, where indices 2 and 3 are the
    two hips. The mean of those two rows is subtracted from every joint.
    """
    first_hip = joints[2, :]
    second_hip = joints[3, :]
    pelvis = (first_hip + second_hip) / 2.0
    # Add a leading axis so the pelvis broadcasts over all joints.
    return joints - np.expand_dims(pelvis, axis=0)
def compute_errors(gt3ds, preds):
    """
    Gets MPJPE after pelvis alignment + MPJPE after Procrustes.

    Evaluates on the 14 common joints. Both inputs are root-centered with
    ``align_by_pelvis`` before comparison.

    Inputs:
      - gt3ds: N x 14 x 3 (each sample may also be given flattened)
      - preds: N x 14 x 3 (each sample may also be given flattened)
    Returns:
      - errors: list of N per-sample mean joint errors after pelvis alignment.
      - errors_pa: list of N per-sample mean joint errors after aligning the
        prediction to the ground truth with a similarity transform.
    """
    errors, errors_pa = [], []
    for gt3d, pred in zip(gt3ds, preds):
        # Root align. Both are reshaped so that flattened samples are accepted.
        gt3d = align_by_pelvis(gt3d.reshape(-1, 3))
        pred3d = align_by_pelvis(pred.reshape(-1, 3))

        joint_error = np.sqrt(np.sum((gt3d - pred3d) ** 2, axis=1))
        errors.append(np.mean(joint_error))

        # Get PA error.
        pred3d_sym = compute_similarity_transform(pred3d, gt3d)
        pa_error = np.sqrt(np.sum((gt3d - pred3d_sym) ** 2, axis=1))
        errors_pa.append(np.mean(pa_error))

    return errors, errors_pa
def batch_align_by_pelvis(data_list, pelvis_idxs):
    """
    Root-center predicted and target joints/vertices on their own pelvis.

    Args:
        data_list: sequence ``[pred_j3d, target_j3d, pred_verts, target_verts]``.
            Each element is shaped (frames, num_points, 3).
        pelvis_idxs: list of one or two joint indices, e.g. [0] or [2, 3].
            The pelvis is their mean location per frame. A bare int would drop
            the joint axis, so pass a list.
    Returns:
        Tuple ``(pred_j3d, target_j3d, pred_verts, target_verts)``. The
        predicted joints and vertices are shifted by the predicted pelvis, and
        the target joints and vertices by the target pelvis.
    """
    pred_j3d, target_j3d, pred_verts, target_verts = data_list

    def _pelvis(j3d):
        # (frames, 1, 3) mean over the selected pelvis joints.
        return j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone()

    pred_root = _pelvis(pred_j3d)
    target_root = _pelvis(target_j3d)

    return (
        pred_j3d - pred_root,
        target_j3d - target_root,
        pred_verts - pred_root,
        target_verts - target_root,
    )
def compute_jpe(S1, S2):
    """
    Compute the mean joint position error between two sets of points.

    Args:
        S1, S2 (torch.Tensor): broadcast-compatible tensors whose last axis
            holds coordinates, e.g. (N x J x 3).
    Returns:
        np.ndarray: Euclidean distance per point averaged over the
        second-to-last axis, e.g. shape (N,).
    """
    dist = torch.sqrt(((S1 - S2) ** 2).sum(dim=-1))
    # detach()/cpu() so CUDA tensors and tensors that require grad can be
    # converted, as elsewhere in this module.
    return dist.mean(dim=-1).detach().cpu().numpy()
# The functions below are borrowed from the official SLAHMR implementation.
# Reference: https://github.com/vye16/slahmr/blob/main/slahmr/eval/tools.py
def global_align_joints(gt_joints, pred_joints):
    """
    Align a whole predicted joint sequence to GT with one similarity transform.

    All frames and joints are flattened into a single point cloud. One
    (scale, rotation, translation) is estimated with ``align_pcl`` and then
    applied to every frame.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) globally aligned prediction
    """
    gt_cloud = gt_joints.reshape(-1, 3)
    pred_cloud = pred_joints.reshape(-1, 3)
    scale, rot, offset = align_pcl(gt_cloud, pred_cloud)

    rotated = torch.einsum("ij,tnj->tni", rot, pred_joints)
    return scale * rotated + offset[None, None]
def first_align_joints(gt_joints, pred_joints):
    """
    Align a predicted sequence to GT using only the first two frames.

    The joints of frames 0 and 1 are pooled into one point cloud. The
    similarity transform estimated from them by ``align_pcl`` is then applied
    to every frame of ``pred_joints``.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) aligned prediction
    """
    gt_head = gt_joints[:2].reshape(1, -1, 3)
    pred_head = pred_joints[:2].reshape(1, -1, 3)
    # Shapes: scale (1, 1), rot (1, 3, 3), offset (1, 3).
    scale, rot, offset = align_pcl(gt_head, pred_head)

    # rot has a leading dim of 1 and is broadcast over the T frames by einsum.
    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale * rotated + offset[:, None]
def local_align_joints(gt_joints, pred_joints):
    """
    Align every predicted frame to its GT frame with its own similarity transform.

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns (T, J, 3) per-frame aligned prediction
    """
    # Shapes: scale (T, 1), rot (T, 3, 3), offset (T, 3).
    scale, rot, offset = align_pcl(gt_joints, pred_joints)

    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale[:, None] * rotated + offset[:, None]
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """Align X to Y with a similarity transform using the Umeyama method.

    Finds s, R, t such that X' = s * R * X + t is aligned with Y.

    :param Y (*, N, 3) first trajectory (target)
    :param X (*, N, 3) second trajectory (source)
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale if True, s is fixed to 1 (rigid alignment)
    :returns s (*, 1), R (*, 3, 3), t (*, 3), on Y's device and dtype
    """
    *dims, N, _ = Y.shape
    # All helper tensors follow the input's device/dtype. Otherwise CUDA inputs
    # fail on device mismatch and float64 inputs fail in matmul(S, Vh).
    factory = dict(device=Y.device, dtype=Y.dtype)
    N = torch.ones(*dims, 1, 1, **factory) * N

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # S flips the last axis where needed so that det(R) = +1.
    S = torch.eye(3, **factory).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    D = torch.diag_embed(D)  # (*, 3, 3)
    if fixed_scale:
        s = torch.ones(*dims, 1, **factory)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
def compute_foot_sliding(target_output, pred_output, masks, thr=1e-2):
    """Compute foot sliding error.

    A foot vertex is labeled "in contact" on a frame transition when its
    ground-truth displacement is below ``thr``. With the default, that is
    1 cm/frame, assuming metric vertex units. The error is the predicted
    displacement of the same vertex on those transitions.

    Args:
        target_output (SMPL ModelOutput): only ``.vertices`` (N x V x 3) is read.
        pred_output (SMPL ModelOutput): only ``.vertices`` is read.
        masks (N): frame mask, applied to the target vertices only.
        thr (float): contact threshold on per-frame displacement.
    Returns:
        error (1-D np.ndarray): predicted displacement for every
        (transition, foot vertex) pair labeled as contact.
    """
    # Foot vertex indices on the mesh.
    foot_idxs = [3216, 3387, 6617, 6787]

    def _step_lengths(verts):
        # Displacement norm of each foot vertex between consecutive frames.
        return (verts[1:] - verts[:-1]).norm(2, dim=-1)

    gt_contact = _step_lengths(target_output.vertices[masks][:, foot_idxs]) < thr

    # NOTE(review): only the target is masked here; pred_output is presumably
    # already restricted to the masked frames by the caller. Confirm.
    pred_steps = _step_lengths(pred_output.vertices[:, foot_idxs])

    return pred_steps[gt_contact].cpu().numpy()
def compute_jitter(pred_output, fps=30):
    """Compute the jitter of a motion.

    Jitter is the magnitude of the third finite difference of the first 24
    joints, scaled by ``fps ** 3``. Its norm is averaged over joints and the
    result is divided by 10.

    Args:
        pred_output (SMPL ModelOutput): only ``.joints`` (N x J x 3, J >= 24)
            is read.
        fps (float): frame rate of the sequence.
    Returns:
        jitter (N-3): one value per window of four consecutive frames.
    """
    joints = pred_output.joints[:, :24]
    third_diff = joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]
    per_joint = torch.norm(third_diff * (fps ** 3), dim=2)
    return per_joint.mean(dim=-1).cpu().numpy() / 10.0
def compute_rte(target_trans, pred_trans):
    """Compute the root translation error (RTE) normalized by path length.

    The predicted root trajectory is rigidly aligned to the ground truth with
    ``align_pcl``, using rotation and translation with the scale fixed to 1.
    The per-frame Euclidean error is then divided by the total distance
    travelled by the ground-truth root.

    Args:
        target_trans (T x 3): ground-truth root translations.
        pred_trans (T x 3): predicted root translations.
    Returns:
        np.ndarray (T,): normalized per-frame RTE. If the ground-truth root
        never moves, the division by zero yields inf/nan.
    """
    # Compute the global alignment (scale fixed to 1).
    _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True)
    pred_trans_hat = (
        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    )[0]

    # Total length of the ground-truth trajectory (sum of per-step distances).
    disp = (target_trans[1:] - target_trans[:-1]).norm(2, dim=-1).sum()

    # Absolute root-translation error per frame, normalized by path length.
    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)
    return (rte / disp).detach().cpu().numpy()