GR00T / scripts /verify_droid_rotation_correction.py
yqi19's picture
add: source files (batch 4)
b88b79e verified
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Verify that DROID demo data eef_9d uses the correct rotation convention.
Computes eef_9d from raw cartesian_position two ways (with and without
DROID_EEF_ROTATION_CORRECT) and compares against the pretrained model's
normalization statistics to determine which convention matches.
Usage:
python scripts/verify_droid_rotation_correction.py
python scripts/verify_droid_rotation_correction.py --dataset-path demo_data/droid_sample
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
import numpy as np
from scipy.spatial.transform import Rotation
# Configure the root logger at import time: INFO level, rendered as "LEVEL: message".
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)

# 3x3 matrix that _euler_to_eef_9d right-multiplies onto every rotation matrix
# when apply_correction=True. Each row and column holds exactly one +/-1 entry
# (a signed axis permutation) and the determinant is +1.
DROID_EEF_ROTATION_CORRECT = np.array(
    [[0, 0, -1], [-1, 0, 0], [0, 1, 0]],
    dtype=np.float64,
)

# Top-level key tried first by _download_eef_stats when looking up the eef_9d
# statistics in the model's statistics.json ("default" is tried second).
EMBODIMENT_TAG = "oxe_droid_relative_eef_relative_joint"
def _euler_to_eef_9d(cartesian_position: np.ndarray, *, apply_correction: bool) -> np.ndarray:
    """Turn XYZ + euler poses into the 9-D eef representation (XYZ + rot6d).

    Args:
        cartesian_position: Array-like whose last axis holds three position
            values followed by three euler angles. Any leading axes are
            flattened into a single batch axis.
        apply_correction: When True, every rotation matrix is right-multiplied
            by ``DROID_EEF_ROTATION_CORRECT`` before rot6d is extracted.

    Returns:
        A float32 array of shape ``(N, 9)``: columns 0-2 are the position,
        columns 3-8 are the first two rows of the rotation matrix, row-major.
    """
    poses = np.asarray(cartesian_position, dtype=np.float64)
    positions = poses[..., :3].reshape(-1, 3)
    angles = poses[..., 3:].reshape(-1, 3)

    # Upper-case axis letters select scipy's intrinsic rotation convention.
    matrices = Rotation.from_euler("XYZ", angles).as_matrix()
    if apply_correction:
        matrices = np.matmul(matrices, DROID_EEF_ROTATION_CORRECT)

    # rot6d keeps only the top two rows of each 3x3 matrix.
    top_two_rows = matrices[:, :2, :].reshape(-1, 6)
    eef_9d = np.concatenate((positions, top_two_rows), axis=1)
    return eef_9d.astype(np.float32)
def _load_cartesian_positions(dataset_path: str) -> np.ndarray:
    """Load observation.state.cartesian_position from all episode parquets.

    Every ``*.parquet`` file found recursively under ``<dataset_path>/data`` is
    read in sorted path order. Files that lack the column, or that contain the
    column but no rows, are skipped.

    Args:
        dataset_path: Root directory of the dataset.

    Returns:
        The per-file stacked column values concatenated along axis 0.

    Raises:
        RuntimeError: If no file contributed any cartesian_position rows.
    """
    import pandas as pd

    column = "observation.state.cartesian_position"
    all_cart = []
    for pq in sorted((Path(dataset_path) / "data").rglob("*.parquet")):
        df = pd.read_parquet(pq)
        # np.stack raises ValueError on an empty sequence, so a file that has
        # the column but zero rows must be skipped rather than abort the load.
        if column not in df.columns or df.empty:
            continue
        all_cart.append(np.stack(df[column].values))
    if not all_cart:
        raise RuntimeError("No cartesian_position found in any parquet file")
    return np.concatenate(all_cart, axis=0)
def _download_eef_stats(hf_repo_id: str) -> dict | None:
    """Download statistics.json and extract eef_9d stats for DROID.

    The lookup tries ``stats[EMBODIMENT_TAG]["state"]["eef_9d"]`` first and
    ``stats["default"]["state"]["eef_9d"]`` second, returning the first
    non-empty entry.

    This is best-effort: any failure (missing ``huggingface_hub``, download
    error, unreadable or unexpectedly shaped JSON) is logged as a warning and
    reported as ``None`` instead of raising.

    Args:
        hf_repo_id: Hugging Face repository id to fetch ``statistics.json`` from.

    Returns:
        The eef_9d statistics entry, or ``None`` if it could not be obtained.
    """
    try:
        from huggingface_hub import hf_hub_download

        path = hf_hub_download(repo_id=hf_repo_id, filename="statistics.json")
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            stats = json.load(f)
        for tag_key in (EMBODIMENT_TAG, "default"):
            eef = stats.get(tag_key, {}).get("state", {}).get("eef_9d")
            if eef:
                return eef
    except Exception as e:
        logger.warning(f"Could not download statistics from {hf_repo_id}: {e}")
    return None
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between ``a`` and ``b``.

    Yields 0.0 when the product of the two norms is not strictly positive
    (for example when either vector is all zeros).
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if not norm_product > 0:
        return 0.0
    return float(np.dot(a, b) / norm_product)
