File size: 4,285 Bytes
7e120dd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | #!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
import torch
# Make the repository root (and, when it is checked out, the vendored GVHMR
# tree) importable before the project-level imports that follow run.
REPO_ROOT = Path(__file__).resolve().parents[1]
GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR"

_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)

# GVHMR is optional on disk: only register it when the directory exists.
_gvhmr_root_entry = str(GVHMR_ROOT)
if GVHMR_ROOT.is_dir() and _gvhmr_root_entry not in sys.path:
    sys.path.insert(0, _gvhmr_root_entry)
from hydra import compose, initialize_config_dir
from hydra.utils import instantiate
from genmo.datamodule.mocap_trainX_testY import collate_fn
from genmo.utils.eval_utils import compute_camcoord_metrics
from genmo.utils.net_utils import load_pretrained_model, to_cuda
from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx
def _load_cfg(config_name: str, overrides: list[str]):
    """Compose a Hydra config from the repository's ``configs`` directory.

    Args:
        config_name: Name of the top-level config to compose.
        overrides: Hydra command-line style overrides (``"a.b=value"``).

    Returns:
        The composed config object returned by ``hydra.compose``.
    """
    config_dir = str(REPO_ROOT / "configs")
    with initialize_config_dir(config_dir=config_dir, version_base=None):
        composed = compose(config_name=config_name, overrides=overrides)
    return composed
def _select_sample(dataset, idx: int | None):
    """Return ``dataset[idx]``; a missing index (``None``) selects item 0.

    ``idx`` is coerced with ``int()`` before indexing.
    """
    position = 0 if idx is None else int(idx)
    return dataset[position]
def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the single-sample debug evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True, help="Path to checkpoint")
    parser.add_argument("--config-name", default="finetune_unity")
    parser.add_argument("--root", default=None, help="Dataset root (processed_dataset)")
    parser.add_argument("--sample-idx", type=int, default=0)
    parser.add_argument("--out-dir", default="outputs/debug_unity_sample")
    return parser.parse_args()


def _to_cpu(obj):
    """Recursively copy tensors inside dicts/lists/tuples to normal CPU tensors.

    Cloning outside inference mode turns inference tensors into regular
    tensors, so the saved file loads and behaves normally on any machine.
    Non-tensor leaves are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().clone()
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_to_cpu(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_to_cpu(v) for v in obj)
    return obj


def _smpl_joints_and_verts(body_model, smplx2smpl, j_regressor, params):
    """Run the SMPL-X body model on ``params``; return (SMPL joints, SMPL vertices).

    Each frame's SMPL-X vertices are mapped to SMPL vertices with the
    ``smplx2smpl`` matrix, then regressed to joints with ``j_regressor``.
    """
    body_out = body_model(**params)
    verts = torch.stack([torch.matmul(smplx2smpl, v) for v in body_out.vertices])
    j3d = torch.einsum("jv,fvi->fji", j_regressor, verts)
    return j3d, verts


def main() -> None:
    """Evaluate a checkpoint on one Unity validation sample and dump the results.

    Writes ``metrics.json`` (per-metric means from ``compute_camcoord_metrics``)
    and ``pred_gt.pt`` (predicted/GT camera-frame params and SMPL joints, all
    on CPU) into ``--out-dir``.
    """
    args = _parse_args()

    overrides: list[str] = []
    if args.root:
        overrides.append(f"train_datasets.unity.root={args.root}")
        overrides.append(f"test_datasets.unity_val.root={args.root}")
    cfg = _load_cfg(args.config_name, overrides)

    dataset = instantiate(cfg.test_datasets.unity_val, _recursive_=False)
    sample = _select_sample(dataset, args.sample_idx)
    # Collate a single sample so it matches the model's batched input format.
    batch = collate_fn([sample], mode="val", collate_cfg=cfg.data.collate_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = instantiate(cfg.model, _recursive_=False).to(device)
    load_pretrained_model(model, args.ckpt)
    model.eval()
    if device.type == "cuda":
        batch = to_cuda(batch)

    with torch.inference_mode():
        outputs = model.validation(batch, test_mode="default", batch_idx=0, dataloader_idx=0)

    # Build the body model and SMPL helper tensors outside inference mode so
    # module parameters/buffers are created as regular tensors.
    smplx = make_smplx("supermotion").to(device)
    body_models_dir = REPO_ROOT / "third_party" / "GVHMR" / "inputs" / "checkpoints" / "body_models"
    J_regressor = torch.load(body_models_dir / "smpl_neutral_J_regressor.pt", map_location=device)
    smplx2smpl = torch.load(body_models_dir / "smplx2smpl_sparse.pt", map_location=device)

    # Take element 0 of each GT entry (the batch holds exactly one sample).
    gt_params_c = {k: v[0] for k, v in batch["smpl_params_c"].items()}
    # NOTE(review): prediction is used without indexing, i.e. assumed to
    # already be per-sample (no batch dim) — confirm against model.validation.
    pred_params_c = outputs["pred_smpl_params_incam"]
    mask = batch["mask"]["valid"][0] if "mask" in batch and "valid" in batch["mask"] else None

    # Pure evaluation: keep autograd off, and avoid mixing the inference
    # tensors produced above into autograd-tracked ops.
    with torch.inference_mode():
        gt_j3d, gt_verts = _smpl_joints_and_verts(smplx, smplx2smpl, J_regressor, gt_params_c)
        pred_j3d, pred_verts = _smpl_joints_and_verts(smplx, smplx2smpl, J_regressor, pred_params_c)
        metrics = compute_camcoord_metrics(
            {
                "pred_j3d": pred_j3d,
                "target_j3d": gt_j3d,
                "pred_verts": pred_verts,
                "target_verts": gt_verts,
            },
            mask=mask,
        )

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    metrics_path = out_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump({k: float(v.mean()) for k, v in metrics.items()}, f, indent=2)

    # Everything saved is moved to CPU so the file loads without a GPU.
    pred_gt_path = out_dir / "pred_gt.pt"
    torch.save(
        {
            "pred_smpl_params_incam": _to_cpu(pred_params_c),
            "gt_smpl_params_c": _to_cpu(gt_params_c),
            "pred_j3d": _to_cpu(pred_j3d),
            "gt_j3d": _to_cpu(gt_j3d),
        },
        pred_gt_path,
    )
    print("Wrote:", metrics_path)
    print("Wrote:", pred_gt_path)


if __name__ == "__main__":
    main()
|