File size: 6,643 Bytes
7e120dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env python3
"""
Diagnostic script to check Unity dataset coordinate consistency.

Checks:
- Rotation consistency between `smpl_params_c`, `smpl_params_w`, and `T_w2c`
- `cam_angvel` matches the convention used in preprocessing
- SMPL forward-kinematics consistency between camera/world parameters

Run with the gvhmr env python:
  /root/miniconda3/envs/gvhmr/bin/python debug_unity_data.py
"""
import torch
import numpy as np
from pathlib import Path
from scipy.spatial.transform import Rotation as R

def axis_angle_to_matrix(aa):
    """Turn an axis-angle (rotation vector) array into rotation matrices.

    The input is handed to SciPy's ``Rotation.from_rotvec`` unchanged and the
    result of ``as_matrix()`` is returned as a numpy array.
    """
    rotation = R.from_rotvec(aa)
    matrix = rotation.as_matrix()
    return matrix

def check_single_sequence(pt_path):
    """Run coordinate-consistency diagnostics on one preprocessed ``.pt`` clip.

    Three checks are executed and reported on stdout. The SMPL checks look at
    frame 0 only; the camera check covers every frame of ``T_w2c``.

    1. Rotation: ``smpl_params_w["global_orient"]`` must equal
       ``R_w2c^T @ smpl_params_c["global_orient"]`` to within 1 degree.
    2. Camera angular velocity (only when ``cam_angvel`` is stored): the
       stored values must match ``rot6d(R_i @ R_{i-1}^T)`` recomputed from
       ``T_w2c`` (identity for the first frame) to within 1e-4.
    3. SMPL forward kinematics: the first 22 joints of the camera-frame body,
       mapped to world coordinates with ``T_w2c``, must be within 5 cm
       (mean distance) of the world-frame joints.

    Args:
        pt_path: ``pathlib.Path`` of a file loadable with ``torch.load`` that
            holds ``smpl_params_c``, ``smpl_params_w``, ``T_w2c`` and,
            optionally, ``cam_angvel``.

    Returns:
        bool: ``True`` only if every check ran and passed. A check that could
        not be executed (for example because a project module failed to
        import) counts as failed, as does a NaN error value.
    """
    # Pass/fail tolerances. Each one is used both for the printed verdict and
    # for the returned flag so the two can never disagree.
    max_rot_err_deg = 1.0
    max_angvel_diff = 1e-4
    max_joint_err_m = 0.05

    print(f"\n{'='*80}")
    print(f"Checking: {pt_path.name}")
    print(f"{'='*80}")

    # NOTE: torch.load unpickles the file; only point this at trusted data.
    data = torch.load(pt_path, map_location="cpu")

    # Extract key data
    smpl_c = data["smpl_params_c"]
    smpl_w = data["smpl_params_w"]
    T_w2c = data["T_w2c"].numpy()

    # Check first frame
    idx = 0
    print(f"\n[Frame {idx}]")

    # Ground truth
    go_c_gt = smpl_c["global_orient"][idx].numpy()  # (3,) axis-angle
    go_w_gt = smpl_w["global_orient"][idx].numpy()  # (3,) axis-angle

    # Convert to matrices
    R_c_gt = axis_angle_to_matrix(go_c_gt)  # Pelvis in camera frame
    R_w_gt = axis_angle_to_matrix(go_w_gt)  # Pelvis in world frame
    R_w2c = T_w2c[idx, :3, :3]  # World to camera
    R_c2w = R_w2c.T  # Camera to world (inverse of a rotation is its transpose)

    # Verify: R_w = R_c2w @ R_c
    R_w_reconstructed = R_c2w @ R_c_gt

    # The angle of the relative rotation is the geodesic error between both.
    R_diff = R_w_reconstructed @ R_w_gt.T
    angle_err_deg = np.linalg.norm(R.from_matrix(R_diff).as_rotvec()) * 180.0 / np.pi
    # `<=` (rather than `not >`) makes a NaN error count as a failure.
    ok_rot = bool(angle_err_deg <= max_rot_err_deg)

    print(f"Ground truth global_orient_c (axis-angle): {go_c_gt}")
    print(f"Ground truth global_orient_w (axis-angle): {go_w_gt}")
    print(f"\nReconstruction test: R_w = R_c2w @ R_c")
    print(f"  Rotation error: {angle_err_deg:.4f}°")

    if not ok_rot:
        print(f"  ❌ ERROR: Rotation mismatch > 1°!")
        print(f"  R_w (ground truth):\n{R_w_gt}")
        print(f"  R_w (reconstructed):\n{R_w_reconstructed}")
    else:
        print(f"  ✅ OK: Rotations are consistent")

    # Check cam_angvel computation (should match preprocess convention)
    print(f"\n[Camera Angular Velocity Check]")
    cam_ok = True
    if "cam_angvel" in data:
        cam_angvel = data["cam_angvel"]  # (L, 6) - 6D rotation
        print(f"  cam_angvel shape: {cam_angvel.shape}")
        print(f"  cam_angvel[0]: {cam_angvel[0].numpy()}")

        # Manually compute cam_angvel and compare.
        # Convention (see `tools/demo/process_dataset.py:compute_velocity`):
        #   cam_angvel[0] = [1,0,0, 0,1,0]  (identity, rotation6d)
        #   cam_angvel[i] = rot6d(R_i @ R_{i-1}^T)
        try:
            from genmo.utils.rotation_conversions import matrix_to_rotation_6d
        except ImportError as e:
            # Handled like the FK check below: report and mark as failed
            # instead of aborting the remaining diagnostics.
            print(f"  ⚠️  Could not run cam_angvel check: {e}")
            cam_ok = False
        else:
            R_w2c_t = torch.from_numpy(T_w2c[:, :3, :3]).float()
            num_frames = int(R_w2c_t.shape[0])
            cam_angvel_manual = torch.zeros((num_frames, 6), dtype=torch.float32)
            cam_angvel_manual[0] = cam_angvel_manual.new_tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
            if num_frames > 1:
                R_diff_manual = R_w2c_t[1:] @ R_w2c_t[:-1].transpose(-1, -2)
                cam_angvel_manual[1:] = matrix_to_rotation_6d(R_diff_manual)

            if cam_angvel.shape != cam_angvel_manual.shape:
                # Subtracting would either raise or silently broadcast.
                print(
                    f"  ❌ WARNING: cam_angvel shape {tuple(cam_angvel.shape)} "
                    f"does not match expected {tuple(cam_angvel_manual.shape)}!"
                )
                cam_ok = False
            else:
                diff = (cam_angvel - cam_angvel_manual).abs().max().item()
                print(f"  Manual vs stored cam_angvel max diff: {diff:.6f}")
                cam_ok = diff <= max_angvel_diff  # NaN counts as a mismatch
                if not cam_ok:
                    print(f"  ❌ WARNING: cam_angvel mismatch!")
                else:
                    print(f"  ✅ OK: cam_angvel matches manual computation")

    # Check SMPL forward kinematics consistency
    print(f"\n[SMPL FK Check]")
    fk_ok = True
    try:
        from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx
        smplx_model = make_smplx("supermotion").eval()

        with torch.no_grad():
            # Incam SMPL
            out_c = smplx_model(
                global_orient=smpl_c["global_orient"][idx:idx+1],
                body_pose=smpl_c["body_pose"][idx:idx+1],
                betas=smpl_c["betas"][idx:idx+1],
                transl=smpl_c["transl"][idx:idx+1]
            )
            joints_c = out_c.joints[0, :22].numpy()  # (22, 3)

            # Global SMPL
            out_w = smplx_model(
                global_orient=smpl_w["global_orient"][idx:idx+1],
                body_pose=smpl_w["body_pose"][idx:idx+1],
                betas=smpl_w["betas"][idx:idx+1],
                transl=smpl_w["transl"][idx:idx+1]
            )
            joints_w = out_w.joints[0, :22].numpy()  # (22, 3)

        # Transform camera->world using T_w2c (world->camera):
        #   x_c = R_w2c x_w + t_w2c
        # => x_w = R_w2c^T (x_c - t_w2c)
        t_w2c = T_w2c[idx, :3, 3]
        joints_c2w = (R_c2w @ (joints_c - t_w2c).T).T

        # Compare
        joint_err = float(np.linalg.norm(joints_c2w - joints_w, axis=-1).mean())
        print(f"  Mean joint error (incam→world vs world GT): {joint_err:.4f}m")

        fk_ok = joint_err <= max_joint_err_m  # NaN counts as a mismatch
        if not fk_ok:
            print(f"  ❌ ERROR: Joint mismatch > 5cm!")
        else:
            print(f"  ✅ OK: SMPL joints are consistent")

    except Exception as e:
        # Deliberately broad: the body model is optional tooling and may fail
        # in several ways (missing package, missing model files, missing
        # keys). The check is reported as not runnable and counted as failed.
        print(f"  ⚠️  Could not run SMPL FK check: {e}")
        fk_ok = False

    # Consider the clip consistent only if all checks are reasonable.
    return ok_rot and cam_ok and fk_ok

