| |
| """ |
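Flip Unity genmo_features .pt files by rotating only their existing smpl_params_w.

The world-frame root orientation and translation are rotated 180 degrees about
the x-axis; the world frame and camera data are left untouched, so the pipeline
does not have to be re-run. Applying the flip twice restores the original data,
so do not run it in place more than once.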
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
# Repository root: this script lives one directory below it. The root is put at
# the front of sys.path (if not already listed) so the `genmo` import below
# resolves when the script is executed directly from a checkout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle |
|
|
|
|
| def _squeeze_batch(x: torch.Tensor) -> torch.Tensor: |
| if x.ndim >= 3 and x.shape[0] == 1: |
| return x[0] |
| return x |
|
|
|
|
| def _axis_angle_from_axis(axis: str, degrees: float) -> torch.Tensor: |
| axis = axis.lower().strip() |
| if axis == "x": |
| vec = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32) |
| elif axis == "y": |
| vec = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32) |
| elif axis == "z": |
| vec = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32) |
| else: |
| raise ValueError(f"Unsupported axis: {axis}") |
| return vec * float(math.radians(float(degrees))) |
|
|
|
|
def _save_atomic(data: object, out_path: Path) -> None:
    """``torch.save`` ``data`` to a sibling temp file, then rename it over ``out_path``.

    The rename happens only after the write has finished, so an error or an
    interrupted process mid-write leaves any existing ``out_path`` intact
    instead of truncated. This matters when feature files are overwritten in
    place.
    """
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        torch.save(data, tmp_path)
        tmp_path.replace(out_path)
    except BaseException:
        # Do not leave a half-written temp file behind; re-raise the original error.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise


def _fix_one(path: Path, out_dir: Path | None) -> None:
    """Rotate the world-frame SMPL root of one feature file 180 degrees about x.

    Loads ``path`` and left-multiplies the per-frame root rotation
    ``smpl_params_w["global_orient"]`` (axis-angle) by a fixed 180-degree
    x-axis rotation ``R_fix``. It also rotates ``smpl_params_w["transl"]`` by
    the same matrix. Every other entry in the file is saved back unchanged.
    The updated tensors keep the shape and dtype they had on disk; the math
    itself runs in float32.

    Files without ``smpl_params_w``, without its ``global_orient``/``transl``
    entries, or with tensors that are not ``(F, 3)`` after dropping a size-1
    batch axis are skipped with a ``[skip]`` message.

    Because ``R_fix @ R_fix`` is the identity, processing the same file twice
    restores the original values. Do not re-run this in place.

    NOTE(review): ``transl`` is rotated about the world origin. If it is the
    SMPL translation applied on top of a non-zero root-joint rest position,
    the body does not rotate exactly rigidly. Confirm whether that matters.

    Args:
        path: ``.pt`` file to process.
        out_dir: Directory to write the result into (created if needed).
            ``None`` overwrites ``path`` in place.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects; only run this
    # on feature files you produced or otherwise trust.
    data = torch.load(path, map_location="cpu", weights_only=False)

    if "smpl_params_w" not in data:
        print(f"[skip] Missing smpl_params_w in {path.name}")
        return

    smpl_w = data["smpl_params_w"]
    if "global_orient" not in smpl_w or "transl" not in smpl_w:
        print(f"[skip] Missing global_orient/transl in {path.name}")
        return

    orig_orient = smpl_w["global_orient"]
    orig_transl = smpl_w["transl"]
    # Work on float32 copies with a leading size-1 batch axis removed.
    global_orient_w = _squeeze_batch(orig_orient).float()
    transl_w = _squeeze_batch(orig_transl).float()

    if (
        global_orient_w.ndim != 2
        or transl_w.ndim != 2
        or global_orient_w.shape[-1] != 3
        or transl_w.shape[-1] != 3
    ):
        print(f"[skip] Unexpected shapes in {path.name}")
        return

    aa_fix = _axis_angle_from_axis("x", 180.0).to(global_orient_w)
    R_fix = axis_angle_to_matrix(aa_fix[None])[0]  # single rotation matrix
    # Change of world frame: R_new[f] = R_fix @ R_old[f] for every frame f.
    R_w = torch.matmul(R_fix[None], axis_angle_to_matrix(global_orient_w))
    global_orient_w = matrix_to_axis_angle(R_w)
    # Rotate each per-frame translation vector: t_new[f] = R_fix @ t_old[f].
    transl_w = torch.einsum("ij,fj->fi", R_fix, transl_w)

    # Write back in the on-disk layout (batch axis and dtype restored) so
    # consumers see the same format the pipeline produced. smpl_w is the dict
    # stored inside `data`, so mutating it updates what gets saved.
    smpl_w["global_orient"] = global_orient_w.reshape(orig_orient.shape).to(
        orig_orient.dtype
    )
    smpl_w["transl"] = transl_w.reshape(orig_transl.shape).to(orig_transl.dtype)

    if out_dir is not None:
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / path.name
    else:
        out_path = path
    _save_atomic(data, out_path)
    print(f"[ok] {path.name}")
|
|
|
|
def main() -> None:
    """Command-line entry point: flip every ``*.pt`` file in ``--features-dir``.

    Files are processed in sorted order. Results go to ``--output-dir`` when it
    is given (and non-empty); otherwise each file is overwritten in place.
    Prints a message and returns if the directory has no ``.pt`` files.

    Raises:
        FileNotFoundError: If ``--features-dir`` is not an existing directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--features-dir", required=True, help="Path to genmo_features directory")
    cli.add_argument("--output-dir", default=None, help="Optional output dir (defaults to in-place)")
    options = cli.parse_args()

    features_dir = Path(options.features_dir)
    if not features_dir.is_dir():
        raise FileNotFoundError(f"features dir not found: {features_dir}")

    # An empty --output-dir string is treated like "not given" (in-place).
    out_dir = None
    if options.output_dir:
        out_dir = Path(options.output_dir)

    pt_files = sorted(features_dir.glob("*.pt"))
    if pt_files:
        for pt_file in pt_files:
            _fix_one(pt_file, out_dir)
    else:
        print(f"No .pt files found in {features_dir}")
|
|
|
|
# Run the CLI only when executed as a script; importing this module has no
# effect beyond the sys.path setup and definitions above.
if __name__ == "__main__":
    main()
|
|