| |
| """ |
| Check internal consistency of Unity processed samples. |
| |
| This script compares: |
| - saved `smpl_params_w` vs `smpl_params_c` + `T_w2c`-derived world params |
| - saved `cam_angvel` (if present) vs recomputed from `T_w2c` |
| |
| If these disagree by non-trivial amounts, fine-tuning can "break global" because the |
| network is asked to fit mutually inconsistent targets/condition signals. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
| from typing import Dict, Any, Iterable |
|
|
| import torch |
|
|
| from genmo.utils.geo_transform import compute_cam_angvel, normalize_T_w2c |
| from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle |
|
|
|
|
def _load_pt(path: Path) -> Dict[str, Any]:
    """
    Read one saved sample file and return its contents with tensors on the CPU.

    ``weights_only=False`` allows ``torch.load`` to unpickle arbitrary Python
    objects, so this must only be pointed at trusted files.
    """
    payload = torch.load(path, map_location="cpu", weights_only=False)
    return payload
|
|
|
|
def _axis_angle_angle_deg(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Return the geodesic distance, in degrees, between two axis-angle rotations.

    ``a`` and ``b`` have shape ``(..., 3)``; the result has shape ``(...)``.
    The distance is the rotation angle of the relative rotation ``R_a^T R_b``.
    """
    rad_to_deg = 180.0 / torch.pi
    rot_a, rot_b = (axis_angle_to_matrix(v.float()) for v in (a, b))
    relative = torch.matmul(rot_a.transpose(-1, -2), rot_b)
    return matrix_to_axis_angle(relative).norm(dim=-1) * rad_to_deg
|
|
|
|
def _derive_world_from_c_and_T(
    smpl_params_c: Dict[str, torch.Tensor],
    T_w2c: torch.Tensor,
    offset: torch.Tensor | None = None,
) -> Dict[str, torch.Tensor]:
    """
    Derive world-frame ``transl`` / ``global_orient`` from camera-frame params.

    With ``R_c2w = R_w2c^T`` and ``t_c2w = -R_c2w t_w2c`` taken from ``T_w2c``::

        R_w      = R_c2w R_c
        transl_w = R_c2w (transl_c + offset) + t_c2w - offset

    ``offset`` defaults to zero, which is the plain point transform
    ``p_w = R_c2w p_c + t_c2w``. Pass the body model's root-joint rest
    position (e.g. the SMPL pelvis for the subject's betas) when
    ``global_orient`` rotates about that joint rather than about the origin.
    NOTE(review): in the standard SMPL formulation it does, so without the
    offset a non-zero ``transl`` error can appear whenever ``R_c2w != I`` even
    for consistent data -- confirm how the camera-frame params were produced.

    The upper-left 3x3 of ``T_w2c`` is assumed to be a pure rotation (its
    transpose is used as the inverse).

    Args:
        smpl_params_c: camera-frame params; reads ``"transl"`` ``(..., 3)`` and
            ``"global_orient"`` (axis-angle, ``(..., 3)``).
        T_w2c: world-to-camera transforms ``(..., 4, 4)``. Leading dims
            broadcast against the params, so a single ``(4, 4)`` matrix can be
            applied to a whole sequence.
        offset: optional root-joint offset, ``(3,)`` or ``(..., 3)``.

    Returns:
        ``{"transl": (..., 3), "global_orient": (..., 3) axis-angle}``.
    """
    R_w2c = T_w2c[..., :3, :3].float()
    t_w2c = T_w2c[..., :3, 3].float()
    R_c2w = R_w2c.transpose(-1, -2)
    # matmul (rather than a fixed "fij" einsum) broadcasts leading dims.
    t_c2w = -(R_c2w @ t_w2c.unsqueeze(-1)).squeeze(-1)

    transl_c = smpl_params_c["transl"].float()
    # Zero offset reduces exactly to R_c2w @ transl_c + t_c2w.
    off = torch.zeros_like(transl_c) if offset is None else offset.to(transl_c)
    transl_w = (R_c2w @ (transl_c + off).unsqueeze(-1)).squeeze(-1) + t_c2w - off

    R_c = axis_angle_to_matrix(smpl_params_c["global_orient"].float())
    R_w = R_c2w @ R_c
    go_w = matrix_to_axis_angle(R_w)

    return {"transl": transl_w, "global_orient": go_w}