def main(dataset_root="./processed_dataset", num_check=3):
    """Check the first few feature files of a processed dataset and print a verdict.

    Looks for ``*.pt`` files in ``<dataset_root>/genmo_features``, runs
    ``check_single_sequence`` on the first ``num_check`` of them (sorted by
    path) and prints an overall summary. A missing feature directory or an
    empty one is reported on stdout and ends the run early.

    Args:
        dataset_root: Directory that contains the ``genmo_features`` folder.
            Defaults to ``./processed_dataset``.
        num_check: Maximum number of sequences to check; ``None`` checks every
            file found. Defaults to 3.

    Raises:
        ValueError: If ``num_check`` is a number smaller than 1, because a run
            that checks nothing would misleadingly report success.
    """
    if num_check is not None and num_check < 1:
        raise ValueError(f"num_check must be >= 1 or None, got {num_check}")

    feat_dir = Path(dataset_root) / "genmo_features"

    if not feat_dir.exists():
        print(f"Error: {feat_dir} not found!")
        return

    pt_files = sorted(feat_dir.glob("*.pt"))
    print(f"Found {len(pt_files)} sequences")

    if not pt_files:
        print("No .pt files found!")
        return

    # Build a full list (not a lazy generator) so that every selected sequence
    # is checked and reported even after an earlier one has failed.
    results = [check_single_sequence(pt_file) for pt_file in pt_files[:num_check]]
    all_ok = all(results)

    print(f"\n{'='*80}")
    if all_ok:
        print("✅ All checks passed! Data appears consistent.")
        print("\nIf training still has high loss, the issue is likely:")
        print("  1. Model architecture/hyperparameters")
        print("  2. Normalization statistics mismatch")
        print("  3. Sequence length handling during training")
    else:
        print("❌ Data consistency issues found!")
        print("\nThis explains the high training loss.")
        print("You need to fix the coordinate system in process_dataset.py")
    print(f"{'='*80}\n")

# Run the diagnostics only when this file is executed as a script, so that
# importing it (e.g. to reuse the check functions) has no side effects.
if __name__ == "__main__":
    main()