| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Verify that DROID demo data eef_9d uses the correct rotation convention. |
| |
| Computes eef_9d from raw cartesian_position two ways (with and without |
| DROID_EEF_ROTATION_CORRECT) and compares against the pretrained model's |
| normalization statistics to determine which convention matches. |
| |
| Usage: |
| python scripts/verify_droid_rotation_correction.py |
| python scripts/verify_droid_rotation_correction.py --dataset-path demo_data/droid_sample |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
|
|
| import numpy as np |
| from scipy.spatial.transform import Rotation |
|
|
|
|
# Log lines carry only the level and the message; this script's report is
# meant to be read directly from the console.
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)


# Fixed 3x3 matrix (a proper rotation: orthonormal, determinant +1) that
# _euler_to_eef_9d right-multiplies onto each end-effector rotation matrix
# when the correction is requested.
DROID_EEF_ROTATION_CORRECT = np.array(
    [
        [0, 0, -1],
        [-1, 0, 0],
        [0, 1, 0],
    ],
    dtype=np.float64,
)


# Preferred top-level key when looking up eef_9d entries in the model's
# statistics.json.
EMBODIMENT_TAG = "oxe_droid_relative_eef_relative_joint"
|
|
|
|
def _euler_to_eef_9d(cartesian_position: np.ndarray, *, apply_correction: bool) -> np.ndarray:
    """Build eef_9d rows (XYZ + rot6d) from cartesian_position rows (XYZ + euler).

    Args:
        cartesian_position: Array-like whose trailing axis holds three position
            values followed by three euler angles. Leading axes are flattened.
        apply_correction: When true, each rotation matrix is right-multiplied
            by DROID_EEF_ROTATION_CORRECT before the rot6d part is extracted.

    Returns:
        A float32 array with one 9-value row per input pose: the position
        followed by the first two rows of the rotation matrix, flattened.
    """
    pose = np.asarray(cartesian_position, dtype=np.float64)
    position = pose[..., :3].reshape(-1, 3)
    angles = pose[..., 3:].reshape(-1, 3)

    # Upper-case axes select scipy's intrinsic rotation convention.
    matrices = Rotation.from_euler("XYZ", angles).as_matrix()
    if apply_correction:
        matrices = np.matmul(matrices, DROID_EEF_ROTATION_CORRECT)

    # rot6d here is the top two rows of each 3x3 matrix, row-major.
    top_two_rows = matrices[:, :2, :].reshape(-1, 6)
    eef_9d = np.concatenate((position, top_two_rows), axis=-1)
    return eef_9d.astype(np.float32)
|
|
|
|
def _load_cartesian_positions(dataset_path: str) -> np.ndarray:
    """Collect observation.state.cartesian_position from every episode parquet.

    Parquet files are discovered recursively under ``<dataset_path>/data`` and
    read in sorted path order; files without the column are skipped.

    Args:
        dataset_path: Root directory of the dataset.

    Returns:
        All per-file arrays concatenated along the first axis.

    Raises:
        RuntimeError: If no parquet file provided the column.
    """
    # Imported lazily so pandas is only needed when data is actually loaded.
    import pandas as pd

    column = "observation.state.cartesian_position"
    data_dir = Path(dataset_path) / "data"

    chunks = []
    for parquet_file in sorted(data_dir.rglob("*.parquet")):
        frame = pd.read_parquet(parquet_file)
        if column not in frame.columns:
            continue
        # Each cell holds one per-timestep vector; stack them into a 2-D array.
        chunks.append(np.stack(frame[column].values))

    if chunks:
        return np.concatenate(chunks, axis=0)
    raise RuntimeError("No cartesian_position found in any parquet file")
|
|
|
|
def _download_eef_stats(hf_repo_id: str) -> dict | None:
    """Fetch statistics.json from a HuggingFace repo and pull out eef_9d stats.

    The entry under EMBODIMENT_TAG is preferred, with "default" as fallback;
    within either, the value at ``state -> eef_9d`` is used.

    Args:
        hf_repo_id: HuggingFace repository id to download from.

    Returns:
        The first non-empty eef_9d stats entry found, or None when neither key
        has one or when anything in the download/parse path fails (the failure
        is logged as a warning rather than raised).
    """
    try:
        # Imported lazily so the dependency is only required for this step.
        from huggingface_hub import hf_hub_download

        stats_file = hf_hub_download(repo_id=hf_repo_id, filename="statistics.json")
        with open(stats_file) as handle:
            all_stats = json.load(handle)

        candidates = (
            all_stats.get(tag, {}).get("state", {}).get("eef_9d")
            for tag in (EMBODIMENT_TAG, "default")
        )
        found = next((entry for entry in candidates if entry), None)
        if found:
            return found
    except Exception as e:
        # Deliberately best-effort: the caller treats None as "cannot verify".
        logger.warning(f"Could not download statistics from {hf_repo_id}: {e}")
    return None
|
|
|
|
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between vectors *a* and *b*.

    Yields 0.0 when the product of the two norms is not strictly positive
    (e.g. either vector is all zeros), avoiding a division by zero.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if not norm_product > 0:
        return 0.0
    return float(np.dot(a, b) / norm_product)
