hmr-dataset / genmo /utils /eval_utils.py
zirobtc's picture
Upload folder using huggingface_hub
fbb20ff verified
import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from scipy.signal import argrelextrema
@torch.no_grad()
def compute_camcoord_metrics(batch, pelvis_idxs=[1, 2], fps=30, mask=None):
    """Pelvis-aligned pose errors between prediction and ground truth.

    Args:
        batch (dict): {
            "pred_j3d": (..., J, 3) tensor
            "target_j3d":
            "pred_verts":
            "target_verts":
        }
            The dict must not contain a "mask" key; pass ``mask`` instead.
        pelvis_idxs (list): joint indices averaged to locate the pelvis.
        fps: frame rate forwarded to ``compute_error_accel``.
        mask: optional index applied to the leading dimension of all four
            tensors before any metric is computed.
    Returns:
        cam_coord_metrics (dict): {
            "pa_mpjpe": numpy array, joint error after similarity alignment
            "mpjpe": numpy array, joint error after pelvis alignment only
            "pve": numpy array, vertex error after pelvis alignment only
            "accel": numpy array, acceleration error
        }
        The three position errors are multiplied by 1000 (metres to
        millimetres, assuming the inputs are in metres).
    """
    # All data is in camera coordinates
    keys = ("pred_j3d", "target_j3d", "pred_verts", "target_verts")
    tensors = [batch[key].cpu() for key in keys]

    if mask is not None:
        mask = mask.cpu()
        tensors = [tensor[mask].clone() for tensor in tensors]
    assert "mask" not in batch

    # Express every point set relative to its own pelvis
    pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
        tensors, pelvis_idxs=pelvis_idxs
    )

    to_mm = 1000
    procrustes_j3d = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
    return {
        "pa_mpjpe": compute_jpe(procrustes_j3d, target_j3d) * to_mm,
        "mpjpe": compute_jpe(pred_j3d, target_j3d) * to_mm,
        "pve": compute_jpe(pred_verts, target_verts) * to_mm,
        "accel": compute_error_accel(
            joints_pred=pred_j3d, joints_gt=target_j3d, fps=fps
        ),
    }
@torch.no_grad()
def compute_music_metrics(batch, mask=None):
    """Compute the physical foot contact score (PFC) and beat align score (BAS).

    Args:
        batch (dict): {
            "pred_j3d_glob": (T, J, 3) tensor of predicted joints; the y axis
                is treated as up. ``compute_motion_beats`` reshapes it to 24
                joints, so J is expected to be 24.
            "music_beats": (T,) numpy array, non-zero entries mark beat frames
        }
        mask: unused; accepted for signature parity with the other metrics.
    Returns:
        music_metrics (dict): {
            "PFC": mean of (min left-foot speed * min right-foot speed *
                normalized root acceleration), times 10000. NaN when the root
                acceleration is zero on every frame (0 / 0 normalization).
            "BAS": mean over music beats of exp(-d^2 / 2 / 9), with d the
                frame distance to the closest motion beat. 0.0 when no motion
                beat is detected, NaN when there is no music beat.
        }
    """
    # All data is in global coordinates
    pred_j3d_glob = batch["pred_j3d_glob"].cpu().numpy()  # (T, J, 3)
    up_dir = 1  # y is up
    flat_dirs = [i for i in range(3) if i != up_dir]
    DT = 1 / 30
    assert pred_j3d_glob.ndim == 3

    root_v = (
        pred_j3d_glob[1:, 0, :] - pred_j3d_glob[:-1, 0, :]
    ) / DT  # root velocity (T-1, 3)
    root_a = (root_v[1:, :] - root_v[:-1, :]) / DT  # root acceleration (T-2, 3)
    # clamp the up-direction of root acceleration
    root_a[:, up_dir] = np.maximum(root_a[:, up_dir], 0)  # (T-2, 3)
    # l2 norm, then normalize by the largest acceleration of the sequence
    root_a = np.linalg.norm(root_a, axis=-1)  # (T-2,)
    scaling = root_a.max()
    root_a = root_a / scaling

    foot_idx = [7, 10, 8, 11]
    feet = pred_j3d_glob[:, foot_idx, :]  # (T, 4, 3)
    foot_v = np.linalg.norm(
        feet[2:, :, flat_dirs] - feet[1:-1, :, flat_dirs], axis=-1
    )  # horizontal velocity (T-2, 4)
    # Slowest joint of each foot: columns (0, 1) and (2, 3) of foot_idx
    foot_mins = np.zeros((len(foot_v), 2))
    foot_mins[:, 0] = np.minimum(foot_v[:, 0], foot_v[:, 1])
    foot_mins[:, 1] = np.minimum(foot_v[:, 2], foot_v[:, 3])
    foot_loss = (
        foot_mins[:, 0] * foot_mins[:, 1] * root_a
    )  # min leftv * min rightv * root_a (T-2,)
    pfc = foot_loss.mean() * 10000

    # compute Beat Align Score
    motion_beats = compute_motion_beats(pred_j3d_glob)[0]
    music_beats = compute_music_beats(batch["music_beats"])
    if len(music_beats) == 0:
        # No beat to align to: the average is undefined (was ZeroDivisionError).
        bas = float("nan")
    elif len(motion_beats) == 0:
        # No motion beat near any music beat (np.min would raise on empty input).
        bas = 0.0
    else:
        ba = sum(
            np.exp(-np.min((motion_beats - bb) ** 2) / 2 / 9) for bb in music_beats
        )
        bas = ba / len(music_beats)

    return {
        "PFC": pfc,
        "BAS": bas,
    }
