#!/usr/bin/env python3 """ Check internal consistency of Unity processed samples. This script compares: - saved `smpl_params_w` vs `smpl_params_c` + `T_w2c`-derived world params - saved `cam_angvel` (if present) vs recomputed from `T_w2c` If these disagree by non-trivial amounts, fine-tuning can "break global" because the network is asked to fit mutually inconsistent targets/condition signals. """ from __future__ import annotations import argparse from pathlib import Path from typing import Dict, Any, Iterable import torch from genmo.utils.geo_transform import compute_cam_angvel, normalize_T_w2c from genmo.utils.rotation_conversions import axis_angle_to_matrix, matrix_to_axis_angle def _load_pt(path: Path) -> Dict[str, Any]: return torch.load(path, map_location="cpu", weights_only=False) def _axis_angle_angle_deg(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Geodesic angle between rotations in axis-angle form. Shapes (..., 3). Returns (...,) degrees.""" Ra = axis_angle_to_matrix(a.float()) Rb = axis_angle_to_matrix(b.float()) Rrel = Ra.transpose(-1, -2) @ Rb aa = matrix_to_axis_angle(Rrel) return aa.norm(dim=-1) * (180.0 / torch.pi) def _derive_world_from_c_and_T( smpl_params_c: Dict[str, torch.Tensor], T_w2c: torch.Tensor ) -> Dict[str, torch.Tensor]: """ Derive world translation/orientation from camera params and T_w2c. Assumes: p_w = R_c2w * p_c + t_c2w R_w = R_c2w * R_c """ R_w2c = T_w2c[:, :3, :3].float() # (L,3,3) t_w2c = T_w2c[:, :3, 3].float() # (L,3) R_c2w = R_w2c.transpose(-1, -2) t_c2w = -torch.einsum("fij,fj->fi", R_c2w, t_w2c) transl_c = smpl_params_c["transl"].float() transl_w = torch.einsum("fij,fj->fi", R_c2w, transl_c) + t_c2w R_c = axis_angle_to_matrix(smpl_params_c["global_orient"].float()) R_w = torch.einsum("fij,fjk->fik", R_c2w, R_c) go_w = matrix_to_axis_angle(R_w) return {"transl": transl_w, "global_orient": go_w} def _maybe_slice(x: Any, start: int, end: int) -> Any: if isinstance(x, dict): return {k: _maybe_slice(v, start, end) for k, v in x.items()} if isinstance(x, torch.Tensor): return x[start:end] return x def _iter_paths(root: Path, glob_pat: str) -> Iterable[Path]: if root.is_file(): yield root return yield from sorted(root.glob(glob_pat)) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument( "--root", type=Path, default=Path("processed_dataset/genmo_features"), help="A single .pt file or a directory containing .pt files.", ) ap.add_argument("--glob", type=str, default="*.pt") ap.add_argument("--max", type=int, default=10, help="Max files to check") ap.add_argument("--frame0_only", action="store_true", help="Check only frame 0") args = ap.parse_args() paths = list(_iter_paths(args.root, args.glob)) if not paths: raise SystemExit(f"No files found under {args.root} with glob {args.glob!r}") paths = paths[: max(1, int(args.max))] all_stats = [] for p in paths: d = _load_pt(p) smpl_c = d["smpl_params_c"] smpl_w = d["smpl_params_w"] T_w2c = d["T_w2c"] if args.frame0_only: smpl_c = _maybe_slice(smpl_c, 0, 1) smpl_w = _maybe_slice(smpl_w, 0, 1) T_w2c = _maybe_slice(T_w2c, 0, 1) derived = _derive_world_from_c_and_T(smpl_c, T_w2c) diff_t = smpl_w["transl"].float() - derived["transl"] # (L, 3) dt = diff_t.norm(dim=-1) # (L,) diff_t_abs_mean = diff_t.abs().mean(dim=0) diff_t_abs_max = diff_t.abs().max(dim=0)[0] dgo = _axis_angle_angle_deg(smpl_w["global_orient"].float(), derived["global_orient"]) # (L,) # cam_angvel consistency cam_av_saved = d.get("cam_angvel", None) cam_av_deg = None if cam_av_saved is not None: cam_av_saved = cam_av_saved[: T_w2c.shape[0]] Tw = T_w2c.float() if Tw.ndim == 3 and Tw.shape[-2:] == (4, 4) and Tw.shape[0] >= 2: Tw = normalize_T_w2c(Tw) cam_av_re = compute_cam_angvel(Tw[:, :3, :3]) # (L,6) # matrix_to_rotation_6d convention: compare directly in 6D space cam_av_deg = (cam_av_saved.float() - cam_av_re.float()).abs().mean().item() stat = { "file": p.name, "max_transl_err_m": float(dt.max().item()), "mean_transl_err_m": float(dt.mean().item()), "mean_abs_dx": float(diff_t_abs_mean[0].item()), "mean_abs_dy": float(diff_t_abs_mean[1].item()), "mean_abs_dz": float(diff_t_abs_mean[2].item()), "max_abs_dx": float(diff_t_abs_max[0].item()), "max_abs_dy": float(diff_t_abs_max[1].item()), "max_abs_dz": float(diff_t_abs_max[2].item()), "max_go_err_deg": float(dgo.max().item()), "mean_go_err_deg": float(dgo.mean().item()), "cam_angvel_absmean_diff": None if cam_av_deg is None else float(cam_av_deg), } all_stats.append(stat) print( f"{p.name}: " f"transl_err mean={stat['mean_transl_err_m']:.4f}m max={stat['max_transl_err_m']:.4f}m | " f"|Δt| mean_abs(x,y,z)=({stat['mean_abs_dx']:.4f},{stat['mean_abs_dy']:.4f},{stat['mean_abs_dz']:.4f}) " f"max_abs(x,y,z)=({stat['max_abs_dx']:.4f},{stat['max_abs_dy']:.4f},{stat['max_abs_dz']:.4f}) | " f"go_err mean={stat['mean_go_err_deg']:.2f}° max={stat['max_go_err_deg']:.2f}°" + ( "" if stat["cam_angvel_absmean_diff"] is None else f" | cam_angvel_absmean_diff={stat['cam_angvel_absmean_diff']:.6f}" ) ) # Aggregate max_transl = max(s["max_transl_err_m"] for s in all_stats) max_go = max(s["max_go_err_deg"] for s in all_stats) print(f"\nAggregate over {len(all_stats)} files: max_transl_err_m={max_transl:.4f}, max_go_err_deg={max_go:.2f}") if __name__ == "__main__": main()