File size: 7,097 Bytes
b88b79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Verify that DROID demo data eef_9d uses the correct rotation convention.

Computes eef_9d from raw cartesian_position two ways (with and without
DROID_EEF_ROTATION_CORRECT) and compares against the pretrained model's
normalization statistics to determine which convention matches.

Usage:
  python scripts/verify_droid_rotation_correction.py
  python scripts/verify_droid_rotation_correction.py --dataset-path demo_data/droid_sample
"""

from __future__ import annotations

import argparse
import json
import logging
from pathlib import Path

import numpy as np
from scipy.spatial.transform import Rotation


logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Fixed 3x3 rotation that _euler_to_eef_9d right-multiplies onto each
# euler-derived rotation matrix when apply_correction=True.
DROID_EEF_ROTATION_CORRECT = np.array(
    [
        [0, 0, -1],
        [-1, 0, 0],
        [0, 1, 0],
    ],
    dtype=np.float64,
)

# statistics.json key tried first (before "default") when looking up eef_9d stats.
EMBODIMENT_TAG = "oxe_droid_relative_eef_relative_joint"


def _euler_to_eef_9d(cartesian_position: np.ndarray, *, apply_correction: bool) -> np.ndarray:
    """Convert cartesian_position (XYZ + euler) to eef_9d (XYZ + rot6d)."""
    cart = np.asarray(cartesian_position, dtype=np.float64)
    xyz = cart[..., :3].reshape(-1, 3)
    euler = cart[..., 3:].reshape(-1, 3)
    rot = Rotation.from_euler("XYZ", euler).as_matrix()
    if apply_correction:
        rot = rot @ DROID_EEF_ROTATION_CORRECT
    rot6d = rot[:, :2, :].reshape(-1, 6)
    return np.concatenate([xyz, rot6d], axis=-1).astype(np.float32)


def _load_cartesian_positions(dataset_path: str) -> np.ndarray:
    """Load observation.state.cartesian_position from all episode parquets."""
    import pandas as pd

    all_cart = []
    for pq in sorted((Path(dataset_path) / "data").rglob("*.parquet")):
        df = pd.read_parquet(pq)
        if "observation.state.cartesian_position" in df.columns:
            all_cart.append(np.stack(df["observation.state.cartesian_position"].values))
    if not all_cart:
        raise RuntimeError("No cartesian_position found in any parquet file")
    return np.concatenate(all_cart, axis=0)


def _download_eef_stats(hf_repo_id: str) -> dict | None:
    """Download ``statistics.json`` from a HuggingFace repo and return its eef_9d stats.

    The stats are read from ``stats[tag]["state"]["eef_9d"]``, trying
    ``EMBODIMENT_TAG`` first and then ``"default"``; the first truthy entry wins.

    Best-effort: any failure (``huggingface_hub`` not installed, download error,
    malformed JSON, unexpected structure) is logged as a warning and ``None`` is
    returned. ``None`` is also returned when neither key holds eef_9d stats.

    Args:
        hf_repo_id: HuggingFace model repo id to download from.

    Returns:
        The eef_9d stats entry, or ``None`` if unavailable.
    """
    try:
        from huggingface_hub import hf_hub_download

        path = hf_hub_download(repo_id=hf_repo_id, filename="statistics.json")
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            stats = json.load(f)
        for tag_key in (EMBODIMENT_TAG, "default"):
            eef = stats.get(tag_key, {}).get("state", {}).get("eef_9d")
            if eef:
                return eef
    except Exception as e:  # deliberate best-effort boundary; caller handles None
        logger.warning("Could not download statistics from %s: %s", hf_repo_id, e)
    return None


def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom > 0 else 0.0


def _rot6d_stat(eef_9d: np.ndarray, fn) -> np.ndarray:
    """Reduce each rot6d column (3:9) of ``eef_9d`` with ``fn`` (e.g. ``np.mean``)."""
    # Up-cast to float64 so reductions over many float32 rows stay accurate.
    return np.asarray(fn(eef_9d[:, 3:9].astype(np.float64), axis=0))


def _rmse(a: np.ndarray, b: np.ndarray) -> float:
    """Root-mean-square error between two arrays of the same shape."""
    return float(np.sqrt(np.mean((a - b) ** 2)))


def _model_rot6d(model_stats: dict, stat_name: str) -> np.ndarray | None:
    """Return the rot6d part (3:9) of ``model_stats[stat_name]``, or None if absent/malformed."""
    if stat_name not in model_stats:
        return None
    values = np.asarray(model_stats[stat_name], dtype=np.float64)
    # Anything but a 1-D vector with all 9 eef_9d dims would either crash or,
    # when slicing yields a single value, silently broadcast against 6 values.
    if values.ndim != 1 or values.size < 9:
        return None
    return values[3:9]


def verify(dataset_path: str, hf_repo_id: str) -> bool:
    """Run the verification. Returns True if with_correction is the better match.

    Builds eef_9d from the dataset's cartesian_position with and without
    ``DROID_EEF_ROTATION_CORRECT`` and compares only the rot6d part (columns
    3:9) against the model's eef_9d statistics using:

    * cosine similarity of the per-column means, and
    * RMSE of each available statistic (mean / std / min / max).

    Args:
        dataset_path: Root of the DROID dataset (see ``_load_cartesian_positions``).
        hf_repo_id: HuggingFace model repo whose ``statistics.json`` is compared.

    Returns:
        True only if the corrected variant has the higher cosine similarity AND
        a lower RMSE on at least half (rounded up) of the compared statistics.
        False if the correction has no measurable effect, if usable model stats
        are unavailable, or if the checks do not favour the correction.
    """
    logger.info(f"Loading cartesian_position from {dataset_path} ...")
    cart = _load_cartesian_positions(dataset_path)
    logger.info(f"Loaded {len(cart)} timesteps")

    eef_no_corr = _euler_to_eef_9d(cart, apply_correction=False)
    eef_with_corr = _euler_to_eef_9d(cart, apply_correction=True)

    rot6d_diff = np.abs(eef_no_corr[:, 3:] - eef_with_corr[:, 3:])
    if rot6d_diff.max() < 1e-6:
        logger.error("Correction matrix has no effect — euler angles may be degenerate")
        return False

    logger.info(f"\nComparing against model: {hf_repo_id}")
    model_stats = _download_eef_stats(hf_repo_id)
    if not model_stats:
        logger.error(f"No eef_9d stats found for {hf_repo_id} — cannot verify")
        return False

    # The cosine check needs a usable mean; fail cleanly instead of KeyError.
    model_mean = _model_rot6d(model_stats, "mean")
    if model_mean is None:
        logger.error(f"eef_9d stats for {hf_repo_id} lack a 9-dim 'mean' — cannot verify")
        return False

    # --- Cosine similarity of rot6d mean ---
    cos_no = _cosine_similarity(_rot6d_stat(eef_no_corr, np.mean), model_mean)
    cos_with = _cosine_similarity(_rot6d_stat(eef_with_corr, np.mean), model_mean)

    # --- Per-stat RMSE (rot6d dims only) ---
    stat_fns = {"mean": np.mean, "std": np.std, "min": np.min, "max": np.max}
    rmse_results: dict[str, tuple[float, float]] = {}
    for stat_name, fn in stat_fns.items():
        model_rot = _model_rot6d(model_stats, stat_name)
        if model_rot is None:
            if stat_name in model_stats:
                logger.warning(f"Skipping '{stat_name}': model stats are not a 9-dim vector")
            continue
        rmse_results[stat_name] = (
            _rmse(_rot6d_stat(eef_no_corr, fn), model_rot),
            _rmse(_rot6d_stat(eef_with_corr, fn), model_rot),
        )

    # --- Print results ---
    logger.info("")
    logger.info("  Cosine similarity of rot6d mean vs pretrained model:")
    logger.info(f"    no_correction:   {cos_no:+.6f}")
    logger.info(f"    with_correction: {cos_with:+.6f}")
    logger.info("")
    logger.info("  RMSE of rot6d stats vs pretrained model (lower = better):")
    logger.info(f"    {'stat':>5}  {'no_correction':>15}  {'with_correction':>15}  {'winner':>15}")
    with_wins = 0
    for stat_name, (rmse_no, rmse_with) in rmse_results.items():
        with_better = rmse_with < rmse_no  # ties count for no_correction
        with_wins += int(with_better)
        winner = "with_correction" if with_better else "no_correction"
        logger.info(f"    {stat_name:>5}  {rmse_no:>15.6f}  {rmse_with:>15.6f}  {winner:>15}")

    # Require RMSE wins on at least half of the compared stats, rounding UP.
    # `len // 2` rounds down and would accept 0 of 1 or 1 of 3 wins.
    passed = cos_with > cos_no and 2 * with_wins >= len(rmse_results)
    logger.info("")
    if passed:
        logger.info("  RESULT: PASS — with_correction matches the pretrained model better")
    else:
        logger.info("  RESULT: FAIL — no_correction appears closer (unexpected)")
    return passed


def main():
    """Command-line entry point: parse options, run ``verify`` and exit.

    Exits with status 0 when ``verify`` reports a pass and 1 otherwise.
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument(
        "--dataset-path",
        default="demo_data/droid_sample",
        help="Path to DROID demo dataset",
    )
    cli.add_argument(
        "--hf-repo-id",
        default="nvidia/GR00T-N1.7-3B",
        help="HuggingFace model repo to compare against",
    )
    opts = cli.parse_args()
    exit_code = 0 if verify(opts.dataset_path, opts.hf_repo_id) else 1
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()