@torch.no_grad()
def compute_global_metrics(batch, mask=None):
    """Global-coordinate metrics (following WHAM); invalid frames are already skipped.

    Args:
        batch (dict): {
            "pred_j3d_glob": (F, J, 3) tensor
            "target_j3d_glob":
            "pred_verts_glob":
            "target_verts_glob":
        }
            The dict must not contain a "mask" key; pass ``mask`` instead.
        mask: optional index applied to the leading dimension of all tensors.
    Returns:
        global_metrics (dict): {
            "wa2_mpjpe": (F, ) numpy array, error after aligning each chunk
                on its first two frames, times 1000
            "waa_mpjpe": (F, ) numpy array, error after aligning each chunk
                on all of its frames, times 1000
            "rte": root translation error, times 100
            "jitter": jitter of the predicted joints at 30 fps
            "fs": foot sliding, times 1000
        }
    """
    # All data is in global coordinates
    names = ("pred_j3d_glob", "target_j3d_glob", "pred_verts_glob", "target_verts_glob")
    pred_j3d_glob, target_j3d_glob, pred_verts_glob, target_verts_glob = [
        batch[name].cpu() for name in names
    ]

    if mask is not None:
        mask = mask.cpu()
        pred_j3d_glob, target_j3d_glob, pred_verts_glob, target_verts_glob = [
            tensor[mask].clone()
            for tensor in (
                pred_j3d_glob,
                target_j3d_glob,
                pred_verts_glob,
                target_verts_glob,
            )
        ]
    assert "mask" not in batch

    num_frames = pred_j3d_glob.shape[0]

    # The sequence is compared window by window
    window = 100
    first_aligned_err, fully_aligned_err = [], []
    for lo in range(0, num_frames, window):
        hi = min(num_frames, lo + window)
        gt_chunk = target_j3d_glob[lo:hi].clone().cpu()
        pred_chunk = pred_j3d_glob[lo:hi].clone().cpu()

        pred_on_first = first_align_joints(gt_chunk, pred_chunk)
        pred_on_all = global_align_joints(gt_chunk, pred_chunk)
        first_aligned_err.append(compute_jpe(gt_chunk, pred_on_first))
        fully_aligned_err.append(compute_jpe(gt_chunk, pred_on_all))

    to_mm = 1000
    root_target = target_j3d_glob[:, 0].cpu()
    root_pred = pred_j3d_glob[:, 0].cpu()
    return {
        "wa2_mpjpe": np.concatenate(first_aligned_err) * to_mm,
        "waa_mpjpe": np.concatenate(fully_aligned_err) * to_mm,
        "rte": compute_rte(root_target, root_pred) * 1e2,
        "jitter": compute_jitter(pred_j3d_glob, fps=30),
        "fs": compute_foot_sliding(target_verts_glob, pred_verts_glob) * to_mm,
    }
