#!/usr/bin/env python3 """ Flip Unity genmo_features .pt files by rotating existing smpl_params_w only. This keeps the world/camera untouched and avoids re-running the pipeline. """ from __future__ import annotations import argparse import math import sys from pathlib import Path import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle def _squeeze_batch(x: torch.Tensor) -> torch.Tensor: if x.ndim >= 3 and x.shape[0] == 1: return x[0] return x def _axis_angle_from_axis(axis: str, degrees: float) -> torch.Tensor: axis = axis.lower().strip() if axis == "x": vec = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32) elif axis == "y": vec = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32) elif axis == "z": vec = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32) else: raise ValueError(f"Unsupported axis: {axis}") return vec * float(math.radians(float(degrees))) def _fix_one(path: Path, out_dir: Path | None) -> None: data = torch.load(path, map_location="cpu", weights_only=False) if "smpl_params_w" not in data: print(f"[skip] Missing smpl_params_w in {path.name}") return smpl_w = data["smpl_params_w"] global_orient_w = _squeeze_batch(smpl_w["global_orient"]).float() transl_w = _squeeze_batch(smpl_w["transl"]).float() if global_orient_w.ndim != 2 or transl_w.ndim != 2: print(f"[skip] Unexpected shapes in {path.name}") return # Hardcoded SMPLX-only flip to fix upside-down global orientation. aa_fix = _axis_angle_from_axis("x", 180.0).to(global_orient_w) R_fix = axis_angle_to_matrix(aa_fix[None])[0] R_w = torch.matmul(R_fix[None], axis_angle_to_matrix(global_orient_w)) global_orient_w = matrix_to_axis_angle(R_w) transl_w = torch.einsum("ij,fj->fi", R_fix, transl_w) smpl_w["global_orient"] = global_orient_w smpl_w["transl"] = transl_w data["smpl_params_w"] = smpl_w out_path = (out_dir / path.name) if out_dir else path if out_dir: out_dir.mkdir(parents=True, exist_ok=True) torch.save(data, out_path) print(f"[ok] {path.name}") def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--features-dir", required=True, help="Path to genmo_features directory") parser.add_argument("--output-dir", default=None, help="Optional output dir (defaults to in-place)") args = parser.parse_args() features_dir = Path(args.features_dir) if not features_dir.is_dir(): raise FileNotFoundError(f"features dir not found: {features_dir}") out_dir = Path(args.output_dir) if args.output_dir else None pt_files = sorted(features_dir.glob("*.pt")) if not pt_files: print(f"No .pt files found in {features_dir}") return for path in pt_files: _fix_one(path, out_dir) if __name__ == "__main__": main()