| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
# Make the repository root importable when this script is run directly
# (it lives one directory below the repo root), and also expose the vendored
# GVHMR checkout when present so its `hmr4d` package resolves.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR"
if GVHMR_ROOT.is_dir() and str(GVHMR_ROOT) not in sys.path:
    sys.path.insert(0, str(GVHMR_ROOT))
|
|
| from genmo.utils.eval_utils import compute_camcoord_metrics |
| from genmo.utils.geo_transform import compute_cam_angvel, compute_cam_tvel, normalize_T_w2c |
| from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle |
| from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx |
|
|
|
|
| def _to_tensor(x): |
| if isinstance(x, torch.Tensor): |
| return x |
| return torch.as_tensor(x) |
|
|
|
|
| def _slice(x: torch.Tensor, n: int) -> torch.Tensor: |
| return x[:n].clone() |
|
|
|
|
| def _compare_arrays(a: torch.Tensor, b: torch.Tensor): |
| if a.shape != b.shape: |
| return {"shape_a": list(a.shape), "shape_b": list(b.shape)} |
| diff = a - b |
| return { |
| "mae": float(diff.abs().mean().item()), |
| "rmse": float((diff.pow(2).mean().sqrt()).item()), |
| } |
|
|
|
|
def _load_smplx_tools(device: torch.device):
    """Load the SMPL-X body model plus SMPL conversion assets on *device*.

    Returns:
        (smplx, smplx2smpl, j_reg): the eval-mode SMPL-X module, the sparse
        SMPL-X→SMPL vertex mapping matrix, and the SMPL joint regressor.
    """
    smplx = make_smplx("supermotion").to(device).eval()
    # Reuse the module-level GVHMR_ROOT constant instead of re-deriving
    # REPO_ROOT / "third_party" / "GVHMR" twice — keeps the checkpoint
    # location consistent with the sys.path bootstrap above.
    body_models_dir = GVHMR_ROOT / "inputs" / "checkpoints" / "body_models"
    smplx2smpl = torch.load(body_models_dir / "smplx2smpl_sparse.pt", map_location=device)
    j_reg = torch.load(body_models_dir / "smpl_neutral_J_regressor.pt", map_location=device)
    return smplx, smplx2smpl, j_reg
|
|
|
|
| def _smplx_to_j3d(params: dict, smplx, smplx2smpl, j_reg): |
| out = smplx(**params) |
| verts = out.vertices if hasattr(out, "vertices") else out[0].vertices |
| verts = torch.stack([torch.matmul(smplx2smpl, v) for v in verts]) |
| j3d = torch.einsum("jv,fvi->fji", j_reg, verts) |
| return verts, j3d |
|
|
|
|
| def _prepare_params(params: dict, n: int, device: torch.device): |
| out = {} |
| for k, v in params.items(): |
| if not isinstance(v, torch.Tensor): |
| v = torch.as_tensor(v) |
| out[k] = _slice(v, n).to(device) |
| return out |
|
|
|
|
| def _min_len(*lens: int) -> int: |
| lens = [int(x) for x in lens if x is not None] |
| return min(lens) if lens else 0 |
|
|
|
|
def main() -> None:
    """Compare a model debug-I/O dump against a formatted dataset file.

    Loads ``--debug-io`` (a ``debug_io.pt`` dump with ``inputs``/``outputs``
    dicts) and ``--dataset-pt``, compares the shared network inputs over the
    common frame prefix, recomputes camera velocity features from ``T_w2c``
    when available, evaluates camera-coordinate SMPL metrics between
    predicted and ground-truth parameters, and writes everything as JSON
    to ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug-io", required=True, help="Path to debug_io.pt")
    parser.add_argument("--dataset-pt", required=True, help="Path to dataset .pt file")
    parser.add_argument("--num-frames", type=int, default=60)
    parser.add_argument("--out", default="outputs/debug_compare.json")
    args = parser.parse_args()

    # weights_only=False: these files hold arbitrary pickled objects, so they
    # must come from a trusted source (pickle can execute code on load).
    debug = torch.load(args.debug_io, map_location="cpu", weights_only=False)
    data = torch.load(args.dataset_pt, map_location="cpu", weights_only=False)

    inputs = debug.get("inputs", {})
    outputs = debug.get("outputs", {})

    n = int(args.num_frames)
    results = {"num_frames": n, "inputs": {}, "metrics": {}, "formatted_dataset": True}

    # Derive camera angular/translational velocity features from the
    # world-to-camera transforms, mirroring how the model inputs are built.
    dataset_inputs = dict(data)
    if "T_w2c" in data:
        T_w2c = _to_tensor(data["T_w2c"]).float()
        if T_w2c.ndim == 3:  # expect (F, 4, 4) homogeneous transforms
            normed_T_w2c = normalize_T_w2c(T_w2c)
            dataset_inputs["cam_angvel"] = compute_cam_angvel(normed_T_w2c[:, :3, :3])
            dataset_inputs["cam_tvel"] = compute_cam_tvel(normed_T_w2c[:, :3, 3])

    # Compare each input key both sources share, truncated to the shortest
    # of --num-frames and the two actual lengths.
    input_keys = [
        "bbx_xys",
        "kp2d",
        "K_fullimg",
        "cam_angvel",
        "cam_tvel",
        "T_w2c",
    ]
    for k in input_keys:
        if k in inputs and k in dataset_inputs:
            a_t = _to_tensor(inputs[k])
            b_t = _to_tensor(dataset_inputs[k])
            n_in = _min_len(n, a_t.shape[0], b_t.shape[0])
            a = _slice(a_t, n_in)
            b = _slice(b_t, n_in)
            results["inputs"][k] = _compare_arrays(a, b)

    # Some dumps store in-camera predictions under a "pred_" prefix.
    pred_key = "smpl_params_incam"
    if pred_key not in outputs:
        pred_key = "pred_smpl_params_incam"

    if pred_key in outputs and "smpl_params_c" in data:
        device = torch.device("cpu")
        smplx, smplx2smpl, j_reg = _load_smplx_tools(device)
        pred_raw = outputs[pred_key]
        gt_raw = data["smpl_params_c"]
        # Use global_orient's frame count as each side's sequence length.
        pred_len = int(_to_tensor(pred_raw["global_orient"]).shape[0])
        gt_len = int(_to_tensor(gt_raw["global_orient"]).shape[0])
        n_eval = _min_len(n, pred_len, gt_len)
        pred_params = _prepare_params(pred_raw, n_eval, device)
        gt_params = _prepare_params(gt_raw, n_eval, device)

        pred_verts, pred_j3d = _smplx_to_j3d(pred_params, smplx, smplx2smpl, j_reg)
        gt_verts, gt_j3d = _smplx_to_j3d(gt_params, smplx, smplx2smpl, j_reg)

        metrics = compute_camcoord_metrics(
            {
                "pred_j3d": pred_j3d,
                "target_j3d": gt_j3d,
                "pred_verts": pred_verts,
                "target_verts": gt_verts,
            }
        )
        # Collapse each per-frame metric to a scalar mean for JSON output.
        results["metrics"] = {k: float(v.mean()) for k, v in metrics.items()}

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w") as f:
        json.dump(results, f, indent=2)

    print(f"Wrote {out_path}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|