| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
| |
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Vendored GVHMR checkout; it is only put on sys.path when the directory exists.
GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR"


def _prepend_to_sys_path(path: Path) -> None:
    """Insert ``str(path)`` at the front of ``sys.path`` unless it is already listed."""
    entry = str(path)
    if entry not in sys.path:
        sys.path.insert(0, entry)


# This has to run before the project imports (``genmo``, ``third_party``) are resolved.
_prepend_to_sys_path(REPO_ROOT)
if GVHMR_ROOT.is_dir():
    _prepend_to_sys_path(GVHMR_ROOT)
|
|
| from hydra import compose, initialize_config_dir |
| from hydra.utils import instantiate |
|
|
| from genmo.datamodule.mocap_trainX_testY import collate_fn |
| from genmo.utils.eval_utils import compute_camcoord_metrics |
| from genmo.utils.net_utils import load_pretrained_model, to_cuda |
| from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx |
|
|
|
|
def _load_cfg(config_name: str, overrides: list[str]):
    """Compose the Hydra config ``config_name`` found under ``<REPO_ROOT>/configs``.

    ``overrides`` is a list of Hydra override strings (``key=value``). Hydra is
    initialised only inside the ``initialize_config_dir`` context, and the
    composed config is returned after that context closes.
    """
    configs_dir = REPO_ROOT / "configs"
    with initialize_config_dir(config_dir=str(configs_dir), version_base=None):
        cfg = compose(config_name=config_name, overrides=overrides)
    return cfg
|
|
|
|
| def _select_sample(dataset, idx: int | None): |
| if idx is None: |
| idx = 0 |
| return dataset[int(idx)] |
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Run a checkpoint on one sample of the unity_val test dataset and "
            "write camera-frame metrics plus predicted/GT SMPL outputs."
        )
    )
    parser.add_argument("--ckpt", required=True, help="Path to checkpoint")
    parser.add_argument("--config-name", default="finetune_unity")
    parser.add_argument("--root", default=None, help="Dataset root (processed_dataset)")
    parser.add_argument("--sample-idx", type=int, default=0)
    parser.add_argument("--out-dir", default="outputs/debug_unity_sample")
    return parser.parse_args()


def _dataset_root_overrides(root: str | None) -> list[str]:
    """Return Hydra overrides that point the Unity train and val datasets at ``root``.

    An empty or missing ``root`` yields no overrides, so the config defaults apply.
    """
    if not root:
        return []
    return [
        f"train_datasets.unity.root={root}",
        f"test_datasets.unity_val.root={root}",
    ]


def _load_body_model_assets(device: torch.device):
    """Load the "supermotion" SMPL-X body model and the two conversion tensors onto ``device``.

    Returns ``(smplx_model, J_regressor, smplx2smpl)``. ``J_regressor`` is loaded from
    ``smpl_neutral_J_regressor.pt`` and ``smplx2smpl`` from ``smplx2smpl_sparse.pt``.
    Both files live in GVHMR's ``inputs/checkpoints/body_models`` directory.
    """
    body_models_dir = GVHMR_ROOT / "inputs" / "checkpoints" / "body_models"
    smplx_model = make_smplx("supermotion").to(device)
    J_regressor = torch.load(body_models_dir / "smpl_neutral_J_regressor.pt", map_location=device)
    smplx2smpl = torch.load(body_models_dir / "smplx2smpl_sparse.pt", map_location=device)
    return smplx_model, J_regressor, smplx2smpl


def _smpl_joints_and_verts(smplx_model, params, smplx2smpl, J_regressor):
    """Run ``smplx_model(**params)`` and regress joints from the converted vertices.

    Each frame of ``out.vertices`` (iterated along dim 0) is left-multiplied by
    ``smplx2smpl``, and the results are stacked into ``verts`` of shape (F, V, C).
    ``verts`` is then contracted with ``J_regressor`` (J, V), giving ``j3d`` of
    shape (F, J, C).

    NOTE(review): going by the asset filenames, these are SMPL-X -> SMPL vertices
    and SMPL joints. Confirm.

    Returns ``(j3d, verts)``.
    """
    out = smplx_model(**params)
    verts = torch.stack([torch.matmul(smplx2smpl, frame_verts) for frame_verts in out.vertices])
    j3d = torch.einsum("jv,fvi->fji", J_regressor, verts)
    return j3d, verts


