#!/usr/bin/env python3
"""
Diagnostic script to check Unity dataset coordinate consistency.
Checks:
- Rotation consistency between `smpl_params_c`, `smpl_params_w`, and `T_w2c`
- `cam_angvel` matches the convention used in preprocessing
- SMPL forward-kinematics consistency between camera/world parameters
Run with the gvhmr env python:
/root/miniconda3/envs/gvhmr/bin/python debug_unity_data.py
"""
import torch
import numpy as np
from pathlib import Path
from scipy.spatial.transform import Rotation as R
def axis_angle_to_matrix(aa):
    """Turn axis-angle rotation vector(s) into rotation matrix/matrices.

    Thin wrapper around SciPy's ``Rotation``: the input is interpreted as
    rotation vector(s) and the matching matrix representation is returned
    as a numpy array.
    """
    rotation = R.from_rotvec(aa)
    return rotation.as_matrix()
def _check_rotation(smpl_c, smpl_w, R_w2c, idx):
    """Check that ``global_orient`` in world frame equals ``R_w2c^T @ R_c``.

    Args:
        smpl_c: Mapping with a ``"global_orient"`` tensor (camera frame).
        smpl_w: Mapping with a ``"global_orient"`` tensor (world frame).
        R_w2c: 3x3 numpy world-to-camera rotation for frame ``idx``.
        idx: Frame index to inspect.

    Returns:
        ``True`` when the angular error is at most 1 degree.
    """
    # Ground truth
    go_c_gt = smpl_c["global_orient"][idx].numpy()  # (3,) axis-angle
    go_w_gt = smpl_w["global_orient"][idx].numpy()  # (3,) axis-angle

    # Convert to matrices
    R_c_gt = axis_angle_to_matrix(go_c_gt)  # Pelvis in camera frame
    R_w_gt = axis_angle_to_matrix(go_w_gt)  # Pelvis in world frame
    R_c2w = R_w2c.T  # Camera to world

    # Verify: R_w = R_c2w @ R_c
    R_w_reconstructed = R_c2w @ R_c_gt

    # The rotation-vector norm of the relative rotation is the angular error.
    R_diff = R_w_reconstructed @ R_w_gt.T
    angle_err_deg = np.linalg.norm(R.from_matrix(R_diff).as_rotvec()) * 180.0 / np.pi

    print(f"Ground truth global_orient_c (axis-angle): {go_c_gt}")
    print(f"Ground truth global_orient_w (axis-angle): {go_w_gt}")
    print("\nReconstruction test: R_w = R_c2w @ R_c")
    print(f" Rotation error: {angle_err_deg:.4f}°")

    # `<=` (rather than `not >`) so a NaN error is reported as a failure, and
    # so the printed verdict and the returned verdict can never disagree.
    rot_ok = bool(angle_err_deg <= 1.0)
    if rot_ok:
        print(" ✅ OK: Rotations are consistent")
    else:
        print(" ❌ ERROR: Rotation mismatch > 1°!")
        print(f" R_w (ground truth):\n{R_w_gt}")
        print(f" R_w (reconstructed):\n{R_w_reconstructed}")
    return rot_ok