def _rot6d_column_stats(eef_9d: np.ndarray, fn) -> np.ndarray:
    """Apply the reducer ``fn`` to each rot6d column (indices 3..8) of ``eef_9d``."""
    return np.array([fn(eef_9d[:, i]) for i in range(3, 9)])


def verify(dataset_path: str, hf_repo_id: str) -> bool:
    """Run the verification. Returns True if with_correction is the better match.

    eef_9d is computed from the dataset's cartesian_position both with and
    without the rotation correction. The rot6d columns (indices 3..8) of each
    variant are then compared against the model's eef_9d statistics by:

    * cosine similarity of the per-column mean, and
    * RMSE of the per-column mean/std/min/max, for each stat the model provides.

    Args:
        dataset_path: Root directory of the DROID dataset.
        hf_repo_id: Hugging Face model repository providing ``statistics.json``.

    Returns:
        True only if with_correction has the strictly higher cosine similarity
        and the lower RMSE for at least half of the compared stats. False if
        the correction changes nothing, the model statistics (or their
        ``mean`` entry) are unavailable, or no_correction is the closer match.

    Raises:
        RuntimeError: Propagated from the loader when no cartesian_position
            data is found.
    """
    logger.info(f"Loading cartesian_position from {dataset_path} ...")
    cart = _load_cartesian_positions(dataset_path)
    logger.info(f"Loaded {len(cart)} timesteps")

    eef_no_corr = _euler_to_eef_9d(cart, apply_correction=False)
    eef_with_corr = _euler_to_eef_9d(cart, apply_correction=True)

    # If both variants coincide there is nothing to discriminate between.
    rot6d_diff = np.abs(eef_no_corr[:, 3:] - eef_with_corr[:, 3:])
    if rot6d_diff.max() < 1e-6:
        logger.error("Correction matrix has no effect — euler angles may be degenerate")
        return False

    logger.info(f"\nComparing against model: {hf_repo_id}")
    model_stats = _download_eef_stats(hf_repo_id)
    if not model_stats:
        logger.error(f"No eef_9d stats found for {hf_repo_id} — cannot verify")
        return False
    # The other stats are optional below; the mean is required for the cosine
    # check, so report its absence instead of crashing with a KeyError.
    if "mean" not in model_stats:
        logger.error(f"eef_9d stats for {hf_repo_id} have no 'mean' entry — cannot verify")
        return False

    # --- Cosine similarity of rot6d mean ---
    model_mean = np.array(model_stats["mean"])
    cos_no = _cosine_similarity(_rot6d_column_stats(eef_no_corr, np.mean), model_mean[3:9])
    cos_with = _cosine_similarity(_rot6d_column_stats(eef_with_corr, np.mean), model_mean[3:9])

    # --- Per-stat RMSE (rot6d dims only) ---
    stat_fns = {"mean": np.mean, "std": np.std, "min": np.min, "max": np.max}
    rmse_results: dict[str, tuple[float, float]] = {}
    for stat_name, fn in stat_fns.items():
        if stat_name not in model_stats:
            continue
        model_rot = np.array(model_stats[stat_name])[3:9]
        vals_no = _rot6d_column_stats(eef_no_corr, fn)
        vals_with = _rot6d_column_stats(eef_with_corr, fn)
        rmse_results[stat_name] = (
            float(np.sqrt(np.mean((vals_no - model_rot) ** 2))),
            float(np.sqrt(np.mean((vals_with - model_rot) ** 2))),
        )

    # --- Print results ---
    logger.info("")
    logger.info(" Cosine similarity of rot6d mean vs pretrained model:")
    logger.info(f" no_correction: {cos_no:+.6f}")
    logger.info(f" with_correction: {cos_with:+.6f}")
    logger.info("")
    logger.info(" RMSE of rot6d stats vs pretrained model (lower = better):")
    logger.info(f" {'stat':>5} {'no_correction':>15} {'with_correction':>15} {'winner':>15}")
    with_wins = 0
    for stat_name, (rmse_no, rmse_with) in rmse_results.items():
        with_is_better = rmse_with < rmse_no
        winner = "with_correction" if with_is_better else "no_correction"
        if with_is_better:
            with_wins += 1
        logger.info(f" {stat_name:>5} {rmse_no:>15.6f} {rmse_with:>15.6f} {winner:>15}")

    # "At least half" is compared without floor division: `len // 2` would let
    # a minority of wins pass for odd counts (e.g. 1 of 3, or 0 of 1).
    passed = cos_with > cos_no and 2 * with_wins >= len(rmse_results)
    logger.info("")
    if passed:
        logger.info(" RESULT: PASS — with_correction matches the pretrained model better")
    else:
        logger.info(" RESULT: FAIL — no_correction appears closer (unexpected)")
    return passed
def main():
    """Command-line entry point.

    Parses ``--dataset-path`` and ``--hf-repo-id``, runs ``verify`` and exits
    the process with status 0 when it passes and 1 otherwise.
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument(
        "--dataset-path",
        default="demo_data/droid_sample",
        help="Path to DROID demo dataset",
    )
    cli.add_argument(
        "--hf-repo-id",
        default="nvidia/GR00T-N1.7-3B",
        help="HuggingFace model repo to compare against",
    )
    opts = cli.parse_args()

    exit_code = 0 if verify(opts.dataset_path, opts.hf_repo_id) else 1
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()