def _to_cpu(value):
    """Return ``value`` with every tensor, including those nested in dicts/lists/tuples, detached and on CPU."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu()
    if isinstance(value, dict):
        return {k: _to_cpu(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_to_cpu(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_to_cpu(v) for v in value)
    return value


def main() -> None:
    """Evaluate ``--ckpt`` on a single ``unity_val`` sample and dump the results.

    Writes two files into ``--out-dir``:

    * ``metrics.json``: the mean of each entry returned by ``compute_camcoord_metrics``.
    * ``pred_gt.pt``: predicted and GT SMPL parameters plus regressed joints, all
      moved to CPU so the file can be loaded without a GPU.
    """
    args = _parse_args()
    cfg = _load_cfg(args.config_name, _dataset_root_overrides(args.root))

    # Build a batch of size one that contains only the requested validation sample.
    dataset = instantiate(cfg.test_datasets.unity_val, _recursive_=False)
    sample = _select_sample(dataset, args.sample_idx)
    batch = collate_fn([sample], mode="val", collate_cfg=cfg.data.collate_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = instantiate(cfg.model, _recursive_=False).to(device)
    load_pretrained_model(model, args.ckpt)
    model.eval()
    if device.type == "cuda":
        batch = to_cuda(batch)

    smplx_model, J_regressor, smplx2smpl = _load_body_model_assets(device)

    # Evaluation needs no gradients, so everything stays in one inference-mode
    # region. The outputs of model.validation are inference tensors, and feeding
    # them to a module with grad-requiring parameters outside inference mode
    # raises ("Inference tensors cannot be saved for backward").
    with torch.inference_mode():
        outputs = model.validation(batch, test_mode="default", batch_idx=0, dataloader_idx=0)

        # GT params are indexed with [0] to drop the batch dimension. The
        # predictions are used as-is.
        # NOTE(review): this assumes model.validation already returns un-batched
        # params for a batch of size one. Confirm.
        gt_params_c = {k: v[0] for k, v in batch["smpl_params_c"].items()}
        pred_params_c = outputs["pred_smpl_params_incam"]

        gt_j3d, gt_verts = _smpl_joints_and_verts(
            smplx_model, gt_params_c, smplx2smpl, J_regressor
        )
        pred_j3d, pred_verts = _smpl_joints_and_verts(
            smplx_model, pred_params_c, smplx2smpl, J_regressor
        )

        mask = batch["mask"]["valid"][0] if "mask" in batch and "valid" in batch["mask"] else None
        metrics = compute_camcoord_metrics(
            {
                "pred_j3d": pred_j3d,
                "target_j3d": gt_j3d,
                "pred_verts": pred_verts,
                "target_verts": gt_verts,
            },
            mask=mask,
        )
        metric_means = {k: float(v.mean()) for k, v in metrics.items()}

        # Everything is moved to CPU (the param dicts included) so pred_gt.pt
        # does not depend on the device it was produced on.
        payload = {
            "pred_smpl_params_incam": _to_cpu(pred_params_c),
            "gt_smpl_params_c": _to_cpu(gt_params_c),
            "pred_j3d": _to_cpu(pred_j3d),
            "gt_j3d": _to_cpu(gt_j3d),
        }

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    metrics_path = out_dir / "metrics.json"
    with metrics_path.open("w", encoding="utf-8") as f:
        json.dump(metric_means, f, indent=2)

    pred_gt_path = out_dir / "pred_gt.pt"
    torch.save(payload, pred_gt_path)

    print("Wrote:", metrics_path)
    print("Wrote:", pred_gt_path)


if __name__ == "__main__":
    main()
|
|