def _check_cam_angvel(data, T_w2c):
    """Recompute ``cam_angvel`` from ``T_w2c`` and compare with the stored one.

    Convention (see `tools/demo/process_dataset.py:compute_velocity`):
        cam_angvel[0] = [1,0,0, 0,1,0]  (identity, rotation6d)
        cam_angvel[i] = rot6d(R_i @ R_{i-1}^T)

    Args:
        data: Loaded sequence dict; the check is skipped (and passes) when it
            has no ``"cam_angvel"`` entry.
        T_w2c: numpy array of per-frame 4x4 world-to-camera transforms.

    Returns:
        ``True`` when the check passes or is not applicable, ``False`` on a
        mismatch or when the check could not be run.
    """
    print("\n[Camera Angular Velocity Check]")
    if "cam_angvel" not in data:
        return True

    cam_angvel = data["cam_angvel"]  # (L, 6) - 6D rotation
    print(f" cam_angvel shape: {cam_angvel.shape}")

    R_w2c_t = torch.from_numpy(T_w2c[:, :3, :3]).float()
    num_frames = int(R_w2c_t.shape[0])

    # Compare shapes explicitly: a (1, 6) or (6,) tensor would otherwise
    # broadcast silently against the manual (L, 6) result.
    expected_shape = (num_frames, 6)
    if tuple(cam_angvel.shape) != expected_shape:
        print(
            f" ❌ WARNING: cam_angvel shape {tuple(cam_angvel.shape)} "
            f"!= expected {expected_shape}"
        )
        return False

    print(f" cam_angvel[0]: {cam_angvel[0].numpy()}")

    # Project-local import: report a missing package the same way the SMPL FK
    # check does instead of aborting the whole diagnostic run.
    try:
        from genmo.utils.rotation_conversions import matrix_to_rotation_6d
    except ImportError as e:
        print(f" ⚠️ Could not run cam_angvel check: {e}")
        return False

    cam_angvel_manual = torch.zeros((num_frames, 6), dtype=torch.float32)
    cam_angvel_manual[0] = cam_angvel_manual.new_tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
    if num_frames > 1:
        R_diff_manual = R_w2c_t[1:] @ R_w2c_t[:-1].transpose(-1, -2)
        cam_angvel_manual[1:] = matrix_to_rotation_6d(R_diff_manual)

    diff = (cam_angvel - cam_angvel_manual).abs().max().item()
    print(f" Manual vs stored cam_angvel max diff: {diff:.6f}")

    # `<=` so that a NaN difference counts as a mismatch.
    if diff <= 1e-4:
        print(" ✅ OK: cam_angvel matches manual computation")
        return True
    print(" ❌ WARNING: cam_angvel mismatch!")
    return False


def _check_smpl_fk(smpl_c, smpl_w, T_w2c, idx):
    """Run SMPL on camera/world params and compare joints in the world frame.

    Camera-frame joints are mapped to world with the inverse of ``T_w2c``:
        x_c = R_w2c x_w + t_w2c   =>   x_w = R_w2c^T (x_c - t_w2c)

    Returns:
        ``True`` when the mean joint distance is at most 5 cm; ``False`` on a
        larger (or NaN) error, or when the body model could not be run.
    """
    print("\n[SMPL FK Check]")
    frame = slice(idx, idx + 1)
    try:
        from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx

        smplx_model = make_smplx("supermotion").eval()
        with torch.no_grad():
            # Incam SMPL
            out_c = smplx_model(
                global_orient=smpl_c["global_orient"][frame],
                body_pose=smpl_c["body_pose"][frame],
                betas=smpl_c["betas"][frame],
                transl=smpl_c["transl"][frame],
            )
            joints_c = out_c.joints[0, :22].numpy()  # (22, 3)

            # Global SMPL
            out_w = smplx_model(
                global_orient=smpl_w["global_orient"][frame],
                body_pose=smpl_w["body_pose"][frame],
                betas=smpl_w["betas"][frame],
                transl=smpl_w["transl"][frame],
            )
            joints_w = out_w.joints[0, :22].numpy()  # (22, 3)

        R_c2w = T_w2c[idx, :3, :3].T
        t_w2c = T_w2c[idx, :3, 3]
        joints_c2w = (R_c2w @ (joints_c - t_w2c).T).T
        joint_err = np.linalg.norm(joints_c2w - joints_w, axis=-1).mean()
    except Exception as e:
        # Deliberately broad: this is a best-effort diagnostic, and the body
        # model can fail for many environment reasons (missing package,
        # missing model files, ...). The failure is reported, not hidden.
        print(f" ⚠️ Could not run SMPL FK check: {e}")
        return False

    print(f" Mean joint error (incam→world vs world GT): {joint_err:.4f}m")
    # `<=` so that NaN joints (e.g. NaN in the stored params) fail the check.
    if joint_err <= 0.05:
        print(" ✅ OK: SMPL joints are consistent")
        return True
    print(" ❌ ERROR: Joint mismatch > 5cm!")
    return False


