File size: 3,073 Bytes
7e120dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python3
"""
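Fix upside-down Unity genmo_features .pt files by rotating only the existing
world-frame SMPL parameters (smpl_params_w: global_orient and transl) by 180
degrees about the x axis. Cameras and all other stored data are left untouched,
so the pipeline does not need to be re-run. The flip is its own inverse, so
running this script twice in place on the same files undoes it.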
"""
from __future__ import annotations

import argparse
import math
import sys
from pathlib import Path

import torch

# Put the repository root (the parent of this script's directory) at the front
# of sys.path so the local ``genmo`` package can be imported when this file is
# executed directly as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(REPO_ROOT) for entry in sys.path):
    sys.path[:0] = [str(REPO_ROOT)]

from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle


def _squeeze_batch(x: torch.Tensor) -> torch.Tensor:
    """Drop a leading singleton batch axis from tensors of rank 3 or higher.

    A tensor shaped ``(1, F, D)`` becomes ``(F, D)`` (a view of the input).
    Tensors of rank < 3, or whose first axis is not of size 1, are returned
    unchanged (the very same object).
    """
    has_singleton_batch = x.ndim >= 3 and x.shape[0] == 1
    return x[0] if has_singleton_batch else x


def _axis_angle_from_axis(axis: str, degrees: float) -> torch.Tensor:
    """Build an axis-angle vector for a rotation about a principal axis.

    Args:
        axis: One of ``"x"``, ``"y"`` or ``"z"`` (case-insensitive; surrounding
            whitespace is ignored).
        degrees: Rotation angle in degrees.

    Returns:
        A float32 tensor of shape ``(3,)``: the unit axis scaled by the angle
        in radians.

    Raises:
        ValueError: If ``axis`` is not one of the three principal axes.
    """
    name = axis.lower().strip()
    component = {"x": 0, "y": 1, "z": 2}.get(name)
    if component is None:
        raise ValueError(f"Unsupported axis: {name}")
    unit = torch.zeros(3, dtype=torch.float32)
    unit[component] = 1.0
    angle_rad = float(math.radians(float(degrees)))
    return unit * angle_rad


def _fix_one(path: Path, out_dir: Path | None) -> None:
    """Flip the world-frame root orientation and translation stored in one file.

    Loads ``path`` and left-multiplies every per-frame root rotation in
    ``data["smpl_params_w"]["global_orient"]`` (axis-angle) by a fixed
    180-degree rotation about the x axis. It rotates the matching
    ``data["smpl_params_w"]["transl"]`` vectors by the same matrix, then saves
    the whole object back. All other contents of the file are written
    unchanged. The two edited tensors keep their original shape (including a
    leading size-1 batch axis, if present) and their original dtype.

    Args:
        path: ``.pt`` file to process.
        out_dir: Directory to write the result to under the same file name
            (created if needed). ``None`` overwrites ``path`` in place.

    Files that lack the expected keys, or whose tensors are not per-frame
    3-vectors, are reported with a ``[skip]`` line and left untouched.

    The flip is its own inverse: running this twice in place on the same file
    restores the original values.
    """
    # weights_only=False allows arbitrary unpickling; only point this script at
    # trusted, locally generated feature files.
    data = torch.load(path, map_location="cpu", weights_only=False)

    if "smpl_params_w" not in data:
        print(f"[skip] Missing smpl_params_w in {path.name}")
        return

    smpl_w = data["smpl_params_w"]
    if "global_orient" not in smpl_w or "transl" not in smpl_w:
        print(f"[skip] Missing global_orient/transl in {path.name}")
        return

    orient_src = smpl_w["global_orient"]
    transl_src = smpl_w["transl"]
    # Work on (F, 3) float32 tensors; a leading batch axis of size 1 is dropped
    # here and restored before saving.
    global_orient_w = _squeeze_batch(orient_src).float()
    transl_w = _squeeze_batch(transl_src).float()

    if (
        global_orient_w.ndim != 2
        or transl_w.ndim != 2
        or global_orient_w.shape[-1] != 3
        or transl_w.shape[-1] != 3
    ):
        print(f"[skip] Unexpected shapes in {path.name}")
        return

    # Hardcoded SMPLX-only flip to fix upside-down global orientation.
    aa_fix = _axis_angle_from_axis("x", 180.0).to(global_orient_w)
    R_fix = axis_angle_to_matrix(aa_fix[None])[0]
    # Left-multiply so the world flip is applied after each frame's rotation.
    R_w = torch.matmul(R_fix[None], axis_angle_to_matrix(global_orient_w))
    global_orient_w = matrix_to_axis_angle(R_w)
    # transl'[f] = R_fix @ transl[f]. NOTE(review): this rotates about the world
    # origin. For SMPL-family models the root joint is typically offset from the
    # origin, in which case the body also shifts by (R_fix - I) @ J_root.
    # Confirm that this is acceptable for this data.
    transl_w = torch.einsum("ij,fj->fi", R_fix, transl_w)

    # Write back in the original layout so downstream readers see the same
    # shapes and dtypes they did before the fix.
    smpl_w["global_orient"] = global_orient_w.reshape(orient_src.shape).to(orient_src.dtype)
    smpl_w["transl"] = transl_w.reshape(transl_src.shape).to(transl_src.dtype)
    data["smpl_params_w"] = smpl_w

    out_path = (out_dir / path.name) if out_dir else path
    if out_dir:
        out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(data, out_path)
    print(f"[ok] {path.name}")


def main() -> None:
    """Command-line entry point: flip every ``*.pt`` file in ``--features-dir``.

    Files are processed in sorted order and written in place unless
    ``--output-dir`` is given. A failure on one file is reported on stderr and
    does not stop the remaining files. If any file failed, the process exits
    with status 1 after the loop so calling scripts can detect it.

    Raises:
        FileNotFoundError: If ``--features-dir`` is not an existing directory.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--features-dir", required=True, help="Path to genmo_features directory")
    parser.add_argument("--output-dir", default=None, help="Optional output dir (defaults to in-place)")
    args = parser.parse_args()

    features_dir = Path(args.features_dir)
    if not features_dir.is_dir():
        raise FileNotFoundError(f"features dir not found: {features_dir}")

    out_dir = Path(args.output_dir) if args.output_dir else None
    pt_files = sorted(features_dir.glob("*.pt"))
    if not pt_files:
        print(f"No .pt files found in {features_dir}")
        return

    failed: list[Path] = []
    for path in pt_files:
        # Keep going after a bad file. Aborting midway would leave an in-place
        # run half-applied, and re-running the directory would then flip the
        # already-processed files a second time, undoing the fix.
        try:
            _fix_one(path, out_dir)
        except Exception as exc:  # top-level CLI boundary: report and continue
            failed.append(path)
            print(f"[error] {path.name}: {exc!r}", file=sys.stderr)

    if failed:
        print(f"{len(failed)} of {len(pt_files)} file(s) failed", file=sys.stderr)
        sys.exit(1)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()