#!/usr/bin/env python3 from __future__ import annotations import argparse import json import sys from pathlib import Path import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR" if GVHMR_ROOT.is_dir() and str(GVHMR_ROOT) not in sys.path: sys.path.insert(0, str(GVHMR_ROOT)) from genmo.utils.eval_utils import compute_camcoord_metrics from genmo.utils.geo_transform import compute_cam_angvel, compute_cam_tvel, normalize_T_w2c from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx def _to_tensor(x): if isinstance(x, torch.Tensor): return x return torch.as_tensor(x) def _slice(x: torch.Tensor, n: int) -> torch.Tensor: return x[:n].clone() def _compare_arrays(a: torch.Tensor, b: torch.Tensor): if a.shape != b.shape: return {"shape_a": list(a.shape), "shape_b": list(b.shape)} diff = a - b return { "mae": float(diff.abs().mean().item()), "rmse": float((diff.pow(2).mean().sqrt()).item()), } def _load_smplx_tools(device: torch.device): smplx = make_smplx("supermotion").to(device).eval() smplx2smpl_path = ( REPO_ROOT / "third_party" / "GVHMR" / "inputs" / "checkpoints" / "body_models" / "smplx2smpl_sparse.pt" ) j_reg_path = ( REPO_ROOT / "third_party" / "GVHMR" / "inputs" / "checkpoints" / "body_models" / "smpl_neutral_J_regressor.pt" ) smplx2smpl = torch.load(smplx2smpl_path, map_location=device) j_reg = torch.load(j_reg_path, map_location=device) return smplx, smplx2smpl, j_reg def _smplx_to_j3d(params: dict, smplx, smplx2smpl, j_reg): out = smplx(**params) verts = out.vertices if hasattr(out, "vertices") else out[0].vertices verts = torch.stack([torch.matmul(smplx2smpl, v) for v in verts]) j3d = torch.einsum("jv,fvi->fji", j_reg, verts) return verts, j3d def _prepare_params(params: dict, n: int, device: torch.device): out = {} for k, v in params.items(): if not isinstance(v, torch.Tensor): v = torch.as_tensor(v) out[k] = _slice(v, n).to(device) return out def _min_len(*lens: int) -> int: lens = [int(x) for x in lens if x is not None] return min(lens) if lens else 0 def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--debug-io", required=True, help="Path to debug_io.pt") parser.add_argument("--dataset-pt", required=True, help="Path to dataset .pt file") parser.add_argument("--num-frames", type=int, default=60) parser.add_argument("--out", default="outputs/debug_compare.json") args = parser.parse_args() debug = torch.load(args.debug_io, map_location="cpu", weights_only=False) data = torch.load(args.dataset_pt, map_location="cpu", weights_only=False) inputs = debug.get("inputs", {}) outputs = debug.get("outputs", {}) n = int(args.num_frames) results = {"num_frames": n, "inputs": {}, "metrics": {}, "formatted_dataset": True} dataset_inputs = dict(data) if "T_w2c" in data: T_w2c = _to_tensor(data["T_w2c"]).float() if T_w2c.ndim == 3: normed_T_w2c = normalize_T_w2c(T_w2c) dataset_inputs["cam_angvel"] = compute_cam_angvel(normed_T_w2c[:, :3, :3]) dataset_inputs["cam_tvel"] = compute_cam_tvel(normed_T_w2c[:, :3, 3]) # Compare input tensors (if present in both). input_keys = [ "bbx_xys", "kp2d", "K_fullimg", "cam_angvel", "cam_tvel", "T_w2c", ] for k in input_keys: if k in inputs and k in dataset_inputs: a_t = _to_tensor(inputs[k]) b_t = _to_tensor(dataset_inputs[k]) n_in = _min_len(n, a_t.shape[0], b_t.shape[0]) a = _slice(a_t, n_in) b = _slice(b_t, n_in) results["inputs"][k] = _compare_arrays(a, b) # Compare incam SMPLX (pred vs GT from dataset). pred_key = "smpl_params_incam" if pred_key not in outputs: pred_key = "pred_smpl_params_incam" if pred_key in outputs and "smpl_params_c" in data: device = torch.device("cpu") smplx, smplx2smpl, j_reg = _load_smplx_tools(device) pred_raw = outputs[pred_key] gt_raw = data["smpl_params_c"] pred_len = int(_to_tensor(pred_raw["global_orient"]).shape[0]) gt_len = int(_to_tensor(gt_raw["global_orient"]).shape[0]) n_eval = _min_len(n, pred_len, gt_len) pred_params = _prepare_params(pred_raw, n_eval, device) gt_params = _prepare_params(gt_raw, n_eval, device) pred_verts, pred_j3d = _smplx_to_j3d(pred_params, smplx, smplx2smpl, j_reg) gt_verts, gt_j3d = _smplx_to_j3d(gt_params, smplx, smplx2smpl, j_reg) metrics = compute_camcoord_metrics( { "pred_j3d": pred_j3d, "target_j3d": gt_j3d, "pred_verts": pred_verts, "target_verts": gt_verts, } ) results["metrics"] = {k: float(v.mean()) for k, v in metrics.items()} out_path = Path(args.out) out_path.parent.mkdir(parents=True, exist_ok=True) with out_path.open("w") as f: json.dump(results, f, indent=2) print(f"Wrote {out_path}") if __name__ == "__main__": main()