#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
# Make the repository root (and the bundled GVHMR tree, when it exists)
# importable before the project imports below are executed.
REPO_ROOT = Path(__file__).resolve().parents[1]
GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR"
for _root, _wanted in ((REPO_ROOT, True), (GVHMR_ROOT, GVHMR_ROOT.is_dir())):
    if _wanted and str(_root) not in sys.path:
        sys.path.insert(0, str(_root))
from genmo.utils.eval_utils import compute_camcoord_metrics
from genmo.utils.geo_transform import compute_cam_angvel, compute_cam_tvel, normalize_T_w2c
from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle
from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx
def _to_tensor(x):
if isinstance(x, torch.Tensor):
return x
return torch.as_tensor(x)
def _slice(x: torch.Tensor, n: int) -> torch.Tensor:
return x[:n].clone()
def _compare_arrays(a: torch.Tensor, b: torch.Tensor):
if a.shape != b.shape:
return {"shape_a": list(a.shape), "shape_b": list(b.shape)}
diff = a - b
return {
"mae": float(diff.abs().mean().item()),
"rmse": float((diff.pow(2).mean().sqrt()).item()),
}
def _load_smplx_tools(device: torch.device):
    """Build the SMPLX body model plus the SMPLX->SMPL conversion tensors.

    Returns ``(smplx_model, smplx2smpl, j_regressor)``, all placed on *device*.
    """
    body_models_dir = (
        REPO_ROOT / "third_party" / "GVHMR" / "inputs" / "checkpoints" / "body_models"
    )
    smplx_model = make_smplx("supermotion").to(device).eval()
    smplx2smpl = torch.load(body_models_dir / "smplx2smpl_sparse.pt", map_location=device)
    j_regressor = torch.load(
        body_models_dir / "smpl_neutral_J_regressor.pt", map_location=device
    )
    return smplx_model, smplx2smpl, j_regressor
def _smplx_to_j3d(params: dict, smplx, smplx2smpl, j_reg):
out = smplx(**params)
verts = out.vertices if hasattr(out, "vertices") else out[0].vertices
verts = torch.stack([torch.matmul(smplx2smpl, v) for v in verts])
j3d = torch.einsum("jv,fvi->fji", j_reg, verts)
return verts, j3d
def _prepare_params(params: dict, n: int, device: torch.device):
    """Return a copy of *params* with every value tensorized, truncated to the
    first *n* frames, and moved to *device*."""
    prepared = {}
    for key, value in params.items():
        tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
        prepared[key] = tensor[:n].clone().to(device)
    return prepared
def _min_len(*lens: int) -> int:
lens = [int(x) for x in lens if x is not None]
return min(lens) if lens else 0
def main() -> None:
    """Compare a model debug dump against a formatted dataset .pt file.

    Loads both files, reports per-key input MAE/RMSE, and — when camera-frame
    SMPL parameters exist in both — computes camera-coordinate body metrics.
    Results are written as JSON to ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug-io", required=True, help="Path to debug_io.pt")
    parser.add_argument("--dataset-pt", required=True, help="Path to dataset .pt file")
    parser.add_argument("--num-frames", type=int, default=60)
    parser.add_argument("--out", default="outputs/debug_compare.json")
    args = parser.parse_args()
    # NOTE(review): weights_only=False unpickles arbitrary objects — only run
    # this on trusted checkpoint files.
    debug = torch.load(args.debug_io, map_location="cpu", weights_only=False)
    data = torch.load(args.dataset_pt, map_location="cpu", weights_only=False)
    inputs = debug.get("inputs", {})
    outputs = debug.get("outputs", {})
    n = int(args.num_frames)
    results = {"num_frames": n, "inputs": {}, "metrics": {}, "formatted_dataset": True}
    dataset_inputs = dict(data)
    # Derive camera angular/translational velocities from the world-to-camera
    # poses so they can be compared against the model's recorded inputs.
    if "T_w2c" in data:
        T_w2c = _to_tensor(data["T_w2c"]).float()
        if T_w2c.ndim == 3:
            normed_T_w2c = normalize_T_w2c(T_w2c)
            dataset_inputs["cam_angvel"] = compute_cam_angvel(normed_T_w2c[:, :3, :3])
            dataset_inputs["cam_tvel"] = compute_cam_tvel(normed_T_w2c[:, :3, 3])
    # Compare input tensors (if present in both).
    input_keys = [
        "bbx_xys",
        "kp2d",
        "K_fullimg",
        "cam_angvel",
        "cam_tvel",
        "T_w2c",
    ]
    for k in input_keys:
        if k in inputs and k in dataset_inputs:
            a_t = _to_tensor(inputs[k])
            b_t = _to_tensor(dataset_inputs[k])
            # Clamp to the shortest sequence so the compared slices line up.
            n_in = _min_len(n, a_t.shape[0], b_t.shape[0])
            a = _slice(a_t, n_in)
            b = _slice(b_t, n_in)
            results["inputs"][k] = _compare_arrays(a, b)
    # Compare incam SMPLX (pred vs GT from dataset).
    pred_key = "smpl_params_incam"
    if pred_key not in outputs:
        pred_key = "pred_smpl_params_incam"
    if pred_key in outputs and "smpl_params_c" in data:
        device = torch.device("cpu")
        smplx, smplx2smpl, j_reg = _load_smplx_tools(device)
        pred_raw = outputs[pred_key]
        gt_raw = data["smpl_params_c"]
        pred_len = int(_to_tensor(pred_raw["global_orient"]).shape[0])
        gt_len = int(_to_tensor(gt_raw["global_orient"]).shape[0])
        n_eval = _min_len(n, pred_len, gt_len)
        pred_params = _prepare_params(pred_raw, n_eval, device)
        gt_params = _prepare_params(gt_raw, n_eval, device)
        pred_verts, pred_j3d = _smplx_to_j3d(pred_params, smplx, smplx2smpl, j_reg)
        gt_verts, gt_j3d = _smplx_to_j3d(gt_params, smplx, smplx2smpl, j_reg)
        metrics = compute_camcoord_metrics(
            {
                "pred_j3d": pred_j3d,
                "target_j3d": gt_j3d,
                "pred_verts": pred_verts,
                "target_verts": gt_verts,
            }
        )
        # Collapse per-frame metric tensors into scalar means for JSON output.
        results["metrics"] = {k: float(v.mean()) for k, v in metrics.items()}
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w") as f:
        json.dump(results, f, indent=2)
    print(f"Wrote {out_path}")
# CLI entry point.
if __name__ == "__main__":
    main()