# vr-hmr / scripts/fix_genmo_features_world.py
# Uploaded by zirobtc via huggingface_hub ("Upload folder using huggingface_hub"), commit 7e120dd.
#!/usr/bin/env python3
"""
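Flip Unity genmo_features .pt files by rotating only the existing smpl_params_w
(global_orient and transl), 180 degrees about the x axis.
World/camera data is left untouched, so the pipeline does not need to be re-run.
Note: the flip is not idempotent. Running it twice on the same files undoes it.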
"""
from __future__ import annotations
import argparse
import math
import sys
from pathlib import Path
import torch
# Repository root = parent of this script's directory. Prepend it to sys.path
# (unless it is already there) so the `genmo` package imported below resolves
# when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle
def _squeeze_batch(x: torch.Tensor) -> torch.Tensor:
if x.ndim >= 3 and x.shape[0] == 1:
return x[0]
return x
def _axis_angle_from_axis(axis: str, degrees: float) -> torch.Tensor:
axis = axis.lower().strip()
if axis == "x":
vec = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32)
elif axis == "y":
vec = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32)
elif axis == "z":
vec = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)
else:
raise ValueError(f"Unsupported axis: {axis}")
return vec * float(math.radians(float(degrees)))
def _fix_one(
    path: Path,
    out_dir: Path | None,
    *,
    axis: str = "x",
    degrees: float = 180.0,
) -> None:
    """Rotate the world-frame SMPL root stored in one genmo_features ``.pt`` file.

    The rotation ``R_fix`` defaults to 180 degrees about x, the hard-coded
    fix for upside-down global orientation. It is applied to
    ``smpl_params_w["global_orient"]`` by left-multiplication and to
    ``smpl_params_w["transl"]`` as a plain rotation. Every other entry of the
    file is saved back unchanged.

    The rotated tensors are written back with their original shape and dtype,
    so a ``(1, F, 3)`` float64 input stays ``(1, F, 3)`` float64.

    Files without ``smpl_params_w``, without ``global_orient``/``transl``, or
    with tensors that are not ``(F, 3)`` after dropping a singleton batch axis
    are reported with ``[skip]`` and left untouched.

    Args:
        path: The ``.pt`` file to fix.
        out_dir: Destination directory for the fixed file. ``None`` overwrites
            ``path`` in place.
        axis: Rotation axis (``"x"``, ``"y"`` or ``"z"``).
        degrees: Rotation angle in degrees.

    Note:
        The flip is not idempotent. With the default 180-degree rotation,
        running it twice on the same file restores the original pose.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # run this on feature files produced by a trusted pipeline.
    data = torch.load(path, map_location="cpu", weights_only=False)
    if "smpl_params_w" not in data:
        print(f"[skip] Missing smpl_params_w in {path.name}")
        return
    smpl_w = data["smpl_params_w"]

    missing = [key for key in ("global_orient", "transl") if key not in smpl_w]
    if missing:
        print(f"[skip] Missing {', '.join(missing)} in smpl_params_w of {path.name}")
        return

    # Keep the stored tensors so their original shape/dtype can be restored.
    orient_src = smpl_w["global_orient"]
    transl_src = smpl_w["transl"]
    global_orient_w = _squeeze_batch(orient_src).float()
    transl_w = _squeeze_batch(transl_src).float()
    if (
        global_orient_w.ndim != 2
        or transl_w.ndim != 2
        or global_orient_w.shape[-1] != 3
        or transl_w.shape[-1] != 3
    ):
        print(f"[skip] Unexpected shapes in {path.name}")
        return

    aa_fix = _axis_angle_from_axis(axis, degrees).to(global_orient_w)
    R_fix = axis_angle_to_matrix(aa_fix[None])[0]
    # Left-multiply: R_fix acts in the world frame, after the body's own rotation.
    R_w = R_fix[None] @ axis_angle_to_matrix(global_orient_w)
    global_orient_w = matrix_to_axis_angle(R_w)
    # Row-vector form of R_fix @ t for every frame.
    # NOTE(review): this rotates the body about the world origin only if the
    # SMPL root joint sits at the origin in the rest pose. In general the
    # betas-dependent pelvis offset J0 needs transl' = R (transl + J0) - J0.
    # Confirm the resulting offset is acceptable for this data.
    transl_w = transl_w @ R_fix.T

    # Write back in the original layout so the file schema does not change.
    smpl_w["global_orient"] = global_orient_w.reshape(orient_src.shape).to(orient_src.dtype)
    smpl_w["transl"] = transl_w.reshape(transl_src.shape).to(transl_src.dtype)
    data["smpl_params_w"] = smpl_w

    if out_dir is None:
        out_path = path
    else:
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / path.name
    torch.save(data, out_path)
    print(f"[ok] {path.name}")
def main() -> None:
    """Command-line entry point: flip every ``*.pt`` file in ``--features-dir``.

    Files are processed in sorted order. They are rewritten in place unless
    ``--output-dir`` is given, in which case the flipped copies are written
    there under the same file names.

    Raises:
        FileNotFoundError: If ``--features-dir`` is not an existing directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--features-dir", required=True, help="Path to genmo_features directory")
    cli.add_argument("--output-dir", default=None, help="Optional output dir (defaults to in-place)")
    opts = cli.parse_args()

    src_dir = Path(opts.features_dir)
    if not src_dir.is_dir():
        raise FileNotFoundError(f"features dir not found: {src_dir}")

    dst_dir = None
    if opts.output_dir:
        dst_dir = Path(opts.output_dir)

    candidates = sorted(src_dir.glob("*.pt"))
    if not candidates:
        print(f"No .pt files found in {src_dir}")
        return

    # NOTE: each run applies the fixed rotation again. Running this twice
    # in place on the same directory undoes the first run.
    for pt_path in candidates:
        _fix_one(pt_path, dst_dir)


if __name__ == "__main__":
    main()