#!/usr/bin/env python3 from __future__ import annotations import argparse import json import os import sys from pathlib import Path import torch # Ensure repo + GVHMR roots are importable. REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR" if GVHMR_ROOT.is_dir() and str(GVHMR_ROOT) not in sys.path: sys.path.insert(0, str(GVHMR_ROOT)) from hydra import compose, initialize_config_dir from hydra.utils import instantiate from genmo.datamodule.mocap_trainX_testY import collate_fn from genmo.utils.eval_utils import compute_camcoord_metrics from genmo.utils.net_utils import load_pretrained_model, to_cuda from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx def _load_cfg(config_name: str, overrides: list[str]): with initialize_config_dir(config_dir=str(REPO_ROOT / "configs"), version_base=None): return compose(config_name=config_name, overrides=overrides) def _select_sample(dataset, idx: int | None): if idx is None: idx = 0 return dataset[int(idx)] def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--ckpt", required=True, help="Path to checkpoint") parser.add_argument("--config-name", default="finetune_unity") parser.add_argument("--root", default=None, help="Dataset root (processed_dataset)") parser.add_argument("--sample-idx", type=int, default=0) parser.add_argument("--out-dir", default="outputs/debug_unity_sample") args = parser.parse_args() overrides = [] if args.root: overrides.append(f"train_datasets.unity.root={args.root}") overrides.append(f"test_datasets.unity_val.root={args.root}") cfg = _load_cfg(args.config_name, overrides) dataset = instantiate(cfg.test_datasets.unity_val, _recursive_=False) sample = _select_sample(dataset, args.sample_idx) batch = collate_fn([sample], mode="val", collate_cfg=cfg.data.collate_cfg) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = instantiate(cfg.model, _recursive_=False).to(device) load_pretrained_model(model, args.ckpt) model.eval() if device.type == "cuda": batch = to_cuda(batch) with torch.inference_mode(): outputs = model.validation(batch, test_mode="default", batch_idx=0, dataloader_idx=0) smplx = make_smplx("supermotion").to(device) j_reg_path = REPO_ROOT / "third_party" / "GVHMR" / "inputs" / "checkpoints" / "body_models" / "smpl_neutral_J_regressor.pt" smplx2smpl_path = REPO_ROOT / "third_party" / "GVHMR" / "inputs" / "checkpoints" / "body_models" / "smplx2smpl_sparse.pt" J_regressor = torch.load(j_reg_path, map_location=device) smplx2smpl = torch.load(smplx2smpl_path, map_location=device) gt_params_c = {k: v[0] for k, v in batch["smpl_params_c"].items()} pred_params_c = outputs["pred_smpl_params_incam"] gt_out = smplx(**gt_params_c) pred_out = smplx(**pred_params_c) gt_verts = torch.stack([torch.matmul(smplx2smpl, v) for v in gt_out.vertices]) pred_verts = torch.stack([torch.matmul(smplx2smpl, v) for v in pred_out.vertices]) gt_j3d = torch.einsum("jv,fvi->fji", J_regressor, gt_verts) pred_j3d = torch.einsum("jv,fvi->fji", J_regressor, pred_verts) mask = batch["mask"]["valid"][0] if "mask" in batch and "valid" in batch["mask"] else None metrics = compute_camcoord_metrics( { "pred_j3d": pred_j3d, "target_j3d": gt_j3d, "pred_verts": pred_verts, "target_verts": gt_verts, }, mask=mask, ) out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) with open(out_dir / "metrics.json", "w") as f: json.dump({k: float(v.mean()) for k, v in metrics.items()}, f, indent=2) torch.save( { "pred_smpl_params_incam": pred_params_c, "gt_smpl_params_c": gt_params_c, "pred_j3d": pred_j3d.detach().cpu(), "gt_j3d": gt_j3d.detach().cpu(), }, out_dir / "pred_gt.pt", ) print("Wrote:", out_dir / "metrics.json") print("Wrote:", out_dir / "pred_gt.pt") if __name__ == "__main__": main()