#!/usr/bin/env python3
"""
Flip Unity genmo_features .pt files by rotating existing smpl_params_w only.
This keeps the world/camera untouched and avoids re-running the pipeline.
"""
from __future__ import annotations
import argparse
import math
import sys
from pathlib import Path
import torch
# REPO_ROOT is the directory one level above the folder containing this
# script. It is prepended to sys.path (once) before the `genmo` import that
# follows — presumably so the package resolves when the script is run
# directly from a checkout; confirm against the repository layout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle
def _squeeze_batch(x: torch.Tensor) -> torch.Tensor:
    """Drop a leading singleton dimension from tensors of rank three or more.

    Tensors of rank two or lower, and tensors whose first dimension is not of
    size one, are returned unchanged.
    """
    has_unit_batch = x.ndim >= 3 and x.shape[0] == 1
    return x[0] if has_unit_batch else x
def _axis_angle_from_axis(axis: str, degrees: float) -> torch.Tensor:
    """Build an axis-angle vector for a rotation about a principal axis.

    Args:
        axis: ``"x"``, ``"y"`` or ``"z"``; letter case and surrounding
            whitespace are ignored.
        degrees: Rotation angle in degrees.

    Returns:
        A float32 tensor with three elements: the unit axis scaled by the
        angle converted to radians.

    Raises:
        ValueError: If ``axis`` does not name one of the three principal axes.
    """
    unit_axes = {
        "x": (1.0, 0.0, 0.0),
        "y": (0.0, 1.0, 0.0),
        "z": (0.0, 0.0, 1.0),
    }
    axis = axis.lower().strip()
    if axis not in unit_axes:
        raise ValueError(f"Unsupported axis: {axis}")
    direction = torch.tensor(unit_axes[axis], dtype=torch.float32)
    angle_rad = float(math.radians(float(degrees)))
    return direction * angle_rad
def _fix_one(path: Path, out_dir: Path | None) -> None:
    """Rotate the world-frame SMPL root pose in one features file by 180 deg about X.

    Loads ``path``, left-multiplies ``smpl_params_w["global_orient"]`` by the
    fix rotation, applies the same rotation to ``smpl_params_w["transl"]``,
    and saves the result. Every other entry of the file is written back
    untouched, and the two updated tensors keep the shape and dtype they had
    on disk.

    A file is reported with a ``[skip]`` line and left unmodified when it
    lacks ``smpl_params_w``, lacks either of the two keys, or holds tensors
    that are not ``(frames, 3)`` after dropping a leading singleton batch
    dimension.

    Args:
        path: Features ``.pt`` file to process.
        out_dir: Directory to write the result into (created if needed);
            ``None`` overwrites ``path`` in place.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects, which can
    # execute code. Only run this on feature files from a trusted source.
    data = torch.load(path, map_location="cpu", weights_only=False)
    if "smpl_params_w" not in data:
        print(f"[skip] Missing smpl_params_w in {path.name}")
        return
    smpl_w = data["smpl_params_w"]

    # An incomplete parameter dict is skipped like any other malformed file
    # instead of aborting the whole batch with a KeyError.
    missing = [key for key in ("global_orient", "transl") if key not in smpl_w]
    if missing:
        print(f"[skip] Missing {', '.join(missing)} in {path.name}")
        return

    raw_orient = smpl_w["global_orient"]
    raw_transl = smpl_w["transl"]
    # The math below runs in float32 on (frames, 3) tensors.
    global_orient_w = _squeeze_batch(raw_orient).float()
    transl_w = _squeeze_batch(raw_transl).float()
    if (
        global_orient_w.ndim != 2
        or transl_w.ndim != 2
        or global_orient_w.shape[-1] != 3
        or transl_w.shape[-1] != 3
    ):
        print(f"[skip] Unexpected shapes in {path.name}")
        return

    # Hardcoded SMPLX-only flip to fix upside-down global orientation.
    aa_fix = _axis_angle_from_axis("x", 180.0).to(global_orient_w)
    R_fix = axis_angle_to_matrix(aa_fix[None])[0]
    # Left-multiplication composes the fix after the stored root rotation.
    R_w = torch.matmul(R_fix[None], axis_angle_to_matrix(global_orient_w))
    global_orient_w = matrix_to_axis_angle(R_w)
    # NOTE(review): transl is rotated about the world origin. SMPL-family
    # models usually rotate the body about the pelvis joint, so an exact
    # rigid flip may also need the pelvis offset — confirm this is acceptable.
    transl_w = torch.einsum("ij,fj->fi", R_fix, transl_w)

    # Write back in the layout and dtype the file already used, so consumers
    # of the features see only the rotation change (not a dropped batch
    # dimension or a silent cast to float32).
    smpl_w["global_orient"] = global_orient_w.reshape(raw_orient.shape).to(raw_orient.dtype)
    smpl_w["transl"] = transl_w.reshape(raw_transl.shape).to(raw_transl.dtype)
    data["smpl_params_w"] = smpl_w

    out_path = (out_dir / path.name) if out_dir else path
    if out_dir:
        out_dir.mkdir(parents=True, exist_ok=True)

    # Save to a sibling temp file and rename it over the target, so an
    # interrupted run cannot leave a truncated file where the (possibly only)
    # copy of the features used to be. The ".tmp" suffix keeps leftovers out
    # of a later "*.pt" glob.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        torch.save(data, tmp_path)
        tmp_path.replace(out_path)
    finally:
        if tmp_path.exists():
            tmp_path.unlink()
    print(f"[ok] {path.name}")
def main() -> None:
    """Command-line entry point: apply the flip to every ``*.pt`` file in a directory.

    Reads ``--features-dir`` (required) and ``--output-dir`` (optional; when
    omitted, files are rewritten in place), then processes the matching files
    in sorted order.

    Raises:
        FileNotFoundError: If ``--features-dir`` is not an existing directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--features-dir", required=True, help="Path to genmo_features directory")
    cli.add_argument("--output-dir", default=None, help="Optional output dir (defaults to in-place)")
    opts = cli.parse_args()

    src_dir = Path(opts.features_dir)
    if not src_dir.is_dir():
        raise FileNotFoundError(f"features dir not found: {src_dir}")

    dst_dir = None
    if opts.output_dir:
        dst_dir = Path(opts.output_dir)

    candidates = sorted(src_dir.glob("*.pt"))
    if candidates:
        for pt_path in candidates:
            _fix_one(pt_path, dst_dir)
    else:
        print(f"No .pt files found in {src_dir}")


if __name__ == "__main__":
    main()
|