File size: 4,285 Bytes
7e120dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

import torch

# Make the repo root and the vendored GVHMR tree importable *before* the
# project imports below run; GVHMR is prepended last so it wins lookups.
REPO_ROOT = Path(__file__).resolve().parents[1]  # script lives one level below repo root
_repo_str = str(REPO_ROOT)
if _repo_str not in sys.path:
    sys.path.insert(0, _repo_str)
GVHMR_ROOT = REPO_ROOT / "third_party" / "GVHMR"
_gvhmr_str = str(GVHMR_ROOT)
if GVHMR_ROOT.is_dir() and _gvhmr_str not in sys.path:
    sys.path.insert(0, _gvhmr_str)

from hydra import compose, initialize_config_dir
from hydra.utils import instantiate

from genmo.datamodule.mocap_trainX_testY import collate_fn
from genmo.utils.eval_utils import compute_camcoord_metrics
from genmo.utils.net_utils import load_pretrained_model, to_cuda
from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx


def _load_cfg(config_name: str, overrides: list[str] | None = None):
    """Compose a Hydra config from the repo's ``configs/`` directory.

    Args:
        config_name: Name of the config file (without the ``.yaml`` suffix).
        overrides: Optional Hydra override strings, e.g. ``["a.b=1"]``.
            Defaults to no overrides.

    Returns:
        The composed config object.
    """
    with initialize_config_dir(config_dir=str(REPO_ROOT / "configs"), version_base=None):
        return compose(config_name=config_name, overrides=overrides or [])


def _select_sample(dataset, idx: int | None):
    if idx is None:
        idx = 0
    return dataset[int(idx)]


def main() -> None:
    """Debug-run one unity validation sample end to end.

    Composes the Hydra config, instantiates the validation dataset and the
    model, runs a single collated sample through ``model.validation``,
    regresses SMPL joints from the SMPL-X outputs, computes camera-coordinate
    metrics, and writes ``metrics.json`` plus raw pred/GT tensors to
    ``--out-dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True, help="Path to checkpoint")
    parser.add_argument("--config-name", default="finetune_unity")
    parser.add_argument("--root", default=None, help="Dataset root (processed_dataset)")
    parser.add_argument("--sample-idx", type=int, default=0)
    parser.add_argument("--out-dir", default="outputs/debug_unity_sample")
    args = parser.parse_args()

    # Point both the train and val dataset configs at --root when given.
    overrides = []
    if args.root:
        overrides.append(f"train_datasets.unity.root={args.root}")
        overrides.append(f"test_datasets.unity_val.root={args.root}")

    cfg = _load_cfg(args.config_name, overrides)

    # Build a one-sample "val" batch exactly as the datamodule would.
    dataset = instantiate(cfg.test_datasets.unity_val, _recursive_=False)
    sample = _select_sample(dataset, args.sample_idx)
    batch = collate_fn([sample], mode="val", collate_cfg=cfg.data.collate_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = instantiate(cfg.model, _recursive_=False).to(device)
    load_pretrained_model(model, args.ckpt)
    model.eval()

    if device.type == "cuda":
        batch = to_cuda(batch)

    with torch.inference_mode():
        outputs = model.validation(batch, test_mode="default", batch_idx=0, dataloader_idx=0)

    # SMPL-X body model plus the SMPL joint regressor / SMPL-X->SMPL vertex
    # mapping shipped with GVHMR. Reuse the module-level GVHMR_ROOT constant
    # instead of re-spelling REPO_ROOT / "third_party" / "GVHMR" here.
    smplx = make_smplx("supermotion").to(device)
    body_models_dir = GVHMR_ROOT / "inputs" / "checkpoints" / "body_models"
    J_regressor = torch.load(body_models_dir / "smpl_neutral_J_regressor.pt", map_location=device)
    smplx2smpl = torch.load(body_models_dir / "smplx2smpl_sparse.pt", map_location=device)

    # GT params carry a leading batch dim (batch size is 1), so index it away.
    # NOTE(review): pred_smpl_params_incam is used as-is — presumably already
    # unbatched per-frame tensors; confirm against model.validation's output.
    gt_params_c = {k: v[0] for k, v in batch["smpl_params_c"].items()}
    pred_params_c = outputs["pred_smpl_params_incam"]

    gt_out = smplx(**gt_params_c)
    pred_out = smplx(**pred_params_c)

    # Map SMPL-X vertices to SMPL topology one frame at a time (smplx2smpl is
    # applied per frame, not as a single batched matmul).
    gt_verts = torch.stack([torch.matmul(smplx2smpl, v) for v in gt_out.vertices])
    pred_verts = torch.stack([torch.matmul(smplx2smpl, v) for v in pred_out.vertices])

    # Regress 3D joints: (J, V) x (F, V, 3) -> (F, J, 3).
    gt_j3d = torch.einsum("jv,fvi->fji", J_regressor, gt_verts)
    pred_j3d = torch.einsum("jv,fvi->fji", J_regressor, pred_verts)

    # Per-frame validity mask, when the collated batch provides one.
    mask = batch["mask"]["valid"][0] if "mask" in batch and "valid" in batch["mask"] else None
    metrics = compute_camcoord_metrics(
        {
            "pred_j3d": pred_j3d,
            "target_j3d": gt_j3d,
            "pred_verts": pred_verts,
            "target_verts": gt_verts,
        },
        mask=mask,
    )

    # Persist scalar metric means as JSON and the raw tensors for inspection.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "metrics.json", "w") as f:
        json.dump({k: float(v.mean()) for k, v in metrics.items()}, f, indent=2)

    torch.save(
        {
            "pred_smpl_params_incam": pred_params_c,
            "gt_smpl_params_c": gt_params_c,
            "pred_j3d": pred_j3d.detach().cpu(),
            "gt_j3d": gt_j3d.detach().cpu(),
        },
        out_dir / "pred_gt.pt",
    )

    print("Wrote:", out_dir / "metrics.json")
    print("Wrote:", out_dir / "pred_gt.pt")


# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()