@torch.no_grad()
def compute_camcoord_perjoint_metrics(batch, pelvis_idxs=[1, 2]):
    """Per-joint position error in camera coordinates after pelvis alignment.

    Args:
        batch (dict): {
            "pred_j3d": (..., J, 3) tensor
            "target_j3d":
            "pred_verts": read and pelvis-aligned, but not used by any metric
            "target_verts": read and pelvis-aligned, but not used by any metric
        }
        pelvis_idxs (list): joint indices averaged to locate the pelvis.
    Returns:
        camcoord_perjoint_metrics (dict): {
            "mpjpe": (..., J) numpy array, times 1000
        }
    """
    # All data is in camera coordinates
    keys = ("pred_j3d", "target_j3d", "pred_verts", "target_verts")
    tensors = [batch[key].cpu() for key in keys]

    # Express every point set relative to its own pelvis
    aligned = batch_align_by_pelvis(tensors, pelvis_idxs=pelvis_idxs)
    pred_j3d, target_j3d = aligned[0], aligned[1]

    to_mm = 1000
    return {
        "mpjpe": compute_perjoint_jpe(pred_j3d, target_j3d) * to_mm,
    }
# ===== Utilities =====
def compute_jpe(S1, S2):
    """Mean Euclidean distance between corresponding points.

    The distance is taken over the last axis and averaged over the
    second-to-last one; the result is returned as a numpy array.
    """
    offset = S1 - S2
    point_dist = torch.sqrt(offset.pow(2).sum(dim=-1))
    return point_dist.mean(dim=-1).numpy()
def compute_perjoint_jpe(S1, S2):
    """Euclidean distance between corresponding points, kept per point.

    The distance is taken over the last axis; the result is a numpy array
    with that axis removed.
    """
    offset = S1 - S2
    squared_dist = offset.pow(2).sum(dim=-1)
    return torch.sqrt(squared_dist).numpy()
def batch_align_by_pelvis(data_list, pelvis_idxs=[1, 2]):
    """
    Translate joints and vertices so that the pelvis sits at the origin.

    data_list is [pred_j3d, target_j3d, pred_verts, target_verts], each of
    shape (frames, num_points, 3). The pelvis is the mean of the joints listed
    in pelvis_idxs (one or two indices). Predicted data is shifted by the
    predicted pelvis and target data by the target pelvis.
    Returns the four shifted tensors as a tuple, in the same order.
    """
    pred_j3d, target_j3d, pred_verts, target_verts = data_list

    def _pelvis(joints):
        # (frames, 1, 3) so it broadcasts over the point axis
        return joints[:, pelvis_idxs].mean(dim=1, keepdim=True)

    pred_root = _pelvis(pred_j3d)
    target_root = _pelvis(target_j3d)

    return (
        pred_j3d - pred_root,
        target_j3d - target_root,
        pred_verts - pred_root,
        target_verts - target_root,
    )