def check_single_sequence(pt_path):
    """Check a single .pt file for coordinate consistency.

    Runs three checks on the first frame / the whole camera track and prints
    a report: rotation consistency, ``cam_angvel`` convention, and SMPL
    forward-kinematics consistency. All checks are always run so the report
    is complete even after an early failure.

    Args:
        pt_path: ``pathlib.Path`` of a file readable by ``torch.load`` that
            holds ``"smpl_params_c"``, ``"smpl_params_w"``, ``"T_w2c"`` and
            optionally ``"cam_angvel"``.

    Returns:
        ``True`` only if every check passed.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print(f"Checking: {pt_path.name}")
    print(rule)

    # NOTE: torch.load unpickles the file; only point this at trusted data.
    data = torch.load(pt_path, map_location="cpu")

    # Extract key data
    smpl_c = data["smpl_params_c"]
    smpl_w = data["smpl_params_w"]
    T_w2c = data["T_w2c"].numpy()

    # Check first frame
    idx = 0
    num_frames = min(
        len(T_w2c), len(smpl_c["global_orient"]), len(smpl_w["global_orient"])
    )
    if idx >= num_frames:
        # Nothing to index into; report instead of raising IndexError.
        print(" ❌ ERROR: Sequence has no frames to check")
        return False

    print(f"\n[Frame {idx}]")
    rot_ok = _check_rotation(smpl_c, smpl_w, T_w2c[idx, :3, :3], idx)
    cam_ok = _check_cam_angvel(data, T_w2c)
    fk_ok = _check_smpl_fk(smpl_c, smpl_w, T_w2c, idx)

    # Consider the clip consistent only if all checks are reasonable.
    return rot_ok and cam_ok and fk_ok
def main(dataset_root="./processed_dataset", num_check=3):
    """Run the consistency checks on the first few processed sequences.

    Looks for ``*.pt`` files in ``<dataset_root>/genmo_features``, checks the
    first ``num_check`` of them (in sorted order) with
    ``check_single_sequence`` and prints an overall verdict.

    Args:
        dataset_root: Directory containing the ``genmo_features`` folder.
            Defaults to the previously hard-coded ``./processed_dataset``.
        num_check: Maximum number of sequences to check; ``None`` checks all.
            Defaults to the previously hard-coded 3.

    Returns:
        None. Results are reported on stdout only.
    """
    feat_dir = Path(dataset_root) / "genmo_features"
    if not feat_dir.exists():
        print(f"Error: {feat_dir} not found!")
        return

    pt_files = sorted(feat_dir.glob("*.pt"))
    print(f"Found {len(pt_files)} sequences")
    if not pt_files:
        print("No .pt files found!")
        return

    selected = pt_files if num_check is None else pt_files[: max(0, num_check)]

    # Plain loop (not `all(...)` on a generator) so every selected sequence is
    # checked and reported even after the first failure.
    all_ok = True
    for pt_file in selected:
        if not check_single_sequence(pt_file):
            all_ok = False

    rule = "=" * 80
    print(f"\n{rule}")
    if all_ok:
        print("✅ All checks passed! Data appears consistent.")
        print("\nIf training still has high loss, the issue is likely:")
        print(" 1. Model architecture/hyperparameters")
        print(" 2. Normalization statistics mismatch")
        print(" 3. Sequence length handling during training")
    else:
        print("❌ Data consistency issues found!")
        print("\nThis explains the high training loss.")
        print("You need to fix the coordinate system in process_dataset.py")
    print(f"{rule}\n")
# Script entry point: run the diagnostics only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()