# vr-hmr/scripts/check_unity_consistency.py
# zirobtc's picture
# Upload folder using huggingface_hub
# 7e120dd
#!/usr/bin/env python3
"""
Check internal consistency of Unity processed samples.
This script compares:
- saved `smpl_params_w` vs `smpl_params_c` + `T_w2c`-derived world params
- saved `cam_angvel` (if present) vs recomputed from `T_w2c`
If these disagree by non-trivial amounts, fine-tuning can "break global" because the
network is asked to fit mutually inconsistent targets/condition signals.
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, Any, Iterable
import torch
from genmo.utils.geo_transform import compute_cam_angvel, normalize_T_w2c
from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle
def _load_pt(path: Path) -> Dict[str, Any]:
    """Deserialize a processed-sample ``.pt`` file with all tensors mapped to CPU.

    ``weights_only=False`` lets ``torch.load`` unpickle arbitrary Python objects,
    so this should only be pointed at files from a trusted source.
    """
    sample = torch.load(path, map_location="cpu", weights_only=False)
    return sample
def _axis_angle_angle_deg(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the geodesic angle, in degrees, between two axis-angle rotations.

    Both inputs have shape (..., 3); the result has shape (...,).
    """
    rot_a = axis_angle_to_matrix(a.float())
    rot_b = axis_angle_to_matrix(b.float())
    # Relative rotation a^T @ b, converted back to axis-angle; the length of
    # that vector is treated as the rotation angle in radians.
    relative = matrix_to_axis_angle(rot_a.transpose(-1, -2) @ rot_b)
    angle_rad = relative.norm(dim=-1)
    return angle_rad * (180.0 / torch.pi)