def batch_compute_similarity_transform_torch(S1, S2):
    """
    Align each point set in S1 to the matching set in S2 with a similarity
    transform (sR, t), where R is a rotation matrix, t a translation and s a
    scale, i.e. solve the orthogonal Procrustes problem per batch element.

    Args:
        S1: (B, N, D) point-major or (B, D, N) coordinate-major tensor, D in {2, 3}
        S2: tensor with the same layout as S1
    Returns:
        S1_hat: S1 after alignment to S2, in the same layout as the input.

    The layout is decided from the point/coordinate axes only. The input is
    treated as coordinate-major only when axis 1 has size 2 or 3 and axis 2
    does not; every other shape is treated as point-major. (The previous check
    looked at the batch size, so point-major batches of 2 or 3 elements were
    aligned along the wrong axis.)
    """
    coord_major = S1.shape[1] in (2, 3) and S1.shape[2] not in (2, 3)
    transposed = not coord_major
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1**2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.transpose(-2, -1)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z follows the input dtype so that bmm also works for float64 input.
    Z = torch.eye(U.shape[1], device=S1.device, dtype=S1.dtype).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(Vh)))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale (batched trace of R K).
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    # 6. Recover translation.
    t = mu2 - scale * R.bmm(mu1)

    # 7. Apply the transform.
    S1_hat = scale * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat
def batch_compute_scale_trans_torch(S1, S2):
    """
    Compute, per batch element, the similarity transform (sR, t) that brings
    the point set S1 closest to S2, where R is a rotation matrix, t a
    translation and s a scale, i.e. solve the orthogonal Procrustes problem.
    Unlike batch_compute_similarity_transform_torch, the transform itself is
    returned instead of the transformed points.

    Args:
        S1: (B, N, D) point-major or (B, D, N) coordinate-major tensor, D in {2, 3}
        S2: tensor with the same layout as S1
    Returns:
        scale: (B,) tensor
        t: (B, D, 1) tensor
        R: (B, D, D) tensor
        They satisfy ``scale * R @ x + t ~= y`` for column vectors x of S1 and
        y of S2.

    The layout is decided from the point/coordinate axes only. The input is
    treated as coordinate-major only when axis 1 has size 2 or 3 and axis 2
    does not; every other shape is treated as point-major. (The previous check
    looked at the batch size, so point-major batches of 2 or 3 elements were
    aligned along the wrong axis.)
    """
    coord_major = S1.shape[1] in (2, 3) and S1.shape[2] not in (2, 3)
    if not coord_major:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1**2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.transpose(-2, -1)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z follows the input dtype so that bmm also works for float64 input.
    Z = torch.eye(U.shape[1], device=S1.device, dtype=S1.dtype).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(Vh)))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale (batched trace of R K).
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1

    # 6. Recover translation.
    t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1)))

    return scale, t, R
def compute_error_accel(joints_gt, joints_pred, valid_mask=None, fps=None):
    r"""
    Per-frame acceleration error between predicted and ground-truth joints.

    Frames [i-1, i, i+1] give the acceleration at frame i as the finite
    difference X_{i-1} - 2X_i + X_{i+1}. The error of a frame is
    \frac{1}{J} \sum_j || a_pred_{i,j} - a_gt_{i,j} ||, multiplied by fps^2
    when fps is given. The values are returned per frame, not averaged.
    Each frame that is not valid removes the three acceleration entries
    (-1, 0, +1) that depend on it.

    Args:
        joints_gt   : (F, J, 3) array or CPU tensor
        joints_pred : (F, J, 3) array or CPU tensor
        valid_mask  : (F) optional validity flags
        fps         : optional frame rate
    Returns:
        error_accel (F-2) numpy array when valid_mask is None, else (F'),
        F' <= F-2
    """
    # Work in numpy regardless of whether arrays or CPU tensors were passed.
    joints_gt = np.asarray(joints_gt)
    joints_pred = np.asarray(joints_pred)

    # (F, J, 3) -> (F-2) per-joint
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]
    normed = np.linalg.norm(accel_pred - accel_gt, axis=-1).mean(axis=-1)
    if fps is not None:
        normed = normed * fps**2

    if valid_mask is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        # Entry k uses frames k, k+1 and k+2; it is dropped if any is invalid.
        invis = np.logical_not(np.asarray(valid_mask))
        invis1 = np.roll(invis, -1)
        invis2 = np.roll(invis, -2)
        new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
        new_vis = np.logical_not(new_invis)
        if new_vis.sum() == 0:
            print("Warning!!! no valid acceleration error to compute.")

    return normed[new_vis]