|
|
|
|
def _maybe_slice(x: Any, start: int, end: int) -> Any:
    """
    Slice tensors along their first axis, recursing into dicts.

    A tensor becomes ``x[start:end]``; a dict is rebuilt (same key order) with
    each value sliced the same way; any other object is returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x[start:end]
    if not isinstance(x, dict):
        return x
    sliced = {}
    for key, value in x.items():
        sliced[key] = _maybe_slice(value, start, end)
    return sliced
|
|
|
|
def _iter_paths(root: Path, glob_pat: str) -> Iterable[Path]:
    """
    Yield the files to check.

    If ``root`` is a file, it is the only result; otherwise ``root`` is
    globbed with ``glob_pat`` and the matches are yielded in sorted order.
    """
    if root.is_file():
        candidates = [root]
    else:
        candidates = sorted(root.glob(glob_pat))
    yield from candidates
|
|
|
|
def _cam_angvel_diff(
    cam_av_saved: torch.Tensor, T_w2c: torch.Tensor, num_frames: int
) -> float | None:
    """
    Compare a saved ``cam_angvel`` with one recomputed from ``T_w2c``.

    The recomputation always uses the full ``T_w2c`` sequence, because an
    angular velocity needs at least two poses (so ``--frame0_only`` still gets
    a meaningful value). Only the first ``num_frames`` frames are compared,
    truncated to the shorter of the two signals so lengths cannot silently
    broadcast.

    Returns:
        Mean absolute element-wise difference, or None when ``T_w2c`` is not
        an ``(F, 4, 4)`` stack with ``F >= 2`` or nothing is left to compare.

    Raises:
        ValueError: saved and recomputed signals differ in trailing shape.
    """
    Tw = T_w2c.float()
    if Tw.ndim != 3 or tuple(Tw.shape[-2:]) != (4, 4) or Tw.shape[0] < 2:
        return None
    cam_av_re = compute_cam_angvel(normalize_T_w2c(Tw)[:, :3, :3]).float()
    saved = cam_av_saved.float()
    if tuple(saved.shape[1:]) != tuple(cam_av_re.shape[1:]):
        raise ValueError(
            f"cam_angvel shape mismatch: saved {tuple(saved.shape)} "
            f"vs recomputed {tuple(cam_av_re.shape)}"
        )
    n = min(num_frames, saved.shape[0], cam_av_re.shape[0])
    if n <= 0:
        return None
    return float((saved[:n] - cam_av_re[:n]).abs().mean().item())


def _file_stats(name: str, d: Dict[str, Any], frame0_only: bool) -> Dict[str, Any]:
    """
    Compute the consistency statistics for one loaded sample ``d``.

    Raises:
        KeyError: a required key (``smpl_params_c``, ``smpl_params_w``,
            ``T_w2c``, ``transl``, ``global_orient``) is missing.
        ValueError: there are no frames to compare, or ``cam_angvel`` has an
            unexpected shape.
    """
    smpl_c = d["smpl_params_c"]
    smpl_w = d["smpl_params_w"]
    T_w2c_full = d["T_w2c"]
    T_w2c = T_w2c_full

    if frame0_only:
        smpl_c = _maybe_slice(smpl_c, 0, 1)
        smpl_w = _maybe_slice(smpl_w, 0, 1)
        T_w2c = _maybe_slice(T_w2c_full, 0, 1)
    if T_w2c.shape[0] == 0:
        raise ValueError("T_w2c has no frames")

    derived = _derive_world_from_c_and_T(smpl_c, T_w2c)

    diff_t = smpl_w["transl"].float() - derived["transl"]
    dt = diff_t.norm(dim=-1)
    diff_t_abs = diff_t.abs()
    diff_t_abs_mean = diff_t_abs.mean(dim=0)
    diff_t_abs_max = diff_t_abs.max(dim=0).values
    dgo = _axis_angle_angle_deg(smpl_w["global_orient"].float(), derived["global_orient"])

    # The unit is whatever cam_angvel stores (not degrees).
    cam_av_saved = d.get("cam_angvel", None)
    cam_av_diff = None
    if cam_av_saved is not None:
        cam_av_diff = _cam_angvel_diff(cam_av_saved, T_w2c_full, int(T_w2c.shape[0]))

    return {
        "file": name,
        "max_transl_err_m": float(dt.max().item()),
        "mean_transl_err_m": float(dt.mean().item()),
        "mean_abs_dx": float(diff_t_abs_mean[0].item()),
        "mean_abs_dy": float(diff_t_abs_mean[1].item()),
        "mean_abs_dz": float(diff_t_abs_mean[2].item()),
        "max_abs_dx": float(diff_t_abs_max[0].item()),
        "max_abs_dy": float(diff_t_abs_max[1].item()),
        "max_abs_dz": float(diff_t_abs_max[2].item()),
        "max_go_err_deg": float(dgo.max().item()),
        "mean_go_err_deg": float(dgo.mean().item()),
        "cam_angvel_absmean_diff": cam_av_diff,
    }


def _format_stat(stat: Dict[str, Any]) -> str:
    """Render one file's statistics as a single human-readable line."""
    line = (
        f"{stat['file']}: "
        f"transl_err mean={stat['mean_transl_err_m']:.4f}m max={stat['max_transl_err_m']:.4f}m | "
        f"|Δt| mean_abs(x,y,z)=({stat['mean_abs_dx']:.4f},{stat['mean_abs_dy']:.4f},{stat['mean_abs_dz']:.4f}) "
        f"max_abs(x,y,z)=({stat['max_abs_dx']:.4f},{stat['max_abs_dy']:.4f},{stat['max_abs_dz']:.4f}) | "
        f"go_err mean={stat['mean_go_err_deg']:.2f}° max={stat['max_go_err_deg']:.2f}°"
    )
    if stat["cam_angvel_absmean_diff"] is not None:
        line += f" | cam_angvel_absmean_diff={stat['cam_angvel_absmean_diff']:.6f}"
    return line


