# vr-hmr/scripts/compare_debug_with_dataset.py
# Uploaded via huggingface_hub by zirobtc (commit 7e120dd).
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR"
if GVHMR_ROOT.is_dir() and str(GVHMR_ROOT) not in sys.path:
sys.path.insert(0, str(GVHMR_ROOT))
from genmo.utils.eval_utils import compute_camcoord_metrics
from genmo.utils.geo_transform import compute_cam_angvel, compute_cam_tvel, normalize_T_w2c
from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle
from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx
def _to_tensor(x):
if isinstance(x, torch.Tensor):
return x
return torch.as_tensor(x)
def _slice(x: torch.Tensor, n: int) -> torch.Tensor:
return x[:n].clone()
def _compare_arrays(a: torch.Tensor, b: torch.Tensor):
if a.shape != b.shape:
return {"shape_a": list(a.shape), "shape_b": list(b.shape)}
diff = a - b
return {
"mae": float(diff.abs().mean().item()),
"rmse": float((diff.pow(2).mean().sqrt()).item()),
}
def _load_smplx_tools(device: torch.device):
    """Build the SMPL-X model and load the SMPL conversion tensors.

    Returns ``(smplx, smplx2smpl, j_reg)`` all placed on *device*:
    the eval-mode SMPL-X module, the SMPL-X→SMPL vertex mapping, and
    the SMPL neutral joint regressor.
    """
    body_models = REPO_ROOT / "third_party" / "GVHMR" / "inputs" / "checkpoints" / "body_models"
    smplx = make_smplx("supermotion").to(device).eval()
    smplx2smpl = torch.load(body_models / "smplx2smpl_sparse.pt", map_location=device)
    j_reg = torch.load(body_models / "smpl_neutral_J_regressor.pt", map_location=device)
    return smplx, smplx2smpl, j_reg
def _smplx_to_j3d(params: dict, smplx, smplx2smpl, j_reg):
out = smplx(**params)
verts = out.vertices if hasattr(out, "vertices") else out[0].vertices
verts = torch.stack([torch.matmul(smplx2smpl, v) for v in verts])
j3d = torch.einsum("jv,fvi->fji", j_reg, verts)
return verts, j3d
def _prepare_params(params: dict, n: int, device: torch.device):
out = {}
for k, v in params.items():
if not isinstance(v, torch.Tensor):
v = torch.as_tensor(v)
out[k] = _slice(v, n).to(device)
return out
def _min_len(*lens: int) -> int:
lens = [int(x) for x in lens if x is not None]
return min(lens) if lens else 0
def main() -> None:
    """Compare a saved debug I/O dump (debug_io.pt) against a formatted dataset .pt.

    Writes a JSON report containing per-input MAE/RMSE and, when SMPL-X
    parameters are present in both files, camera-coordinate pose metrics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug-io", required=True, help="Path to debug_io.pt")
    parser.add_argument("--dataset-pt", required=True, help="Path to dataset .pt file")
    parser.add_argument("--num-frames", type=int, default=60)
    parser.add_argument("--out", default="outputs/debug_compare.json")
    args = parser.parse_args()
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load files from trusted sources.
    debug = torch.load(args.debug_io, map_location="cpu", weights_only=False)
    data = torch.load(args.dataset_pt, map_location="cpu", weights_only=False)
    inputs = debug.get("inputs", {})
    outputs = debug.get("outputs", {})
    n = int(args.num_frames)
    results = {"num_frames": n, "inputs": {}, "metrics": {}, "formatted_dataset": True}
    dataset_inputs = dict(data)
    # Derive camera angular/translational velocities from the world-to-camera
    # extrinsics (when present) so they can be compared against the dump.
    # Assumes T_w2c is a (F, 4, 4) transform stack — TODO confirm upstream.
    if "T_w2c" in data:
        T_w2c = _to_tensor(data["T_w2c"]).float()
        if T_w2c.ndim == 3:
            normed_T_w2c = normalize_T_w2c(T_w2c)
            dataset_inputs["cam_angvel"] = compute_cam_angvel(normed_T_w2c[:, :3, :3])
            dataset_inputs["cam_tvel"] = compute_cam_tvel(normed_T_w2c[:, :3, 3])
    # Compare input tensors (if present in both).
    input_keys = [
        "bbx_xys",
        "kp2d",
        "K_fullimg",
        "cam_angvel",
        "cam_tvel",
        "T_w2c",
    ]
    for k in input_keys:
        if k in inputs and k in dataset_inputs:
            a_t = _to_tensor(inputs[k])
            b_t = _to_tensor(dataset_inputs[k])
            # Clamp to the shortest available frame count so shapes match.
            n_in = _min_len(n, a_t.shape[0], b_t.shape[0])
            a = _slice(a_t, n_in)
            b = _slice(b_t, n_in)
            results["inputs"][k] = _compare_arrays(a, b)
    # Compare incam SMPLX (pred vs GT from dataset).
    # The dump may use either key name for the predicted in-camera params.
    pred_key = "smpl_params_incam"
    if pred_key not in outputs:
        pred_key = "pred_smpl_params_incam"
    if pred_key in outputs and "smpl_params_c" in data:
        device = torch.device("cpu")
        smplx, smplx2smpl, j_reg = _load_smplx_tools(device)
        pred_raw = outputs[pred_key]
        gt_raw = data["smpl_params_c"]
        # Evaluate only frames present in both prediction and ground truth.
        pred_len = int(_to_tensor(pred_raw["global_orient"]).shape[0])
        gt_len = int(_to_tensor(gt_raw["global_orient"]).shape[0])
        n_eval = _min_len(n, pred_len, gt_len)
        pred_params = _prepare_params(pred_raw, n_eval, device)
        gt_params = _prepare_params(gt_raw, n_eval, device)
        pred_verts, pred_j3d = _smplx_to_j3d(pred_params, smplx, smplx2smpl, j_reg)
        gt_verts, gt_j3d = _smplx_to_j3d(gt_params, smplx, smplx2smpl, j_reg)
        metrics = compute_camcoord_metrics(
            {
                "pred_j3d": pred_j3d,
                "target_j3d": gt_j3d,
                "pred_verts": pred_verts,
                "target_verts": gt_verts,
            }
        )
        # Collapse per-frame metric tensors into scalar means for the report.
        results["metrics"] = {k: float(v.mean()) for k, v in metrics.items()}
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w") as f:
        json.dump(results, f, indent=2)
    print(f"Wrote {out_path}")
# Script entry point.
if __name__ == "__main__":
    main()