|
|
|
|
def verify(dataset_path: str, hf_repo_id: str) -> bool:
    """Check which rotation convention matches the pretrained model's stats.

    eef_9d is computed from the dataset's cartesian_position both with and
    without DROID_EEF_ROTATION_CORRECT. The rot6d part (columns 3..8) of each
    variant is then compared with the model's eef_9d normalization statistics
    using (a) cosine similarity of the mean vector and (b) RMSE of every
    available statistic among mean/std/min/max.

    Args:
        dataset_path: Root directory of the DROID demo dataset.
        hf_repo_id: HuggingFace model repo providing statistics.json.

    Returns:
        True when with_correction has the higher cosine similarity AND wins at
        least half of the available RMSE comparisons. False otherwise, and
        also when the check cannot be carried out (correction has no effect,
        stats unavailable, or stats lack a "mean" entry).

    Raises:
        RuntimeError: Propagated from _load_cartesian_positions when the
            dataset contains no cartesian_position data.
    """
    logger.info(f"Loading cartesian_position from {dataset_path} ...")
    cart = _load_cartesian_positions(dataset_path)
    logger.info(f"Loaded {len(cart)} timesteps")

    eef_no_corr = _euler_to_eef_9d(cart, apply_correction=False)
    eef_with_corr = _euler_to_eef_9d(cart, apply_correction=True)

    # Columns 3..8 are the rot6d part. Reduce in float64 so statistics over
    # many timesteps are not limited by float32 accumulation.
    rot_no = eef_no_corr[:, 3:9].astype(np.float64)
    rot_with = eef_with_corr[:, 3:9].astype(np.float64)

    if np.abs(rot_no - rot_with).max() < 1e-6:
        logger.error("Correction matrix has no effect — euler angles may be degenerate")
        return False

    logger.info(f"\nComparing against model: {hf_repo_id}")
    model_stats = _download_eef_stats(hf_repo_id)
    if not model_stats:
        logger.error(f"No eef_9d stats found for {hf_repo_id} — cannot verify")
        return False
    if "mean" not in model_stats:
        # The cosine check needs the mean; fail cleanly instead of a KeyError.
        logger.error(f"eef_9d stats for {hf_repo_id} have no 'mean' entry — cannot verify")
        return False

    # Cosine similarity of the rot6d mean vectors.
    model_mean_rot = np.asarray(model_stats["mean"], dtype=np.float64)[3:9]
    cos_no = _cosine_similarity(rot_no.mean(axis=0), model_mean_rot)
    cos_with = _cosine_similarity(rot_with.mean(axis=0), model_mean_rot)

    # RMSE of each available per-column statistic against the model's value.
    stat_fns = {"mean": np.mean, "std": np.std, "min": np.min, "max": np.max}
    rmse_results: dict[str, tuple[float, float]] = {}
    for stat_name, fn in stat_fns.items():
        if stat_name not in model_stats:
            continue
        model_rot = np.asarray(model_stats[stat_name], dtype=np.float64)[3:9]
        vals_no = fn(rot_no, axis=0)
        vals_with = fn(rot_with, axis=0)
        rmse_results[stat_name] = (
            float(np.sqrt(np.mean((vals_no - model_rot) ** 2))),
            float(np.sqrt(np.mean((vals_with - model_rot) ** 2))),
        )

    # Report.
    logger.info("")
    logger.info(" Cosine similarity of rot6d mean vs pretrained model:")
    logger.info(f" no_correction: {cos_no:+.6f}")
    logger.info(f" with_correction: {cos_with:+.6f}")
    logger.info("")
    logger.info(" RMSE of rot6d stats vs pretrained model (lower = better):")
    logger.info(f" {'stat':>5} {'no_correction':>15} {'with_correction':>15} {'winner':>15}")
    with_wins = 0
    for stat_name, (rmse_no, rmse_with) in rmse_results.items():
        with_is_better = rmse_with < rmse_no
        winner = "with_correction" if with_is_better else "no_correction"
        if with_is_better:
            with_wins += 1
        logger.info(f" {stat_name:>5} {rmse_no:>15.6f} {rmse_with:>15.6f} {winner:>15}")

    # "At least half" without floor division: `wins >= n // 2` would let a
    # minority pass whenever an odd number of statistics is available.
    passed = cos_with > cos_no and 2 * with_wins >= len(rmse_results)
    logger.info("")
    if passed:
        logger.info(" RESULT: PASS — with_correction matches the pretrained model better")
    else:
        logger.info(" RESULT: FAIL — no_correction appears closer (unexpected)")
    return passed
|
|
|
|
def main():
    """Parse command-line options, run verify, and exit 0 on pass / 1 on fail."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    cli.add_argument(
        "--dataset-path",
        help="Path to DROID demo dataset",
        default="demo_data/droid_sample",
    )
    cli.add_argument(
        "--hf-repo-id",
        help="HuggingFace model repo to compare against",
        default="nvidia/GR00T-N1.7-3B",
    )
    options = cli.parse_args()

    # The process exit status mirrors the verification outcome.
    exit_code = 0 if verify(options.dataset_path, options.hf_repo_id) else 1
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()
|
|