def main() -> None:
    """
    CLI entry point: check up to ``--max`` sample files and print per-file and
    aggregate consistency statistics.

    A file missing a required key, with no frames, or with a malformed
    ``cam_angvel`` is reported as skipped instead of aborting the whole run.

    NOTE(review): world ``transl`` is re-derived with a plain rigid transform.
    If the dataset was converted using an SMPL root-joint (pelvis) offset, a
    non-zero ``transl_err`` is expected even for consistent data -- confirm
    how ``smpl_params_c`` was produced before acting on that number.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--root",
        type=Path,
        default=Path("processed_dataset/genmo_features"),
        help="A single .pt file or a directory containing .pt files.",
    )
    ap.add_argument("--glob", type=str, default="*.pt")
    ap.add_argument("--max", type=int, default=10, help="Max files to check")
    ap.add_argument("--frame0_only", action="store_true", help="Check only frame 0")
    args = ap.parse_args()

    paths = list(_iter_paths(args.root, args.glob))
    if not paths:
        raise SystemExit(f"No files found under {args.root} with glob {args.glob!r}")
    paths = paths[: max(1, int(args.max))]

    all_stats = []
    skipped = []
    for p in paths:
        d = _load_pt(p)
        try:
            stat = _file_stats(p.name, d, args.frame0_only)
        except (KeyError, ValueError) as err:
            skipped.append(p.name)
            print(f"{p.name}: SKIPPED ({type(err).__name__}: {err})")
            continue
        all_stats.append(stat)
        print(_format_stat(stat))

    if not all_stats:
        raise SystemExit(f"No file under {args.root} could be checked ({len(skipped)} skipped).")

    max_transl = max(s["max_transl_err_m"] for s in all_stats)
    max_go = max(s["max_go_err_deg"] for s in all_stats)
    print(f"\nAggregate over {len(all_stats)} files: max_transl_err_m={max_transl:.4f}, max_go_err_deg={max_go:.2f}")

    cam_diffs = [
        s["cam_angvel_absmean_diff"] for s in all_stats if s["cam_angvel_absmean_diff"] is not None
    ]
    if cam_diffs:
        print(f"max cam_angvel_absmean_diff={max(cam_diffs):.6f} over {len(cam_diffs)} files")
    if skipped:
        print(f"Skipped {len(skipped)} file(s): {', '.join(skipped)}")
|
|
|
|
# Run the consistency check only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|