def _derive_world_from_c_and_T(
    smpl_params_c: Dict[str, torch.Tensor], T_w2c: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """
    Map camera-frame SMPL translation and root orientation into the world frame.

    The world-to-camera transform is inverted per frame and applied as:
        p_w = R_c2w * p_c + t_c2w
        R_w = R_c2w * R_c

    Only ``transl`` and ``global_orient`` are read from ``smpl_params_c``; the
    returned dict carries the same two keys expressed in world coordinates.
    """
    # Invert the rigid transform: R_c2w = R_w2c^T, t_c2w = -R_c2w @ t_w2c.
    rot_w2c = T_w2c[:, :3, :3].float()  # (L,3,3)
    trans_w2c = T_w2c[:, :3, 3].float()  # (L,3)
    rot_c2w = rot_w2c.transpose(-1, -2)
    trans_c2w = -torch.einsum("bij,bj->bi", rot_c2w, trans_w2c)

    # Translation: rotate the camera-frame position, then add the camera origin offset.
    pos_cam = smpl_params_c["transl"].float()
    pos_world = torch.einsum("bij,bj->bi", rot_c2w, pos_cam) + trans_c2w

    # Orientation: compose the camera-to-world rotation with the camera-frame root rotation.
    rot_cam = axis_angle_to_matrix(smpl_params_c["global_orient"].float())
    orient_world = matrix_to_axis_angle(torch.einsum("bij,bjk->bik", rot_c2w, rot_cam))

    return {"transl": pos_world, "global_orient": orient_world}
def _maybe_slice(x: Any, start: int, end: int) -> Any:
    """Slice tensors to ``[start:end]`` along their first dimension.

    Dicts are traversed recursively (a new dict with the same keys is built);
    any value that is neither a dict nor a tensor is returned unchanged.
    """
    if isinstance(x, dict):
        sliced = {}
        for key, value in x.items():
            sliced[key] = _maybe_slice(value, start, end)
        return sliced
    return x[start:end] if isinstance(x, torch.Tensor) else x
def _iter_paths(root: Path, glob_pat: str) -> Iterable[Path]:
    """Yield the sample files designated by ``root``.

    If ``root`` is itself a file it is yielded as the only item; otherwise the
    entries of ``root`` matching ``glob_pat`` are yielded in sorted order.
    """
    if not root.is_file():
        matches = sorted(root.glob(glob_pat))
        yield from matches
    else:
        yield root
def _cam_angvel_absmean_diff(cam_av_saved: torch.Tensor, T_w2c: torch.Tensor) -> float | None:
    """Mean absolute difference between saved and recomputed ``cam_angvel``.

    The recomputed value comes from ``compute_cam_angvel`` applied to the
    rotation part of ``T_w2c`` (normalized first when it is a stack of 4x4
    matrices). The two are compared element-wise, directly in whatever
    representation ``compute_cam_angvel`` returns.

    Returns ``None`` when the comparison is not meaningful: an angular velocity
    needs at least two frames, so a single-frame ``T_w2c`` (e.g. under
    ``--frame0_only``) is skipped instead of producing NaN or an error.
    """
    num_frames = T_w2c.shape[0]
    if num_frames < 2:
        return None

    Tw = T_w2c.float()
    if Tw.ndim == 3 and Tw.shape[-2:] == (4, 4):
        Tw = normalize_T_w2c(Tw)
    cam_av_re = compute_cam_angvel(Tw[:, :3, :3])

    # Trim both sides to their common length so a length mismatch cannot
    # silently broadcast or raise inside the subtraction.
    cam_av_saved = cam_av_saved[:num_frames]
    common = min(cam_av_saved.shape[0], cam_av_re.shape[0])
    if common == 0:
        return None
    diff = cam_av_saved[:common].float() - cam_av_re[:common].float()
    return float(diff.abs().mean().item())


def _check_sample(path: Path, frame0_only: bool) -> Dict[str, Any]:
    """Load one processed sample and compute its consistency statistics.

    Raises:
        KeyError: if the sample lacks ``smpl_params_c``, ``smpl_params_w`` or
            ``T_w2c``; the message names the file and the missing keys.
    """
    d = _load_pt(path)
    required = ("smpl_params_c", "smpl_params_w", "T_w2c")
    missing = [key for key in required if key not in d]
    if missing:
        raise KeyError(f"{path}: missing required key(s): {missing}")

    smpl_c = d["smpl_params_c"]
    smpl_w = d["smpl_params_w"]
    T_w2c = d["T_w2c"]
    if frame0_only:
        smpl_c = _maybe_slice(smpl_c, 0, 1)
        smpl_w = _maybe_slice(smpl_w, 0, 1)
        T_w2c = _maybe_slice(T_w2c, 0, 1)

    # World params implied by the camera-frame params and the camera pose.
    derived = _derive_world_from_c_and_T(smpl_c, T_w2c)

    diff_t = smpl_w["transl"].float() - derived["transl"]  # (L, 3)
    dt = diff_t.norm(dim=-1)  # (L,)
    diff_t_abs = diff_t.abs()
    diff_t_abs_mean = diff_t_abs.mean(dim=0)
    diff_t_abs_max = diff_t_abs.max(dim=0)[0]
    dgo = _axis_angle_angle_deg(smpl_w["global_orient"].float(), derived["global_orient"])  # (L,)

    # cam_angvel is optional; only checked when the sample stores it.
    cam_av_saved = d.get("cam_angvel", None)
    cam_av_diff = None
    if cam_av_saved is not None:
        cam_av_diff = _cam_angvel_absmean_diff(cam_av_saved, T_w2c)

    return {
        "file": path.name,
        "max_transl_err_m": float(dt.max().item()),
        "mean_transl_err_m": float(dt.mean().item()),
        "mean_abs_dx": float(diff_t_abs_mean[0].item()),
        "mean_abs_dy": float(diff_t_abs_mean[1].item()),
        "mean_abs_dz": float(diff_t_abs_mean[2].item()),
        "max_abs_dx": float(diff_t_abs_max[0].item()),
        "max_abs_dy": float(diff_t_abs_max[1].item()),
        "max_abs_dz": float(diff_t_abs_max[2].item()),
        "max_go_err_deg": float(dgo.max().item()),
        "mean_go_err_deg": float(dgo.mean().item()),
        "cam_angvel_absmean_diff": cam_av_diff,
    }


def _format_stat(stat: Dict[str, Any]) -> str:
    """Render one per-file statistics dict as a single report line."""
    line = (
        f"{stat['file']}: "
        f"transl_err mean={stat['mean_transl_err_m']:.4f}m max={stat['max_transl_err_m']:.4f}m | "
        f"|Δt| mean_abs(x,y,z)=({stat['mean_abs_dx']:.4f},{stat['mean_abs_dy']:.4f},{stat['mean_abs_dz']:.4f}) "
        f"max_abs(x,y,z)=({stat['max_abs_dx']:.4f},{stat['max_abs_dy']:.4f},{stat['max_abs_dz']:.4f}) | "
        f"go_err mean={stat['mean_go_err_deg']:.2f}° max={stat['max_go_err_deg']:.2f}°"
    )
    # The cam_angvel field is omitted when it was absent or could not be checked.
    if stat["cam_angvel_absmean_diff"] is not None:
        line += f" | cam_angvel_absmean_diff={stat['cam_angvel_absmean_diff']:.6f}"
    return line


def main() -> None:
    """Command-line entry point: check up to ``--max`` samples and print a report.

    For every selected file, one line of translation / orientation /
    cam_angvel discrepancy statistics is printed, followed by the worst-case
    translation and orientation error over all checked files.

    Raises:
        SystemExit: if no file matches ``--root`` / ``--glob``.
        KeyError: if a sample is missing a required key.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--root",
        type=Path,
        default=Path("processed_dataset/genmo_features"),
        help="A single .pt file or a directory containing .pt files.",
    )
    ap.add_argument("--glob", type=str, default="*.pt")
    ap.add_argument("--max", type=int, default=10, help="Max files to check")
    ap.add_argument("--frame0_only", action="store_true", help="Check only frame 0")
    args = ap.parse_args()

    paths = list(_iter_paths(args.root, args.glob))
    if not paths:
        raise SystemExit(f"No files found under {args.root} with glob {args.glob!r}")
    # Always check at least one file, even if --max is zero or negative.
    paths = paths[: max(1, int(args.max))]

    all_stats = []
    for p in paths:
        stat = _check_sample(p, args.frame0_only)
        all_stats.append(stat)
        print(_format_stat(stat))

    # Aggregate
    max_transl = max(s["max_transl_err_m"] for s in all_stats)
    max_go = max(s["max_go_err_deg"] for s in all_stats)
    print(f"\nAggregate over {len(all_stats)} files: max_transl_err_m={max_transl:.4f}, max_go_err_deg={max_go:.2f}")
# Run the consistency check only when executed as a script, not on import.
if __name__ == "__main__":
    main()