def compute_rte(target_trans, pred_trans):
    """Root translation error, normalized by the ground-truth path length.

    The predicted trajectory is first aligned to the target with a rotation
    and a translation (``align_pcl`` with the scale fixed to 1). The per-frame
    Euclidean distance to the target is then divided by the total length of
    the target trajectory.

    Args:
        target_trans: (F, 3) tensor, ground-truth root positions
        pred_trans: (F, 3) tensor, predicted root positions
    Returns:
        (F,) numpy array. The values are inf/nan when the target trajectory
        has zero length.
    """
    # Compute the global alignment
    _, rot, trans = align_pcl(
        target_trans[None, :], pred_trans[None, :], fixed_scale=True
    )
    pred_trans_hat = (
        torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    )[0]

    # Entire displacement of the ground-truth trajectory: the sum of the
    # distances between consecutive frames.
    step_lengths = (target_trans[1:] - target_trans[:-1]).norm(2, dim=-1)
    disp = step_lengths.sum()

    # Compute absolute root-translation-error (RTE)
    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)

    # Normalize it to the displacement
    return (rte / disp).numpy()
def compute_jitter(joints, fps=30):
    """Jitter of a motion, from the third finite difference of the joints.

    Args:
        joints (N, J, 3).
        fps (float).
    Returns:
        jitter (N-3): per-frame norm of the third difference times fps^3,
        averaged over joints and divided by 10.
    """
    third_diff = joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]
    jerk = third_diff * (fps**3)
    per_frame = torch.norm(jerk, dim=2).mean(dim=-1)
    return per_frame.cpu().numpy() / 10.0
def compute_foot_sliding(target_verts, pred_verts, thr=1e-2):
    """Foot sliding error of the prediction while the target foot is in contact.

    A foot vertex is labelled as in contact between two consecutive frames
    when its displacement in the target is below ``thr`` (1 cm/frame by
    default, assuming metres). The error is the displacement of the same
    vertex in the prediction.
    Args:
        target_verts (N, 6890, 3).
        pred_verts (N, 6890, 3).
    Returns:
        error: 1-D numpy array with one entry per (frame step, foot vertex)
        pair that is in contact.
    """
    assert target_verts.shape == pred_verts.shape
    assert target_verts.shape[-2] == 6890

    # Foot vertices idxs
    foot_idxs = [3216, 3387, 6617, 6787]

    def _step_length(verts):
        feet = verts[:, foot_idxs]
        return (feet[1:] - feet[:-1]).norm(2, dim=-1)

    # Contact label comes from the target, the error from the prediction
    in_contact = _step_length(target_verts) < thr
    sliding = _step_length(pred_verts)[in_contact]
    return sliding.cpu().numpy()
def convert_joints22_to_24(joints22, ratio2220=0.3438, ratio2321=0.3345):
    """Extend a 22-joint skeleton to 24 joints by extrapolating two joints.

    Joint 22 is placed beyond joint 20 along the 18 -> 20 direction and joint
    23 beyond joint 21 along the 19 -> 21 direction, at ``ratio2220`` and
    ``ratio2321`` times the respective segment length. The first 22 joints
    are copied unchanged.

    Args:
        joints22: (..., 22, 3) tensor
    Returns:
        (..., 24, 3) tensor on the device of ``joints22``.
    """
    lead_shape = joints22.shape[:-2]
    joints24 = torch.zeros(*lead_shape, 24, 3).to(joints22.device)
    joints24[..., :22, :] = joints22

    # (new joint, end joint, its parent, extrapolation ratio)
    extensions = ((22, 20, 18, ratio2220), (23, 21, 19, ratio2321))
    for new_idx, end_idx, parent_idx, ratio in extensions:
        end_pos = joints22[..., end_idx, :]
        joints24[..., new_idx, :] = end_pos + ratio * (
            end_pos - joints22[..., parent_idx, :]
        )
    return joints24
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """Find the similarity transform that aligns X with Y (Umeyama method).

    X' = s * R * X + t is aligned with Y.
    All helper tensors are created on the device and in the dtype of Y, so
    float64 and non-CPU inputs are supported.

    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale if True, the scale is fixed to 1 and only R, t are fitted
    :returns s (*, 1), R (*, 3, 3), t (*, 3)
    """
    *dims, N, _ = Y.shape
    N = torch.full((*dims, 1, 1), float(N), device=Y.device, dtype=Y.dtype)

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # S flips the last axis where U V' would be a reflection, so det(R) = 1
    S = (
        torch.eye(3, device=Y.device, dtype=Y.dtype)
        .reshape(*(1,) * (len(dims)), 3, 3)
        .repeat(*dims, 1, 1)
    )
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    D = torch.diag_embed(D)  # (*, 3, 3)
    if fixed_scale:
        s = torch.ones(*dims, 1, device=Y.device, dtype=Y.dtype)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
def global_align_joints(gt_joints, pred_joints):
    """
    Align the prediction to the ground truth with one similarity transform
    fitted on all joints of all frames.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns aligned prediction (T, J, 3)
    """
    # Flatten frames and joints into a single point cloud for the fit
    flat_gt = gt_joints.reshape(-1, 3)
    flat_pred = pred_joints.reshape(-1, 3)
    scale, rot, trans = align_pcl(flat_gt, flat_pred)

    rotated = torch.einsum("ij,tnj->tni", rot, pred_joints)
    return scale * rotated + trans[None, None]
def first_align_joints(gt_joints, pred_joints):
    """
    Align the prediction to the ground truth with a similarity transform
    fitted on the first two frames only, then applied to every frame.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns aligned prediction (T, J, 3)
    """
    # Joints of the first two frames as one point cloud with a batch dim of 1
    head_gt = gt_joints[:2].reshape(1, -1, 3)
    head_pred = pred_joints[:2].reshape(1, -1, 3)

    # (1, 1), (1, 3, 3), (1, 3)
    scale, rot, trans = align_pcl(head_gt, head_pred)

    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale * rotated + trans[:, None]
def rearrange_by_mask(x, mask):
    """
    Scatter the rows of x into a zero tensor at the positions selected by mask.

    x (L, *)
    mask (M,), M >= L, with exactly L selected entries when M > L
    Returns x itself when M == L, otherwise a new (M, *) tensor with the dtype
    and device of x whose unselected rows are zero.
    """
    num_slots = mask.size(0)
    num_rows = x.size(0)
    if num_slots == num_rows:
        return x

    assert num_slots > num_rows
    assert mask.sum() == num_rows

    scattered = x.new_zeros((num_slots, *x.shape[1:]))
    scattered[mask] = x
    return scattered
def as_np_array(d):
    """Return d as a numpy array.

    Arrays are returned as they are, tensors are moved to the CPU and
    converted, and anything else is passed to ``np.array``.
    """
    if isinstance(d, np.ndarray):
        return d
    if isinstance(d, torch.Tensor):
        return d.cpu().numpy()
    return np.array(d)
def compute_motion_beats(keypoints):
    """Detect motion beats as local minima of the smoothed kinetic velocity.

    keypoints is reshaped to (-1, 24, 3). The kinetic velocity of a frame step
    is the mean joint displacement; it is smoothed with a Gaussian filter
    (sigma=5) before the minima are located.
    Returns the tuple produced by ``argrelextrema``; its first element holds
    the frame-step indices of the beats.
    """
    joints = keypoints.reshape(-1, 24, 3)
    frame_delta = joints[1:] - joints[:-1]
    joint_speed = np.sqrt(np.sum(frame_delta**2, axis=2))
    smoothed = gaussian_filter(np.mean(joint_speed, axis=1), sigma=5)
    return argrelextrema(smoothed, np.less)
def compute_music_beats(beats):
    """Return the indices of the frames flagged as music beats.

    beats is a per-frame array whose non-zero entries mark a beat.
    """
    is_beat = beats.astype(bool)
    frame_ids = np.arange(len(is_beat))
    return frame_ids[is_beat]