add: source files (batch 1)
Browse files- .github/ISSUE_TEMPLATE/bug_report.yml +46 -0
- .github/ISSUE_TEMPLATE/documentation.yml +21 -0
- .github/ISSUE_TEMPLATE/feature_request.yml +26 -0
- .github/actions/setup-venv/action.yml +55 -0
- .github/pull_request_template.md +17 -0
- .github/workflows/main.yml +68 -0
- AGENTS.md +80 -0
- ATTRIBUTIONS.md +0 -0
- CLAUDE.md +80 -0
- CONTRIBUTING.md +7 -0
- FAQ.md +81 -0
- README.md +555 -0
- examples/DROID/README.md +110 -0
- examples/DROID/main_gr00t.py +469 -0
- examples/DROID/server_client.py +365 -0
- examples/DROID/utils.py +81 -0
- examples/LIBERO/README.md +196 -0
- examples/LIBERO/modality.json +75 -0
- examples/SO100/README.md +87 -0
- examples/SO100/modality.json +35 -0
- examples/SO100/so100_config.py +70 -0
- examples/SimplerEnv/README.md +141 -0
- examples/SimplerEnv/bridge_modality.json +77 -0
- examples/SimplerEnv/convert_av1_to_h264.py +129 -0
- examples/SimplerEnv/fractal_modality.json +77 -0
- examples/finetune.sh +158 -0
- examples/mask-guided-background-suppression/README.md +203 -0
- examples/mask-guided-background-suppression/so101_config.py +62 -0
- examples/mask-guided-background-suppression/test_extra_augmentation.py +198 -0
- getting_started/data_config.md +331 -0
- getting_started/data_preparation.md +164 -0
- getting_started/finetune_new_embodiment.md +153 -0
- getting_started/hardware_recommendation.md +95 -0
- getting_started/policy.md +574 -0
- getting_started/real_world_deployment.md +459 -0
- gr00t/__init__.py +129 -0
- gr00t/configs/__init__.py +14 -0
- gr00t/configs/base_config.py +150 -0
- gr00t/configs/data/__init__.py +14 -0
- gr00t/configs/data/data_config.py +95 -0
- gr00t/configs/data/embodiment_configs.py +208 -0
- gr00t/configs/deepspeed/zero2_config.json +33 -0
- gr00t/configs/deepspeed/zero3_config.json +31 -0
- gr00t/configs/finetune_config.py +163 -0
- gr00t/configs/model/__init__.py +52 -0
- gr00t/configs/model/gr00t_n1d7.py +179 -0
- gr00t/configs/training/__init__.py +14 -0
- gr00t/configs/training/training_config.py +127 -0
- gr00t/data/__init__.py +14 -0
- gr00t/data/collator/__init__.py +16 -0
.github/ISSUE_TEMPLATE/bug_report.yml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 🐛 Bug Report
|
| 2 |
+
description: Create a report to help us reproduce and fix the bug
|
| 3 |
+
labels: 'bug'
|
| 4 |
+
|
| 5 |
+
body:
|
| 6 |
+
- type: markdown
|
| 7 |
+
attributes:
|
| 8 |
+
value: >
|
| 9 |
+
#### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/Isaac-GR00T/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
| 10 |
+
- type: textarea
|
| 11 |
+
attributes:
|
| 12 |
+
label: 🐛 Describe the bug
|
| 13 |
+
description: |
|
| 14 |
+
Please provide a clear and concise description of what the bug is.
|
| 15 |
+
|
| 16 |
+
If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example:
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
# All necessary imports at the beginning
|
| 20 |
+
import gr00t
|
| 21 |
+
|
| 22 |
+
# A succinct reproducing example trimmed down to the essential parts:
|
| 23 |
+
assert False is True, "Oh no!"
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
|
| 27 |
+
|
| 28 |
+
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
|
| 29 |
+
placeholder: |
|
| 30 |
+
A clear and concise description of what the bug is.
|
| 31 |
+
validations:
|
| 32 |
+
required: true
|
| 33 |
+
- type: textarea
|
| 34 |
+
attributes:
|
| 35 |
+
label: Versions
|
| 36 |
+
description: |
|
| 37 |
+
Please run the following and paste the output below.
|
| 38 |
+
```sh
|
| 39 |
+
python --version && pip freeze
|
| 40 |
+
```
|
| 41 |
+
validations:
|
| 42 |
+
required: true
|
| 43 |
+
- type: markdown
|
| 44 |
+
attributes:
|
| 45 |
+
value: >
|
| 46 |
+
Thanks for contributing 🎉!
|
.github/ISSUE_TEMPLATE/documentation.yml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 📚 Documentation
|
| 2 |
+
description: Report an issue related to https://github.com/NVIDIA/Isaac-GR00T
|
| 3 |
+
labels: 'documentation'
|
| 4 |
+
|
| 5 |
+
body:
|
| 6 |
+
- type: textarea
|
| 7 |
+
attributes:
|
| 8 |
+
label: 📚 The doc issue
|
| 9 |
+
description: >
|
| 10 |
+
A clear and concise description of what content in https://github.com/NVIDIA/Isaac-GR00T is an issue.
|
| 11 |
+
validations:
|
| 12 |
+
required: true
|
| 13 |
+
- type: textarea
|
| 14 |
+
attributes:
|
| 15 |
+
label: Suggest a potential alternative/fix
|
| 16 |
+
description: >
|
| 17 |
+
Tell us how we could improve the documentation in this regard.
|
| 18 |
+
- type: markdown
|
| 19 |
+
attributes:
|
| 20 |
+
value: >
|
| 21 |
+
Thanks for contributing 🎉!
|
.github/ISSUE_TEMPLATE/feature_request.yml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 🚀 Feature request
|
| 2 |
+
description: Submit a proposal/request for a new feature
|
| 3 |
+
labels: 'feature request'
|
| 4 |
+
|
| 5 |
+
body:
|
| 6 |
+
- type: textarea
|
| 7 |
+
attributes:
|
| 8 |
+
label: 🚀 The feature, motivation and pitch
|
| 9 |
+
description: >
|
| 10 |
+
A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
|
| 11 |
+
validations:
|
| 12 |
+
required: true
|
| 13 |
+
- type: textarea
|
| 14 |
+
attributes:
|
| 15 |
+
label: Alternatives
|
| 16 |
+
description: >
|
| 17 |
+
A description of any alternative solutions or features you've considered, if any.
|
| 18 |
+
- type: textarea
|
| 19 |
+
attributes:
|
| 20 |
+
label: Additional context
|
| 21 |
+
description: >
|
| 22 |
+
Add any other context or screenshots about the feature request.
|
| 23 |
+
- type: markdown
|
| 24 |
+
attributes:
|
| 25 |
+
value: >
|
| 26 |
+
Thanks for contributing 🎉!
|
.github/actions/setup-venv/action.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Python virtualenv
|
| 2 |
+
description: Set up a Python virtual environment with caching
|
| 3 |
+
inputs:
|
| 4 |
+
python-version:
|
| 5 |
+
description: The Python version to use
|
| 6 |
+
required: true
|
| 7 |
+
cache-prefix:
|
| 8 |
+
description: Update this to invalidate the cache
|
| 9 |
+
required: true
|
| 10 |
+
default: v4
|
| 11 |
+
runs:
|
| 12 |
+
using: composite
|
| 13 |
+
steps:
|
| 14 |
+
- name: Setup Python
|
| 15 |
+
uses: actions/setup-python@v4
|
| 16 |
+
with:
|
| 17 |
+
python-version: ${{ inputs.python-version }}
|
| 18 |
+
|
| 19 |
+
- shell: bash
|
| 20 |
+
run: |
|
| 21 |
+
# Install prerequisites.
|
| 22 |
+
pip install --upgrade pip setuptools wheel virtualenv
|
| 23 |
+
|
| 24 |
+
- shell: bash
|
| 25 |
+
run: |
|
| 26 |
+
# Get the exact Python version to use in the cache key.
|
| 27 |
+
echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV
|
| 28 |
+
|
| 29 |
+
- uses: actions/cache@v4
|
| 30 |
+
id: virtualenv-cache
|
| 31 |
+
with:
|
| 32 |
+
path: .venv
|
| 33 |
+
key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('pyproject.toml') }}
|
| 34 |
+
|
| 35 |
+
- if: steps.virtualenv-cache.outputs.cache-hit != 'true'
|
| 36 |
+
shell: bash
|
| 37 |
+
run: |
|
| 38 |
+
# Set up virtual environment without cache hit.
|
| 39 |
+
test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
|
| 40 |
+
. .venv/bin/activate
|
| 41 |
+
pip install ruff
|
| 42 |
+
|
| 43 |
+
- if: steps.virtualenv-cache.outputs.cache-hit == 'true'
|
| 44 |
+
shell: bash
|
| 45 |
+
run: |
|
| 46 |
+
# Set up virtual environment from cache hit.
|
| 47 |
+
. .venv/bin/activate
|
| 48 |
+
|
| 49 |
+
- shell: bash
|
| 50 |
+
run: |
|
| 51 |
+
# Show environment info.
|
| 52 |
+
. .venv/bin/activate
|
| 53 |
+
echo "✓ Installed $(python --version) virtual environment to $(which python)"
|
| 54 |
+
echo "Packages:"
|
| 55 |
+
pip freeze
|
.github/pull_request_template.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- To ensure we can review your pull request promptly please complete this template entirely. -->
|
| 2 |
+
|
| 3 |
+
<!-- Please reference the issue number here. You can replace "Fixes" with "Closes" if it makes more sense. -->
|
| 4 |
+
Fixes #
|
| 5 |
+
|
| 6 |
+
Changes proposed in this pull request:
|
| 7 |
+
<!-- Please list all changes/additions here. -->
|
| 8 |
+
-
|
| 9 |
+
|
| 10 |
+
## Before submitting
|
| 11 |
+
|
| 12 |
+
<!-- Please complete this checklist BEFORE submitting your PR to speed along the review process. -->
|
| 13 |
+
- [ ] I've read and followed all steps in the [Making a pull request](https://github.com/NVIDIA/Isaac-GR00T/blob/main/CONTRIBUTING.md#making-a-pull-request)
|
| 14 |
+
section of the `CONTRIBUTING` docs.
|
| 15 |
+
- [ ] I've updated or added any relevant docstrings.
|
| 16 |
+
- [ ] If this PR fixes a bug, I've added a test that will fail without my fix.
|
| 17 |
+
- [ ] If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.
|
.github/workflows/main.yml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Main
|
| 2 |
+
|
| 3 |
+
concurrency:
|
| 4 |
+
group: ${{ github.workflow }}-${{ github.ref }}
|
| 5 |
+
cancel-in-progress: true
|
| 6 |
+
|
| 7 |
+
on:
|
| 8 |
+
pull_request:
|
| 9 |
+
branches:
|
| 10 |
+
- main
|
| 11 |
+
push:
|
| 12 |
+
branches:
|
| 13 |
+
- main
|
| 14 |
+
tags:
|
| 15 |
+
- "v*.*.*"
|
| 16 |
+
|
| 17 |
+
env:
|
| 18 |
+
# Change this to invalidate existing cache.
|
| 19 |
+
CACHE_PREFIX: v4
|
| 20 |
+
PYTHONPATH: ./
|
| 21 |
+
|
| 22 |
+
jobs:
|
| 23 |
+
checks:
|
| 24 |
+
name: Python ${{ matrix.python }} - ${{ matrix.task.name }}
|
| 25 |
+
runs-on: [ubuntu-latest]
|
| 26 |
+
timeout-minutes: 15
|
| 27 |
+
strategy:
|
| 28 |
+
fail-fast: false
|
| 29 |
+
matrix:
|
| 30 |
+
include:
|
| 31 |
+
- python: "3.10"
|
| 32 |
+
task:
|
| 33 |
+
name: Lint
|
| 34 |
+
run: |
|
| 35 |
+
ruff check .
|
| 36 |
+
ruff format --check .
|
| 37 |
+
|
| 38 |
+
steps:
|
| 39 |
+
- uses: actions/checkout@v4
|
| 40 |
+
with:
|
| 41 |
+
lfs: true
|
| 42 |
+
|
| 43 |
+
- name: Pull LFS objects
|
| 44 |
+
run: git lfs pull
|
| 45 |
+
|
| 46 |
+
- name: Setup Python environment
|
| 47 |
+
uses: ./.github/actions/setup-venv
|
| 48 |
+
with:
|
| 49 |
+
python-version: ${{ matrix.python }}
|
| 50 |
+
cache-prefix: ${{ env.CACHE_PREFIX }}
|
| 51 |
+
|
| 52 |
+
- name: ${{ matrix.task.name }}
|
| 53 |
+
run: |
|
| 54 |
+
. .venv/bin/activate
|
| 55 |
+
${{ matrix.task.run }}
|
| 56 |
+
|
| 57 |
+
- name: Upload package distribution files
|
| 58 |
+
if: matrix.task.name == 'Build'
|
| 59 |
+
uses: actions/upload-artifact@v4
|
| 60 |
+
with:
|
| 61 |
+
name: package
|
| 62 |
+
path: dist
|
| 63 |
+
|
| 64 |
+
- name: Clean up
|
| 65 |
+
if: always()
|
| 66 |
+
run: |
|
| 67 |
+
. .venv/bin/activate
|
| 68 |
+
pip uninstall -y gr00t
|
AGENTS.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md — Isaac GR00T N1.7
|
| 2 |
+
|
| 3 |
+
## Project overview
|
| 4 |
+
|
| 5 |
+
Isaac GR00T N1.7 is an open vision-language-action (VLA) model for generalized humanoid robot skills.
|
| 6 |
+
The repo contains the model, training pipeline, evaluation harness, and deployment tooling.
|
| 7 |
+
|
| 8 |
+
- **Language:** Python 3.10 (dGPU, Orin); Python 3.12 (Thor, DGX Spark — see deployment dir)
|
| 9 |
+
- **Package manager:** [uv](https://docs.astral.sh/uv/)
|
| 10 |
+
- **Build system:** setuptools (see `pyproject.toml`)
|
| 11 |
+
- **CI:** internal GitLab CI (`.gitlab-ci.yml` + includes under `ci/`, not shipped to the public GitHub EA repo); public GitHub Actions (`.github/workflows/`)
|
| 12 |
+
|
| 13 |
+
## Quick-start commands
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
# Install (dev mode with all extras)
|
| 17 |
+
uv sync --all-extras
|
| 18 |
+
|
| 19 |
+
# Lint and format (uses ruff via pre-commit)
|
| 20 |
+
pre-commit run --all-files
|
| 21 |
+
|
| 22 |
+
# Run CPU tests
|
| 23 |
+
python -m pytest tests/ -m "not gpu" -v --timeout=300
|
| 24 |
+
|
| 25 |
+
# Run GPU tests
|
| 26 |
+
python -m pytest tests/ -m gpu -v --timeout=300
|
| 27 |
+
|
| 28 |
+
# Build package
|
| 29 |
+
uv build
|
| 30 |
+
|
| 31 |
+
# Validate lockfile
|
| 32 |
+
uv lock --locked
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Code style
|
| 36 |
+
|
| 37 |
+
- Formatter: `ruff format` (double quotes, spaces, line-length 100)
|
| 38 |
+
- Linter: `ruff check` with rules E, F, I (ignores E501)
|
| 39 |
+
- Config lives in `pyproject.toml` under `[tool.ruff]`
|
| 40 |
+
- Run `pre-commit run --all-files` before committing
|
| 41 |
+
|
| 42 |
+
## Directory layout
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
gr00t/ # Main package
|
| 46 |
+
configs/ # Training, data, and model configs
|
| 47 |
+
data/ # Data loading, embodiment tags, dataset processing
|
| 48 |
+
eval/ # Evaluation (run_gr00t_server.py)
|
| 49 |
+
experiment/ # Training pipeline (launch_finetune.py, trainer.py)
|
| 50 |
+
model/ # Model architecture (N1.7, base, modules)
|
| 51 |
+
policy/ # Policy inference (Gr00tPolicy, server/client)
|
| 52 |
+
examples/ # Per-embodiment example configs and READMEs
|
| 53 |
+
scripts/ # Deployment, conversion, and utility scripts
|
| 54 |
+
deployment/ # Platform install scripts (dgpu, orin, thor, spark)
|
| 55 |
+
tests/ # pytest suite (markers: gpu, not gpu)
|
| 56 |
+
getting_started/ # User-facing guides and notebooks
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Key entry points
|
| 60 |
+
|
| 61 |
+
- **Fine-tune:** `bash examples/finetune.sh --base-model-path <path> --dataset-path <path> --embodiment-tag <tag> --output-dir <dir>`
|
| 62 |
+
- **Inference server:** `python gr00t/eval/run_gr00t_server.py --model-path <path> --embodiment-tag <tag>`
|
| 63 |
+
- **ONNX export:** `python scripts/deployment/export_onnx_n1d7.py`
|
| 64 |
+
- **TensorRT build:** `python scripts/deployment/build_trt_pipeline.py`
|
| 65 |
+
- **Benchmark:** `python scripts/deployment/benchmark_inference.py`
|
| 66 |
+
|
| 67 |
+
## Testing
|
| 68 |
+
|
| 69 |
+
- Test markers: `gpu` (requires GPU), default is CPU-safe
|
| 70 |
+
- Fixtures live in `tests/fixtures/` and `demo_data/`
|
| 71 |
+
- CI runs CPU and GPU tests in separate jobs with 300s timeout
|
| 72 |
+
|
| 73 |
+
## Deployment platforms
|
| 74 |
+
|
| 75 |
+
- **dGPU (H100, A100, RTX):** CUDA 12.8 — install via `scripts/deployment/dgpu/install_deps.sh`, container via top-level `docker/Dockerfile` (supports x86_64 and aarch64)
|
| 76 |
+
- **Jetson Orin:** CUDA 12.6 — install via `scripts/deployment/orin/install_deps.sh`, container via `scripts/deployment/orin/Dockerfile`
|
| 77 |
+
- **Jetson Thor:** CUDA 13.0 — install via `scripts/deployment/thor/install_deps.sh`, container via `scripts/deployment/thor/Dockerfile`
|
| 78 |
+
- **DGX Spark:** CUDA 13.0 — install via `scripts/deployment/spark/install_deps.sh`, container via `scripts/deployment/spark/Dockerfile`
|
| 79 |
+
|
| 80 |
+
Each Jetson/Spark platform ships an `activate_*.sh` helper (`scripts/activate_orin.sh`, `scripts/activate_spark.sh`, `scripts/activate_thor.sh`) that exports platform-specific library paths. For dGPU, the standard `source .venv/bin/activate` is sufficient.
|
ATTRIBUTIONS.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md — Isaac GR00T N1.7
|
| 2 |
+
|
| 3 |
+
## Project overview
|
| 4 |
+
|
| 5 |
+
Isaac GR00T N1.7 is an open vision-language-action (VLA) model for generalized humanoid robot skills.
|
| 6 |
+
The repo contains the model, training pipeline, evaluation harness, and deployment tooling.
|
| 7 |
+
|
| 8 |
+
- **Language:** Python 3.10 (dGPU, Orin); Python 3.12 (Thor, DGX Spark — see deployment dir)
|
| 9 |
+
- **Package manager:** [uv](https://docs.astral.sh/uv/)
|
| 10 |
+
- **Build system:** setuptools (see `pyproject.toml`)
|
| 11 |
+
- **CI:** internal GitLab CI (`.gitlab-ci.yml` + includes under `ci/`, not shipped to the public GitHub EA repo); public GitHub Actions (`.github/workflows/`)
|
| 12 |
+
|
| 13 |
+
## Quick-start commands
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
# Install (dev mode with all extras)
|
| 17 |
+
uv sync --all-extras
|
| 18 |
+
|
| 19 |
+
# Lint and format (uses ruff via pre-commit)
|
| 20 |
+
pre-commit run --all-files
|
| 21 |
+
|
| 22 |
+
# Run CPU tests
|
| 23 |
+
python -m pytest tests/ -m "not gpu" -v --timeout=300
|
| 24 |
+
|
| 25 |
+
# Run GPU tests
|
| 26 |
+
python -m pytest tests/ -m gpu -v --timeout=300
|
| 27 |
+
|
| 28 |
+
# Build package
|
| 29 |
+
uv build
|
| 30 |
+
|
| 31 |
+
# Validate lockfile
|
| 32 |
+
uv lock --locked
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Code style
|
| 36 |
+
|
| 37 |
+
- Formatter: `ruff format` (double quotes, spaces, line-length 100)
|
| 38 |
+
- Linter: `ruff check` with rules E, F, I (ignores E501)
|
| 39 |
+
- Config lives in `pyproject.toml` under `[tool.ruff]`
|
| 40 |
+
- Run `pre-commit run --all-files` before committing
|
| 41 |
+
|
| 42 |
+
## Directory layout
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
gr00t/ # Main package
|
| 46 |
+
configs/ # Training, data, and model configs
|
| 47 |
+
data/ # Data loading, embodiment tags, dataset processing
|
| 48 |
+
eval/ # Evaluation (run_gr00t_server.py)
|
| 49 |
+
experiment/ # Training pipeline (launch_finetune.py, trainer.py)
|
| 50 |
+
model/ # Model architecture (N1.7, base, modules)
|
| 51 |
+
policy/ # Policy inference (Gr00tPolicy, server/client)
|
| 52 |
+
examples/ # Per-embodiment example configs and READMEs
|
| 53 |
+
scripts/ # Deployment, conversion, and utility scripts
|
| 54 |
+
deployment/ # Platform install scripts (dgpu, orin, thor, spark)
|
| 55 |
+
tests/ # pytest suite (markers: gpu, not gpu)
|
| 56 |
+
getting_started/ # User-facing guides and notebooks
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Key entry points
|
| 60 |
+
|
| 61 |
+
- **Fine-tune:** `bash examples/finetune.sh --base-model-path <path> --dataset-path <path> --embodiment-tag <tag> --output-dir <dir>`
|
| 62 |
+
- **Inference server:** `python gr00t/eval/run_gr00t_server.py --model-path <path> --embodiment-tag <tag>`
|
| 63 |
+
- **ONNX export:** `python scripts/deployment/export_onnx_n1d7.py`
|
| 64 |
+
- **TensorRT build:** `python scripts/deployment/build_trt_pipeline.py`
|
| 65 |
+
- **Benchmark:** `python scripts/deployment/benchmark_inference.py`
|
| 66 |
+
|
| 67 |
+
## Testing
|
| 68 |
+
|
| 69 |
+
- Test markers: `gpu` (requires GPU), default is CPU-safe
|
| 70 |
+
- Fixtures live in `tests/fixtures/` and `demo_data/`
|
| 71 |
+
- CI runs CPU and GPU tests in separate jobs with 300s timeout
|
| 72 |
+
|
| 73 |
+
## Deployment platforms
|
| 74 |
+
|
| 75 |
+
- **dGPU (H100, A100, RTX):** CUDA 12.8 — install via `scripts/deployment/dgpu/install_deps.sh`, container via top-level `docker/Dockerfile` (supports x86_64 and aarch64)
|
| 76 |
+
- **Jetson Orin:** CUDA 12.6 — install via `scripts/deployment/orin/install_deps.sh`, container via `scripts/deployment/orin/Dockerfile`
|
| 77 |
+
- **Jetson Thor:** CUDA 13.0 — install via `scripts/deployment/thor/install_deps.sh`, container via `scripts/deployment/thor/Dockerfile`
|
| 78 |
+
- **DGX Spark:** CUDA 13.0 — install via `scripts/deployment/spark/install_deps.sh`, container via `scripts/deployment/spark/Dockerfile`
|
| 79 |
+
|
| 80 |
+
Each Jetson/Spark platform ships an `activate_*.sh` helper (`scripts/activate_orin.sh`, `scripts/activate_spark.sh`, `scripts/activate_thor.sh`) that exports platform-specific library paths. For dGPU, the standard `source .venv/bin/activate` is sufficient.
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributions
|
| 2 |
+
|
| 3 |
+
During Early Access we are not accepting pull requests while the codebase stabilizes. If you encounter issues or have suggestions, please open an [Issue](https://github.com/NVIDIA/Isaac-GR00T/issues) in this repository.
|
| 4 |
+
|
| 5 |
+
## Support
|
| 6 |
+
|
| 7 |
+
Support during Early Access is best-effort. We will continue iterating toward a more stable General Availability (GA) release.
|
FAQ.md
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GR00T N1.7 FAQ
|
| 2 |
+
|
| 3 |
+
## Infrastructure & Hardware
|
| 4 |
+
|
| 5 |
+
### Is the data loader GPU-accelerated?
|
| 6 |
+
|
| 7 |
+
No, the current data loader is CPU-based. However, it has been heavily optimized for multimodal data to ensure it does not become a training bottleneck. We validated this on various configurations, including GB200, H100, and local desktops with RTX 4090 GPUs. We are actively exploring GPU-accelerated approaches for future releases.
|
| 8 |
+
|
| 9 |
+
### Is the same data loader used for both pre-training and post-training?
|
| 10 |
+
|
| 11 |
+
Yes, the data loading pipeline is unified across both training stages.
|
| 12 |
+
|
| 13 |
+
### What is the role of the Policy Remote Server in the deployment diagram?
|
| 14 |
+
|
| 15 |
+
The Policy Remote Server decouples inference from the physical robot. This allows users to run the policy on a high-compute cluster (e.g., H100s) for faster inference while the robot operates in a separate environment. It separates dependencies and enables scaling beyond the robot's onboard compute.
|
| 16 |
+
|
| 17 |
+
## Workflow & Architecture
|
| 18 |
+
|
| 19 |
+
### Why retain only specific LLM layers (e.g., 16 layers) during fine-tuning?
|
| 20 |
+
|
| 21 |
+
This configuration was empirically tuned for the backbone (e.g., Eagle/Cosmos-Reason). Research suggests early layers capture grammatical structure, while middle-to-late layers are highly expressive. However, the very last layers are often over-optimized for next-token prediction; pruning or freezing them can sometimes yield better representations for vision-language-action alignment.
|
| 22 |
+
|
| 23 |
+
### How do you verify if the language model is successfully aligned with the action space?
|
| 24 |
+
|
| 25 |
+
We evaluate this end-to-end via downstream task success. We design evaluation tasks that are ambiguous without language instructions (e.g., "pick the pear" from a bowl of mixed fruit). If the robot succeeds, it confirms the model is correctly grounding language commands into physical actions.
|
| 26 |
+
|
| 27 |
+
## Data Strategy & Volume
|
| 28 |
+
|
| 29 |
+
### How much data is required for post-training on a new embodiment or task?
|
| 30 |
+
|
| 31 |
+
Data requirements depend heavily on task complexity and scene variation. Typical guidelines include:
|
| 32 |
+
|
| 33 |
+
- **Simple, fixed-location tasks (Pick & Place):** ~100 trajectories.
|
| 34 |
+
- **Complex scenes or multi-step tasks:** ~500+ trajectories.
|
| 35 |
+
- **High-DoF humanoid tasks:** ~2,000+ trajectories (e.g., shelf-picking with G1).
|
| 36 |
+
- **Fine manipulation:** ~100–500 episodes, ideally with human motion pre-training.
|
| 37 |
+
|
| 38 |
+
### What is the recommended strategy for improving success rates on hard tasks?
|
| 39 |
+
|
| 40 |
+
We recommend an iterative approach: start with ~100 teleoperated demonstrations, train a policy, and then use HG-DAgger (Human Gated Dataset Aggregation). Run the policy, intervene when it fails, and add the corrections from those trajectories to the dataset. This helps the model cover out-of-distribution states that pure behavior cloning (BC) might miss, and recover from partial failure states (e.g., a grip slipping or imprecise item placement).
|
| 41 |
+
|
| 42 |
+
### Does including real-robot data from other embodiments help if I only care about one robot?
|
| 43 |
+
|
| 44 |
+
Yes. Even if cross-embodiment generalization is not your goal, including diverse real-robot data adds visual diversity and robustness to the VLA's backbone, improving performance on your specific target robot.
|
| 45 |
+
|
| 46 |
+
### Does GR00T N1.7 support synthetic data generation via Cosmos?
|
| 47 |
+
|
| 48 |
+
While research models (like DreamGen) show promise, a robust, product-ready pipeline for generating synthetic training data via Cosmos is currently in development and not yet part of the standard release.
|
| 49 |
+
|
| 50 |
+
## Model Capabilities
|
| 51 |
+
|
| 52 |
+
### Can the model handle lighting changes or different object colors?
|
| 53 |
+
|
| 54 |
+
VLMs can struggle with drastic appearance changes (e.g., hard shadows or significant hue shifts). While we haven't released specific lighting ablations, we strongly recommend using color jitter augmentation during training and collecting diverse data (20–50 episodes) under different lighting conditions to prevent overfitting.
|
| 55 |
+
|
| 56 |
+
### Can GR00T models perform reasoning or Visual Question Answering (VQA)?
|
| 57 |
+
|
| 58 |
+
The GR00T N1.x series is optimized specifically for action generation, not open-ended reasoning or VQA. Capabilities requiring complex semantic reasoning are targeted for future N2 releases.
|
| 59 |
+
|
| 60 |
+
### Can the model learn "retry" behaviors?
|
| 61 |
+
|
| 62 |
+
The current architecture is stateless and does not inherently "know" if a previous attempt failed. While some retry behavior may emerge from high-quality data, explicit recovery strategies are best achieved through DAgger (collecting data on recovery from failure) or Reinforcement Learning (RL), rather than pure Imitation Learning.
|
| 63 |
+
|
| 64 |
+
### Does the model distinguish between left and right arms in bimanual tasks?
|
| 65 |
+
|
| 66 |
+
Yes, provided the training data is distinct or annotated (e.g., instructions specifying "left arm" vs. "right arm"). If the dataset contains mixed, unannotated data where both arms perform identical tasks indiscriminately, the model may struggle to distinguish them.
|
| 67 |
+
|
| 68 |
+
### Is there a zero-shot cross-embodiment VLA model?
|
| 69 |
+
|
| 70 |
+
No. While cross-embodiment data improves generalization, a true "zero-shot" model (one that works perfectly on a new robot without *any* fine-tuning) does not currently exist in the open VLA landscape.
|
| 71 |
+
|
| 72 |
+
### Will differences in object shape between training and deployment cause the success rate to drop?
|
| 73 |
+
|
| 74 |
+
It depends on the degree of deviation. If the target object's shape differs drastically from the training data, performance will likely drop significantly. However, if the shape variation is minor and shares a similar grasping affordance (e.g., a slightly different bottle shape that is still grasped from the side), the model may still succeed, though with potentially lower reliability than on the original objects.
|
| 75 |
+
|
| 76 |
+
### Has the impact of large viewpoint changes (e.g., head movement) on task difficulty been studied?
|
| 77 |
+
|
| 78 |
+
Yes. Large viewpoint changes effectively change the observation distribution, which can complicate simple tasks. For example, a "simple" handover becomes complex if the robot's head moves significantly, altering the camera's perspective of its own hands.
|
| 79 |
+
|
| 80 |
+
- **Current Status:** Most public GR00T demos feature a relatively fixed head position to stabilize observations.
|
| 81 |
+
- **Mitigation:** To handle natural head movement, we recommend training with aggressive camera pose augmentation or collecting data that explicitly includes head motion to ensure the policy becomes robust to viewpoint shifts.
|
README.md
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
<img src="media/header_compress.png" width="800" alt="NVIDIA Isaac GR00T N1.7 Header">
|
| 4 |
+
|
| 5 |
+
<!-- --- -->
|
| 6 |
+
|
| 7 |
+
<p style="font-size: 1.2em;">
|
| 8 |
+
<a href="https://developer.nvidia.com/isaac/gr00t"><strong>Website</strong></a> |
|
| 9 |
+
<a href="https://huggingface.co/collections/nvidia/gr00t-n17"><strong>Model</strong></a> |
|
| 10 |
+
<a href="https://huggingface.co/collections/nvidia/physical-ai"><strong>Dataset</strong></a> |
|
| 11 |
+
<a href="https://arxiv.org/abs/2503.14734"><strong>Paper</strong></a> |
|
| 12 |
+
<a href="https://developer.nvidia.com/isaac"><strong>NVIDIA Isaac</strong></a> |
|
| 13 |
+
<a href="FAQ.md"><strong>FAQ</strong></a>
|
| 14 |
+
</p>
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
## Table of Contents
|
| 18 |
+
|
| 19 |
+
- [NVIDIA Isaac GR00T](#nvidia-isaac-gr00t)
|
| 20 |
+
- [What's New in GR00T N1.7](#whats-new-in-gr00t-n17)
|
| 21 |
+
- [Installation](#installation)
|
| 22 |
+
- [Model Checkpoints & Embodiment Tags](#model-checkpoints--embodiment-tags)
|
| 23 |
+
- [Data Format](#data-format)
|
| 24 |
+
- [Inference](#inference)
|
| 25 |
+
- [Fine-tuning](#fine-tuning)
|
| 26 |
+
- [Evaluation](#evaluation)
|
| 27 |
+
- [Contributions](#contributions)
|
| 28 |
+
- [License](#license)
|
| 29 |
+
- [Citation](#citation)
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## NVIDIA Isaac GR00T
|
| 34 |
+
|
| 35 |
+
<table style="width:100%; table-layout:fixed;">
|
| 36 |
+
<tr>
|
| 37 |
+
<td style="width:33.33%; text-align:center;">
|
| 38 |
+
<img src="media/unitree_g1.gif" style="max-width:100%; height:auto;">
|
| 39 |
+
</td>
|
| 40 |
+
<td style="width:33.33%; text-align:center;">
|
| 41 |
+
<img src="media/agibot_g1.gif" style="max-width:100%; height:auto;">
|
| 42 |
+
</td>
|
| 43 |
+
<td style="width:33.33%; text-align:center;">
|
| 44 |
+
<img src="media/yam.gif" style="max-width:100%; height:auto;">
|
| 45 |
+
</td>
|
| 46 |
+
</tr>
|
| 47 |
+
</table>
|
| 48 |
+
|
| 49 |
+
> We just released GR00T N1.7 Early Access, the latest version of GR00T N1 with a new VLM backbone (Cosmos-Reason2-2B / Qwen3-VL) and improved performance.
|
| 50 |
+
|
| 51 |
+
> **This is an Early Access (EA) release.** You are welcome to download the model, explore the codebase, and begin building on the stack, with the understanding that support and stability guarantees are limited until the GA release.
|
| 52 |
+
>
|
| 53 |
+
> **What's available:**
|
| 54 |
+
> - Pre-trained GR00T N1.7 model weights and reference code
|
| 55 |
+
> - Fine-tuning and inference with custom robot data or demonstrations
|
| 56 |
+
> - Experimentation, prototyping, and research use cases
|
| 57 |
+
>
|
| 58 |
+
> **Available at GA:**
|
| 59 |
+
> - Production deployment with commercial support
|
| 60 |
+
> - Complete benchmarks and a fully validated, stable feature set
|
| 61 |
+
> - Pull request contributions
|
| 62 |
+
>
|
| 63 |
+
> We welcome feedback - please feel free to raise issues in this repository.
|
| 64 |
+
|
| 65 |
+
> To use older versions: [N1.6](https://github.com/NVIDIA/Isaac-GR00T/releases/tag/n1.6-release) | [N1.5](https://github.com/NVIDIA/Isaac-GR00T/tree/n1.5-release)
|
| 66 |
+
|
| 67 |
+
NVIDIA Isaac GR00T N1.7 is an open vision-language-action (VLA) model for generalized humanoid robot skills. This cross-embodiment model takes multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
| 68 |
+
|
| 69 |
+
GR00T N1.7 is trained on a diverse mixture of robot data including bimanual, semi-humanoid and an expansive humanoid dataset. It is adaptable through post-training for specific embodiments, tasks and environments.
|
| 70 |
+
|
| 71 |
+
GR00T N1.7 is fully commercially licensable under Apache 2.0. It delivers comparable performance to N1.6, with improved generalization and language-following capabilities driven by the inclusion of 20K hours of EgoScale human video data in pretraining.
|
| 72 |
+
|
| 73 |
+
The neural network architecture of GR00T N1.7 is a combination of vision-language foundation model and diffusion transformer head that denoises continuous actions. Here is a schematic diagram of the architecture:
|
| 74 |
+
|
| 75 |
+
<div align="center">
|
| 76 |
+
<img src="media/model-architecture.png" width="800" alt="model-architecture">
|
| 77 |
+
</div>
|
| 78 |
+
|
| 79 |
+
### Workflow Overview
|
| 80 |
+
|
| 81 |
+
1. **Prepare data** — Collect robot demonstrations (video, state, action) and convert them to the [GR00T LeRobot format](#data-format). Demo datasets are included for quick testing.
|
| 82 |
+
2. **Run inference** — Try zero-shot inference with the base model on [pretrain embodiments](#embodiment-tags), or use a [finetuned checkpoint](#checkpoints) for benchmark tasks.
|
| 83 |
+
3. **Fine-tune** — Adapt the model to your robot using [`launch_finetune.py`](#fine-tuning) with your own data and modality config.
|
| 84 |
+
4. **Evaluate** — Validate with [open-loop evaluation](#open-loop-evaluation), then test in [simulation benchmarks](#benchmark-examples) or on real hardware via the [Policy API](getting_started/policy.md).
|
| 85 |
+
5. **Deploy** — Connect `Gr00tPolicy` to your robot controller, optionally accelerated with [TensorRT](scripts/deployment/README.md).
|
| 86 |
+
|
| 87 |
+
## What's New in GR00T N1.7
|
| 88 |
+
|
| 89 |
+
GR00T N1.7 builds on N1.6 with a new VLM backbone and code-level improvements.
|
| 90 |
+
|
| 91 |
+
1. **Relative EEF Action Space** — N1.7 adopts a relative end-effector action space shared across robot and human embodiments. Representing actions as deltas from the current pose (rather than absolute targets) improves generalization and is a key factor in the model's cross-embodiment performance. See [`getting_started/finetune_new_embodiment.md`](getting_started/finetune_new_embodiment.md) for guidance on configuring relative EEF for your own robot.
|
| 92 |
+
|
| 93 |
+
2. **Human Video Pretraining** — N1.7 is pretrained on 20K hours of EgoScale human video data alongside diverse robot demonstrations. Because the relative EEF action representation is consistent across both human and robot data, the model can transfer manipulation priors learned from human video directly to robot control.
|
| 94 |
+
|
| 95 |
+
### Key Changes from N1.6
|
| 96 |
+
|
| 97 |
+
- **New VLM backbone:** Cosmos-Reason2-2B (Qwen3-VL architecture), replacing the Eagle backbone used in N1.6. Supports flexible resolution and encodes images in their native aspect ratio without padding.
|
| 98 |
+
- Simplified data processing pipeline (`processing_gr00t_n1d7.py`).
|
| 99 |
+
- Added full pipeline export to ONNX and TensorRT with improved frequency.
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## Installation
|
| 104 |
+
|
| 105 |
+
### Hardware Requirements
|
| 106 |
+
|
| 107 |
+
**Inference:** 1 GPU with 16 GB+ VRAM (e.g., RTX 4090, L40, H100, Jetson AGX Thor/Orin, DGX Spark).
|
| 108 |
+
|
| 109 |
+
**Fine-tuning:** 1 or more GPUs with 40 GB+ VRAM recommended. We recommend H100 or L40 nodes for optimal performance. Other hardware (e.g., A6000) works but may require longer training time. See the [Hardware Recommendation Guide](getting_started/hardware_recommendation.md) for detailed specs.
|
| 110 |
+
|
| 111 |
+
**CUDA / Python per platform:** dGPU on CUDA 12.8 with Python 3.10; Jetson Orin on CUDA 12.6 with Python 3.10; Jetson Thor and DGX Spark on CUDA 13.0 with Python 3.12. The per-platform install scripts and Dockerfiles live under `scripts/deployment/`; see the [Deployment & Inference Guide](scripts/deployment/README.md) for the full matrix.
|
| 112 |
+
|
| 113 |
+
### Clone the Repository
|
| 114 |
+
|
| 115 |
+
GR00T relies on submodules for certain dependencies. Include them when cloning:
|
| 116 |
+
|
| 117 |
+
**Note:** `git-lfs` is **required** to download parquet data files in `/demo_data`. Install it before cloning: `sudo apt install git-lfs && git lfs install`.
|
| 118 |
+
```sh
|
| 119 |
+
git clone --recurse-submodules https://github.com/NVIDIA/Isaac-GR00T
|
| 120 |
+
cd Isaac-GR00T
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
If you've already cloned without submodules, initialize them separately:
|
| 124 |
+
|
| 125 |
+
```sh
|
| 126 |
+
git submodule update --init --recursive
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Set Up the Environment
|
| 130 |
+
|
| 131 |
+
GR00T uses [uv](https://github.com/astral-sh/uv) for fast, reproducible dependency management. Install uv first:
|
| 132 |
+
|
| 133 |
+
```sh
|
| 134 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
#### dGPU (x86_64) — Default
|
| 138 |
+
|
| 139 |
+
Install FFmpeg (required by `torchcodec`, the default video backend):
|
| 140 |
+
```sh
|
| 141 |
+
sudo apt-get update && sudo apt-get install -y ffmpeg
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Create the environment and install GR00T:
|
| 145 |
+
```sh
|
| 146 |
+
uv sync --python 3.10
|
| 147 |
+
```
|
| 148 |
+
GPU dependencies (flash-attn, TensorRT, etc.) are included in the default install.
|
| 149 |
+
|
| 150 |
+
Verify the installation:
|
| 151 |
+
```sh
|
| 152 |
+
uv run python -c "import gr00t; print('GR00T installed successfully')"
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
> **`flash-attn` message on every `uv run`:** You may see `Installing flash-attn...` each time you run `uv run`. This is a known `uv` behavior with URL-pinned wheel sources — `uv` re-validates the cached wheel against the source URL on each invocation. It is **not** rebuilding from source; the wheel is already cached locally and the operation takes 2-3 seconds. This only affects x86_64 platforms.
|
| 156 |
+
> To suppress it, remove the `flash-attn` entries under `[tool.uv.sources]` in your local `pyproject.toml` after the initial install. But that will break `uv lock` and cause flash-attn to build from source on next lock regeneration.
|
| 157 |
+
|
| 158 |
+
<details>
|
| 159 |
+
<summary><strong>Alternative: pip install (without uv)</strong></summary>
|
| 160 |
+
|
| 161 |
+
If you prefer pip/conda over uv, create a Python 3.10 virtualenv and install:
|
| 162 |
+
```sh
|
| 163 |
+
python3.10 -m venv .venv && source .venv/bin/activate
|
| 164 |
+
pip install -e .
|
| 165 |
+
```
|
| 166 |
+
Note: GPU dependencies (flash-attn, TensorRT) may require manual installation with pip. The `uv` workflow handles these automatically.
|
| 167 |
+
</details>
|
| 168 |
+
|
| 169 |
+
> **If fine-tuning fails with `CUDA_HOME is unset`:** Run `bash scripts/deployment/dgpu/install_deps.sh` once to configure CUDA paths, or manually `export CUDA_HOME=/usr/local/cuda`.
|
| 170 |
+
|
| 171 |
+
> **CUDA 13.x Users (Thor, Spark, and other CUDA 13+ platforms):** PyTorch 2.7 pins Triton to 3.3.1, which does not recognize CUDA major version 13+. This causes a `RuntimeError` in Triton's `ptx_get_version()`. Run the patch script to fix:
|
| 172 |
+
> ```sh
|
| 173 |
+
> uv run bash scripts/patch_triton_cuda13.sh
|
| 174 |
+
> ```
|
| 175 |
+
|
| 176 |
+
> **GB300 (sm_103) Users:** Triton 3.3.1 (pinned by PyTorch 2.7) does not support the GB300 GPU architecture (sm_103). `torch.compile` will fail on GB300. Use PyTorch eager mode or TensorRT inference instead. Triton 3.5.1+ adds sm_103 support but is not yet compatible with the pinned PyTorch version.
|
| 177 |
+
|
| 178 |
+
> **aarch64 Video Backend:** On aarch64 platforms (Thor, Orin, Spark), `torchcodec` is the required video backend. `install_deps.sh` prefers the prebuilt aarch64 wheel under `scripts/deployment/dgpu/wheels/` (shared by Thor/Spark against FFmpeg 6; Orin uses a matching build against FFmpeg 4) and falls back to a source build only if the wheel is missing. If you encounter `NotImplementedError` from the video backend, ensure `torchcodec` was installed successfully during setup. Other backends (decord, pyav) are not supported on aarch64.
|
| 179 |
+
|
| 180 |
+
<details>
|
| 181 |
+
<summary><strong>DGX Spark</strong> (tested with DGX Spark GB10)</summary>
|
| 182 |
+
|
| 183 |
+
```bash
|
| 184 |
+
bash scripts/deployment/spark/install_deps.sh
|
| 185 |
+
source .venv/bin/activate
|
| 186 |
+
source scripts/activate_spark.sh
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
See the [Spark setup guide](scripts/deployment/README.md#dgx-spark-setup) for Docker and bare metal details.
|
| 190 |
+
</details>
|
| 191 |
+
|
| 192 |
+
<details>
|
| 193 |
+
<summary><strong>Jetson AGX Thor</strong> (tested with JetPack 7.1)</summary>
|
| 194 |
+
|
| 195 |
+
> **flash-attn on older systems (e.g., Ubuntu 20.04 with glibc < 2.35):** The pre-built `flash-attn` wheel may fail with `ImportError: glibc_compat.so: cannot open shared object file`. To fix this, build from source:
|
| 196 |
+
> ```sh
|
| 197 |
+
> uv pip install flash-attn==2.7.4.post1 --no-binary flash-attn --no-cache
|
| 198 |
+
> ```
|
| 199 |
+
> This compiles locally (~10-30 minutes) and avoids the glibc compatibility issue.
|
| 200 |
+
|
| 201 |
+
```bash
|
| 202 |
+
bash scripts/deployment/thor/install_deps.sh
|
| 203 |
+
source .venv/bin/activate
|
| 204 |
+
source scripts/activate_thor.sh
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
See the [Thor setup guide](scripts/deployment/README.md#jetson-thor-setup) for Docker and bare metal details.
|
| 208 |
+
</details>
|
| 209 |
+
|
| 210 |
+
<details>
|
| 211 |
+
<summary><strong>Jetson Orin</strong> (tested with JetPack 6.2)</summary>
|
| 212 |
+
|
| 213 |
+
```bash
|
| 214 |
+
bash scripts/deployment/orin/install_deps.sh
|
| 215 |
+
source .venv/bin/activate
|
| 216 |
+
source scripts/activate_orin.sh
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
See the [Orin setup guide](scripts/deployment/README.md#jetson-orin-setup) for Docker and bare metal details.
|
| 220 |
+
</details>
|
| 221 |
+
|
| 222 |
+
For a containerized setup that avoids system-level dependency conflicts, see our [Docker Setup Guide](docker/README.md).
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
## Model Checkpoints & Embodiment Tags
|
| 227 |
+
|
| 228 |
+
### Checkpoints
|
| 229 |
+
|
| 230 |
+
| Checkpoint | Type | Embodiment Tag | Description |
|
| 231 |
+
|------------|------|---------------|-------------|
|
| 232 |
+
| [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B) | Base | See [pretrain tags](getting_started/policy.md#--embodiment-tag) | Base model (3B params) — zero-shot inference on pretrain embodiments, or finetune for new tasks |
|
| 233 |
+
| [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO) | Finetuned | `LIBERO_PANDA` | Finetuned on [LIBERO](https://libero-project.github.io/) benchmark (Franka Panda) |
|
| 234 |
+
| [`nvidia/GR00T-N1.7-DROID`](https://huggingface.co/nvidia/GR00T-N1.7-DROID) | Finetuned | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | Finetuned on [DROID](https://droid-dataset.github.io/) dataset |
|
| 235 |
+
| [`nvidia/GR00T-N1.7-SimplerEnv-Bridge`](https://huggingface.co/nvidia/GR00T-N1.7-SimplerEnv-Bridge) | Finetuned | `SIMPLER_ENV_WIDOWX` | Finetuned on SimplerEnv Bridge (WidowX) |
|
| 236 |
+
| [`nvidia/GR00T-N1.7-SimplerEnv-Fractal`](https://huggingface.co/nvidia/GR00T-N1.7-SimplerEnv-Fractal) | Finetuned | `SIMPLER_ENV_GOOGLE` | Finetuned on SimplerEnv Fractal (Google Robot) |
|
| 237 |
+
|
| 238 |
+
> Older versions: [N1.6 checkpoints](https://github.com/NVIDIA/Isaac-GR00T/tree/n1.6-release) | [N1.5 checkpoints](https://github.com/NVIDIA/Isaac-GR00T/tree/n1.5-release)
|
| 239 |
+
|
| 240 |
+
### Embodiment Tags
|
| 241 |
+
|
| 242 |
+
Every inference or finetuning command requires an `--embodiment-tag`. The tag determines which modality config (state/action keys, normalization) the model uses. Tags are **case-insensitive**.
|
| 243 |
+
|
| 244 |
+
For the full list of pretrain and posttrain tags, see the [Policy API Guide — Embodiment Tags](getting_started/policy.md#--embodiment-tag).
|
| 245 |
+
|
| 246 |
+
---
|
| 247 |
+
|
| 248 |
+
## Data Format
|
| 249 |
+
|
| 250 |
+
GR00T uses a flavor of the [LeRobot v2 dataset format](https://github.com/huggingface/lerobot) with an additional `meta/modality.json` file that describes state/action/video structure. A dataset looks like:
|
| 251 |
+
|
| 252 |
+
```
|
| 253 |
+
my_dataset/
|
| 254 |
+
meta/
|
| 255 |
+
info.json # dataset metadata
|
| 256 |
+
episodes.jsonl # episode index and lengths
|
| 257 |
+
tasks.jsonl # language task descriptions
|
| 258 |
+
modality.json # state/action/video key mapping (GR00T-specific)
|
| 259 |
+
data/chunk-000/ # parquet files (state, action per timestep)
|
| 260 |
+
videos/chunk-000/ # mp4 video files per episode
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
The `modality.json` maps how the concatenated state/action arrays split into named fields (e.g., `x`, `y`, `z`, `gripper`) and which video keys are available. This is what the embodiment tag uses to interpret the data.
|
| 264 |
+
|
| 265 |
+
**Included demo datasets** (ready to use, no download needed):
|
| 266 |
+
|
| 267 |
+
| Dataset | Robot | Embodiment Tag | Use Case |
|
| 268 |
+
|---------|-------|---------------|----------|
|
| 269 |
+
| `demo_data/droid_sample` | DROID (3 episodes) | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | Zero-shot or finetuned inference (DROID) |
|
| 270 |
+
| `demo_data/libero_demo` | LIBERO Panda (5 episodes) | `LIBERO_PANDA` | Inference with finetuned checkpoint |
|
| 271 |
+
| `demo_data/simplerenv_bridge_sample` | WidowX (SimplerEnv Bridge) | `SIMPLER_ENV_WIDOWX` | Inference with finetuned SimplerEnv Bridge checkpoint |
|
| 272 |
+
| `demo_data/simplerenv_fractal_sample` | Google Robot (SimplerEnv Fractal) | `SIMPLER_ENV_GOOGLE` | Inference with finetuned SimplerEnv Fractal checkpoint |
|
| 273 |
+
| `demo_data/cube_to_bowl_5` | SO100 arm (5 episodes) | `NEW_EMBODIMENT` | Fine-tuning custom embodiment example |
|
| 274 |
+
| `demo_data/cube_to_bowl_5_with_mask` | SO100 arm + per-frame masks | `NEW_EMBODIMENT` | [Mask-guided background suppression](examples/mask-guided-background-suppression/README.md) example |
|
| 275 |
+
|
| 276 |
+
> To generate more DROID episodes: `python scripts/download_droid_sample.py --num-episodes 10`
|
| 277 |
+
|
| 278 |
+
**Using your own data:** Convert your demonstrations to the format above. If coming from LeRobot v3, use the conversion script: `python scripts/lerobot_conversion/convert_v3_to_v2.py`. See the full [Data Preparation Guide](getting_started/data_preparation.md) for schema details and examples.
|
| 279 |
+
|
| 280 |
+
---
|
| 281 |
+
|
| 282 |
+
## Inference
|
| 283 |
+
|
| 284 |
+
### Zero-Shot Inference (Base Model)
|
| 285 |
+
|
| 286 |
+
The included `demo_data/droid_sample` dataset works with the base model out of the box — no finetuning or checkpoint download needed:
|
| 287 |
+
|
| 288 |
+
```bash
|
| 289 |
+
uv run python scripts/deployment/standalone_inference_script.py \
|
| 290 |
+
--model-path nvidia/GR00T-N1.7-3B \
|
| 291 |
+
--dataset-path demo_data/droid_sample \
|
| 292 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
|
| 293 |
+
--traj-ids 1 2 \
|
| 294 |
+
--inference-mode pytorch \
|
| 295 |
+
--action-horizon 8
|
| 296 |
+
```
|
| 297 |
+
|
| 298 |
+
This runs open-loop inference on 2 DROID episodes, comparing predicted actions against ground truth. The base model downloads automatically from HuggingFace on first run (~6 GB).
|
| 299 |
+
|
| 300 |
+
### Finetuned Inference
|
| 301 |
+
|
| 302 |
+
For posttrain embodiments, use a finetuned checkpoint. Most finetuned checkpoints (e.g., DROID, SimplerEnv) have a flat file structure and can be passed directly as a HuggingFace model ID — no manual download needed:
|
| 303 |
+
|
| 304 |
+
```bash
|
| 305 |
+
uv run python scripts/deployment/standalone_inference_script.py \
|
| 306 |
+
--model-path nvidia/GR00T-N1.7-DROID \
|
| 307 |
+
--dataset-path demo_data/droid_sample \
|
| 308 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
|
| 309 |
+
--traj-ids 1 2 \
|
| 310 |
+
--inference-mode pytorch \
|
| 311 |
+
--action-horizon 8
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
Some checkpoints (e.g., LIBERO) use a nested folder structure with model files under a subfolder. HuggingFace does not support nested repo paths in `--model-path`, so you must download first:
|
| 315 |
+
|
| 316 |
+
```bash
|
| 317 |
+
uv run hf download nvidia/GR00T-N1.7-LIBERO \
|
| 318 |
+
--include "libero_10/config.json" "libero_10/embodiment_id.json" \
|
| 319 |
+
"libero_10/model-*.safetensors" "libero_10/model.safetensors.index.json" \
|
| 320 |
+
"libero_10/processor_config.json" "libero_10/statistics.json" \
|
| 321 |
+
--local-dir checkpoints/GR00T-N1.7-LIBERO
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+
```bash
|
| 325 |
+
uv run python scripts/deployment/standalone_inference_script.py \
|
| 326 |
+
--model-path checkpoints/GR00T-N1.7-LIBERO/libero_10 \
|
| 327 |
+
--dataset-path demo_data/libero_demo \
|
| 328 |
+
--embodiment-tag LIBERO_PANDA \
|
| 329 |
+
--traj-ids 0 1 2 \
|
| 330 |
+
--inference-mode pytorch \
|
| 331 |
+
--action-horizon 8
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
### Server-Client Inference (for Deployment)
|
| 335 |
+
|
| 336 |
+
For real-world deployment or simulation evaluation, use the server-client architecture. The policy runs on a GPU server; a lightweight client sends observations and receives actions over ZMQ.
|
| 337 |
+
|
| 338 |
+
**Terminal 1 — Start the policy server:**
|
| 339 |
+
```bash
|
| 340 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 341 |
+
--model-path nvidia/GR00T-N1.7-3B \
|
| 342 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
|
| 343 |
+
--device cuda:0
|
| 344 |
+
```
|
| 345 |
+
|
| 346 |
+
**Terminal 2 — Run open-loop evaluation as a client:**
|
| 347 |
+
```bash
|
| 348 |
+
uv run python gr00t/eval/open_loop_eval.py \
|
| 349 |
+
--dataset-path demo_data/droid_sample \
|
| 350 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
|
| 351 |
+
--host 127.0.0.1 \
|
| 352 |
+
--port 5555 \
|
| 353 |
+
--traj-ids 1 2 \
|
| 354 |
+
--action-horizon 8
|
| 355 |
+
```
|
| 356 |
+
|
| 357 |
+
> **Tip:** If you get `ZMQError: Address already in use`, the default port 5555 is occupied. Use `--port <other_port>`.
|
| 358 |
+
|
| 359 |
+
For connecting to a real robot (e.g., DROID hardware), see [examples/DROID/README.md](examples/DROID/README.md). For faster inference with TensorRT, see the [Deployment & Inference Guide](scripts/deployment/README.md).
|
| 360 |
+
|
| 361 |
+
See the complete [Policy API Guide](getting_started/policy.md) for documentation on observation/action formats, batched inference, and troubleshooting.
|
| 362 |
+
|
| 363 |
+
---
|
| 364 |
+
|
| 365 |
+
## Fine-tuning
|
| 366 |
+
|
| 367 |
+
### Reproducing Benchmark Results
|
| 368 |
+
|
| 369 |
+
Each benchmark has a self-contained README with dataset download, finetune, and evaluation commands:
|
| 370 |
+
|
| 371 |
+
| Benchmark | Embodiment | Guide |
|
| 372 |
+
|-----------|-----------|-------|
|
| 373 |
+
| LIBERO | `LIBERO_PANDA` | [examples/LIBERO/README.md](examples/LIBERO/README.md) |
|
| 374 |
+
| SimplerEnv (Fractal) | `SIMPLER_ENV_GOOGLE` | [examples/SimplerEnv/README.md](examples/SimplerEnv/README.md) |
|
| 375 |
+
| SimplerEnv (Bridge) | `SIMPLER_ENV_WIDOWX` | [examples/SimplerEnv/README.md](examples/SimplerEnv/README.md) |
|
| 376 |
+
| SO100 | `NEW_EMBODIMENT` | [examples/SO100/README.md](examples/SO100/README.md) |
|
| 377 |
+
|
| 378 |
+
### Fine-tune on Your Own Robot ("NEW_EMBODIMENT")
|
| 379 |
+
|
| 380 |
+
To finetune GR00T on your own robot data and configuration, follow the detailed tutorial at [`getting_started/finetune_new_embodiment.md`](getting_started/finetune_new_embodiment.md).
|
| 381 |
+
|
| 382 |
+
Ensure your input data follows the [GR00T LeRobot format](#data-format), and specify your modality configuration via `--modality-config-path`.
|
| 383 |
+
|
| 384 |
+
**Single GPU:**
|
| 385 |
+
```bash
|
| 386 |
+
CUDA_VISIBLE_DEVICES=0 uv run python \
|
| 387 |
+
gr00t/experiment/launch_finetune.py \
|
| 388 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 389 |
+
--dataset-path demo_data/cube_to_bowl_5 \
|
| 390 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 391 |
+
--modality-config-path examples/SO100/so100_config.py \
|
| 392 |
+
--num-gpus 1 \
|
| 393 |
+
--output-dir /tmp/test_finetune \
|
| 394 |
+
--max-steps 2000 \
|
| 395 |
+
--global-batch-size 32 \
|
| 396 |
+
--dataloader-num-workers 4
|
| 397 |
+
```
|
| 398 |
+
|
| 399 |
+
**Multi-GPU (e.g., 8xH100):**
|
| 400 |
+
```bash
|
| 401 |
+
uv run torchrun --nproc_per_node=8 --master_port=29500 \
|
| 402 |
+
gr00t/experiment/launch_finetune.py \
|
| 403 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 404 |
+
--dataset-path demo_data/cube_to_bowl_5 \
|
| 405 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 406 |
+
--modality-config-path examples/SO100/so100_config.py \
|
| 407 |
+
--num-gpus 8 \
|
| 408 |
+
--output-dir /tmp/test_finetune_8gpu \
|
| 409 |
+
--max-steps 2000 \
|
| 410 |
+
--global-batch-size 32 \
|
| 411 |
+
--dataloader-num-workers 4
|
| 412 |
+
```
|
| 413 |
+
|
| 414 |
+
Replace `demo_data/cube_to_bowl_5` and `examples/SO100/so100_config.py` with your own dataset and modality config. See [`examples/SO100`](examples/SO100/README.md) for a complete walkthrough.
|
| 415 |
+
|
| 416 |
+
> **Note:** Use `uv run torchrun` (not bare `torchrun`) to ensure the correct virtual environment is used. Add `--use-wandb` to enable Weights & Biases logging. For more extensive configuration, use `gr00t/experiment/launch_train.py`.
|
| 417 |
+
|
| 418 |
+
### Training Tips
|
| 419 |
+
|
| 420 |
+
- Maximize batch size for your hardware and train for a few thousand steps.
|
| 421 |
+
- Users may observe 5-6% variance between runs due to non-deterministic image augmentations. Keep this in mind when comparing to reported benchmarks.
|
| 422 |
+
- **`--state_dropout_prob`** (model config default: 0.8; finetune CLI default: 0.2; see `gr00t/configs/finetune_config.py`): Randomly drops state inputs during training to improve generalization and reduce state-dependency. The shipped benchmark scripts override the CLI default per suite: LIBERO 10-Long uses 0.2 (the CLI default), SimplerEnv Bridge uses 0.8, SimplerEnv Fractal uses 0.5. If your task relies heavily on proprioceptive state, lower this value.
|
| 423 |
+
|
| 424 |
+
---
|
| 425 |
+
|
| 426 |
+
## Evaluation
|
| 427 |
+
|
| 428 |
+
### Open-Loop Evaluation
|
| 429 |
+
|
| 430 |
+
Compare predicted actions against ground truth from your dataset:
|
| 431 |
+
|
| 432 |
+
```bash
|
| 433 |
+
uv run python gr00t/eval/open_loop_eval.py \
|
| 434 |
+
--dataset-path <DATASET_PATH> \
|
| 435 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 436 |
+
--model-path <CHECKPOINT_PATH> \
|
| 437 |
+
--traj-ids 0 \
|
| 438 |
+
--action-horizon 16
|
| 439 |
+
```
|
| 440 |
+
|
| 441 |
+
This generates a visualization at `/tmp/open_loop_eval/traj_{traj_id}.jpeg` with ground truth vs. predicted actions and MSE metrics. Use `--save-plot-path <dir>` to save plots to a custom location.
|
| 442 |
+
|
| 443 |
+
### Closed-Loop Evaluation
|
| 444 |
+
|
| 445 |
+
Test your model in simulation or on real hardware using the server-client architecture:
|
| 446 |
+
|
| 447 |
+
```bash
|
| 448 |
+
# Start the policy server
|
| 449 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 450 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 451 |
+
--model-path <CHECKPOINT_PATH> \
|
| 452 |
+
--device cuda:0 \
|
| 453 |
+
--host 0.0.0.0 --port 5555
|
| 454 |
+
```
|
| 455 |
+
|
| 456 |
+
```python
|
| 457 |
+
from gr00t.policy.server_client import PolicyClient
|
| 458 |
+
|
| 459 |
+
policy = PolicyClient(host="localhost", port=5555)
|
| 460 |
+
env = YourEnvironment()
|
| 461 |
+
obs, info = env.reset()
|
| 462 |
+
action, info = policy.get_action(obs)
|
| 463 |
+
obs, reward, done, truncated, info = env.step(action)
|
| 464 |
+
```
|
| 465 |
+
|
| 466 |
+
**Debugging with ReplayPolicy:** To verify your environment setup without a trained model, start the server with `--dataset-path <DATASET_PATH>` (omit `--model-path`) to replay recorded actions from the dataset.
|
| 467 |
+
|
| 468 |
+
See the complete [Policy API Guide](getting_started/policy.md) for observation/action formats, batched inference, and troubleshooting.
|
| 469 |
+
|
| 470 |
+
### Benchmark Examples
|
| 471 |
+
|
| 472 |
+
We support evaluation on public benchmarks using a server-client architecture. The policy server reuses the project root's uv environment; simulation clients have individual setup scripts.
|
| 473 |
+
|
| 474 |
+
You can use [the verification script](scripts/eval/check_sim_eval_ready.py) to verify that all dependencies are properly configured.
|
| 475 |
+
|
| 476 |
+
**Zero-shot** (evaluate with the base model, no finetuning):
|
| 477 |
+
- [DROID](examples/DROID/README.md) — real-world DROID robot (also available as the finetuned `nvidia/GR00T-N1.7-DROID` checkpoint; `examples/DROID/README.md` covers both paths)
|
| 478 |
+
|
| 479 |
+
**Finetuned** (evaluate with finetuned checkpoints):
|
| 480 |
+
- [DROID](examples/DROID/README.md) — real-world DROID robot via `nvidia/GR00T-N1.7-DROID`
|
| 481 |
+
- [LIBERO](examples/LIBERO/README.md) — LIBERO benchmark (Franka Panda)
|
| 482 |
+
- [SimplerEnv](examples/SimplerEnv/README.md) — Google Robot (Fractal) and WidowX (Bridge)
|
| 483 |
+
- [SO100](examples/SO100/README.md) — SO100 custom embodiment workflow
|
| 484 |
+
|
| 485 |
+
<details>
|
| 486 |
+
<summary><strong>Adding a New Sim Benchmark</strong></summary>
|
| 487 |
+
|
| 488 |
+
Each sim benchmark registers its environments under a gym env_name with the format `{prefix}/{task_name}` (e.g., `libero_sim/LIVING_ROOM_SCENE2_put_soup_in_basket`). The evaluation framework uses the prefix to look up the corresponding `EmbodimentTag` via a mapping in [`gr00t/eval/sim/env_utils.py`](gr00t/eval/sim/env_utils.py).
|
| 489 |
+
|
| 490 |
+
> **Important:** The env_name prefix and the `EmbodimentTag` value are often different. For example, `libero_sim` maps to `EmbodimentTag.LIBERO_PANDA` (`"libero_sim"`). Do not assume they match.
|
| 491 |
+
|
| 492 |
+
To add a new benchmark:
|
| 493 |
+
|
| 494 |
+
1. Add an entry to `ENV_PREFIX_TO_EMBODIMENT_TAG` in `gr00t/eval/sim/env_utils.py`:
|
| 495 |
+
```python
|
| 496 |
+
ENV_PREFIX_TO_EMBODIMENT_TAG = {
|
| 497 |
+
...
|
| 498 |
+
"my_new_benchmark": EmbodimentTag.MY_ROBOT,
|
| 499 |
+
}
|
| 500 |
+
```
|
| 501 |
+
2. If the benchmark has multiple env_name prefixes (e.g., `my_benchmark_v1`, `my_benchmark_v2`), all related prefixes **must** map to the same `EmbodimentTag`.
|
| 502 |
+
3. Add corresponding test cases in `tests/gr00t/eval/sim/test_env_utils.py` and update the `test_all_known_prefixes_present` test.
|
| 503 |
+
</details>
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
---
|
| 508 |
+
|
| 509 |
+
# Contributions
|
| 510 |
+
|
| 511 |
+
During Early Access we are not accepting pull requests while the codebase stabilizes. If you encounter issues or have suggestions, please open an [Issue](https://github.com/NVIDIA/Isaac-GR00T/issues) in this repository.
|
| 512 |
+
|
| 513 |
+
# Support
|
| 514 |
+
|
| 515 |
+
Support during Early Access is best-effort. We will continue iterating toward a more stable General Availability (GA) release.
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
## License
|
| 519 |
+
|
| 520 |
+
- **Code:** Apache 2.0 — see [LICENSE](LICENSE)
|
| 521 |
+
- **Model weights:** [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/)
|
| 522 |
+
|
| 523 |
+
```
|
| 524 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 525 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 526 |
+
#
|
| 527 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 528 |
+
# you may not use this file except in compliance with the License.
|
| 529 |
+
# You may obtain a copy of the License at
|
| 530 |
+
#
|
| 531 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 532 |
+
#
|
| 533 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 534 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 535 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 536 |
+
# See the License for the specific language governing permissions and
|
| 537 |
+
# limitations under the License.
|
| 538 |
+
```
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
## Citation
|
| 542 |
+
|
| 543 |
+
[Paper Site](https://research.nvidia.com/labs/lpr/publication/gr00tn1_2025/)
|
| 544 |
+
```bibtex
|
| 545 |
+
@inproceedings{gr00tn1_2025,
|
| 546 |
+
archivePrefix = {arxiv},
|
| 547 |
+
eprint = {2503.14734},
|
| 548 |
+
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
|
| 549 |
+
author = {NVIDIA and Johan Bjorck and Fernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
|
| 550 |
+
month = {March},
|
| 551 |
+
year = {2025},
|
| 552 |
+
booktitle = {ArXiv Preprint},
|
| 553 |
+
}
|
| 554 |
+
```
|
| 555 |
+
|
examples/DROID/README.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GR00T DROID
|
| 2 |
+
|
| 3 |
+
The N1.7 base model supports DROID inference out of the box via the `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` pretrain tag. A finetuned checkpoint is also available at [`nvidia/GR00T-N1.7-DROID`](https://huggingface.co/nvidia/GR00T-N1.7-DROID).
|
| 4 |
+
|
| 5 |
+
> **Note:** The DROID dataset contains multiple language instruction paraphrases per episode (`language_instruction`, `language_instruction_2`, `language_instruction_3`). These are used for language augmentation during training. At inference time, only the first language key is used.
|
| 6 |
+
|
| 7 |
+
## Data Format
|
| 8 |
+
|
| 9 |
+
The DROID embodiment expects the following modality structure:
|
| 10 |
+
|
| 11 |
+
| Modality | Keys | Dimensions |
|
| 12 |
+
|----------|------|------------|
|
| 13 |
+
| Video | `exterior_image_1_left`, `wrist_image_left` | 2 cameras |
|
| 14 |
+
| State | `eef_9d`, `gripper_position`, `joint_position` | 9D + 1D + 7D = 17D |
|
| 15 |
+
| Action | `eef_9d`, `gripper_position`, `joint_position` | 9D + 1D + 7D = 17D |
|
| 16 |
+
| Language | `annotation.language.language_instruction` | text |
|
| 17 |
+
|
| 18 |
+
Action representations:
|
| 19 |
+
- `eef_9d`: relative end-effector (XYZ + rotation 6D)
|
| 20 |
+
- `gripper_position`: absolute (1D)
|
| 21 |
+
- `joint_position`: relative joint positions (7D)
|
| 22 |
+
|
| 23 |
+
### Preparing DROID Demo Data
|
| 24 |
+
|
| 25 |
+
The full DROID dataset ([lerobot/droid_1.0.1](https://huggingface.co/datasets/lerobot/droid_1.0.1)) is ~358 GB with 95k+ episodes in LeRobot v3.0 format. To create a small sample for testing:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
uv pip install jsonlines # one-time dependency
|
| 29 |
+
python scripts/download_droid_sample.py
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
This downloads the first data/video chunk (~170 MB) and extracts 3 episodes into `demo_data/droid_sample/` in GR00T LeRobot v2.0 format.
|
| 33 |
+
|
| 34 |
+
**Key conversion notes:**
|
| 35 |
+
- Source is LeRobot v3.0 (consolidated parquet + concatenated videos) — the script converts to v2.0 (per-episode parquet + per-episode mp4).
|
| 36 |
+
- Video keys in the raw dataset (`exterior_1_left`, `wrist_left`) differ from the model config keys (`exterior_image_1_left`, `wrist_image_left`). The data loader auto-maps by position — no manual renaming needed.
|
| 37 |
+
- Language instructions are loaded via the `task_index` column mapped through `tasks.jsonl`.
|
| 38 |
+
|
| 39 |
+
## 1. Standalone Inference (with demo data)
|
| 40 |
+
|
| 41 |
+
After preparing demo data, run inference directly (no server needed):
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
uv run python scripts/deployment/standalone_inference_script.py \
|
| 45 |
+
--model-path nvidia/GR00T-N1.7-3B \
|
| 46 |
+
--dataset-path demo_data/droid_sample \
|
| 47 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
|
| 48 |
+
--traj-ids 0 1 \
|
| 49 |
+
--inference-mode pytorch \
|
| 50 |
+
--action-horizon 8
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
> **Note:** Episode 0 may have an empty language instruction. If inference fails on episode 0, try `--traj-ids 1 2`.
|
| 54 |
+
|
| 55 |
+
Expected zero-shot performance on the base model (not finetuned):
|
| 56 |
+
|
| 57 |
+
| Metric | Value |
|
| 58 |
+
|--------|-------|
|
| 59 |
+
| Average MSE | ~0.0149 |
|
| 60 |
+
| Average MAE | ~0.0753 |
|
| 61 |
+
| Inference per step (base) | ~262 ms (H100) |
|
| 62 |
+
| Inference per step (finetuned) | ~253 ms (H100) |
|
| 63 |
+
|
| 64 |
+
## 2. Inference Server (for real-world deployment)
|
| 65 |
+
|
| 66 |
+
### Using the base model (zero-shot):
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 70 |
+
--model-path nvidia/GR00T-N1.7-3B \
|
| 71 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### Using the finetuned model:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 78 |
+
--model-path nvidia/GR00T-N1.7-DROID \
|
| 79 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## 3. Fine-tuning
|
| 83 |
+
|
| 84 |
+
Fine-tune the base model on DROID data using the shared launcher:
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
|
| 88 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 89 |
+
--dataset-path demo_data/droid_sample \
|
| 90 |
+
--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
|
| 91 |
+
--output-dir /tmp/droid_finetune
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
> **Note:** The above uses the small `demo_data/droid_sample` (3 episodes) for quick validation. For production training, replace `--dataset-path` with the full DROID dataset.
|
| 95 |
+
|
| 96 |
+
## 4. Robot Control Script
|
| 97 |
+
|
| 98 |
+
1. Install the DROID package on the robot control laptop/workstation — [instructions](https://droid-dataset.github.io/droid/software-setup/host-installation.html#configuring-the-laptopworkstation)
|
| 99 |
+
|
| 100 |
+
2. Install dependencies for the GR00T control script in the environment from step 1:
|
| 101 |
+
```bash
|
| 102 |
+
pip install tyro moviepy==1.0.3 pydantic numpy==1.26.4
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
3. Enter the camera IDs for your ZED cameras in `examples/DROID/main_gr00t.py`.
|
| 106 |
+
|
| 107 |
+
4. Start the control script:
|
| 108 |
+
```bash
|
| 109 |
+
python examples/DROID/main_gr00t.py --external-camera="left" # or "right"
|
| 110 |
+
```
|
examples/DROID/main_gr00t.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# ruff: noqa
|
| 17 |
+
# NOTE: this requires installation of the droid repo.
|
| 18 |
+
# Adapted from https://github.com/Physical-Intelligence/openpi/blob/main/examples/droid/main.py
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import contextlib
|
| 23 |
+
import dataclasses
|
| 24 |
+
import datetime
|
| 25 |
+
import faulthandler
|
| 26 |
+
import os
|
| 27 |
+
import signal
|
| 28 |
+
import time
|
| 29 |
+
from collections import deque
|
| 30 |
+
|
| 31 |
+
import cv2
|
| 32 |
+
import imageio
|
| 33 |
+
import numpy as np
|
| 34 |
+
import pandas as pd
|
| 35 |
+
import tqdm
|
| 36 |
+
import tyro
|
| 37 |
+
from moviepy.editor import ImageSequenceClip
|
| 38 |
+
from PIL import Image
|
| 39 |
+
|
| 40 |
+
from droid.robot_env import RobotEnv
|
| 41 |
+
from server_client import PolicyClient
|
| 42 |
+
from utils import resize_with_pad
|
| 43 |
+
from scipy.spatial.transform import Rotation
|
| 44 |
+
|
| 45 |
+
faulthandler.enable()
|
| 46 |
+
|
| 47 |
+
# DROID data collection frequency -- we slow down execution to match this frequency
|
| 48 |
+
DROID_CONTROL_FREQUENCY = 15
|
| 49 |
+
RESOLUTION = (180, 320) # resize images to this resolution before sending to the policy server
|
| 50 |
+
|
| 51 |
+
# Egocentric frame correction: R_euler is post-multiplied by this matrix
|
| 52 |
+
# to match the OXE DROID training pipeline (TFG convention).
|
| 53 |
+
DROID_EEF_ROTATION_CORRECT = np.array(
|
| 54 |
+
[[0, 0, -1], [-1, 0, 0], [0, 1, 0]],
|
| 55 |
+
dtype=np.float64,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def compute_eef_9d(cartesian_position: np.ndarray) -> np.ndarray:
|
| 60 |
+
"""Convert cartesian_position (XYZ + euler 3D) to eef_9d (XYZ + rot6d).
|
| 61 |
+
|
| 62 |
+
Uses extrinsic XYZ Euler convention (scipy ``"XYZ"``, equivalent to
|
| 63 |
+
``tfg.rotation_matrix_3d.from_euler``) and post-multiplies by
|
| 64 |
+
``DROID_EEF_ROTATION_CORRECT`` to match the pretrained model.
|
| 65 |
+
"""
|
| 66 |
+
c = np.asarray(cartesian_position, dtype=np.float64).reshape(6)
|
| 67 |
+
xyz = c[:3]
|
| 68 |
+
euler = c[3:6]
|
| 69 |
+
rot_robot = Rotation.from_euler("XYZ", euler).as_matrix()
|
| 70 |
+
rot_mat = rot_robot @ DROID_EEF_ROTATION_CORRECT
|
| 71 |
+
rot6d = rot_mat[:2, :].reshape(6)
|
| 72 |
+
return np.concatenate([xyz, rot6d]).astype(np.float32)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclasses.dataclass
|
| 76 |
+
class Args:
|
| 77 |
+
# Hardware parameters
|
| 78 |
+
|
| 79 |
+
left_camera_id: str = "<SET THIS>" # e.g., "24259877"
|
| 80 |
+
right_camera_id: str = "<SET THIS>" # e.g., "24514023"
|
| 81 |
+
wrist_camera_id: str = "<SET THIS>" # e.g., "13062452"
|
| 82 |
+
|
| 83 |
+
# Policy parameters
|
| 84 |
+
policy_host: str = "localhost"
|
| 85 |
+
policy_port: int = 5555
|
| 86 |
+
policy_api_token: str = None
|
| 87 |
+
|
| 88 |
+
results_dir: str = None # if None, will use the current timestamp as the results directory
|
| 89 |
+
|
| 90 |
+
# Rollout parameters
|
| 91 |
+
max_timesteps: int = 600 # how many steps to run each rollout
|
| 92 |
+
|
| 93 |
+
# How many actions to execute from a predicted action chunk before querying policy server again
|
| 94 |
+
open_loop_horizon: int = 15
|
| 95 |
+
external_camera: str = (
|
| 96 |
+
"left" # which exterior camera to use for the policy server, choose from ["left", "right"]
|
| 97 |
+
)
|
| 98 |
+
render_camera: str = "left" # which camera to render saved video from
|
| 99 |
+
render_fps: int = 50
|
| 100 |
+
|
| 101 |
+
debug: bool = False
|
| 102 |
+
vis_cameras: bool = False
|
| 103 |
+
|
| 104 |
+
delay_seconds: int = 5
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
|
| 108 |
+
# waiting for a new action chunk, it will raise an exception and the server connection dies.
|
| 109 |
+
# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
|
| 110 |
+
@contextlib.contextmanager
|
| 111 |
+
def prevent_keyboard_interrupt():
|
| 112 |
+
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
|
| 113 |
+
interrupted = False
|
| 114 |
+
original_handler = signal.getsignal(signal.SIGINT)
|
| 115 |
+
|
| 116 |
+
def handler(signum, frame):
|
| 117 |
+
nonlocal interrupted
|
| 118 |
+
interrupted = True
|
| 119 |
+
|
| 120 |
+
signal.signal(signal.SIGINT, handler)
|
| 121 |
+
try:
|
| 122 |
+
yield
|
| 123 |
+
finally:
|
| 124 |
+
signal.signal(signal.SIGINT, original_handler)
|
| 125 |
+
if interrupted:
|
| 126 |
+
raise KeyboardInterrupt
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def main(args: Args):
|
| 130 |
+
assert args.external_camera in ["left", "right"], (
|
| 131 |
+
f"Invalid exterior camera: {args.exterior_camera}"
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
if args.results_dir is None:
|
| 135 |
+
results_dir = f"results_gr00t_{datetime.datetime.now().strftime('%Y_%m_%d')}"
|
| 136 |
+
else:
|
| 137 |
+
results_dir = args.results_dir
|
| 138 |
+
|
| 139 |
+
# Initialize the Panda environment.
|
| 140 |
+
env = RobotEnv(action_space="joint_position", gripper_action_space="position")
|
| 141 |
+
print("Created the droid env!")
|
| 142 |
+
|
| 143 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 144 |
+
|
| 145 |
+
policy_client = PolicyClient(
|
| 146 |
+
host=args.policy_host, port=args.policy_port, api_token=args.policy_api_token
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
modality_config = policy_client.get_modality_config()
|
| 150 |
+
video_delta = modality_config["video"].delta_indices
|
| 151 |
+
video_T = len(video_delta)
|
| 152 |
+
video_history_len = max(-min(video_delta), 0) + 1 if video_delta else 1
|
| 153 |
+
video_keys = modality_config["video"].modality_keys
|
| 154 |
+
state_keys = modality_config["state"].modality_keys
|
| 155 |
+
state_T = len(modality_config["state"].delta_indices)
|
| 156 |
+
print(
|
| 157 |
+
f"Model config — video T={video_T} (delta={video_delta}), "
|
| 158 |
+
f"state T={state_T}, keys: video={video_keys}, state={state_keys}"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
|
| 162 |
+
|
| 163 |
+
if args.debug:
|
| 164 |
+
debug_dir = os.path.join(results_dir, "debug_data")
|
| 165 |
+
os.makedirs(debug_dir, exist_ok=True)
|
| 166 |
+
os.makedirs(os.path.join(debug_dir, "videos/wrist_image/"), exist_ok=True)
|
| 167 |
+
os.makedirs(os.path.join(debug_dir, "videos/exterior_image_1_left/"), exist_ok=True)
|
| 168 |
+
|
| 169 |
+
instruction = None
|
| 170 |
+
while True:
|
| 171 |
+
if instruction is None:
|
| 172 |
+
instruction = input("Enter instruction: ")
|
| 173 |
+
else:
|
| 174 |
+
if input("Change instruction? (enter y or n) ").lower() == "y":
|
| 175 |
+
instruction = input("Enter instruction: ")
|
| 176 |
+
|
| 177 |
+
time.sleep(args.delay_seconds)
|
| 178 |
+
|
| 179 |
+
# Rollout parameters
|
| 180 |
+
actions_from_chunk_completed = 0
|
| 181 |
+
pred_action_chunk = None
|
| 182 |
+
|
| 183 |
+
# Prepare to save video of rollout
|
| 184 |
+
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
|
| 185 |
+
video = []
|
| 186 |
+
if args.debug:
|
| 187 |
+
model_wrist_image_writer = imageio.get_writer(
|
| 188 |
+
os.path.join(
|
| 189 |
+
debug_dir, "videos/wrist_image/", f"model_wrist_image_{timestamp}.mp4"
|
| 190 |
+
),
|
| 191 |
+
fps=5,
|
| 192 |
+
)
|
| 193 |
+
model_exterior_image_1_left_writer = imageio.get_writer(
|
| 194 |
+
os.path.join(
|
| 195 |
+
debug_dir,
|
| 196 |
+
"videos/exterior_image_1_left/",
|
| 197 |
+
f"model_exterior_image_1_left_{timestamp}.mp4",
|
| 198 |
+
),
|
| 199 |
+
fps=5,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
bar = tqdm.tqdm(range(args.max_timesteps))
|
| 203 |
+
print("Running rollout... press Ctrl+C to stop early.")
|
| 204 |
+
|
| 205 |
+
# Profiling variables (reset for each rollout)
|
| 206 |
+
rollout_start_time = time.time()
|
| 207 |
+
obs_times = deque(maxlen=50) # Track observation collection times
|
| 208 |
+
server_times = deque(maxlen=50) # Track server response times
|
| 209 |
+
action_count = 0
|
| 210 |
+
frame_buffer = deque(maxlen=video_history_len)
|
| 211 |
+
|
| 212 |
+
for t_step in bar:
|
| 213 |
+
step_start_time = time.time()
|
| 214 |
+
try:
|
| 215 |
+
# Get the current observation
|
| 216 |
+
obs_start_time = time.time()
|
| 217 |
+
curr_obs = _extract_observation(
|
| 218 |
+
args,
|
| 219 |
+
env.get_observation(),
|
| 220 |
+
# Save the first observation to disk
|
| 221 |
+
save_to_disk=t_step == 0,
|
| 222 |
+
)
|
| 223 |
+
obs_time = time.time() - obs_start_time
|
| 224 |
+
obs_times.append(obs_time)
|
| 225 |
+
|
| 226 |
+
video.append(curr_obs[f"{args.render_camera}_image"])
|
| 227 |
+
|
| 228 |
+
# Resize every step so the rolling frame buffer stays current.
|
| 229 |
+
left_image = resize_with_pad(curr_obs["left_image"], RESOLUTION[0], RESOLUTION[1])
|
| 230 |
+
right_image = resize_with_pad(curr_obs["right_image"], RESOLUTION[0], RESOLUTION[1])
|
| 231 |
+
wrist_image = resize_with_pad(curr_obs["wrist_image"], RESOLUTION[0], RESOLUTION[1])
|
| 232 |
+
|
| 233 |
+
if args.external_camera == "left":
|
| 234 |
+
ext_image = left_image
|
| 235 |
+
elif args.external_camera == "right":
|
| 236 |
+
ext_image = right_image
|
| 237 |
+
|
| 238 |
+
frame_buffer.append({"ext": ext_image, "wrist": wrist_image})
|
| 239 |
+
|
| 240 |
+
# Send websocket request to policy server if it's time to predict a new chunk
|
| 241 |
+
if (
|
| 242 |
+
actions_from_chunk_completed == 0
|
| 243 |
+
or actions_from_chunk_completed >= args.open_loop_horizon
|
| 244 |
+
):
|
| 245 |
+
actions_from_chunk_completed = 0
|
| 246 |
+
|
| 247 |
+
if args.debug:
|
| 248 |
+
model_wrist_image_writer.append_data(wrist_image)
|
| 249 |
+
model_exterior_image_1_left_writer.append_data(ext_image)
|
| 250 |
+
|
| 251 |
+
# Build video tensor with T frames derived from the model's
|
| 252 |
+
# delta_indices (e.g. [-15, 0] -> T=2, [0] -> T=1).
|
| 253 |
+
if video_T == 1:
|
| 254 |
+
video_dict = {
|
| 255 |
+
"exterior_image_1_left": ext_image[None, None, ...],
|
| 256 |
+
"wrist_image_left": wrist_image[None, None, ...],
|
| 257 |
+
} # (B=1, T=1, H, W, C)
|
| 258 |
+
else:
|
| 259 |
+
hist_frame = frame_buffer[0]
|
| 260 |
+
cur_frame = frame_buffer[-1]
|
| 261 |
+
video_dict = {
|
| 262 |
+
"exterior_image_1_left": np.stack(
|
| 263 |
+
[hist_frame["ext"], cur_frame["ext"]]
|
| 264 |
+
)[None, ...],
|
| 265 |
+
"wrist_image_left": np.stack([hist_frame["wrist"], cur_frame["wrist"]])[
|
| 266 |
+
None, ...
|
| 267 |
+
],
|
| 268 |
+
} # (B=1, T=video_T, H, W, C)
|
| 269 |
+
|
| 270 |
+
# Build state dict from the model's reported state keys.
|
| 271 |
+
state_dict = {}
|
| 272 |
+
state_source = {
|
| 273 |
+
"eef_9d": curr_obs["eef_9d"],
|
| 274 |
+
"gripper_position": curr_obs["gripper_position"],
|
| 275 |
+
"joint_position": curr_obs["joint_position"],
|
| 276 |
+
}
|
| 277 |
+
for key in state_keys:
|
| 278 |
+
state_dict[key] = state_source[key][None, None, ...].astype(
|
| 279 |
+
np.float32
|
| 280 |
+
) # (B=1, T=1, D)
|
| 281 |
+
|
| 282 |
+
lang_key = modality_config["language"].modality_keys[0]
|
| 283 |
+
request_data = {
|
| 284 |
+
"video": video_dict,
|
| 285 |
+
"state": state_dict,
|
| 286 |
+
"language": {lang_key: [[instruction]]},
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
if args.vis_cameras:
|
| 290 |
+
# viz the left image 1 and wrist image and use cv2 to display them side by side
|
| 291 |
+
left_image_display = cv2.resize(
|
| 292 |
+
left_image, (wrist_image.shape[1], wrist_image.shape[0])
|
| 293 |
+
)
|
| 294 |
+
combined_display = np.concatenate([left_image_display, wrist_image], axis=1)
|
| 295 |
+
# convert to bgr
|
| 296 |
+
combined_display = combined_display[..., ::-1]
|
| 297 |
+
cv2.imshow("Camera Views", combined_display)
|
| 298 |
+
cv2.waitKey(1)
|
| 299 |
+
|
| 300 |
+
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
|
| 301 |
+
# Ctrl+C will be handled after the server call is complete
|
| 302 |
+
server_start_time = time.time()
|
| 303 |
+
with prevent_keyboard_interrupt():
|
| 304 |
+
# this returns action chunk [N, 8] of joint position actions (7) + gripper position (1)
|
| 305 |
+
response = policy_client.get_action(request_data)
|
| 306 |
+
server_time = time.time() - server_start_time
|
| 307 |
+
server_times.append(server_time)
|
| 308 |
+
|
| 309 |
+
pred_action_chunk = np.concatenate(
|
| 310 |
+
(
|
| 311 |
+
response[0]["joint_position"][0],
|
| 312 |
+
response[0]["gripper_position"][0],
|
| 313 |
+
),
|
| 314 |
+
axis=1,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# Select current action to execute from chunk
|
| 318 |
+
action = pred_action_chunk[actions_from_chunk_completed]
|
| 319 |
+
actions_from_chunk_completed += 1
|
| 320 |
+
|
| 321 |
+
# Binarize gripper action
|
| 322 |
+
if action[-1].item() > 0.5:
|
| 323 |
+
action = np.concatenate([action[:-1], np.ones((1,))])
|
| 324 |
+
else:
|
| 325 |
+
action = np.concatenate([action[:-1], np.zeros((1,))])
|
| 326 |
+
|
| 327 |
+
env.step(action)
|
| 328 |
+
action_count += 1
|
| 329 |
+
|
| 330 |
+
# Sleep to match DROID data collection frequency
|
| 331 |
+
elapsed_time = time.time() - step_start_time
|
| 332 |
+
if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
|
| 333 |
+
time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
|
| 334 |
+
|
| 335 |
+
# profiling stats
|
| 336 |
+
if obs_times:
|
| 337 |
+
avg_obs_time = np.mean(obs_times) * 1000
|
| 338 |
+
min_obs_time = np.min(obs_times) * 1000
|
| 339 |
+
max_obs_time = np.max(obs_times) * 1000
|
| 340 |
+
else:
|
| 341 |
+
avg_obs_time = min_obs_time = max_obs_time = 0
|
| 342 |
+
|
| 343 |
+
if server_times:
|
| 344 |
+
avg_server_time = np.mean(server_times) * 1000
|
| 345 |
+
min_server_time = np.min(server_times) * 1000
|
| 346 |
+
max_server_time = np.max(server_times) * 1000
|
| 347 |
+
else:
|
| 348 |
+
avg_server_time = min_server_time = max_server_time = 0
|
| 349 |
+
|
| 350 |
+
total_elapsed = time.time() - rollout_start_time
|
| 351 |
+
actions_per_sec = action_count / total_elapsed if total_elapsed > 0 else 0
|
| 352 |
+
|
| 353 |
+
bar.set_description(
|
| 354 |
+
f"Obs: {avg_obs_time:.1f}ms [{min_obs_time:.1f}-{max_obs_time:.1f}] | "
|
| 355 |
+
f"Server: {avg_server_time:.1f}ms [{min_server_time:.1f}-{max_server_time:.1f}] | "
|
| 356 |
+
f"Actions/sec: {actions_per_sec:.2f}"
|
| 357 |
+
)
|
| 358 |
+
except KeyboardInterrupt:
|
| 359 |
+
break
|
| 360 |
+
|
| 361 |
+
os.makedirs(os.path.join(results_dir, "videos"), exist_ok=True)
|
| 362 |
+
video = np.stack(video)
|
| 363 |
+
# replace whitespace with underscores in instruction
|
| 364 |
+
sanitized_instruction = instruction.replace(" ", "_")
|
| 365 |
+
save_filename = os.path.join(
|
| 366 |
+
results_dir, "videos", f"{sanitized_instruction}_video_" + timestamp
|
| 367 |
+
)
|
| 368 |
+
ImageSequenceClip(list(video), fps=args.render_fps).write_videofile(
|
| 369 |
+
save_filename + ".mp4", codec="libx264"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
if args.debug:
|
| 373 |
+
model_wrist_image_writer.close()
|
| 374 |
+
model_exterior_image_1_left_writer.close()
|
| 375 |
+
|
| 376 |
+
success: str | float | None = None
|
| 377 |
+
while not isinstance(success, float):
|
| 378 |
+
success = input(
|
| 379 |
+
"Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
|
| 380 |
+
)
|
| 381 |
+
if success == "y":
|
| 382 |
+
success = 1.0
|
| 383 |
+
elif success == "n":
|
| 384 |
+
success = 0.0
|
| 385 |
+
|
| 386 |
+
success = float(success) / 100
|
| 387 |
+
if not (0 <= success <= 1):
|
| 388 |
+
print(f"Success must be a number in [0, 100] but got: {success * 100}")
|
| 389 |
+
|
| 390 |
+
new_row = {
|
| 391 |
+
"success": success,
|
| 392 |
+
"duration": t_step,
|
| 393 |
+
"video_filename": save_filename,
|
| 394 |
+
}
|
| 395 |
+
new_index = len(df)
|
| 396 |
+
df.loc[new_index] = new_row
|
| 397 |
+
|
| 398 |
+
if input("Do one more eval? (enter y or n) ").lower() != "y":
|
| 399 |
+
break
|
| 400 |
+
env.reset(randomize=False)
|
| 401 |
+
|
| 402 |
+
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
|
| 403 |
+
csv_filename = os.path.join(results_dir, f"eval_{timestamp}.csv")
|
| 404 |
+
df.to_csv(csv_filename)
|
| 405 |
+
print(f"Results saved to {csv_filename}")
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def _extract_observation(args: Args, obs_dict, *, stereo_camera="left", save_to_disk=False):
|
| 409 |
+
image_observations = obs_dict["image"]
|
| 410 |
+
key_left = f"{args.left_camera_id}_{stereo_camera}"
|
| 411 |
+
key_right = f"{args.right_camera_id}_{stereo_camera}"
|
| 412 |
+
key_wrist = f"{args.wrist_camera_id}_{stereo_camera}"
|
| 413 |
+
|
| 414 |
+
left_image = image_observations.get(key_left)
|
| 415 |
+
right_image = image_observations.get(key_right)
|
| 416 |
+
wrist_image = image_observations.get(key_wrist)
|
| 417 |
+
|
| 418 |
+
available = list(image_observations.keys())
|
| 419 |
+
assert left_image is not None, (
|
| 420 |
+
f"Left camera not found for key {key_left!r}. Available keys: {available}. "
|
| 421 |
+
"Set --left-camera-id to the ZED serial used in observation keys."
|
| 422 |
+
)
|
| 423 |
+
assert right_image is not None, (
|
| 424 |
+
f"Right camera not found for key {key_right!r}. Available keys: {available}. "
|
| 425 |
+
"Set --right-camera-id to the ZED serial used in observation keys."
|
| 426 |
+
)
|
| 427 |
+
assert wrist_image is not None, (
|
| 428 |
+
f"Wrist camera not found for key {key_wrist!r}. Available keys: {available}. "
|
| 429 |
+
"Set --wrist-camera-id to the ZED serial used in observation keys."
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
# Drop the alpha dimension
|
| 433 |
+
left_image = left_image[..., :3]
|
| 434 |
+
right_image = right_image[..., :3]
|
| 435 |
+
wrist_image = wrist_image[..., :3]
|
| 436 |
+
|
| 437 |
+
# Convert to RGB
|
| 438 |
+
left_image = left_image[..., ::-1]
|
| 439 |
+
right_image = right_image[..., ::-1]
|
| 440 |
+
wrist_image = wrist_image[..., ::-1]
|
| 441 |
+
|
| 442 |
+
# In addition to image observations, also capture the proprioceptive state
|
| 443 |
+
robot_state = obs_dict["robot_state"]
|
| 444 |
+
cartesian_position = np.array(robot_state["cartesian_position"])
|
| 445 |
+
joint_position = np.array(robot_state["joint_positions"])
|
| 446 |
+
gripper_position = np.array([robot_state["gripper_position"]])
|
| 447 |
+
eef_9d = compute_eef_9d(cartesian_position)
|
| 448 |
+
|
| 449 |
+
# Save the images to disk so that they can be viewed live while the robot is running
|
| 450 |
+
# Create one combined image to make live viewing easy
|
| 451 |
+
if save_to_disk:
|
| 452 |
+
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
|
| 453 |
+
combined_image = Image.fromarray(combined_image)
|
| 454 |
+
combined_image.save("robot_camera_views.png")
|
| 455 |
+
|
| 456 |
+
return {
|
| 457 |
+
"left_image": left_image,
|
| 458 |
+
"right_image": right_image,
|
| 459 |
+
"wrist_image": wrist_image,
|
| 460 |
+
"cartesian_position": cartesian_position,
|
| 461 |
+
"eef_9d": eef_9d,
|
| 462 |
+
"joint_position": joint_position,
|
| 463 |
+
"gripper_position": gripper_position,
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
if __name__ == "__main__":
|
| 468 |
+
args: Args = tyro.cli(Args)
|
| 469 |
+
main(args)
|
examples/DROID/server_client.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from abc import ABC, abstractmethod
|
| 19 |
+
from dataclasses import asdict, dataclass, is_dataclass
|
| 20 |
+
from enum import Enum
|
| 21 |
+
import io
|
| 22 |
+
from typing import Any
|
| 23 |
+
|
| 24 |
+
import msgpack
|
| 25 |
+
import numpy as np
|
| 26 |
+
import zmq
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def to_json_serializable(obj: Any) -> Any:
|
| 30 |
+
"""
|
| 31 |
+
Recursively convert dataclasses and numpy arrays to JSON-serializable format.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
obj: Object to convert (can be dataclass, numpy array, dict, list, etc.)
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
JSON-serializable representation of the object
|
| 38 |
+
"""
|
| 39 |
+
if is_dataclass(obj) and not isinstance(obj, type):
|
| 40 |
+
# Convert dataclass to dict, then recursively process the dict
|
| 41 |
+
return to_json_serializable(asdict(obj))
|
| 42 |
+
elif isinstance(obj, np.ndarray):
|
| 43 |
+
# Convert numpy array to list
|
| 44 |
+
return obj.tolist()
|
| 45 |
+
elif isinstance(obj, np.integer):
|
| 46 |
+
# Convert numpy integers to Python int
|
| 47 |
+
return int(obj)
|
| 48 |
+
elif isinstance(obj, np.floating):
|
| 49 |
+
# Convert numpy floats to Python float
|
| 50 |
+
return float(obj)
|
| 51 |
+
elif isinstance(obj, np.bool_):
|
| 52 |
+
# Convert numpy bool to Python bool
|
| 53 |
+
return bool(obj)
|
| 54 |
+
elif isinstance(obj, dict):
|
| 55 |
+
# Recursively process dictionary values
|
| 56 |
+
return {key: to_json_serializable(value) for key, value in obj.items()}
|
| 57 |
+
elif isinstance(obj, (list, tuple)):
|
| 58 |
+
# Recursively process list/tuple elements
|
| 59 |
+
return [to_json_serializable(item) for item in obj]
|
| 60 |
+
elif isinstance(obj, set):
|
| 61 |
+
# Convert set to list
|
| 62 |
+
return [to_json_serializable(item) for item in obj]
|
| 63 |
+
elif isinstance(obj, (str, int, float, bool, type(None))):
|
| 64 |
+
# Already JSON-serializable
|
| 65 |
+
return obj
|
| 66 |
+
elif isinstance(obj, Enum):
|
| 67 |
+
return obj.name
|
| 68 |
+
else:
|
| 69 |
+
# For other types, try to convert to string as fallback
|
| 70 |
+
# You might want to handle specific types differently
|
| 71 |
+
return str(obj)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class MessageType(Enum):
|
| 75 |
+
START_OF_EPISODE = "start_of_episode"
|
| 76 |
+
END_OF_EPISODE = "end_of_episode"
|
| 77 |
+
EPISODE_STEP = "episode_step"
|
| 78 |
+
IMAGE = "image"
|
| 79 |
+
TEXT = "text"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ActionRepresentation(Enum):
|
| 83 |
+
RELATIVE = "relative"
|
| 84 |
+
DELTA = "delta"
|
| 85 |
+
ABSOLUTE = "absolute"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class ActionType(Enum):
|
| 89 |
+
EEF = "eef"
|
| 90 |
+
NON_EEF = "non_eef"
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ActionFormat(Enum):
|
| 94 |
+
DEFAULT = "default"
|
| 95 |
+
XYZ_ROT6D = "xyz+rot6d"
|
| 96 |
+
XYZ_ROTVEC = "xyz+rotvec"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dataclass
|
| 100 |
+
class ActionConfig:
|
| 101 |
+
rep: ActionRepresentation
|
| 102 |
+
type: ActionType
|
| 103 |
+
format: ActionFormat
|
| 104 |
+
state_key: str | None = None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dataclass
|
| 108 |
+
class ModalityConfig:
|
| 109 |
+
"""Configuration for a modality defining how data should be sampled and loaded.
|
| 110 |
+
|
| 111 |
+
This class specifies which indices to sample relative to a base index and which
|
| 112 |
+
keys to load for a particular modality (e.g., video, state, action).
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
delta_indices: list[int]
|
| 116 |
+
"""Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices."""
|
| 117 |
+
modality_keys: list[str]
|
| 118 |
+
"""The keys to load for the modality in the dataset."""
|
| 119 |
+
sin_cos_embedding_keys: list[str] | None = None
|
| 120 |
+
"""Optional list of keys to apply sin/cos encoding. If None or empty, use min/max normalization for all keys."""
|
| 121 |
+
mean_std_embedding_keys: list[str] | None = None
|
| 122 |
+
"""Optional list of keys to apply mean/std normalization. If None or empty, use min/max normalization for all keys."""
|
| 123 |
+
action_configs: list[ActionConfig] | None = None
|
| 124 |
+
|
| 125 |
+
def __post_init__(self):
|
| 126 |
+
"""Set default values for action-related fields if not specified."""
|
| 127 |
+
if self.action_configs is not None:
|
| 128 |
+
assert len(self.action_configs) == len(self.modality_keys), (
|
| 129 |
+
f"Number of action configs ({len(self.action_configs)}) must match number of modality keys ({len(self.modality_keys)})"
|
| 130 |
+
)
|
| 131 |
+
parsed_action_configs = []
|
| 132 |
+
for action_config in self.action_configs:
|
| 133 |
+
if isinstance(action_config, dict):
|
| 134 |
+
action_config = ActionConfig(
|
| 135 |
+
rep=ActionRepresentation[action_config["rep"]],
|
| 136 |
+
type=ActionType[action_config["type"]],
|
| 137 |
+
format=ActionFormat[action_config["format"]],
|
| 138 |
+
state_key=action_config.get("state_key", None),
|
| 139 |
+
)
|
| 140 |
+
parsed_action_configs.append(action_config)
|
| 141 |
+
self.action_configs = parsed_action_configs
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class MsgSerializer:
|
| 145 |
+
@staticmethod
|
| 146 |
+
def to_bytes(data: Any) -> bytes:
|
| 147 |
+
return msgpack.packb(data, default=MsgSerializer.encode_custom_classes)
|
| 148 |
+
|
| 149 |
+
@staticmethod
|
| 150 |
+
def from_bytes(data: bytes) -> Any:
|
| 151 |
+
return msgpack.unpackb(data, object_hook=MsgSerializer.decode_custom_classes)
|
| 152 |
+
|
| 153 |
+
@staticmethod
|
| 154 |
+
def decode_custom_classes(obj):
|
| 155 |
+
if not isinstance(obj, dict):
|
| 156 |
+
return obj
|
| 157 |
+
if "__ModalityConfig_class__" in obj:
|
| 158 |
+
return ModalityConfig(**obj["as_json"])
|
| 159 |
+
if "__ndarray_class__" in obj:
|
| 160 |
+
return np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False)
|
| 161 |
+
return obj
|
| 162 |
+
|
| 163 |
+
@staticmethod
|
| 164 |
+
def encode_custom_classes(obj):
|
| 165 |
+
if isinstance(obj, ModalityConfig):
|
| 166 |
+
# Convert to dict and let msgpack recursively handle nested objects
|
| 167 |
+
return {"__ModalityConfig_class__": True, "as_json": to_json_serializable(obj)}
|
| 168 |
+
if isinstance(obj, np.ndarray):
|
| 169 |
+
output = io.BytesIO()
|
| 170 |
+
np.save(output, obj, allow_pickle=False)
|
| 171 |
+
return {"__ndarray_class__": True, "as_npy": output.getvalue()}
|
| 172 |
+
return obj
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class BasePolicy(ABC):
|
| 176 |
+
"""Abstract base class for robotic control policies.
|
| 177 |
+
|
| 178 |
+
This class defines the interface that all policies must implement, including
|
| 179 |
+
methods for action computation, input/output validation, and state management.
|
| 180 |
+
|
| 181 |
+
Subclasses must implement:
|
| 182 |
+
- check_observation(): Validate observation format
|
| 183 |
+
- check_action(): Validate action format
|
| 184 |
+
- _get_action(): Core action computation logic
|
| 185 |
+
- reset(): Reset policy to initial state
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
def __init__(self, *, strict: bool = True):
|
| 189 |
+
self.strict = strict
|
| 190 |
+
|
| 191 |
+
@abstractmethod
|
| 192 |
+
def check_observation(self, observation: dict[str, Any]) -> None:
|
| 193 |
+
"""Check if the observation is valid.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
observation: Dictionary containing the current state/observation of the environment
|
| 197 |
+
|
| 198 |
+
Raises:
|
| 199 |
+
AssertionError: If the observation is invalid.
|
| 200 |
+
"""
|
| 201 |
+
pass
|
| 202 |
+
|
| 203 |
+
@abstractmethod
|
| 204 |
+
def check_action(self, action: dict[str, Any]) -> None:
|
| 205 |
+
"""Check if the action is valid.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
action: Dictionary containing the action to be executed
|
| 209 |
+
|
| 210 |
+
Raises:
|
| 211 |
+
AssertionError: If the action is invalid.
|
| 212 |
+
"""
|
| 213 |
+
pass
|
| 214 |
+
|
| 215 |
+
@abstractmethod
|
| 216 |
+
def _get_action(
|
| 217 |
+
self, observation: dict[str, Any], options: dict[str, Any] | None = None
|
| 218 |
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
| 219 |
+
"""Compute and return the next action based on current observation.
|
| 220 |
+
|
| 221 |
+
This method should be overridden by subclasses to implement policy-specific
|
| 222 |
+
action computation. Input validation is handled by the public get_action() method.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
observation: Dictionary containing the current state/observation
|
| 226 |
+
options: Optional configuration dict for action computation
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Tuple of (action, info):
|
| 230 |
+
- action: Dictionary containing the action to be executed
|
| 231 |
+
- info: Dictionary containing additional metadata (e.g., confidence scores)
|
| 232 |
+
"""
|
| 233 |
+
pass
|
| 234 |
+
|
| 235 |
+
def get_action(
|
| 236 |
+
self, observation: dict[str, Any], options: dict[str, Any] | None = None
|
| 237 |
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
| 238 |
+
"""Compute and return the next action based on current observation with validation.
|
| 239 |
+
|
| 240 |
+
This is the main public interface. It validates the observation, calls
|
| 241 |
+
the internal _get_action(), and validates the resulting action.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
observation: Dictionary containing the current state/observation
|
| 245 |
+
options: Optional configuration dict for action computation
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
Tuple of (action, info):
|
| 249 |
+
- action: Dictionary containing the validated action
|
| 250 |
+
- info: Dictionary containing additional metadata
|
| 251 |
+
|
| 252 |
+
Raises:
|
| 253 |
+
AssertionError/ValueError: If observation or action validation fails
|
| 254 |
+
"""
|
| 255 |
+
if self.strict:
|
| 256 |
+
self.check_observation(observation)
|
| 257 |
+
action, info = self._get_action(observation, options)
|
| 258 |
+
if self.strict:
|
| 259 |
+
self.check_action(action)
|
| 260 |
+
return action, info
|
| 261 |
+
|
| 262 |
+
@abstractmethod
|
| 263 |
+
def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
|
| 264 |
+
"""Reset the policy to its initial state.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
options: Dictionary containing the options for the reset
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
Dictionary containing the info after resetting the policy
|
| 271 |
+
"""
|
| 272 |
+
pass
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class PolicyClient(BasePolicy):
|
| 276 |
+
def __init__(
|
| 277 |
+
self,
|
| 278 |
+
host: str = "localhost",
|
| 279 |
+
port: int = 5555,
|
| 280 |
+
timeout_ms: int = 15000,
|
| 281 |
+
api_token: str = None,
|
| 282 |
+
strict: bool = False,
|
| 283 |
+
):
|
| 284 |
+
super().__init__(strict=strict)
|
| 285 |
+
self.context = zmq.Context()
|
| 286 |
+
self.host = host
|
| 287 |
+
self.port = port
|
| 288 |
+
self.timeout_ms = timeout_ms
|
| 289 |
+
self.api_token = api_token
|
| 290 |
+
self._init_socket()
|
| 291 |
+
|
| 292 |
+
def _init_socket(self):
|
| 293 |
+
"""Initialize or reinitialize the socket with current settings"""
|
| 294 |
+
self.socket = self.context.socket(zmq.REQ)
|
| 295 |
+
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
| 296 |
+
|
| 297 |
+
def ping(self) -> bool:
|
| 298 |
+
try:
|
| 299 |
+
self.call_endpoint("ping", requires_input=False)
|
| 300 |
+
return True
|
| 301 |
+
except zmq.error.ZMQError:
|
| 302 |
+
self._init_socket() # Recreate socket for next attempt
|
| 303 |
+
return False
|
| 304 |
+
|
| 305 |
+
def kill_server(self):
|
| 306 |
+
"""
|
| 307 |
+
Kill the server.
|
| 308 |
+
"""
|
| 309 |
+
self.call_endpoint("kill", requires_input=False)
|
| 310 |
+
|
| 311 |
+
def call_endpoint(
|
| 312 |
+
self, endpoint: str, data: dict | None = None, requires_input: bool = True
|
| 313 |
+
) -> Any:
|
| 314 |
+
"""
|
| 315 |
+
Call an endpoint on the server.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
endpoint: The name of the endpoint.
|
| 319 |
+
data: The input data for the endpoint.
|
| 320 |
+
requires_input: Whether the endpoint requires input data.
|
| 321 |
+
"""
|
| 322 |
+
request: dict = {"endpoint": endpoint}
|
| 323 |
+
if requires_input:
|
| 324 |
+
request["data"] = data
|
| 325 |
+
if self.api_token:
|
| 326 |
+
request["api_token"] = self.api_token
|
| 327 |
+
|
| 328 |
+
self.socket.send(MsgSerializer.to_bytes(request))
|
| 329 |
+
message = self.socket.recv()
|
| 330 |
+
if message == b"ERROR":
|
| 331 |
+
raise RuntimeError("Server error. Make sure we are running the correct policy server.")
|
| 332 |
+
response = MsgSerializer.from_bytes(message)
|
| 333 |
+
|
| 334 |
+
if isinstance(response, dict) and "error" in response:
|
| 335 |
+
raise RuntimeError(f"Server error: {response['error']}")
|
| 336 |
+
return response
|
| 337 |
+
|
| 338 |
+
def __del__(self):
|
| 339 |
+
"""Cleanup resources on destruction"""
|
| 340 |
+
self.socket.close()
|
| 341 |
+
self.context.term()
|
| 342 |
+
|
| 343 |
+
def _get_action(
|
| 344 |
+
self, observation: dict[str, Any], options: dict[str, Any] | None = None
|
| 345 |
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
| 346 |
+
response = self.call_endpoint(
|
| 347 |
+
"get_action", {"observation": observation, "options": options}
|
| 348 |
+
)
|
| 349 |
+
return tuple(response) # Convert list (from msgpack) to tuple of (action, info)
|
| 350 |
+
|
| 351 |
+
def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
|
| 352 |
+
return self.call_endpoint("reset", {"options": options})
|
| 353 |
+
|
| 354 |
+
def get_modality_config(self) -> dict[str, ModalityConfig]:
|
| 355 |
+
return self.call_endpoint("get_modality_config", requires_input=False)
|
| 356 |
+
|
| 357 |
+
def check_observation(self, observation: dict[str, Any]) -> None:
|
| 358 |
+
raise NotImplementedError(
|
| 359 |
+
"check_observation is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
def check_action(self, action: dict[str, Any]) -> None:
|
| 363 |
+
raise NotImplementedError(
|
| 364 |
+
"check_action is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
|
| 365 |
+
)
|
examples/DROID/utils.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Taken from https://github.com/Physical-Intelligence/openpi/tree/main/packages/openpi-client/src/openpi_client
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
from PIL import Image
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
|
| 25 |
+
"""Converts an image to uint8 if it is a float image.
|
| 26 |
+
|
| 27 |
+
This is important for reducing the size of the image when sending it over the network.
|
| 28 |
+
"""
|
| 29 |
+
if np.issubdtype(img.dtype, np.floating):
|
| 30 |
+
img = (255 * img).astype(np.uint8)
|
| 31 |
+
return img
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def resize_with_pad(
|
| 35 |
+
images: np.ndarray, height: int, width: int, method=Image.BILINEAR
|
| 36 |
+
) -> np.ndarray:
|
| 37 |
+
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
images: A batch of images in [..., height, width, channel] format.
|
| 41 |
+
height: The target height of the image.
|
| 42 |
+
width: The target width of the image.
|
| 43 |
+
method: The interpolation method to use. Default is bilinear.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The resized images in [..., height, width, channel].
|
| 47 |
+
"""
|
| 48 |
+
# If the images are already the correct size, return them as is.
|
| 49 |
+
if images.shape[-3:-1] == (height, width):
|
| 50 |
+
return images
|
| 51 |
+
|
| 52 |
+
original_shape = images.shape
|
| 53 |
+
|
| 54 |
+
images = images.reshape(-1, *original_shape[-3:])
|
| 55 |
+
resized = np.stack(
|
| 56 |
+
[_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images]
|
| 57 |
+
)
|
| 58 |
+
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
| 62 |
+
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
| 63 |
+
width without distortion by padding with zeros.
|
| 64 |
+
|
| 65 |
+
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
| 66 |
+
"""
|
| 67 |
+
cur_width, cur_height = image.size
|
| 68 |
+
if cur_width == width and cur_height == height:
|
| 69 |
+
return image # No need to resize if the image is already the correct size.
|
| 70 |
+
|
| 71 |
+
ratio = max(cur_width / width, cur_height / height)
|
| 72 |
+
resized_height = int(cur_height / ratio)
|
| 73 |
+
resized_width = int(cur_width / ratio)
|
| 74 |
+
resized_image = image.resize((resized_width, resized_height), resample=method)
|
| 75 |
+
|
| 76 |
+
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
| 77 |
+
pad_height = max(0, int((height - resized_height) / 2))
|
| 78 |
+
pad_width = max(0, int((width - resized_width) / 2))
|
| 79 |
+
zero_image.paste(resized_image, (pad_width, pad_height))
|
| 80 |
+
assert zero_image.size == (width, height)
|
| 81 |
+
return zero_image
|
examples/LIBERO/README.md
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LIBERO
|
| 2 |
+
|
| 3 |
+
Benchmark for studying knowledge transfer in lifelong robot learning. Includes multiple suites: **Spatial** (spatial reasoning), **Object** (object generalization), **Goal** (goal-conditioned learning), and **10 Long** (long-horizon multi-step tasks). Provides RGB images, proprioception data, and language task specifications.
|
| 4 |
+
|
| 5 |
+
For more information, see the [official website](https://libero-project.github.io/main.html).
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# LIBERO evaluation benchmark result
|
| 10 |
+
|
| 11 |
+
> **Note:** The full task list is attached at the end of this document.
|
| 12 |
+
|
| 13 |
+
All four suites were finetuned with the same hyper-parameters, including
|
| 14 |
+
`--state-dropout-prob 0.2` (the finetune CLI default from
|
| 15 |
+
`gr00t/configs/finetune_config.py`).
|
| 16 |
+
|
| 17 |
+
| Task | Success rate | max_steps | grad_accum_steps | batch_size |
|
| 18 |
+
|-----------|--------------------|-----------|------------------|------------|
|
| 19 |
+
| Spatial | 195/200 (97.65%) | 20K | 1 | 640 |
|
| 20 |
+
| Goal | 195/200 (97.5%) | 20K | 1 | 640 |
|
| 21 |
+
| Object | 197/200 (98.45%) | 20K | 1 | 640 |
|
| 22 |
+
| 10 (Long) | 189/200 (94.35%) | 20K | 1 | 640 |
|
| 23 |
+
|
| 24 |
+
# Fine-tune LIBERO 10 (long)
|
| 25 |
+
|
| 26 |
+
To reproduce our finetune results, use the following commands to setup dataset and launch finetune experiments. Please remember to set `WANDB_API_KEY` since `--use-wandb` is turned on by default. If you don't have a WANDB account, please remove this argument:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
uv run hf download \
|
| 30 |
+
--repo-type dataset IPEC-COMMUNITY/libero_10_no_noops_1.0.0_lerobot \
|
| 31 |
+
--local-dir examples/LIBERO/libero_10_no_noops_1.0.0_lerobot/
|
| 32 |
+
|
| 33 |
+
# Copy the patches and run the finetune script
|
| 34 |
+
cp -r examples/LIBERO/modality.json examples/LIBERO/libero_10_no_noops_1.0.0_lerobot/meta/
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Run the shared finetune launcher:
|
| 38 |
+
```bash
|
| 39 |
+
NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
|
| 40 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 41 |
+
--dataset-path examples/LIBERO/libero_10_no_noops_1.0.0_lerobot/ \
|
| 42 |
+
--embodiment-tag LIBERO_PANDA \
|
| 43 |
+
--output-dir /tmp/libero_10 \
|
| 44 |
+
--state-dropout-prob 0.2
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
# Fine-tune LIBERO goal
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
uv run hf download \
|
| 51 |
+
--repo-type dataset IPEC-COMMUNITY/libero_goal_no_noops_1.0.0_lerobot \
|
| 52 |
+
--local-dir examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/
|
| 53 |
+
|
| 54 |
+
# Copy the patches and run the finetune script
|
| 55 |
+
cp -r examples/LIBERO/modality.json examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/meta/
|
| 56 |
+
## This is a patch for one of the episode where the image seems to be corrupted.
|
| 57 |
+
cp examples/LIBERO/patches/episode_000082.mp4 examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/videos/chunk-000/observation.images.wrist_image/
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
Run the shared finetune launcher:
|
| 61 |
+
```bash
|
| 62 |
+
NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
|
| 63 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 64 |
+
--dataset-path examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/ \
|
| 65 |
+
--embodiment-tag LIBERO_PANDA \
|
| 66 |
+
--output-dir /tmp/libero_goal
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
# Fine-tune LIBERO object
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
uv run hf download \
|
| 73 |
+
--repo-type dataset IPEC-COMMUNITY/libero_object_no_noops_1.0.0_lerobot \
|
| 74 |
+
--local-dir examples/LIBERO/libero_object_no_noops_1.0.0_lerobot/
|
| 75 |
+
|
| 76 |
+
# Copy the patches and run the finetune script
|
| 77 |
+
cp -r examples/LIBERO/modality.json examples/LIBERO/libero_object_no_noops_1.0.0_lerobot/meta/
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
Run the shared finetune launcher:
|
| 81 |
+
```bash
|
| 82 |
+
NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
|
| 83 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 84 |
+
--dataset-path examples/LIBERO/libero_object_no_noops_1.0.0_lerobot/ \
|
| 85 |
+
--embodiment-tag LIBERO_PANDA \
|
| 86 |
+
--output-dir /tmp/libero_object
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
# Fine-tune LIBERO spatial
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
uv run hf download \
|
| 93 |
+
--repo-type dataset IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
|
| 94 |
+
--local-dir examples/LIBERO/libero_spatial_no_noops_1.0.0_lerobot/
|
| 95 |
+
|
| 96 |
+
# Copy the patches and run the finetune script
|
| 97 |
+
cp -r examples/LIBERO/modality.json examples/LIBERO/libero_spatial_no_noops_1.0.0_lerobot/meta/
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
Run the shared finetune launcher:
|
| 101 |
+
```bash
|
| 102 |
+
NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
|
| 103 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 104 |
+
--dataset-path examples/LIBERO/libero_spatial_no_noops_1.0.0_lerobot/ \
|
| 105 |
+
--embodiment-tag LIBERO_PANDA \
|
| 106 |
+
--output-dir /tmp/libero_spatial
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
# Evaluate checkpoint
|
| 110 |
+
|
| 111 |
+
First, setup the evaluation simulation environment. This only needs to run once for each simulation benchmark. After it's done, we only need to launch server and client.
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
sudo apt update
|
| 115 |
+
sudo apt install libegl1-mesa-dev libglu1-mesa
|
| 116 |
+
bash gr00t/eval/sim/LIBERO/setup_libero.sh
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Then, download the finetuned model to a local directory (HuggingFace does not support nested repo paths directly):
|
| 120 |
+
```bash
|
| 121 |
+
uv run hf download nvidia/GR00T-N1.7-LIBERO --include "libero_10/config.json" "libero_10/embodiment_id.json" "libero_10/model-*.safetensors" "libero_10/model.safetensors.index.json" "libero_10/processor_config.json" "libero_10/statistics.json" --local-dir checkpoints/GR00T-N1.7-LIBERO
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
Run client server evaluation under the project root directory in separate terminals:
|
| 125 |
+
|
| 126 |
+
**Terminal 1 - Server:**
|
| 127 |
+
```bash
|
| 128 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 129 |
+
--model-path checkpoints/GR00T-N1.7-LIBERO/libero_10 \
|
| 130 |
+
--embodiment-tag LIBERO_PANDA \
|
| 131 |
+
--use-sim-policy-wrapper
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
> **Note:** Replace `checkpoints/GR00T-N1.7-LIBERO/libero_10` with your own checkpoint path (e.g., `/tmp/libero_10/checkpoint-20000/`) if evaluating a locally finetuned model.
|
| 135 |
+
|
| 136 |
+
**Terminal 2 - Client:**
|
| 137 |
+
```bash
|
| 138 |
+
gr00t/eval/sim/LIBERO/libero_uv/.venv/bin/python gr00t/eval/rollout_policy.py \
|
| 139 |
+
--n-episodes 10 \
|
| 140 |
+
--policy-client-host 127.0.0.1 \
|
| 141 |
+
--policy-client-port 5555 \
|
| 142 |
+
--max-episode-steps 720 \
|
| 143 |
+
--env-name libero_sim/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it \
|
| 144 |
+
--n-action-steps 8 \
|
| 145 |
+
--n-envs 5
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
# Full task list
|
| 149 |
+
|
| 150 |
+
## Libero 10 (Long)
|
| 151 |
+
- `libero_sim/LIVING_ROOM_SCENE2_put_both_the_alphabet_soup_and_the_tomato_sauce_in_the_basket`
|
| 152 |
+
- `libero_sim/LIVING_ROOM_SCENE2_put_both_the_cream_cheese_box_and_the_butter_in_the_basket`
|
| 153 |
+
- `libero_sim/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it`
|
| 154 |
+
- `libero_sim/KITCHEN_SCENE4_put_the_black_bowl_in_the_bottom_drawer_of_the_cabinet_and_close_it`
|
| 155 |
+
- `libero_sim/LIVING_ROOM_SCENE5_put_the_white_mug_on_the_left_plate_and_put_the_yellow_and_white_mug_on_the_right_plate`
|
| 156 |
+
- `libero_sim/STUDY_SCENE1_pick_up_the_book_and_place_it_in_the_back_compartment_of_the_caddy`
|
| 157 |
+
- `libero_sim/LIVING_ROOM_SCENE6_put_the_white_mug_on_the_plate_and_put_the_chocolate_pudding_to_the_right_of_the_plate`
|
| 158 |
+
- `libero_sim/LIVING_ROOM_SCENE1_put_both_the_alphabet_soup_and_the_cream_cheese_box_in_the_basket`
|
| 159 |
+
- `libero_sim/KITCHEN_SCENE8_put_both_moka_pots_on_the_stove`
|
| 160 |
+
- `libero_sim/KITCHEN_SCENE6_put_the_yellow_and_white_mug_in_the_microwave_and_close_it`
|
| 161 |
+
|
| 162 |
+
## Libero Goal
|
| 163 |
+
- `libero_sim/open_the_middle_drawer_of_the_cabinet`
|
| 164 |
+
- `libero_sim/put_the_bowl_on_the_stove`
|
| 165 |
+
- `libero_sim/put_the_wine_bottle_on_top_of_the_cabinet`
|
| 166 |
+
- `libero_sim/open_the_top_drawer_and_put_the_bowl_inside`
|
| 167 |
+
- `libero_sim/put_the_bowl_on_top_of_the_cabinet`
|
| 168 |
+
- `libero_sim/push_the_plate_to_the_front_of_the_stove`
|
| 169 |
+
- `libero_sim/put_the_cream_cheese_in_the_bowl`
|
| 170 |
+
- `libero_sim/turn_on_the_stove`
|
| 171 |
+
- `libero_sim/put_the_bowl_on_the_plate`
|
| 172 |
+
- `libero_sim/put_the_wine_bottle_on_the_rack`
|
| 173 |
+
|
| 174 |
+
## Libero Object
|
| 175 |
+
- `libero_sim/pick_up_the_alphabet_soup_and_place_it_in_the_basket`
|
| 176 |
+
- `libero_sim/pick_up_the_cream_cheese_and_place_it_in_the_basket`
|
| 177 |
+
- `libero_sim/pick_up_the_salad_dressing_and_place_it_in_the_basket`
|
| 178 |
+
- `libero_sim/pick_up_the_bbq_sauce_and_place_it_in_the_basket`
|
| 179 |
+
- `libero_sim/pick_up_the_ketchup_and_place_it_in_the_basket`
|
| 180 |
+
- `libero_sim/pick_up_the_tomato_sauce_and_place_it_in_the_basket`
|
| 181 |
+
- `libero_sim/pick_up_the_butter_and_place_it_in_the_basket`
|
| 182 |
+
- `libero_sim/pick_up_the_milk_and_place_it_in_the_basket`
|
| 183 |
+
- `libero_sim/pick_up_the_chocolate_pudding_and_place_it_in_the_basket`
|
| 184 |
+
- `libero_sim/pick_up_the_orange_juice_and_place_it_in_the_basket`
|
| 185 |
+
|
| 186 |
+
## Libero Spatial
|
| 187 |
+
- `libero_sim/pick_up_the_black_bowl_between_the_plate_and_the_ramekin_and_place_it_on_the_plate`
|
| 188 |
+
- `libero_sim/pick_up_the_black_bowl_next_to_the_ramekin_and_place_it_on_the_plate`
|
| 189 |
+
- `libero_sim/pick_up_the_black_bowl_from_table_center_and_place_it_on_the_plate`
|
| 190 |
+
- `libero_sim/pick_up_the_black_bowl_on_the_cookie_box_and_place_it_on_the_plate`
|
| 191 |
+
- `libero_sim/pick_up_the_black_bowl_in_the_top_drawer_of_the_wooden_cabinet_and_place_it_on_the_plate`
|
| 192 |
+
- `libero_sim/pick_up_the_black_bowl_on_the_ramekin_and_place_it_on_the_plate`
|
| 193 |
+
- `libero_sim/pick_up_the_black_bowl_next_to_the_cookie_box_and_place_it_on_the_plate`
|
| 194 |
+
- `libero_sim/pick_up_the_black_bowl_on_the_stove_and_place_it_on_the_plate`
|
| 195 |
+
- `libero_sim/pick_up_the_black_bowl_next_to_the_plate_and_place_it_on_the_plate`
|
| 196 |
+
- `libero_sim/pick_up_the_black_bowl_on_the_wooden_cabinet_and_place_it_on_the_plate`
|
examples/LIBERO/modality.json
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"state": {
|
| 3 |
+
"x": {
|
| 4 |
+
"start": 0,
|
| 5 |
+
"end": 1
|
| 6 |
+
},
|
| 7 |
+
"y": {
|
| 8 |
+
"start": 1,
|
| 9 |
+
"end": 2
|
| 10 |
+
},
|
| 11 |
+
"z": {
|
| 12 |
+
"start": 2,
|
| 13 |
+
"end": 3
|
| 14 |
+
},
|
| 15 |
+
"roll": {
|
| 16 |
+
"start": 3,
|
| 17 |
+
"end": 4
|
| 18 |
+
},
|
| 19 |
+
"pitch": {
|
| 20 |
+
"start": 4,
|
| 21 |
+
"end": 5
|
| 22 |
+
},
|
| 23 |
+
"yaw": {
|
| 24 |
+
"start": 5,
|
| 25 |
+
"end": 6
|
| 26 |
+
},
|
| 27 |
+
"gripper": {
|
| 28 |
+
"start": 6,
|
| 29 |
+
"end": 8
|
| 30 |
+
}
|
| 31 |
+
},
|
| 32 |
+
"action": {
|
| 33 |
+
"x": {
|
| 34 |
+
"start": 0,
|
| 35 |
+
"end": 1
|
| 36 |
+
},
|
| 37 |
+
"y": {
|
| 38 |
+
"start": 1,
|
| 39 |
+
"end": 2
|
| 40 |
+
},
|
| 41 |
+
"z": {
|
| 42 |
+
"start": 2,
|
| 43 |
+
"end": 3
|
| 44 |
+
},
|
| 45 |
+
"roll": {
|
| 46 |
+
"start": 3,
|
| 47 |
+
"end": 4
|
| 48 |
+
},
|
| 49 |
+
"pitch": {
|
| 50 |
+
"start": 4,
|
| 51 |
+
"end": 5
|
| 52 |
+
},
|
| 53 |
+
"yaw": {
|
| 54 |
+
"start": 5,
|
| 55 |
+
"end": 6
|
| 56 |
+
},
|
| 57 |
+
"gripper": {
|
| 58 |
+
"start": 6,
|
| 59 |
+
"end": 7
|
| 60 |
+
}
|
| 61 |
+
},
|
| 62 |
+
"video": {
|
| 63 |
+
"image": {
|
| 64 |
+
"original_key": "observation.images.image"
|
| 65 |
+
},
|
| 66 |
+
"wrist_image": {
|
| 67 |
+
"original_key": "observation.images.wrist_image"
|
| 68 |
+
}
|
| 69 |
+
},
|
| 70 |
+
"annotation": {
|
| 71 |
+
"human.action.task_description": {
|
| 72 |
+
"original_key": "task_index"
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
}
|
examples/SO100/README.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Finetuning SO100 Model
|
| 2 |
+
|
| 3 |
+
This guide shows how to finetune dataset collected from [SO100](https://huggingface.co/docs/lerobot/en/so101) robot, and evaluate the model on the real robot.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
## Dataset
|
| 7 |
+
|
| 8 |
+
To collect the dataset via teleoperation, please refer to the official documentation in lerobot: https://huggingface.co/docs/lerobot/il_robots?teleoperate_so101=Command
|
| 9 |
+
|
| 10 |
+
**Dataset Path:** [izuluaga/finish_sandwich](https://huggingface.co/datasets/izuluaga/finish_sandwich)
|
| 11 |
+
|
| 12 |
+
Visualize it with this [link](https://huggingface.co/spaces/lerobot/visualize_dataset?path=%2Fizuluaga%2Ffinish_sandwich%2Fepisode_0)
|
| 13 |
+
|
| 14 |
+
## Handling the dataset
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
uv run --project scripts/lerobot_conversion \
|
| 18 |
+
python scripts/lerobot_conversion/convert_v3_to_v2.py \
|
| 19 |
+
--repo-id izuluaga/finish_sandwich \
|
| 20 |
+
--root examples/SO100/finish_sandwich_lerobot
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Then move the `modality.json` file to the root of the dataset.
|
| 24 |
+
```bash
|
| 25 |
+
cp examples/SO100/modality.json examples/SO100/finish_sandwich_lerobot/izuluaga/finish_sandwich/meta/modality.json
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Finetuning
|
| 29 |
+
|
| 30 |
+
Run the shared finetune launcher directly, using absolute joint positions (feel free to experiment with relative positions):
|
| 31 |
+
```bash
|
| 32 |
+
CUDA_VISIBLE_DEVICES=0 NUM_GPUS=1 uv run bash examples/finetune.sh \
|
| 33 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 34 |
+
--dataset-path examples/SO100/finish_sandwich_lerobot/izuluaga/finish_sandwich \
|
| 35 |
+
--modality-config-path examples/SO100/so100_config.py \
|
| 36 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 37 |
+
--output-dir /tmp/so100_finetune
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Open-Loop Evaluation
|
| 41 |
+
|
| 42 |
+
Evaluate the finetuned model with the following command:
|
| 43 |
+
```bash
|
| 44 |
+
uv run python gr00t/eval/open_loop_eval.py \
|
| 45 |
+
--dataset-path examples/SO100/finish_sandwich_lerobot/izuluaga/finish_sandwich/ \
|
| 46 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 47 |
+
--model-path /tmp/so100_finetune/checkpoint-10000 \
|
| 48 |
+
--traj-ids 0 \
|
| 49 |
+
--action-horizon 16 \
|
| 50 |
+
--steps 400
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Evaluation Results
|
| 54 |
+
|
| 55 |
+
The evaluation produces visualizations comparing predicted actions against ground truth trajectories:
|
| 56 |
+
|
| 57 |
+
<img src="../../media/open_loop_eval_so100.jpg" width="800" alt="Open-loop evaluation results showing predicted vs ground truth trajectories" />
|
| 58 |
+
|
| 59 |
+
## Closed-Loop Evaluation
|
| 60 |
+
|
| 61 |
+
Please refer to [eval_so100.py](../../gr00t/eval/real_robot/SO100/eval_so100.py) for how to write SO100 deployment code using Policy API.
|
| 62 |
+
|
| 63 |
+
1. set up client side deps
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
cd gr00t/eval/real_robot/SO100
|
| 67 |
+
uv venv
|
| 68 |
+
source .venv/bin/activate
|
| 69 |
+
uv pip install -e . --verbose
|
| 70 |
+
uv pip install --no-deps -e ../../../../
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
2. Start policy server
|
| 74 |
+
```bash
|
| 75 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 76 |
+
--model-path /tmp/so100_finetune/checkpoint-10000 \
|
| 77 |
+
--embodiment-tag NEW_EMBODIMENT
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
3. Run the eval script, as client.
|
| 81 |
+
```bash
|
| 82 |
+
uv run python gr00t/eval/real_robot/SO100/eval_so100.py \
|
| 83 |
+
--robot.type=so101_follower --robot.port=/dev/ttyACM2 \
|
| 84 |
+
--robot.id=orange_follower \
|
| 85 |
+
--robot.cameras="{ wrist: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30}}" \
|
| 86 |
+
--policy-host=localhost --policy-port=5555 --lang-instruction="finish the ham cheese olives sandwich"
|
| 87 |
+
```
|
examples/SO100/modality.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"state": {
|
| 3 |
+
"single_arm": {
|
| 4 |
+
"start": 0,
|
| 5 |
+
"end": 5
|
| 6 |
+
},
|
| 7 |
+
"gripper": {
|
| 8 |
+
"start": 5,
|
| 9 |
+
"end": 6
|
| 10 |
+
}
|
| 11 |
+
},
|
| 12 |
+
"action": {
|
| 13 |
+
"single_arm": {
|
| 14 |
+
"start": 0,
|
| 15 |
+
"end": 5
|
| 16 |
+
},
|
| 17 |
+
"gripper": {
|
| 18 |
+
"start": 5,
|
| 19 |
+
"end": 6
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"video": {
|
| 23 |
+
"front": {
|
| 24 |
+
"original_key": "observation.images.front"
|
| 25 |
+
},
|
| 26 |
+
"wrist": {
|
| 27 |
+
"original_key": "observation.images.wrist"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"annotation": {
|
| 31 |
+
"human.task_description": {
|
| 32 |
+
"original_key": "task_index"
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
}
|
examples/SO100/so100_config.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from gr00t.configs.data.embodiment_configs import register_modality_config
|
| 17 |
+
from gr00t.data.embodiment_tags import EmbodimentTag
|
| 18 |
+
from gr00t.data.types import (
|
| 19 |
+
ActionConfig,
|
| 20 |
+
ActionFormat,
|
| 21 |
+
ActionRepresentation,
|
| 22 |
+
ActionType,
|
| 23 |
+
ModalityConfig,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
so100_config = {
|
| 28 |
+
# Video: current frame only; keys must match "video" entries in meta/modality.json
|
| 29 |
+
"video": ModalityConfig(
|
| 30 |
+
delta_indices=[0],
|
| 31 |
+
modality_keys=["front", "wrist"], # front third-person view + wrist egocentric
|
| 32 |
+
),
|
| 33 |
+
# State: current proprioceptive reading; keys must match "state" entries in meta/modality.json
|
| 34 |
+
"state": ModalityConfig(
|
| 35 |
+
delta_indices=[0],
|
| 36 |
+
modality_keys=[
|
| 37 |
+
"single_arm", # joint positions
|
| 38 |
+
"gripper", # gripper state
|
| 39 |
+
],
|
| 40 |
+
),
|
| 41 |
+
# Action: 16-step prediction horizon; one ActionConfig per modality key
|
| 42 |
+
"action": ModalityConfig(
|
| 43 |
+
delta_indices=list(range(0, 16)), # predict 16 future steps
|
| 44 |
+
modality_keys=[
|
| 45 |
+
"single_arm",
|
| 46 |
+
"gripper",
|
| 47 |
+
],
|
| 48 |
+
action_configs=[
|
| 49 |
+
# single_arm: RELATIVE = delta from current state (better generalization)
|
| 50 |
+
ActionConfig(
|
| 51 |
+
rep=ActionRepresentation.RELATIVE,
|
| 52 |
+
type=ActionType.NON_EEF, # joint-space, not end-effector
|
| 53 |
+
format=ActionFormat.DEFAULT,
|
| 54 |
+
),
|
| 55 |
+
# gripper: ABSOLUTE = target position (binary open/close works better absolute)
|
| 56 |
+
ActionConfig(
|
| 57 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 58 |
+
type=ActionType.NON_EEF,
|
| 59 |
+
format=ActionFormat.DEFAULT,
|
| 60 |
+
),
|
| 61 |
+
],
|
| 62 |
+
),
|
| 63 |
+
# Language: task instruction from annotation field in the dataset
|
| 64 |
+
"language": ModalityConfig(
|
| 65 |
+
delta_indices=[0],
|
| 66 |
+
modality_keys=["annotation.human.task_description"],
|
| 67 |
+
),
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
register_modality_config(so100_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
|
examples/SimplerEnv/README.md
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SimplerEnv
|
| 2 |
+
|
| 3 |
+
Framework for evaluating real-world robot manipulation policies (RT-1, RT-1-X, Octo) in simulation. Replicates common setups like Google Robot and WidowX+Bridge, with GPU-accelerated simulations (10-15x speedup). Offers visual matching and variant aggregation evaluation methods for robust policy assessment.
|
| 4 |
+
|
| 5 |
+
For more information, see the [official repository](https://github.com/simpler-env/SimplerEnv).
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Fine-tune Simpler Env bridge dataset (WidowX robot)
|
| 10 |
+
|
| 11 |
+
To reproduce our finetune results, use the following commands to setup dataset and launch finetune experiments. Please remember to set `WANDB_API_KEY` since `--use-wandb` is turned on by default. If you don't have a WANDB account, please remove this argument:
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
uv run hf download \
|
| 15 |
+
--repo-type dataset IPEC-COMMUNITY/bridge_orig_lerobot \
|
| 16 |
+
--local-dir examples/SimplerEnv/bridge_orig_lerobot/
|
| 17 |
+
|
| 18 |
+
# Copy the patches and run the finetune script
|
| 19 |
+
cp examples/SimplerEnv/bridge_modality.json examples/SimplerEnv/bridge_orig_lerobot/meta/modality.json
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=1024 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
|
| 24 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 25 |
+
--dataset-path examples/SimplerEnv/bridge_orig_lerobot/ \
|
| 26 |
+
--embodiment-tag SIMPLER_ENV_WIDOWX \
|
| 27 |
+
--output-dir /tmp/bridge_finetune \
|
| 28 |
+
--state-dropout-prob 0.8
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
# Fine-tune Simpler Env fractal dataset (Google robot)
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
uv run hf download \
|
| 35 |
+
--repo-type dataset IPEC-COMMUNITY/fractal20220817_data_lerobot \
|
| 36 |
+
--local-dir examples/SimplerEnv/fractal20220817_data_lerobot/
|
| 37 |
+
|
| 38 |
+
# Copy the patches and run the finetune script
|
| 39 |
+
cp -r examples/SimplerEnv/fractal_modality.json examples/SimplerEnv/fractal20220817_data_lerobot/meta/modality.json
|
| 40 |
+
uv run python examples/SimplerEnv/convert_av1_to_h264.py examples/SimplerEnv/fractal20220817_data_lerobot --jobs 16 # (Optional) if AV1 doesn't work on your machine
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=1024 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
|
| 45 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 46 |
+
--dataset-path examples/SimplerEnv/fractal20220817_data_lerobot/ \
|
| 47 |
+
--embodiment-tag SIMPLER_ENV_GOOGLE \
|
| 48 |
+
--output-dir /tmp/fractal_finetune \
|
| 49 |
+
--state-dropout-prob 0.5
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
# Evaluate checkpoint
|
| 53 |
+
|
| 54 |
+
First, setup the evaluation simulation environment. This only needs to run once for each simulation benchmark. After it's done, we only need to launch server and client.
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
sudo apt update
|
| 58 |
+
sudo apt install libegl1-mesa-dev libglu1-mesa
|
| 59 |
+
bash gr00t/eval/sim/SimplerEnv/setup_SimplerEnv.sh
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
Then, run client server evaluation under the project root directory in separate terminals:
|
| 63 |
+
|
| 64 |
+
## Fractal (Google Robot) Evaluation
|
| 65 |
+
|
| 66 |
+
**Terminal 1 - Server:**
|
| 67 |
+
|
| 68 |
+
You can use either a local finetuned checkpoint path or the remote finetuned checkpoint (provided by us):
|
| 69 |
+
|
| 70 |
+
**Option 1: Local finetuned checkpoint**
|
| 71 |
+
```bash
|
| 72 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 73 |
+
--model-path /tmp/fractal_finetune/checkpoint-30000 \
|
| 74 |
+
--embodiment-tag SIMPLER_ENV_GOOGLE \
|
| 75 |
+
--use-sim-policy-wrapper
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
**Option 2: Remote finetuned checkpoint (directly runnable)**
|
| 79 |
+
```bash
|
| 80 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 81 |
+
--model-path nvidia/GR00T-N1.7-SimplerEnv-Fractal \
|
| 82 |
+
--embodiment-tag SIMPLER_ENV_GOOGLE \
|
| 83 |
+
--use-sim-policy-wrapper
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
**Terminal 2 - Client:**
|
| 87 |
+
```bash
|
| 88 |
+
gr00t/eval/sim/SimplerEnv/simpler_uv/.venv/bin/python gr00t/eval/rollout_policy.py \
|
| 89 |
+
--n-episodes 10 \
|
| 90 |
+
--policy-client-host 127.0.0.1 \
|
| 91 |
+
--policy-client-port 5555 \
|
| 92 |
+
--max-episode-steps 300 \
|
| 93 |
+
--env-name simpler_env_google/google_robot_pick_coke_can \
|
| 94 |
+
--n-action-steps 1 \
|
| 95 |
+
--n-envs 5
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
## Bridge (WidowX) Evaluation
|
| 99 |
+
|
| 100 |
+
**Terminal 1 - Server:**
|
| 101 |
+
|
| 102 |
+
**Option 1: Local finetuned checkpoint**
|
| 103 |
+
```bash
|
| 104 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 105 |
+
--model-path /tmp/bridge_finetune/checkpoint-30000 \
|
| 106 |
+
--embodiment-tag SIMPLER_ENV_WIDOWX \
|
| 107 |
+
--use-sim-policy-wrapper
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
**Option 2: Remote finetuned checkpoint (directly runnable)**
|
| 111 |
+
```bash
|
| 112 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 113 |
+
--model-path nvidia/GR00T-N1.7-SimplerEnv-Bridge \
|
| 114 |
+
--embodiment-tag SIMPLER_ENV_WIDOWX \
|
| 115 |
+
--use-sim-policy-wrapper
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
**Terminal 2 - Client:**
|
| 119 |
+
```bash
|
| 120 |
+
gr00t/eval/sim/SimplerEnv/simpler_uv/.venv/bin/python gr00t/eval/rollout_policy.py \
|
| 121 |
+
--n-episodes 10 \
|
| 122 |
+
--policy-client-host 127.0.0.1 \
|
| 123 |
+
--policy-client-port 5555 \
|
| 124 |
+
--max-episode-steps 300 \
|
| 125 |
+
--env-name simpler_env_widowx/widowx_spoon_on_towel \
|
| 126 |
+
--n-action-steps 4 \
|
| 127 |
+
--n-envs 5
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
Other supported tasks are:
|
| 131 |
+
```
|
| 132 |
+
simpler_env_google/google_robot_pick_object
|
| 133 |
+
simpler_env_google/google_robot_move_near
|
| 134 |
+
simpler_env_google/google_robot_open_drawer
|
| 135 |
+
...
|
| 136 |
+
simpler_env_widowx/widowx_spoon_on_towel
|
| 137 |
+
simpler_env_widowx/widowx_carrot_on_plate
|
| 138 |
+
simpler_env_widowx/widowx_stack_cube
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
you can replace the env_name with the corresponding tasks listed in the SimplerEnv fork this repo pins at `external_dependencies/SimplerEnv` (see `.gitmodules`).
|
examples/SimplerEnv/bridge_modality.json
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"state": {
|
| 3 |
+
"x": {
|
| 4 |
+
"start": 0,
|
| 5 |
+
"end": 1
|
| 6 |
+
},
|
| 7 |
+
"y": {
|
| 8 |
+
"start": 1,
|
| 9 |
+
"end": 2
|
| 10 |
+
},
|
| 11 |
+
"z": {
|
| 12 |
+
"start": 2,
|
| 13 |
+
"end": 3
|
| 14 |
+
},
|
| 15 |
+
"roll": {
|
| 16 |
+
"start": 3,
|
| 17 |
+
"end": 4
|
| 18 |
+
},
|
| 19 |
+
"pitch": {
|
| 20 |
+
"start": 4,
|
| 21 |
+
"end": 5
|
| 22 |
+
},
|
| 23 |
+
"yaw": {
|
| 24 |
+
"start": 5,
|
| 25 |
+
"end": 6
|
| 26 |
+
},
|
| 27 |
+
"pad": {
|
| 28 |
+
"start": 6,
|
| 29 |
+
"end": 7
|
| 30 |
+
},
|
| 31 |
+
"gripper": {
|
| 32 |
+
"start": 7,
|
| 33 |
+
"end": 8
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"action": {
|
| 37 |
+
"x": {
|
| 38 |
+
"start": 0,
|
| 39 |
+
"end": 1
|
| 40 |
+
},
|
| 41 |
+
"y": {
|
| 42 |
+
"start": 1,
|
| 43 |
+
"end": 2
|
| 44 |
+
},
|
| 45 |
+
"z": {
|
| 46 |
+
"start": 2,
|
| 47 |
+
"end": 3
|
| 48 |
+
},
|
| 49 |
+
"roll": {
|
| 50 |
+
"start": 3,
|
| 51 |
+
"end": 4
|
| 52 |
+
},
|
| 53 |
+
"pitch": {
|
| 54 |
+
"start": 4,
|
| 55 |
+
"end": 5
|
| 56 |
+
},
|
| 57 |
+
"yaw": {
|
| 58 |
+
"start": 5,
|
| 59 |
+
"end": 6
|
| 60 |
+
},
|
| 61 |
+
"gripper": {
|
| 62 |
+
"start": 6,
|
| 63 |
+
"end": 7
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"video": {
|
| 67 |
+
"image_0": {
|
| 68 |
+
"original_key": "observation.images.image_0"
|
| 69 |
+
}
|
| 70 |
+
},
|
| 71 |
+
"annotation": {
|
| 72 |
+
"human.action.task_description": {
|
| 73 |
+
"original_key": "task_index"
|
| 74 |
+
},
|
| 75 |
+
"human.validity": {}
|
| 76 |
+
}
|
| 77 |
+
}
|
examples/SimplerEnv/convert_av1_to_h264.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 20 |
+
import os
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
import subprocess
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
VIDEO_EXTS = {".mp4", ".mov", ".mkv"}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def run(cmd):
|
| 29 |
+
return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def is_av1(path: Path) -> bool:
|
| 33 |
+
cmd = [
|
| 34 |
+
"ffprobe",
|
| 35 |
+
"-v",
|
| 36 |
+
"error",
|
| 37 |
+
"-select_streams",
|
| 38 |
+
"v:0",
|
| 39 |
+
"-show_entries",
|
| 40 |
+
"stream=codec_name",
|
| 41 |
+
"-of",
|
| 42 |
+
"default=nw=1:nk=1",
|
| 43 |
+
str(path),
|
| 44 |
+
]
|
| 45 |
+
proc = run(cmd)
|
| 46 |
+
if proc.returncode != 0:
|
| 47 |
+
print(f"[ffprobe FAIL] {path}: {proc.stderr.strip()}")
|
| 48 |
+
return False
|
| 49 |
+
codec = proc.stdout.strip()
|
| 50 |
+
return codec in ("av01", "av1")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def convert_file(path: Path):
|
| 54 |
+
if not is_av1(path):
|
| 55 |
+
print(f"[SKIP] {path}")
|
| 56 |
+
return
|
| 57 |
+
|
| 58 |
+
tmp = path.with_suffix(path.suffix + ".mp4")
|
| 59 |
+
print(f"[CONVERT] {path} -> {tmp}")
|
| 60 |
+
|
| 61 |
+
cmd = [
|
| 62 |
+
"ffmpeg",
|
| 63 |
+
"-y",
|
| 64 |
+
"-i",
|
| 65 |
+
str(path),
|
| 66 |
+
"-c:v",
|
| 67 |
+
"libx264",
|
| 68 |
+
"-qp",
|
| 69 |
+
"0",
|
| 70 |
+
"-pix_fmt",
|
| 71 |
+
"yuv420p",
|
| 72 |
+
"-c:a",
|
| 73 |
+
"copy",
|
| 74 |
+
"-vsync",
|
| 75 |
+
"passthrough",
|
| 76 |
+
"-copyts",
|
| 77 |
+
"-muxdelay",
|
| 78 |
+
"0",
|
| 79 |
+
"-muxpreload",
|
| 80 |
+
"0",
|
| 81 |
+
str(tmp),
|
| 82 |
+
]
|
| 83 |
+
proc = run(cmd)
|
| 84 |
+
if proc.returncode != 0:
|
| 85 |
+
print(f"[ffmpeg FAIL] {path}: {proc.stderr.strip()}")
|
| 86 |
+
if tmp.exists():
|
| 87 |
+
tmp.unlink()
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
tmp.replace(path)
|
| 91 |
+
print(f"[DONE] {path}")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def find_videos(root: Path):
|
| 95 |
+
for dirpath, _, filenames in os.walk(root):
|
| 96 |
+
for name in filenames:
|
| 97 |
+
p = Path(dirpath) / name
|
| 98 |
+
if p.suffix.lower() in VIDEO_EXTS:
|
| 99 |
+
yield p
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def main():
|
| 103 |
+
ap = argparse.ArgumentParser(
|
| 104 |
+
description="Recursively convert AV1 videos to H.264 (lossless-ish) in place."
|
| 105 |
+
)
|
| 106 |
+
ap.add_argument("root", nargs="?", default=".", help="Root directory (default: .)")
|
| 107 |
+
ap.add_argument(
|
| 108 |
+
"-j",
|
| 109 |
+
"--jobs",
|
| 110 |
+
type=int,
|
| 111 |
+
default=os.cpu_count() or 4,
|
| 112 |
+
help="Number of parallel workers (default: CPU count)",
|
| 113 |
+
)
|
| 114 |
+
args = ap.parse_args()
|
| 115 |
+
|
| 116 |
+
root = Path(args.root).resolve()
|
| 117 |
+
files = list(find_videos(root))
|
| 118 |
+
print(f"Scanning {root}, found {len(files)} candidate video files")
|
| 119 |
+
|
| 120 |
+
if not files:
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
with ThreadPoolExecutor(max_workers=args.jobs) as ex:
|
| 124 |
+
for p in files:
|
| 125 |
+
ex.submit(convert_file, p)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
main()
|
examples/SimplerEnv/fractal_modality.json
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"state": {
|
| 3 |
+
"x": {
|
| 4 |
+
"start": 0,
|
| 5 |
+
"end": 1
|
| 6 |
+
},
|
| 7 |
+
"y": {
|
| 8 |
+
"start": 1,
|
| 9 |
+
"end": 2
|
| 10 |
+
},
|
| 11 |
+
"z": {
|
| 12 |
+
"start": 2,
|
| 13 |
+
"end": 3
|
| 14 |
+
},
|
| 15 |
+
"rx": {
|
| 16 |
+
"start": 3,
|
| 17 |
+
"end": 4
|
| 18 |
+
},
|
| 19 |
+
"ry": {
|
| 20 |
+
"start": 4,
|
| 21 |
+
"end": 5
|
| 22 |
+
},
|
| 23 |
+
"rz": {
|
| 24 |
+
"start": 5,
|
| 25 |
+
"end": 6
|
| 26 |
+
},
|
| 27 |
+
"rw": {
|
| 28 |
+
"start": 6,
|
| 29 |
+
"end": 7
|
| 30 |
+
},
|
| 31 |
+
"gripper": {
|
| 32 |
+
"start": 7,
|
| 33 |
+
"end": 8
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"action": {
|
| 37 |
+
"x": {
|
| 38 |
+
"start": 0,
|
| 39 |
+
"end": 1
|
| 40 |
+
},
|
| 41 |
+
"y": {
|
| 42 |
+
"start": 1,
|
| 43 |
+
"end": 2
|
| 44 |
+
},
|
| 45 |
+
"z": {
|
| 46 |
+
"start": 2,
|
| 47 |
+
"end": 3
|
| 48 |
+
},
|
| 49 |
+
"roll": {
|
| 50 |
+
"start": 3,
|
| 51 |
+
"end": 4
|
| 52 |
+
},
|
| 53 |
+
"pitch": {
|
| 54 |
+
"start": 4,
|
| 55 |
+
"end": 5
|
| 56 |
+
},
|
| 57 |
+
"yaw": {
|
| 58 |
+
"start": 5,
|
| 59 |
+
"end": 6
|
| 60 |
+
},
|
| 61 |
+
"gripper": {
|
| 62 |
+
"start": 6,
|
| 63 |
+
"end": 7
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"video": {
|
| 67 |
+
"image": {
|
| 68 |
+
"original_key": "observation.images.image"
|
| 69 |
+
}
|
| 70 |
+
},
|
| 71 |
+
"annotation": {
|
| 72 |
+
"human.action.task_description": {
|
| 73 |
+
"original_key": "task_index"
|
| 74 |
+
},
|
| 75 |
+
"human.validity": {}
|
| 76 |
+
}
|
| 77 |
+
}
|
examples/finetune.sh
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
set -x -euo pipefail
|
| 4 |
+
|
| 5 |
+
NUM_GPUS="${NUM_GPUS:-1}"
|
| 6 |
+
MASTER_PORT="${MASTER_PORT:-29500}"
|
| 7 |
+
SAVE_STEPS="${SAVE_STEPS:-1000}"
|
| 8 |
+
MAX_STEPS="${MAX_STEPS:-10000}"
|
| 9 |
+
USE_WANDB="${USE_WANDB:-1}"
|
| 10 |
+
DATALOADER_NUM_WORKERS="${DATALOADER_NUM_WORKERS:-4}"
|
| 11 |
+
GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-32}"
|
| 12 |
+
SHARD_SIZE="${SHARD_SIZE:-1024}"
|
| 13 |
+
NUM_SHARDS_PER_EPOCH="${NUM_SHARDS_PER_EPOCH:-100000}"
|
| 14 |
+
EPISODE_SAMPLING_RATE="${EPISODE_SAMPLING_RATE:-0.1}"
|
| 15 |
+
|
| 16 |
+
BASE_MODEL_PATH=""
|
| 17 |
+
DATASET_PATH=""
|
| 18 |
+
MODALITY_CONFIG_PATH=""
|
| 19 |
+
EMBODIMENT_TAG=""
|
| 20 |
+
OUTPUT_DIR=""
|
| 21 |
+
EXPERIMENT_NAME=""
|
| 22 |
+
WANDB_PROJECT=""
|
| 23 |
+
STATE_DROPOUT_PROB=""
|
| 24 |
+
EXTRA_ARGS=()
|
| 25 |
+
|
| 26 |
+
usage() {
|
| 27 |
+
cat <<'EOF'
|
| 28 |
+
Usage: bash examples/finetune.sh \
|
| 29 |
+
--base-model-path <path> \
|
| 30 |
+
--dataset-path <path> \
|
| 31 |
+
--embodiment-tag <tag> \
|
| 32 |
+
--output-dir <path> \
|
| 33 |
+
[--modality-config-path <path>] \
|
| 34 |
+
[--state-dropout-prob <value>] \
|
| 35 |
+
[--save-only-model] \
|
| 36 |
+
[-- <extra launch_finetune.py args>...]
|
| 37 |
+
EOF
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
while [ "$#" -gt 0 ]; do
|
| 41 |
+
case "$1" in
|
| 42 |
+
--base-model-path)
|
| 43 |
+
BASE_MODEL_PATH="$2"
|
| 44 |
+
shift 2
|
| 45 |
+
;;
|
| 46 |
+
--dataset-path)
|
| 47 |
+
DATASET_PATH="$2"
|
| 48 |
+
shift 2
|
| 49 |
+
;;
|
| 50 |
+
--modality-config-path)
|
| 51 |
+
MODALITY_CONFIG_PATH="$2"
|
| 52 |
+
shift 2
|
| 53 |
+
;;
|
| 54 |
+
--embodiment-tag)
|
| 55 |
+
EMBODIMENT_TAG="$2"
|
| 56 |
+
shift 2
|
| 57 |
+
;;
|
| 58 |
+
--output-dir)
|
| 59 |
+
OUTPUT_DIR="$2"
|
| 60 |
+
shift 2
|
| 61 |
+
;;
|
| 62 |
+
--experiment-name)
|
| 63 |
+
EXPERIMENT_NAME="$2"
|
| 64 |
+
shift 2
|
| 65 |
+
;;
|
| 66 |
+
--wandb-project)
|
| 67 |
+
WANDB_PROJECT="$2"
|
| 68 |
+
shift 2
|
| 69 |
+
;;
|
| 70 |
+
--state-dropout-prob)
|
| 71 |
+
STATE_DROPOUT_PROB="$2"
|
| 72 |
+
shift 2
|
| 73 |
+
;;
|
| 74 |
+
--save-only-model)
|
| 75 |
+
SAVE_ONLY_MODEL=1
|
| 76 |
+
shift
|
| 77 |
+
;;
|
| 78 |
+
--help|-h)
|
| 79 |
+
usage
|
| 80 |
+
exit 0
|
| 81 |
+
;;
|
| 82 |
+
--)
|
| 83 |
+
shift
|
| 84 |
+
EXTRA_ARGS=("$@")
|
| 85 |
+
break
|
| 86 |
+
;;
|
| 87 |
+
*)
|
| 88 |
+
echo "Unknown argument: $1" >&2
|
| 89 |
+
usage >&2
|
| 90 |
+
exit 1
|
| 91 |
+
;;
|
| 92 |
+
esac
|
| 93 |
+
done
|
| 94 |
+
|
| 95 |
+
for required_var in BASE_MODEL_PATH DATASET_PATH EMBODIMENT_TAG OUTPUT_DIR; do
|
| 96 |
+
if [ -z "${!required_var}" ]; then
|
| 97 |
+
echo "Missing required argument: ${required_var}" >&2
|
| 98 |
+
usage >&2
|
| 99 |
+
exit 1
|
| 100 |
+
fi
|
| 101 |
+
done
|
| 102 |
+
|
| 103 |
+
WANDB_FLAG=()
|
| 104 |
+
if [ "$USE_WANDB" = "1" ]; then
|
| 105 |
+
WANDB_FLAG+=(--use_wandb)
|
| 106 |
+
fi
|
| 107 |
+
|
| 108 |
+
LAUNCH_CMD=(
|
| 109 |
+
gr00t/experiment/launch_finetune.py
|
| 110 |
+
--base_model_path "$BASE_MODEL_PATH"
|
| 111 |
+
--dataset_path "$DATASET_PATH"
|
| 112 |
+
--embodiment_tag "$EMBODIMENT_TAG"
|
| 113 |
+
--num_gpus "$NUM_GPUS"
|
| 114 |
+
--output_dir "$OUTPUT_DIR"
|
| 115 |
+
--save_steps "$SAVE_STEPS"
|
| 116 |
+
--save_total_limit 5
|
| 117 |
+
--max_steps "$MAX_STEPS"
|
| 118 |
+
--warmup_ratio 0.05
|
| 119 |
+
--weight_decay 1e-5
|
| 120 |
+
--learning_rate 1e-4
|
| 121 |
+
"${WANDB_FLAG[@]}"
|
| 122 |
+
--global_batch_size "$GLOBAL_BATCH_SIZE"
|
| 123 |
+
--color_jitter_params brightness 0.3 contrast 0.4 saturation 0.5 hue 0.08
|
| 124 |
+
--dataloader_num_workers "$DATALOADER_NUM_WORKERS"
|
| 125 |
+
--shard_size "$SHARD_SIZE"
|
| 126 |
+
--num_shards_per_epoch "$NUM_SHARDS_PER_EPOCH"
|
| 127 |
+
--episode_sampling_rate "$EPISODE_SAMPLING_RATE"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if [ -n "$MODALITY_CONFIG_PATH" ]; then
|
| 131 |
+
LAUNCH_CMD+=(--modality_config_path "$MODALITY_CONFIG_PATH")
|
| 132 |
+
fi
|
| 133 |
+
if [ -n "$EXPERIMENT_NAME" ]; then
|
| 134 |
+
LAUNCH_CMD+=(--experiment_name "$EXPERIMENT_NAME")
|
| 135 |
+
fi
|
| 136 |
+
if [ -n "$WANDB_PROJECT" ]; then
|
| 137 |
+
LAUNCH_CMD+=(--wandb_project "$WANDB_PROJECT")
|
| 138 |
+
fi
|
| 139 |
+
|
| 140 |
+
if [ -n "$STATE_DROPOUT_PROB" ]; then
|
| 141 |
+
LAUNCH_CMD+=(--state_dropout_prob "$STATE_DROPOUT_PROB")
|
| 142 |
+
fi
|
| 143 |
+
if [ -n "${SAVE_ONLY_MODEL:-}" ]; then
|
| 144 |
+
LAUNCH_CMD+=(--save_only_model)
|
| 145 |
+
fi
|
| 146 |
+
|
| 147 |
+
if [ "${#EXTRA_ARGS[@]}" -gt 0 ]; then
|
| 148 |
+
LAUNCH_CMD+=("${EXTRA_ARGS[@]}")
|
| 149 |
+
fi
|
| 150 |
+
|
| 151 |
+
if [ "$NUM_GPUS" = "1" ]; then
|
| 152 |
+
# Restrict to a single GPU so HF Trainer doesn't wrap the model in DataParallel,
|
| 153 |
+
# which crashes with a StopIteration error in the model's device property.
|
| 154 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
|
| 155 |
+
exec python "${LAUNCH_CMD[@]}"
|
| 156 |
+
fi
|
| 157 |
+
|
| 158 |
+
exec torchrun --nproc_per_node="$NUM_GPUS" --master_port="$MASTER_PORT" "${LAUNCH_CMD[@]}"
|
examples/mask-guided-background-suppression/README.md
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mask-Guided Background Suppression
|
| 2 |
+
|
| 3 |
+
Mask-guided augmentations leverage per-frame segmentation masks to apply targeted image transformations during training. This enables **domain randomization** on specific regions (e.g., replacing backgrounds with noise, tinting foreground objects) without affecting the rest of the image.
|
| 4 |
+
|
| 5 |
+
This feature is controlled via the `--extra_augmentation_config` argument, which accepts a JSON string specifying which mask regions to augment and how.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Prerequisites
|
| 10 |
+
|
| 11 |
+
1. **Segmentation masks** must be pre-generated and stored alongside your dataset. The dataset's `info.json` must include a `mask_path` template, and `modality.json` must define a `"mask"` section mapping camera views.
|
| 12 |
+
|
| 13 |
+
2. **Albumentations transforms** are enabled by default in N1.7 (`use_albumentations_transforms=True` in model config). No extra flag is needed.
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## Supported Augmentation Types
|
| 18 |
+
|
| 19 |
+
### 1. Background Noise Transform
|
| 20 |
+
|
| 21 |
+
Replaces pixels in specified mask regions with **random RGB noise**. Useful for sim-to-real transfer or preventing the model from overfitting to static backgrounds.
|
| 22 |
+
|
| 23 |
+
| Parameter | Type | Description |
|
| 24 |
+
|-----------|------|-------------|
|
| 25 |
+
| `target_mask_values` | `list[int]` | Mask label values to replace with noise (e.g., `[0]` for background) |
|
| 26 |
+
| `p` | `float` | Probability of applying the transform per frame (0.0 to 1.0) |
|
| 27 |
+
|
| 28 |
+
### 2. Masked Region Color Transform
|
| 29 |
+
|
| 30 |
+
Applies a **random color tint** to pixels in specified mask regions. Useful for augmenting the appearance of specific objects (e.g., tables, tools) to improve color generalization.
|
| 31 |
+
|
| 32 |
+
| Parameter | Type | Description |
|
| 33 |
+
|-----------|------|-------------|
|
| 34 |
+
| `target_mask_values` | `list[int]` | Mask label values to apply the tint to (e.g., `[4]`, `[5]`) |
|
| 35 |
+
| `p` | `float` | Probability of applying the transform per frame (0.0 to 1.0) |
|
| 36 |
+
| `alpha_range` | `[min, max]` | Range for blending intensity between original and tint color (default: `[0.3, 1.0]`) |
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## Configuration Format
|
| 41 |
+
|
| 42 |
+
The `--extra_augmentation_config` argument takes a JSON string with two optional keys:
|
| 43 |
+
|
| 44 |
+
```json
|
| 45 |
+
{
|
| 46 |
+
"background_noise_transforms": [
|
| 47 |
+
{"target_mask_values": [0], "p": 0.9}
|
| 48 |
+
],
|
| 49 |
+
"masked_region_transforms": [
|
| 50 |
+
{"target_mask_values": [4], "p": 1.0, "alpha_range": [0.0, 1.0]}
|
| 51 |
+
]
|
| 52 |
+
}
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Multiple transforms of each type can be specified (e.g., different mask values with different probabilities).
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## Quick Start with Demo Data
|
| 60 |
+
|
| 61 |
+
The included demo dataset `demo_data/cube_to_bowl_5_with_mask` contains a single episode with front and wrist camera views, along with pre-generated segmentation masks. The masks were generated using [SAM 3](https://github.com/facebookresearch/sam3) with the text prompt `"background"`, then converted so that background pixels = `0` and foreground pixels = `1` (see [Generating mask files](#generating-mask-files) below).
|
| 62 |
+
|
| 63 |
+
### 1. Background noise only
|
| 64 |
+
|
| 65 |
+
Replace background (mask=0) with random noise:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
uv run python test_extra_augmentation.py \
|
| 69 |
+
--dataset_path ../../demo_data/cube_to_bowl_5_with_mask \
|
| 70 |
+
--embodiment_tag NEW_EMBODIMENT \
|
| 71 |
+
--modality_config_path so101_config.py \
|
| 72 |
+
--extra_augmentation_config '{"background_noise_transforms": [{"target_mask_values": [0], "p": 1.0}]}'
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### 2. Background noise + foreground color tint
|
| 76 |
+
|
| 77 |
+
Apply both transforms together:
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
uv run python test_extra_augmentation.py \
|
| 81 |
+
--dataset_path ../../demo_data/cube_to_bowl_5_with_mask \
|
| 82 |
+
--embodiment_tag NEW_EMBODIMENT \
|
| 83 |
+
--modality_config_path so101_config.py \
|
| 84 |
+
--extra_augmentation_config '{"background_noise_transforms": [{"target_mask_values": [0], "p": 1.0}], "masked_region_transforms": [{"target_mask_values": [1], "p": 1.0, "alpha_range": [0.3, 1.0]}]}' \
|
| 85 |
+
--output_dir /tmp/augmentation_vis --num_frames 5
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
Both commands save side-by-side comparison images (**Original | Augmented | Mask**) under `output_dir/<view_name>/`, with frames sampled evenly across the episode.
|
| 89 |
+
|
| 90 |
+
### 3. Fine-tune with mask-guided augmentation
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
export NUM_GPUS=8
|
| 94 |
+
|
| 95 |
+
torchrun --nproc_per_node=$NUM_GPUS --master_port=29500 \
|
| 96 |
+
gr00t/experiment/launch_finetune.py \
|
| 97 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 98 |
+
--dataset-path <YOUR_DATASET_WITH_MASKS> \
|
| 99 |
+
--embodiment-tag <YOUR_EMBODIMENT_TAG> \
|
| 100 |
+
--num-gpus $NUM_GPUS \
|
| 101 |
+
--output-dir /tmp/mask_augmentation_run \
|
| 102 |
+
--save-steps 1000 \
|
| 103 |
+
--save-total-limit 5 \
|
| 104 |
+
--max-steps 20000 \
|
| 105 |
+
--warmup-ratio 0.05 \
|
| 106 |
+
--weight-decay 1e-5 \
|
| 107 |
+
--learning-rate 1e-4 \
|
| 108 |
+
--use-wandb \
|
| 109 |
+
--global-batch-size 640 \
|
| 110 |
+
--dataloader-num-workers 4 \
|
| 111 |
+
--extra-augmentation-config '{"background_noise_transforms": [{"target_mask_values": [0], "p": 0.9}], "masked_region_transforms": [{"target_mask_values": [4], "p": 1.0, "alpha_range": [0, 1]}]}'
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
---
|
| 115 |
+
|
| 116 |
+
## Dataset Setup
|
| 117 |
+
|
| 118 |
+
To use mask-guided augmentation with your own dataset, ensure:
|
| 119 |
+
|
| 120 |
+
1. **Mask files** are stored as `.npz` files under a `masks/` directory, following the same chunk/episode structure as videos. Each `.npz` contains a single `uint8` array of shape `(num_frames, H, W)` where each pixel holds an integer semantic label (e.g., `0` = background, `1` = object A, `2` = object B).
|
| 121 |
+
|
| 122 |
+
```
|
| 123 |
+
masks/
|
| 124 |
+
└── chunk-000/
|
| 125 |
+
└── observation.images.front/
|
| 126 |
+
└── episode_000000_masks.npz
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
See [Generating mask files](#generating-mask-files) below for how to produce these files.
|
| 130 |
+
|
| 131 |
+
2. **`info.json`** includes a `mask_path` template:
|
| 132 |
+
|
| 133 |
+
```json
|
| 134 |
+
{
|
| 135 |
+
"mask_path": "masks/chunk-{episode_chunk:03d}/{mask_key}/episode_{episode_index:06d}_masks.npz"
|
| 136 |
+
}
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
3. **`modality.json`** includes a `"mask"` section mapping view names to their original keys. The keys should match the actual camera view names in your dataset:
|
| 140 |
+
|
| 141 |
+
```json
|
| 142 |
+
{
|
| 143 |
+
"mask": {
|
| 144 |
+
"<view_name>": {
|
| 145 |
+
"original_key": "<observation.images.xxx>"
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
}
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
For example, if your dataset has `front` and `wrist` cameras:
|
| 152 |
+
|
| 153 |
+
```json
|
| 154 |
+
{
|
| 155 |
+
"mask": {
|
| 156 |
+
"front": {
|
| 157 |
+
"original_key": "observation.images.front"
|
| 158 |
+
},
|
| 159 |
+
"wrist": {
|
| 160 |
+
"original_key": "observation.images.wrist"
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
## Generating Mask Files
|
| 169 |
+
|
| 170 |
+
You can generate mask files using any video segmentation model that produces per-pixel labels. The demo masks in this example were created with [SAM 3](https://github.com/facebookresearch/sam3) (see the [video predictor example](https://github.com/facebookresearch/sam3/blob/main/examples/sam3_video_predictor_example.ipynb) for SAM 3 usage). The workflow was:
|
| 171 |
+
|
| 172 |
+
1. Run SAM 3 on each episode video with a text prompt such as `"background"`. SAM 3 returns per-frame binary masks via `propagate_in_video`.
|
| 173 |
+
2. Convert the binary masks into the label format expected by this pipeline (`0` = background, non-zero = foreground categories) and save as `.npz`:
|
| 174 |
+
|
| 175 |
+
```python
|
| 176 |
+
import numpy as np
|
| 177 |
+
|
| 178 |
+
# sam3_binary_masks: (num_frames, H, W) bool array from SAM 3 (True where prompt matched)
|
| 179 |
+
# For a "background" prompt, invert so that background=0 and foreground=1:
|
| 180 |
+
label_masks = (~sam3_binary_masks).astype(np.uint8)
|
| 181 |
+
|
| 182 |
+
# For multiple prompts, merge into one label array instead:
|
| 183 |
+
# label_masks = np.zeros((num_frames, H, W), dtype=np.uint8)
|
| 184 |
+
# label_masks[prompt_0_masks] = 1
|
| 185 |
+
# label_masks[prompt_1_masks] = 2
|
| 186 |
+
|
| 187 |
+
np.savez_compressed("episode_000000_masks.npz", label_masks)
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
The pipeline loads the array from the `.npz` file (it expects the key `arr_0`, which is the default for `np.savez_compressed`). A single `.npy` file containing the `(num_frames, H, W)` array also works.
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
## How It Works
|
| 195 |
+
|
| 196 |
+
The augmentation pipeline applies mask-based transforms **per-frame** before the standard augmentations (crop, resize, color jitter, etc.):
|
| 197 |
+
|
| 198 |
+
1. For each frame, the corresponding segmentation mask is loaded.
|
| 199 |
+
2. `BackgroundNoiseTransform` replaces all pixels where `mask == target_value` with random RGB noise.
|
| 200 |
+
3. `MaskedColorTransform` blends a random color into all pixels where `mask == target_value`, controlled by `alpha_range`.
|
| 201 |
+
4. Standard augmentations (shared across views via replay) are then applied on top.
|
| 202 |
+
|
| 203 |
+
This ordering ensures that mask-guided augmentations are applied independently per frame, while standard augmentations remain consistent across camera views within the same timestep.
|
examples/mask-guided-background-suppression/so101_config.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from gr00t.configs.data.embodiment_configs import register_modality_config
|
| 17 |
+
from gr00t.data.embodiment_tags import EmbodimentTag
|
| 18 |
+
from gr00t.data.types import (
|
| 19 |
+
ActionConfig,
|
| 20 |
+
ActionFormat,
|
| 21 |
+
ActionRepresentation,
|
| 22 |
+
ActionType,
|
| 23 |
+
ModalityConfig,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
so101_config = {
|
| 28 |
+
"video": ModalityConfig(
|
| 29 |
+
delta_indices=[0],
|
| 30 |
+
modality_keys=["front", "wrist"],
|
| 31 |
+
),
|
| 32 |
+
"mask": ModalityConfig(
|
| 33 |
+
delta_indices=[0],
|
| 34 |
+
modality_keys=["front", "wrist"],
|
| 35 |
+
),
|
| 36 |
+
"state": ModalityConfig(
|
| 37 |
+
delta_indices=[0],
|
| 38 |
+
modality_keys=["single_arm", "gripper"],
|
| 39 |
+
),
|
| 40 |
+
"action": ModalityConfig(
|
| 41 |
+
delta_indices=list(range(16)),
|
| 42 |
+
modality_keys=["single_arm", "gripper"],
|
| 43 |
+
action_configs=[
|
| 44 |
+
ActionConfig(
|
| 45 |
+
rep=ActionRepresentation.RELATIVE,
|
| 46 |
+
type=ActionType.NON_EEF,
|
| 47 |
+
format=ActionFormat.DEFAULT,
|
| 48 |
+
),
|
| 49 |
+
ActionConfig(
|
| 50 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 51 |
+
type=ActionType.NON_EEF,
|
| 52 |
+
format=ActionFormat.DEFAULT,
|
| 53 |
+
),
|
| 54 |
+
],
|
| 55 |
+
),
|
| 56 |
+
"language": ModalityConfig(
|
| 57 |
+
delta_indices=[0],
|
| 58 |
+
modality_keys=["annotation.human.task_description"],
|
| 59 |
+
),
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
register_modality_config(so101_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
|
examples/mask-guided-background-suppression/test_extra_augmentation.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Smoke test: apply extra_augmentation_config to raw frames and save comparison images."""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import importlib
|
| 24 |
+
import json
|
| 25 |
+
import os
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
from gr00t.configs.data.embodiment_configs import MODALITY_CONFIGS
|
| 30 |
+
from gr00t.data.dataset.lerobot_episode_loader import LeRobotEpisodeLoader
|
| 31 |
+
from gr00t.data.embodiment_tags import EmbodimentTag
|
| 32 |
+
from gr00t.model.gr00t_n1d7.image_augmentations import (
|
| 33 |
+
apply_with_replay,
|
| 34 |
+
build_image_transformations_albumentations,
|
| 35 |
+
)
|
| 36 |
+
import numpy as np
|
| 37 |
+
from PIL import Image
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def save_comparison(original, augmented, mask, output_path):
|
| 41 |
+
orig_arr = np.array(original)
|
| 42 |
+
aug_arr = augmented.transpose(1, 2, 0) if augmented.shape[0] == 3 else augmented
|
| 43 |
+
|
| 44 |
+
panels = [orig_arr, aug_arr]
|
| 45 |
+
if mask is not None:
|
| 46 |
+
mask_vis = np.where(mask[..., None] > 0, 255, 0).astype(np.uint8)
|
| 47 |
+
mask_vis = np.broadcast_to(mask_vis, (*mask.shape[:2], 3)).copy()
|
| 48 |
+
if mask_vis.shape[:2] != orig_arr.shape[:2]:
|
| 49 |
+
mask_vis = np.array(
|
| 50 |
+
Image.fromarray(mask_vis).resize(
|
| 51 |
+
(orig_arr.shape[1], orig_arr.shape[0]), Image.NEAREST
|
| 52 |
+
)
|
| 53 |
+
)
|
| 54 |
+
panels.append(mask_vis)
|
| 55 |
+
|
| 56 |
+
h = panels[0].shape[0]
|
| 57 |
+
resized = []
|
| 58 |
+
for p in panels:
|
| 59 |
+
if p.shape[0] != h:
|
| 60 |
+
new_w = int(p.shape[1] * h / p.shape[0])
|
| 61 |
+
p = np.array(Image.fromarray(p).resize((new_w, h), Image.BILINEAR))
|
| 62 |
+
resized.append(p)
|
| 63 |
+
|
| 64 |
+
Image.fromarray(np.concatenate(resized, axis=1)).save(output_path)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main():
|
| 68 |
+
parser = argparse.ArgumentParser()
|
| 69 |
+
parser.add_argument("--dataset_path", required=True)
|
| 70 |
+
parser.add_argument("--embodiment_tag", required=True)
|
| 71 |
+
parser.add_argument("--modality_config_path", default=None)
|
| 72 |
+
parser.add_argument("--extra_augmentation_config", type=str, required=True)
|
| 73 |
+
parser.add_argument("--output_dir", type=str, default="/tmp/augmentation_vis")
|
| 74 |
+
parser.add_argument("--num_frames", type=int, default=5)
|
| 75 |
+
parser.add_argument("--video_backend", type=str, default="torchcodec")
|
| 76 |
+
args = parser.parse_args()
|
| 77 |
+
|
| 78 |
+
if args.modality_config_path:
|
| 79 |
+
path = Path(args.modality_config_path)
|
| 80 |
+
sys.path.append(str(path.parent))
|
| 81 |
+
importlib.import_module(path.stem)
|
| 82 |
+
|
| 83 |
+
embodiment_tag = EmbodimentTag[args.embodiment_tag].value
|
| 84 |
+
modality_configs = MODALITY_CONFIGS[embodiment_tag]
|
| 85 |
+
extra_aug_config = json.loads(args.extra_augmentation_config)
|
| 86 |
+
|
| 87 |
+
train_transform, _ = build_image_transformations_albumentations(
|
| 88 |
+
image_target_size=[224, 224],
|
| 89 |
+
image_crop_size=[224, 224],
|
| 90 |
+
random_rotation_angle=0,
|
| 91 |
+
color_jitter_params=None,
|
| 92 |
+
shortest_image_edge=512,
|
| 93 |
+
crop_fraction=0.95,
|
| 94 |
+
extra_augmentation_config=extra_aug_config,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
loader = LeRobotEpisodeLoader(
|
| 98 |
+
dataset_path=args.dataset_path,
|
| 99 |
+
modality_configs=modality_configs,
|
| 100 |
+
video_backend=args.video_backend,
|
| 101 |
+
)
|
| 102 |
+
episode_df = loader[0]
|
| 103 |
+
|
| 104 |
+
video_cols = [c for c in episode_df.columns if c.startswith("video.")]
|
| 105 |
+
mask_cols = [c for c in episode_df.columns if c.startswith("mask.")]
|
| 106 |
+
print(f"Video columns: {video_cols}")
|
| 107 |
+
print(f"Mask columns: {mask_cols}")
|
| 108 |
+
|
| 109 |
+
num_frames = min(args.num_frames, len(episode_df))
|
| 110 |
+
frame_indices = np.linspace(0, len(episode_df) - 1, num_frames, dtype=int)
|
| 111 |
+
|
| 112 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 113 |
+
|
| 114 |
+
for vcol in video_cols:
|
| 115 |
+
view_name = vcol.replace("video.", "")
|
| 116 |
+
mcol = f"mask.{view_name}"
|
| 117 |
+
has_mask = mcol in mask_cols
|
| 118 |
+
|
| 119 |
+
view_dir = os.path.join(args.output_dir, view_name.replace(".", "_"))
|
| 120 |
+
os.makedirs(view_dir, exist_ok=True)
|
| 121 |
+
|
| 122 |
+
for fidx in frame_indices:
|
| 123 |
+
orig_img = episode_df[vcol].iloc[fidx]
|
| 124 |
+
mask_arr = np.array(episode_df[mcol].iloc[fidx]) if has_mask else None
|
| 125 |
+
masks_list = [mask_arr] if mask_arr is not None else None
|
| 126 |
+
|
| 127 |
+
transformed, _ = apply_with_replay(train_transform, [orig_img], masks_list)
|
| 128 |
+
aug_arr = transformed[0].numpy()
|
| 129 |
+
|
| 130 |
+
out_path = os.path.join(view_dir, f"frame_{fidx:04d}.png")
|
| 131 |
+
save_comparison(orig_img, aug_arr, mask_arr, out_path)
|
| 132 |
+
print(f" Saved: {out_path}")
|
| 133 |
+
|
| 134 |
+
print(f"\nDone! {num_frames} frames x {len(video_cols)} views saved to {args.output_dir}")
|
| 135 |
+
|
| 136 |
+
print("\n" + "=" * 60)
|
| 137 |
+
print("Testing full training pipeline (processor + dataloader) ...")
|
| 138 |
+
print("=" * 60)
|
| 139 |
+
|
| 140 |
+
from gr00t.configs.base_config import get_default_config
|
| 141 |
+
from gr00t.data.dataset.factory import DatasetFactory
|
| 142 |
+
from gr00t.model.gr00t_n1d7.processing_gr00t_n1d7 import Gr00tN1d7Processor
|
| 143 |
+
|
| 144 |
+
config = get_default_config()
|
| 145 |
+
config = config.load_dict(
|
| 146 |
+
{
|
| 147 |
+
"data": {
|
| 148 |
+
"download_cache": False,
|
| 149 |
+
"video_backend": args.video_backend,
|
| 150 |
+
"datasets": [
|
| 151 |
+
{
|
| 152 |
+
"dataset_paths": [args.dataset_path],
|
| 153 |
+
"mix_ratio": 1.0,
|
| 154 |
+
"embodiment_tag": embodiment_tag,
|
| 155 |
+
}
|
| 156 |
+
],
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
)
|
| 160 |
+
config.model.extra_augmentation_config = extra_aug_config
|
| 161 |
+
config.model.use_albumentations_transforms = True
|
| 162 |
+
|
| 163 |
+
processor = Gr00tN1d7Processor(
|
| 164 |
+
modality_configs=config.data.modality_configs,
|
| 165 |
+
statistics=None,
|
| 166 |
+
image_crop_size=config.model.image_crop_size,
|
| 167 |
+
image_target_size=config.model.image_target_size,
|
| 168 |
+
random_rotation_angle=config.model.random_rotation_angle,
|
| 169 |
+
color_jitter_params=config.model.color_jitter_params,
|
| 170 |
+
model_name=config.model.model_name,
|
| 171 |
+
model_type=config.model.backbone_model_type,
|
| 172 |
+
formalize_language=config.model.formalize_language,
|
| 173 |
+
max_state_dim=config.model.max_state_dim,
|
| 174 |
+
max_action_dim=config.model.max_action_dim,
|
| 175 |
+
apply_sincos_state_encoding=config.model.apply_sincos_state_encoding,
|
| 176 |
+
max_action_horizon=config.model.action_horizon,
|
| 177 |
+
use_albumentations=config.model.use_albumentations_transforms,
|
| 178 |
+
extra_augmentation_config=config.model.extra_augmentation_config,
|
| 179 |
+
shortest_image_edge=config.model.shortest_image_edge,
|
| 180 |
+
crop_fraction=config.model.crop_fraction,
|
| 181 |
+
use_relative_action=config.model.use_relative_action,
|
| 182 |
+
)
|
| 183 |
+
processor.train()
|
| 184 |
+
|
| 185 |
+
dataset_factory = DatasetFactory(config=config)
|
| 186 |
+
train_dataset, _ = dataset_factory.build(processor=processor)
|
| 187 |
+
sample = next(iter(train_dataset))
|
| 188 |
+
|
| 189 |
+
print(f"Sample keys: {list(sample.keys())}")
|
| 190 |
+
print(f"VLM keys: {list(sample['vlm_content'].keys())}")
|
| 191 |
+
for k, v in sample.items():
|
| 192 |
+
if hasattr(v, "shape"):
|
| 193 |
+
print(f" {k}: shape={v.shape}, dtype={v.dtype}")
|
| 194 |
+
print("\nPipeline test PASSED!")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
if __name__ == "__main__":
|
| 198 |
+
main()
|
getting_started/data_config.md
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How to prepare your modality configuration
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
The modality configuration defines how your robot's data should be loaded, processed, and interpreted by the model. This configuration bridges your dataset's physical structure (defined in `meta/modality.json`) and the model's data processing pipeline.
|
| 6 |
+
|
| 7 |
+
Each embodiment requires a Python configuration file that specifies:
|
| 8 |
+
- Which observations to use (video cameras, proprioceptive states)
|
| 9 |
+
- How to sample data temporally (current frame, historical frames, future action horizons)
|
| 10 |
+
- How actions should be interpreted and transformed
|
| 11 |
+
- Which language annotations to use
|
| 12 |
+
|
| 13 |
+
## Configuration Structure
|
| 14 |
+
|
| 15 |
+
A modality configuration is a Python dictionary containing four top-level keys: `"video"`, `"state"`, `"action"`, and `"language"`. Each key maps to a `ModalityConfig` object.
|
| 16 |
+
|
| 17 |
+
Here's the [SO-100 example](../examples/SO100/so100_config.py):
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
from gr00t.configs.data.embodiment_configs import register_modality_config
|
| 21 |
+
from gr00t.data.types import ModalityConfig, ActionConfig, ActionRepresentation, ActionType, ActionFormat
|
| 22 |
+
|
| 23 |
+
so100_config = {
|
| 24 |
+
"video": ModalityConfig(...),
|
| 25 |
+
"state": ModalityConfig(...),
|
| 26 |
+
"action": ModalityConfig(...),
|
| 27 |
+
"language": ModalityConfig(...),
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
register_modality_config(so100_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Understanding `ModalityConfig`
|
| 34 |
+
|
| 35 |
+
Each `ModalityConfig` specifies two required fields and several optional ones:
|
| 36 |
+
|
| 37 |
+
### Required Fields
|
| 38 |
+
|
| 39 |
+
**1. `delta_indices` (list[int])**
|
| 40 |
+
|
| 41 |
+
Defines which temporal offsets to sample relative to the current timestep:
|
| 42 |
+
- Current observation: Use [0] for the current timestep (recommended for video and state)
|
| 43 |
+
- Future actions: Use positive indices (e.g., list(range(0, 16))) for action prediction horizons
|
| 44 |
+
|
| 45 |
+
> **Note:** Negative indices (e.g., [-2, -1, 0]) are supported by the data loader for historical context, but no current N1.7 embodiment config uses them. Stick with [0] for video and state unless you have a specific reason to stack frames.
|
| 46 |
+
|
| 47 |
+
Examples:
|
| 48 |
+
```python
|
| 49 |
+
# Single current frame for video
|
| 50 |
+
delta_indices=[0]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# 16-step action prediction horizon
|
| 54 |
+
delta_indices=list(range(0, 16))
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
> **Note:** If you modify `delta_indices` for the action modality (e.g., changing the action horizon from 16 to 8), you **must** regenerate the dataset statistics by re-running `python gr00t/data/stats.py --dataset-path <dataset_path> --embodiment-tag <embodiment_tag>`. The normalization statistics (especially `meta/relative_stats.json`) are computed based on the original `delta_indices` length, and a mismatch will cause errors during training.
|
| 58 |
+
|
| 59 |
+
<details>
|
| 60 |
+
<summary>Example: What happens if you change <code>delta_indices</code> without regenerating stats?</summary>
|
| 61 |
+
|
| 62 |
+
Suppose your action config originally uses a 16-step horizon:
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
"action": ModalityConfig(
|
| 66 |
+
delta_indices=list(range(0, 16)), # 16 steps
|
| 67 |
+
...
|
| 68 |
+
)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Running `python gr00t/data/stats.py` generates `meta/relative_stats.json` with per-step statistics of shape `(16, D)`, where `D` is the action dimension.
|
| 72 |
+
|
| 73 |
+
If you later change the horizon to 8 steps:
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
"action": ModalityConfig(
|
| 77 |
+
delta_indices=list(range(0, 8)), # 8 steps
|
| 78 |
+
...
|
| 79 |
+
)
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
The training data will now have shape `(8, D)`, but the normalization parameters from `relative_stats.json` still have shape `(16, D)`. This dimension mismatch causes an `IndexError` during normalization:
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
IndexError: boolean index did not match indexed array along dimension 0;
|
| 86 |
+
dimension is 8 but corresponding boolean dimension is 16
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
**Fix:** Re-run `python gr00t/data/stats.py --dataset-path <dataset_path> --embodiment-tag <embodiment_tag>` after changing `delta_indices` to regenerate matching statistics.
|
| 90 |
+
|
| 91 |
+
</details>
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
**2. `modality_keys` (list[str])**
|
| 95 |
+
|
| 96 |
+
Specifies which keys to load from your dataset. These keys **must match** the keys defined in your `meta/modality.json` file.
|
| 97 |
+
|
| 98 |
+
For the SO-100 example:
|
| 99 |
+
- **Video keys**: Must match keys in `meta/modality.json` under `"video"` (e.g., `"front"`, `"wrist"`)
|
| 100 |
+
- **State keys**: Must match keys in `meta/modality.json` under `"state"` (e.g., `"single_arm"`, `"gripper"`)
|
| 101 |
+
- **Action keys**: Must match keys in `meta/modality.json` under `"action"` (e.g., `"single_arm"`, `"gripper"`)
|
| 102 |
+
- **Language keys**: Must match keys in `meta/modality.json` under `"annotation"` (e.g., `"annotation.human.task_description"` for SO-100)
|
| 103 |
+
|
| 104 |
+
### Optional Fields
|
| 105 |
+
|
| 106 |
+
**3. `sin_cos_embedding_keys` (list[str] | None)**
|
| 107 |
+
|
| 108 |
+
Specifies which state keys should use sine/cosine encoding. Best for dimensions that are in radians (e.g., joint angles). If not specified, min-max normalization is used. Note that this will duplicate the number of dimensions by 2, and is only recommended for proprioceptive states.
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
"state": ModalityConfig(
|
| 112 |
+
delta_indices=[0],
|
| 113 |
+
modality_keys=["single_arm", "gripper"],
|
| 114 |
+
sin_cos_embedding_keys=["single_arm"], # Apply sin/cos to joint angles
|
| 115 |
+
)
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
**4. `mean_std_embedding_keys` (list[str] | None)**
|
| 119 |
+
|
| 120 |
+
Specifies which keys should use mean/standard deviation normalization instead of min-max normalization.
|
| 121 |
+
|
| 122 |
+
**5. `action_configs` (list[ActionConfig] | None)**
|
| 123 |
+
|
| 124 |
+
Required for the `"action"` modality. Defines how each action modality should be interpreted and transformed. The list must have the **same length and same order** as `modality_keys` — `action_configs[0]` applies to `modality_keys[0]`, `action_configs[1]` to `modality_keys[1]`, etc. A mismatch in ordering will silently apply the wrong representation (e.g., RELATIVE to a gripper that should be ABSOLUTE). See more details in the [Action Modality](#understanding-actionconfig) section.
|
| 125 |
+
|
| 126 |
+
## Configuring Each Modality
|
| 127 |
+
|
| 128 |
+
### Video Modality
|
| 129 |
+
|
| 130 |
+
Defines which camera views to use:
|
| 131 |
+
|
| 132 |
+
```python
|
| 133 |
+
"video": ModalityConfig(
|
| 134 |
+
delta_indices=[0], # Current frame only
|
| 135 |
+
modality_keys=[
|
| 136 |
+
"front", # Must match a key in meta/modality.json under "video"
|
| 137 |
+
],
|
| 138 |
+
)
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
For multiple cameras:
|
| 142 |
+
```python
|
| 143 |
+
"video": ModalityConfig(
|
| 144 |
+
delta_indices=[0],
|
| 145 |
+
modality_keys=["front", "wrist"],
|
| 146 |
+
)
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
### State Modality
|
| 150 |
+
|
| 151 |
+
Defines proprioceptive observations (joint positions, gripper states, etc.):
|
| 152 |
+
|
| 153 |
+
```python
|
| 154 |
+
"state": ModalityConfig(
|
| 155 |
+
delta_indices=[0], # Current state
|
| 156 |
+
modality_keys=[
|
| 157 |
+
"single_arm", # Must match keys in meta/modality.json under "state"
|
| 158 |
+
"gripper",
|
| 159 |
+
],
|
| 160 |
+
)
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
### Action Modality
|
| 164 |
+
|
| 165 |
+
Defines the action space and prediction horizon:
|
| 166 |
+
|
| 167 |
+
```python
|
| 168 |
+
"action": ModalityConfig(
|
| 169 |
+
delta_indices=list(range(0, 16)), # Predict 16 steps into the future
|
| 170 |
+
modality_keys=[
|
| 171 |
+
"single_arm", # Must match keys in meta/modality.json under "action"
|
| 172 |
+
"gripper",
|
| 173 |
+
],
|
| 174 |
+
action_configs=[
|
| 175 |
+
# One ActionConfig per modality_key
|
| 176 |
+
# single_arm
|
| 177 |
+
ActionConfig(
|
| 178 |
+
rep=ActionRepresentation.RELATIVE, # relative control of the single arm
|
| 179 |
+
type=ActionType.NON_EEF,
|
| 180 |
+
format=ActionFormat.DEFAULT,
|
| 181 |
+
),
|
| 182 |
+
# gripper
|
| 183 |
+
ActionConfig(
|
| 184 |
+
rep=ActionRepresentation.ABSOLUTE, # absolute control of the gripper
|
| 185 |
+
type=ActionType.NON_EEF,
|
| 186 |
+
format=ActionFormat.DEFAULT,
|
| 187 |
+
),
|
| 188 |
+
],
|
| 189 |
+
)
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
#### Understanding `ActionConfig`
|
| 193 |
+
|
| 194 |
+
Each `ActionConfig` has three required fields and one optional field:
|
| 195 |
+
|
| 196 |
+
**1. `rep` (ActionRepresentation)**
|
| 197 |
+
|
| 198 |
+
Defines how actions should be interpreted:
|
| 199 |
+
- `RELATIVE`: Actions are deltas from the current state (introduced in the UMI paper)
|
| 200 |
+
- `ABSOLUTE`: Actions are target positions
|
| 201 |
+
|
| 202 |
+
Using relative actions will lead to smoother actions, but might suffer from drifting. If you want to use relative actions, please make sure the state and action stored in the dataset are absolute, and the absolute to relative will be handled in the processor.
|
| 203 |
+
|
| 204 |
+
**2. `type` (ActionType)**
|
| 205 |
+
|
| 206 |
+
Specifies the control space:
|
| 207 |
+
- `EEF`: End-effector/Cartesian space control (Expecting a 9-dimensional vector: x, y, z positions + rotation 6D)
|
| 208 |
+
- `NON_EEF`: Joint space control and other non-EEF control spaces (joint angles, positions, gripper positions, etc.)
|
| 209 |
+
|
| 210 |
+
**3. `format` (ActionFormat)**
|
| 211 |
+
|
| 212 |
+
Defines the action representation format:
|
| 213 |
+
- `DEFAULT`: Standard format (e.g., joint angles, gripper positions)
|
| 214 |
+
- `XYZ_ROT6D`: 3D position + 6D rotation representation for end-effector control
|
| 215 |
+
- `XYZ_ROTVEC`: 3D position + rotation vector for end-effector control
|
| 216 |
+
|
| 217 |
+
**4. `state_key` (str | None)**
|
| 218 |
+
|
| 219 |
+
Optional. Specifies the corresponding reference state key for computing relative actions when `rep=RELATIVE`. If not provided, the system will use the action key as the reference state key.
|
| 220 |
+
|
| 221 |
+
Example with `state_key`:
|
| 222 |
+
```python
|
| 223 |
+
"joint_pos_action_left": ActionConfig(
|
| 224 |
+
rep=ActionRepresentation.RELATIVE,
|
| 225 |
+
type=ActionType.NON_EEF,
|
| 226 |
+
format=ActionFormat.DEFAULT,
|
| 227 |
+
state_key="joint_pos_obs_left", # Use this state to compute relative action
|
| 228 |
+
)
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
### Language Modality
|
| 232 |
+
|
| 233 |
+
Defines which language annotations to use:
|
| 234 |
+
|
| 235 |
+
```python
|
| 236 |
+
"language": ModalityConfig(
|
| 237 |
+
delta_indices=[0],
|
| 238 |
+
modality_keys=["annotation.human.task_description"], # Must match annotation keys in meta/modality.json
|
| 239 |
+
)
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
## Complete Example: SO-100
|
| 243 |
+
|
| 244 |
+
Here's the complete SO-100 configuration with explanations:
|
| 245 |
+
|
| 246 |
+
```python
|
| 247 |
+
so100_config = {
|
| 248 |
+
"video": ModalityConfig(
|
| 249 |
+
delta_indices=[0],
|
| 250 |
+
modality_keys=["front", "wrist"],
|
| 251 |
+
),
|
| 252 |
+
"state": ModalityConfig(
|
| 253 |
+
delta_indices=[0],
|
| 254 |
+
modality_keys=[
|
| 255 |
+
"single_arm",
|
| 256 |
+
"gripper",
|
| 257 |
+
],
|
| 258 |
+
),
|
| 259 |
+
"action": ModalityConfig(
|
| 260 |
+
delta_indices=list(range(0, 16)),
|
| 261 |
+
modality_keys=[
|
| 262 |
+
"single_arm",
|
| 263 |
+
"gripper",
|
| 264 |
+
],
|
| 265 |
+
action_configs=[
|
| 266 |
+
ActionConfig(
|
| 267 |
+
rep=ActionRepresentation.RELATIVE,
|
| 268 |
+
type=ActionType.NON_EEF,
|
| 269 |
+
format=ActionFormat.DEFAULT,
|
| 270 |
+
),
|
| 271 |
+
ActionConfig(
|
| 272 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 273 |
+
type=ActionType.NON_EEF,
|
| 274 |
+
format=ActionFormat.DEFAULT,
|
| 275 |
+
),
|
| 276 |
+
],
|
| 277 |
+
),
|
| 278 |
+
"language": ModalityConfig(
|
| 279 |
+
delta_indices=[0],
|
| 280 |
+
modality_keys=["annotation.human.task_description"],
|
| 281 |
+
),
|
| 282 |
+
}
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
## Key Relationships with `meta/modality.json`
|
| 286 |
+
|
| 287 |
+
The modality configuration's `modality_keys` must reference keys that exist in your dataset's `meta/modality.json`:
|
| 288 |
+
|
| 289 |
+
**Example `meta/modality.json`:**
|
| 290 |
+
```json
|
| 291 |
+
{
|
| 292 |
+
"state": {
|
| 293 |
+
"single_arm": {"start": 0, "end": 5},
|
| 294 |
+
"gripper": {"start": 5, "end": 6},
|
| 295 |
+
},
|
| 296 |
+
"action": {
|
| 297 |
+
"single_arm": {"start": 0, "end": 5},
|
| 298 |
+
"gripper": {"start": 5, "end": 6},
|
| 299 |
+
},
|
| 300 |
+
"video": {
|
| 301 |
+
"front": {"original_key": "observation.images.front"},
|
| 302 |
+
"wrist": {"original_key": "observation.images.wrist"},
|
| 303 |
+
},
|
| 304 |
+
"annotation": {
|
| 305 |
+
"human.task_description": {
|
| 306 |
+
"original_key": "task_index"
|
| 307 |
+
}
|
| 308 |
+
}
|
| 309 |
+
}
|
| 310 |
+
```
|
| 311 |
+
|
| 312 |
+
The system will:
|
| 313 |
+
1. Use `modality_keys` to look up the corresponding entries in `meta/modality.json`
|
| 314 |
+
2. Extract the correct slices from the concatenated state/action arrays
|
| 315 |
+
3. Apply the specified transformations (normalization, action representation conversion)
|
| 316 |
+
|
| 317 |
+
## Registering Your Configuration
|
| 318 |
+
|
| 319 |
+
After defining your configuration, register it so it's available to the training and inference pipelines:
|
| 320 |
+
|
| 321 |
+
```python
|
| 322 |
+
from gr00t.configs.data.embodiment_configs import register_modality_config
|
| 323 |
+
|
| 324 |
+
your_modality_config = {
|
| 325 |
+
...
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
register_modality_config(your_modality_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
Save your configuration to a Python file and pass the path to the `modality_config_path` argument when running the finetuning script.
|
getting_started/data_preparation.md
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Robot Data Preparation Guide
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
This guide shows how to convert your robot data to work with our flavor of the [LeRobot dataset V2 format](https://github.com/huggingface/lerobot?tab=readme-ov-file#the-lerobotdataset-format) ([LeRobot docs](https://huggingface.co/docs/lerobot)) -- `GR00T LeRobot`. While we have added additional structure, our schema maintains full compatibility with the upstream LeRobot v2. The additional metadata and structure allow for more detailed specification and language annotations for your robot data.
|
| 6 |
+
|
| 7 |
+
> The TLDR: Add a `meta/modality.json` file to your LeRobot v2 dataset and follow the schema below.
|
| 8 |
+
|
| 9 |
+
## LeRobot v2 Requirements
|
| 10 |
+
|
| 11 |
+
If you already have a dataset in the LeRobot v2 format, you can skip this section.
|
| 12 |
+
|
| 13 |
+
If you have a dataset in the LeRobot v3.0 format, please use [this script](../scripts/lerobot_conversion/convert_v3_to_v2.py) to convert it to the LeRobot v2 format.
|
| 14 |
+
|
| 15 |
+
> **Why LeRobot v2?** GR00T currently uses the LeRobot v2 data format because many upstream datasets (DROID, LIBERO, Bridge, etc.) are published in v2. We plan to support both v2 and v3 formats natively in a future release. For now, please convert v3 datasets to v2 using the script above.
|
| 16 |
+
|
| 17 |
+
If you have a dataset in another format, please convert it to the LeRobot v2 format satisfying the following requirements.
|
| 18 |
+
|
| 19 |
+
### Structure Requirements
|
| 20 |
+
|
| 21 |
+
The folder should follow a similar structure as below and contain these core folders and files:
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
.
|
| 25 |
+
├─meta
|
| 26 |
+
│ ├─episodes.jsonl
|
| 27 |
+
│ ├─modality.json # -> GR00T LeRobot specific
|
| 28 |
+
│ ├─info.json
|
| 29 |
+
│ └─tasks.jsonl
|
| 30 |
+
├─videos
|
| 31 |
+
│ └─chunk-000
|
| 32 |
+
│ └─observation.images.ego_view
|
| 33 |
+
│ └─episode_000001.mp4
|
| 34 |
+
│ └─episode_000000.mp4
|
| 35 |
+
└─data
|
| 36 |
+
└─chunk-000
|
| 37 |
+
├─episode_000001.parquet
|
| 38 |
+
└─episode_000000.parquet
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Video Observations (video/chunk-*)
|
| 42 |
+
The videos folder will contain the mp4 files associated with each episode following episode_00000X.mp4 naming format where X indicates the episode number.
|
| 43 |
+
**Requirements**:
|
| 44 |
+
- Must be stored as MP4 files.
|
| 45 |
+
- Should be named using the format: `observation.images.<video_name>`
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
### Data (data/chunk-*)
|
| 49 |
+
The data folder will contain all of the parquet files associated with each episode following episode_00000X.parquet naming format where X indicates the episode number.
|
| 50 |
+
Each parquet file will contain:
|
| 51 |
+
- State information: stored as observation.state which is a 1D concatenated array of all state modalities.
|
| 52 |
+
- Action: stored as action which is a 1D concatenated array of all action modalities.
|
| 53 |
+
- Timestamp: stored as timestamp which is a float point number of the starting time.
|
| 54 |
+
- Annotations: stored as annotation.<annotation_source>.<annotation_type>(.<annotation_name>) (see the annotation field in the example configuration for example naming.). No other columns should have the annotation prefix, see the (multiple-annotation-support) if interested in adding multiple annotations.
|
| 55 |
+
|
| 56 |
+
#### Example Parquet File
|
| 57 |
+
Here is a sample of the `cube_to_bowl` dataset that is present in the [demo_data](../demo_data/cube_to_bowl_5/) directory.
|
| 58 |
+
```
|
| 59 |
+
{
|
| 60 |
+
"observation.state":[-0.01,...,0], // 1D array: all state modalities concatenated per modality.json order
|
| 61 |
+
"action":[-0.010,...,0], // 1D array: all action modalities concatenated per modality.json order
|
| 62 |
+
"timestamp":0.049, // float: wall-clock time of this observation (seconds)
|
| 63 |
+
"annotation.human.action.task_description":0, // int: index into meta/tasks.jsonl for the language instruction
|
| 64 |
+
"task_index":0, // int: task identifier (same as annotation index for single-task)
|
| 65 |
+
"annotation.human.validity":1, // int: index into meta/tasks.jsonl for validity label
|
| 66 |
+
"episode_index":0, // int: which episode this frame belongs to
|
| 67 |
+
"index":0, // int: global frame index across all episodes in the dataset
|
| 68 |
+
"next.reward":0, // float: reward at the next timestep (0 if unused)
|
| 69 |
+
"next.done":false // bool: true if this is the last frame of the episode
|
| 70 |
+
}
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### Meta
|
| 74 |
+
|
| 75 |
+
- `episodes.jsonl` contains a list of all the episodes in the entire dataset. Each episode contains a list of tasks and the length of the episode.
|
| 76 |
+
- `tasks.jsonl` contains a list of all the tasks in the entire dataset.
|
| 77 |
+
- `info.json` contains the dataset information.
|
| 78 |
+
|
| 79 |
+
#### meta/tasks.jsonl
|
| 80 |
+
Here is a sample of the `meta/tasks.jsonl` file that contains the task descriptions.
|
| 81 |
+
```
|
| 82 |
+
{"task_index": 0, "task": "pick the squash from the counter and place it in the plate"}
|
| 83 |
+
{"task_index": 1, "task": "valid"}
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
You can refer the task index in the parquet file to get the task description. So in this case, the `annotation.human.action.task_description` for the first observation is "pick the squash from the counter and place it in the plate" and `annotation.human.validity` is "valid".
|
| 87 |
+
|
| 88 |
+
`tasks.jsonl` contains a list of all the tasks in the entire dataset.
|
| 89 |
+
|
| 90 |
+
#### meta/episodes.jsonl
|
| 91 |
+
|
| 92 |
+
Here is a sample of the `meta/episodes.jsonl` file that contains the episode information.
|
| 93 |
+
|
| 94 |
+
```
|
| 95 |
+
{"episode_index": 0, "tasks": [...], "length": 416}
|
| 96 |
+
{"episode_index": 1, "tasks": [...], "length": 470}
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
`episodes.jsonl` contains a list of all the episodes in the entire dataset. Each episode contains a list of tasks and the length of the episode.
|
| 100 |
+
|
| 101 |
+
## GR00T LeRobot Specific Requirements
|
| 102 |
+
|
| 103 |
+
### The `meta/modality.json` Configuration
|
| 104 |
+
|
| 105 |
+
We require an additional metadata file `meta/modality.json` that is not present in the standard LeRobot format. This file provides detailed metadata about state and action modalities, enabling:
|
| 106 |
+
|
| 107 |
+
- **Separate Data Storage and Interpretation:**
|
| 108 |
+
- **State and Action:** Stored as concatenated float32 arrays. The `modality.json` file supplies the metadata necessary to interpret these arrays as distinct, fine-grained fields.
|
| 109 |
+
- **Video:** Stored as separate files, with the configuration file allowing them to be renamed to a standardized format.
|
| 110 |
+
- **Annotations:** Keeps track of all annotation fields. If there are no annotations, do not include the `annotation` field in the configuration file.
|
| 111 |
+
- **Fine-Grained Splitting:** Divides the state and action arrays into more semantically meaningful fields.
|
| 112 |
+
- **Clear Mapping:** Explicit mapping of data dimensions.
|
| 113 |
+
- **Sophisticated Data Transformations:** Supports field-specific normalization and rotation transformations during training.
|
| 114 |
+
|
| 115 |
+
#### Schema
|
| 116 |
+
|
| 117 |
+
```json
|
| 118 |
+
{
|
| 119 |
+
"state": {
|
| 120 |
+
"<state_key>": {
|
| 121 |
+
"start": <int>, // Starting index in the state array
|
| 122 |
+
"end": <int> // Ending index in the state array
|
| 123 |
+
}
|
| 124 |
+
},
|
| 125 |
+
"action": {
|
| 126 |
+
"<action_key>": {
|
| 127 |
+
"start": <int>, // Starting index in the action array
|
| 128 |
+
"end": <int> // Ending index in the action array
|
| 129 |
+
}
|
| 130 |
+
},
|
| 131 |
+
"video": {
|
| 132 |
+
"<new_key>": {
|
| 133 |
+
"original_key": "<original_video_key>"
|
| 134 |
+
}
|
| 135 |
+
},
|
| 136 |
+
"annotation": {
|
| 137 |
+
"<annotation_key>": {} // Empty dictionary to maintain consistency with other modalities
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
#### Example
|
| 143 |
+
|
| 144 |
+
For a concrete example of `modality.json` and the full dataset structure, see the publicly available datasets on HuggingFace:
|
| 145 |
+
[nvidia/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim](https://huggingface.co/datasets/nvidia/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/tree/main).
|
| 146 |
+
|
| 147 |
+
You can also find a working example in the included demo data at [`demo_data/cube_to_bowl_5/meta/modality.json`](../demo_data/cube_to_bowl_5/meta/modality.json).
|
| 148 |
+
|
| 149 |
+
#### Notes
|
| 150 |
+
|
| 151 |
+
- All indices are zero-based and follow Python's array slicing convention (`[start:end]`).
|
| 152 |
+
|
| 153 |
+
## GR00T LeRobot Extensions to Standard LeRobot
|
| 154 |
+
GR00T LeRobot is a flavor of the standard LeRobot format with more opinionated requirements:
|
| 155 |
+
- We will compute `meta/stats.json` and `meta/relative_stats.json` for each dataset, and store them in the `meta` folder.
|
| 156 |
+
- Proprioceptive states must always be included in the "observation.state" keys.
|
| 157 |
+
- We support multi-channel annotation formats (e.g., coarsegrained, finetuned), allowing users to add as many annotation channels as needed via the `annotation.<annotation_source>.<annotation_type>` key.
|
| 158 |
+
- We require an additional metadata file `meta/modality.json` that is not present in the standard LeRobot format.
|
| 159 |
+
|
| 160 |
+
### Multiple Annotation Support
|
| 161 |
+
|
| 162 |
+
To support multiple annotations within a single parquet file, users may add extra columns to the parquet file. Users should treat these columns the same way as the `task_index` column in the original LeRobot v2 dataset:
|
| 163 |
+
|
| 164 |
+
In LeRobot v2, actual language descriptions are stored in a row of the `meta/tasks.jsonl` file, while the parquet file stores only the corresponding index in the `task_index` column. We follow the same convention and store the corresponding index for each annotation in the `annotation.<annotation_source>.<annotation_type>` column. Although the `task_index` column may still be used for the default annotation, a dedicated column `annotation.<annotation_source>.<annotation_type>` is required to ensure it is loadable by our custom data loader.
|
getting_started/finetune_new_embodiment.md
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fine-tune on Custom Embodiments ("NEW_EMBODIMENT")
|
| 2 |
+
|
| 3 |
+
This guide demonstrates how to finetune GR00T on your own robot data and configuration. We provide a complete example for the Huggingface [SO-100](https://github.com/TheRobotStudio/SO-ARM100) robot under `examples/SO100`, which uses `demo_data/cube_to_bowl_5` as the demo dataset.
|
| 4 |
+
|
| 5 |
+
## Step 1: Prepare Your Data
|
| 6 |
+
|
| 7 |
+
Prepare your data in **GR00T-flavored LeRobot v2 format** by following the [data preparation guide](data_preparation.md).
|
| 8 |
+
|
| 9 |
+
## Step 2: Prepare Your Modality Configuration
|
| 10 |
+
|
| 11 |
+
Define your own modality configuration by following the [modality config guide](data_config.md). Below is an example configuration that corresponds to the demo data:
|
| 12 |
+
```python
|
| 13 |
+
from gr00t.configs.data.embodiment_configs import register_modality_config
|
| 14 |
+
from gr00t.data.embodiment_tags import EmbodimentTag
|
| 15 |
+
from gr00t.data.types import (
|
| 16 |
+
ActionConfig,
|
| 17 |
+
ActionFormat,
|
| 18 |
+
ActionRepresentation,
|
| 19 |
+
ActionType,
|
| 20 |
+
ModalityConfig,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
so100_config = {
|
| 25 |
+
# Video: use current frame only ([0]); list camera view names matching modality.json
|
| 26 |
+
"video": ModalityConfig(
|
| 27 |
+
delta_indices=[0],
|
| 28 |
+
modality_keys=[
|
| 29 |
+
"front",
|
| 30 |
+
"wrist",
|
| 31 |
+
],
|
| 32 |
+
),
|
| 33 |
+
# State: current proprioceptive reading; keys must match modality.json "state" entries
|
| 34 |
+
"state": ModalityConfig(
|
| 35 |
+
delta_indices=[0],
|
| 36 |
+
modality_keys=[
|
| 37 |
+
"single_arm",
|
| 38 |
+
"gripper",
|
| 39 |
+
],
|
| 40 |
+
),
|
| 41 |
+
# Action: 16-step prediction horizon; each key needs an ActionConfig
|
| 42 |
+
"action": ModalityConfig(
|
| 43 |
+
delta_indices=list(range(0, 16)), # predict 16 future steps
|
| 44 |
+
modality_keys=[
|
| 45 |
+
"single_arm",
|
| 46 |
+
"gripper",
|
| 47 |
+
],
|
| 48 |
+
action_configs=[
|
| 49 |
+
# single_arm: RELATIVE = delta from current state (better generalization)
|
| 50 |
+
ActionConfig(
|
| 51 |
+
rep=ActionRepresentation.RELATIVE,
|
| 52 |
+
type=ActionType.NON_EEF, # joint-space, not end-effector
|
| 53 |
+
format=ActionFormat.DEFAULT,
|
| 54 |
+
),
|
| 55 |
+
# gripper: ABSOLUTE = target position (binary open/close works better absolute)
|
| 56 |
+
ActionConfig(
|
| 57 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 58 |
+
type=ActionType.NON_EEF,
|
| 59 |
+
format=ActionFormat.DEFAULT,
|
| 60 |
+
),
|
| 61 |
+
],
|
| 62 |
+
),
|
| 63 |
+
# Language: task instruction from annotation field in the dataset
|
| 64 |
+
"language": ModalityConfig(
|
| 65 |
+
delta_indices=[0],
|
| 66 |
+
modality_keys=["annotation.human.task_description"],
|
| 67 |
+
),
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
# Important: always register under EmbodimentTag.NEW_EMBODIMENT for custom embodiments
|
| 71 |
+
register_modality_config(so100_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
## Step 3: Run Fine-tuning
|
| 75 |
+
|
| 76 |
+
We'll use `gr00t/experiment/launch_finetune.py` as the entry point. Ensure that the uv environment is enabled before launching. You can do this by running the command `uv run bash <example_script_name>`.
|
| 77 |
+
|
| 78 |
+
### View Available Arguments
|
| 79 |
+
```bash
|
| 80 |
+
# Display all available arguments
|
| 81 |
+
uv run python gr00t/experiment/launch_finetune.py --help
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
### Execute Fine-tuning
|
| 85 |
+
```bash
|
| 86 |
+
# Configure for single GPU
|
| 87 |
+
export NUM_GPUS=1
|
| 88 |
+
CUDA_VISIBLE_DEVICES=0 uv run python \
|
| 89 |
+
gr00t/experiment/launch_finetune.py \
|
| 90 |
+
--base-model-path nvidia/GR00T-N1.7-3B \
|
| 91 |
+
--dataset-path ./demo_data/cube_to_bowl_5 \
|
| 92 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 93 |
+
--modality-config-path examples/SO100/so100_config.py \
|
| 94 |
+
--num-gpus $NUM_GPUS \
|
| 95 |
+
--output-dir /tmp/so100 \
|
| 96 |
+
--save-total-limit 5 \
|
| 97 |
+
--save-steps 2000 \
|
| 98 |
+
--max-steps 2000 \
|
| 99 |
+
--use-wandb \
|
| 100 |
+
--global-batch-size 32 \
|
| 101 |
+
--color-jitter-params brightness 0.3 contrast 0.4 saturation 0.5 hue 0.08 \
|
| 102 |
+
--dataloader-num-workers 4
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### Key Parameters
|
| 106 |
+
|
| 107 |
+
| Parameter | Description |
|
| 108 |
+
|-----------|-------------|
|
| 109 |
+
| `--base-model-path` | Path to the pre-trained base model checkpoint |
|
| 110 |
+
| `--dataset-path` | Path to your training dataset |
|
| 111 |
+
| `--embodiment-tag` | Tag to identify your robot embodiment |
|
| 112 |
+
| `--modality-config-path` | Path to user-specified modality config (required only for `NEW_EMBODIMENT` tag) |
|
| 113 |
+
| `--output-dir` | Directory where checkpoints will be saved |
|
| 114 |
+
| `--save-steps` | Save checkpoint every N steps |
|
| 115 |
+
| `--max-steps` | Total number of training steps |
|
| 116 |
+
| `--use-wandb` | Enable Weights & Biases logging for experiment tracking |
|
| 117 |
+
|
| 118 |
+
> **Note:** Validation during fine-tuning is disabled by default (`eval_strategy="no"` in the training config). To enable periodic validation, pass `--eval-strategy steps --eval-steps 500` (runs validation every 500 steps) or `--eval-strategy epoch` (runs validation every epoch). You can also adjust `--eval-batch-size` (default: 2).
|
| 119 |
+
|
| 120 |
+
## Step 4: Open Loop Evaluation
|
| 121 |
+
|
| 122 |
+
After finetuning, evaluate the model's performance using open loop evaluation:
|
| 123 |
+
```bash
|
| 124 |
+
uv run python gr00t/eval/open_loop_eval.py \
|
| 125 |
+
--dataset-path ./demo_data/cube_to_bowl_5 \
|
| 126 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 127 |
+
--model-path /tmp/so100/checkpoint-2000 \
|
| 128 |
+
--traj-ids 0 \
|
| 129 |
+
--action-horizon 16 \
|
| 130 |
+
--steps 400 \
|
| 131 |
+
--modality-keys single_arm gripper
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
### `open_loop_eval.py` Parameters
|
| 135 |
+
|
| 136 |
+
| Parameter | Default | Description |
|
| 137 |
+
|-----------|---------|-------------|
|
| 138 |
+
| `--dataset-path` | `demo_data/cube_to_bowl_5/` | Path to LeRobot-format dataset |
|
| 139 |
+
| `--embodiment-tag` | `new_embodiment` | Robot embodiment tag (case-insensitive) |
|
| 140 |
+
| `--model-path` | `None` | Path to checkpoint. If omitted, connects to a running server via `--host`/`--port` |
|
| 141 |
+
| `--traj-ids` | `[0]` | Episode indices to evaluate (space-separated, e.g., `0 1 2`) |
|
| 142 |
+
| `--action-horizon` | `16` | Action steps predicted per inference call |
|
| 143 |
+
| `--steps` | `200` | Max steps per trajectory (capped by actual trajectory length) |
|
| 144 |
+
| `--denoising-steps` | `4` | Diffusion denoising iterations |
|
| 145 |
+
| `--save-plot-path` | `None` | Directory to save GT-vs-predicted comparison plots |
|
| 146 |
+
| `--modality-keys` | `None` | Action keys to plot. If omitted, plots all action dimensions |
|
| 147 |
+
| `--host` / `--port` | `127.0.0.1` / `5555` | Server address when `--model-path` is omitted |
|
| 148 |
+
|
| 149 |
+
### Example Evaluation Result
|
| 150 |
+
|
| 151 |
+
The evaluation generates visualizations comparing predicted actions against ground truth trajectories:
|
| 152 |
+
|
| 153 |
+
<img src="../media/open_loop_eval_so100.jpg" width="800" alt="Open loop evaluation results showing predicted vs ground truth trajectories" />
|
getting_started/hardware_recommendation.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hardware Recommendations
|
| 2 |
+
|
| 3 |
+
GR00T N1.7 has two hardware profiles: **fine-tuning** (needs GPU VRAM and compute) and **inference/deployment** (needs low latency). This guide helps you choose the right hardware for each.
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Inference Hardware
|
| 10 |
+
|
| 11 |
+
**Minimum:** 1 GPU with 16 GB+ VRAM, CUDA 12.6+.
|
| 12 |
+
|
| 13 |
+
The table below summarizes end-to-end inference frequency across tested platforms (GR00T N1.7, 4 denoising steps, 1 camera):
|
| 14 |
+
|
| 15 |
+
| Platform | VRAM | PyTorch Eager | With TensorRT | Use Case |
|
| 16 |
+
|----------|------|---------------|---------------|----------|
|
| 17 |
+
| H100 80GB HBM3 | 80 GB | 11.7 Hz | 35.9 Hz | High-frequency control, multi-env batch inference |
|
| 18 |
+
| H20 96GB HBM3 | 96 GB | 12.0 Hz | 29.4 Hz | Cost-effective datacenter inference |
|
| 19 |
+
| RTX Pro 6000 Blackwell | 96 GB | 12.8 Hz | 35.9 Hz | Workstation inference, development |
|
| 20 |
+
| RTX Pro 5000 72GB | 72 GB | 7.9 Hz | 24.7 Hz | Workstation inference |
|
| 21 |
+
| L40 | 48 GB | 7.8 Hz | 26.0 Hz | Cloud inference |
|
| 22 |
+
| L20 | 48 GB | 7.1 Hz | 23.3 Hz | Cloud inference |
|
| 23 |
+
| DGX Spark | 128 GB shared | 7.9 Hz | 10.1 Hz | Desktop edge, prototyping |
|
| 24 |
+
| AGX Thor | 128 GB shared | 6.9 Hz | 10.7 Hz | Robot-mounted edge deployment |
|
| 25 |
+
| Orin* | 64 GB shared | 2.9 Hz | 4.6 Hz | Legacy Jetson edge |
|
| 26 |
+
|
| 27 |
+
> *Orin uses DiT-only TensorRT (TRT 10.3 does not support the backbone engine). All other platforms use the full TensorRT pipeline.
|
| 28 |
+
|
| 29 |
+
### Key Insights
|
| 30 |
+
|
| 31 |
+
- **30+ Hz** (H100, RTX Pro 6000 with TensorRT): suitable for high-frequency closed-loop control where sub-30 ms latency matters.
|
| 32 |
+
- **10+ Hz** (Thor, Spark with TRT; most dGPUs with torch.compile): sufficient for typical manipulation tasks running at a 10 Hz control rate.
|
| 33 |
+
- **< 5 Hz** (Orin): only suitable for slow, non-reactive tasks. Orin's TRT 10.3 cannot accelerate the backbone — gains are limited to DiT-only mode.
|
| 34 |
+
- **TensorRT Full Pipeline** provides 1.5--3.3x speedup over PyTorch Eager depending on platform. Biggest gains are on datacenter GPUs where backbone acceleration is significant.
|
| 35 |
+
- **torch.compile** is a good zero-effort middle ground (no engine build step), achieving 1.1--1.9x speedup across all platforms.
|
| 36 |
+
|
| 37 |
+
> For full per-component latency breakdown, see the [Deployment Benchmark Results](../scripts/deployment/README.md#benchmark-results).
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## Fine-Tuning Hardware
|
| 42 |
+
|
| 43 |
+
**Minimum:** 1 GPU with 40 GB+ VRAM. GR00T N1.7 is a ~3B parameter model (bfloat16).
|
| 44 |
+
|
| 45 |
+
| Setup | GPUs | VRAM per GPU | Global Batch Size | Notes |
|
| 46 |
+
|-------|------|-------------|-------------------|-------|
|
| 47 |
+
| Quick start / prototyping | 1x H100, L40, or A100 | 40--80 GB | 32 | Single GPU; sufficient for demo datasets |
|
| 48 |
+
| Recommended | 4--8x H100 or L40 | 40--80 GB each | 64--640 | Multi-GPU via torchrun; faster convergence |
|
| 49 |
+
| Full scale | 8x RTX Pro 6000 or DGX | 96 GB each | 640 | Large datasets, production fine-tuning |
|
| 50 |
+
|
| 51 |
+
### Key Details
|
| 52 |
+
|
| 53 |
+
- **Default fine-tuning** tunes the projector + diffusion action head (not the full LLM backbone), keeping peak VRAM under ~35 GB per GPU.
|
| 54 |
+
- **Enabling `--tune-llm` or `--tune-visual`** significantly increases VRAM — 80 GB+ per GPU recommended.
|
| 55 |
+
- **`--gradient-accumulation-steps`** can compensate for fewer GPUs. For example, 4 GPUs with 8 accumulation steps and per-GPU batch of 8 gives an effective global batch size of 256.
|
| 56 |
+
- **Reduce `--num-shards-per-epoch`** if host memory (not VRAM) is limited — this controls how much dataset is preloaded into RAM.
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## Software Requirements
|
| 61 |
+
|
| 62 |
+
| Requirement | Version |
|
| 63 |
+
|-------------|---------|
|
| 64 |
+
| Python | 3.10 |
|
| 65 |
+
| CUDA | 12.6+ (dGPU, Orin) / 13.0 (Thor, Spark) |
|
| 66 |
+
| PyTorch | 2.7+ |
|
| 67 |
+
| OS | Ubuntu 22.04+ (dGPU), JetPack 6.2 (Orin), Ubuntu 24.04 (Thor, Spark) |
|
| 68 |
+
| Package manager | [uv](https://docs.astral.sh/uv/) (recommended) |
|
| 69 |
+
|
| 70 |
+
Platform-specific installation instructions: see the [Deployment Guide](../scripts/deployment/README.md).
|
| 71 |
+
|
| 72 |
+
---
|
| 73 |
+
|
| 74 |
+
## Recommended Configurations
|
| 75 |
+
|
| 76 |
+
### Starter Kit
|
| 77 |
+
|
| 78 |
+
For development, small-scale fine-tuning, and edge deployment:
|
| 79 |
+
|
| 80 |
+
| Component | Recommendation |
|
| 81 |
+
|-----------|---------------|
|
| 82 |
+
| Training | 1--4x L40 (48 GB) or RTX Pro 5000/6000 workstation |
|
| 83 |
+
| Edge Deployment | [Jetson AGX Thor](https://developer.nvidia.com/embedded/jetson) Developer Kit (128 GB shared memory, Blackwell GPU) |
|
| 84 |
+
| Storage | 500 GB+ SSD (datasets + checkpoints) |
|
| 85 |
+
|
| 86 |
+
### Center of Excellence
|
| 87 |
+
|
| 88 |
+
For production fine-tuning and high-throughput inference:
|
| 89 |
+
|
| 90 |
+
| Component | Recommendation |
|
| 91 |
+
|-----------|---------------|
|
| 92 |
+
| Training | DGX with 8x H100/B200, or RTX Pro Server with 8x RTX Pro 6000 Blackwell |
|
| 93 |
+
| Inference Server | H100 or H20 node with TensorRT Full Pipeline (35+ Hz per GPU) |
|
| 94 |
+
| Edge Deployment | [Jetson AGX Thor](https://developer.nvidia.com/embedded/jetson) or [DGX Spark](https://developer.nvidia.com/dgx-spark) |
|
| 95 |
+
| Storage | Scalable networked storage (NFS/S3) for large-scale datasets |
|
getting_started/policy.md
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Understanding the GR00T Policy API
|
| 2 |
+
|
| 3 |
+
This guide explains how to use the `Gr00tPolicy` class to load and run inference with your trained model. After training, you'll use this API to integrate your model with evaluation environments.
|
| 4 |
+
|
| 5 |
+
## Loading the Policy
|
| 6 |
+
|
| 7 |
+
Initialize a policy by providing the embodiment tag, model checkpoint path, and device:
|
| 8 |
+
|
| 9 |
+
```python
|
| 10 |
+
from gr00t.policy import Gr00tPolicy
|
| 11 |
+
from gr00t.data.embodiment_tags import EmbodimentTag
|
| 12 |
+
|
| 13 |
+
# Load your trained model
|
| 14 |
+
policy = Gr00tPolicy(
|
| 15 |
+
model_path="/path/to/your/checkpoint",
|
| 16 |
+
embodiment_tag=EmbodimentTag.NEW_EMBODIMENT, # or other embodiment tags
|
| 17 |
+
device="cuda:0", # or "cpu", or device index like 0
|
| 18 |
+
strict=True # Enable input/output validation (recommended during development)
|
| 19 |
+
)
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
**Parameters:**
|
| 23 |
+
|
| 24 |
+
| Parameter | Type | Default | Description |
|
| 25 |
+
|-----------|------|---------|-------------|
|
| 26 |
+
| `embodiment_tag` | `EmbodimentTag \| str` | *(required)* | Robot type; accepts enum or case-insensitive string (e.g., `"NEW_EMBODIMENT"`) |
|
| 27 |
+
| `model_path` | `str` | *(required)* | Path to model checkpoint directory (local path or HuggingFace model ID) |
|
| 28 |
+
| `device` | `str \| int` | *(required)* | Inference device: `"cuda:0"`, `0`, or `"cpu"` |
|
| 29 |
+
| `strict` | `bool` | `True` | Validates observation shapes and dtypes at runtime. Recommended during development; disable in production for speed |
|
| 30 |
+
|
| 31 |
+
## Inference Parameter Guide
|
| 32 |
+
|
| 33 |
+
When running inference scripts (e.g., `standalone_inference_script.py`, `open_loop_eval.py`), the key parameters are:
|
| 34 |
+
|
| 35 |
+
### `--embodiment-tag`
|
| 36 |
+
|
| 37 |
+
Determines which modality config the model uses (state/action keys, normalization). **Must match the robot type of your dataset.**
|
| 38 |
+
|
| 39 |
+
The tag is **case-insensitive** and accepts either the enum name or the string value.
|
| 40 |
+
For example, `--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` and `--embodiment-tag LIBERO_PANDA` all resolve correctly. An unknown tag will produce an error listing all known options.
|
| 41 |
+
|
| 42 |
+
- **Pretrain tags** (e.g., `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`, `XDOF`, `REAL_G1`) — use for zero-shot inference on datasets that match the pretrained embodiment. The modality config is loaded from the base model checkpoint.
|
| 43 |
+
- **Posttrain tags** (`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`, `LIBERO_PANDA`, `SIMPLER_ENV_GOOGLE`, `SIMPLER_ENV_WIDOWX`) — require a finetuned checkpoint. Passing these to the base model will produce an error.
|
| 44 |
+
- **`NEW_EMBODIMENT`** — use for custom robots. Requires a `--modality-config-path` during finetuning. After finetuning, the config is saved in the checkpoint and loaded automatically during inference.
|
| 45 |
+
- Only one `NEW_EMBODIMENT` modality config may be registered per Python process. Examples like [`examples/SO100/so100_config.py`](../examples/SO100/so100_config.py) and [`examples/mask-guided-background-suppression/so101_config.py`](../examples/mask-guided-background-suppression/so101_config.py) each register under this tag; importing both in the same process will fail. In normal CLI use the selected `--modality-config-path` is the only one imported, so this is not an issue — just don't wire both configs into the same script.
|
| 46 |
+
|
| 47 |
+
#### Known Embodiment Tags
|
| 48 |
+
|
| 49 |
+
**Pretrain tags** — baked into the base model (`nvidia/GR00T-N1.7-3B`), ready for zero-shot inference:
|
| 50 |
+
|
| 51 |
+
| Tag | Robot / Data Source | Value |
|
| 52 |
+
|-----|---------------------|-------|
|
| 53 |
+
| `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | DROID (relative EEF + joint) | `oxe_droid_relative_eef_relative_joint` |
|
| 54 |
+
| `XDOF` | Generic X-DOF (relative EEF + joint) | `xdof_relative_eef_relative_joint` |
|
| 55 |
+
| `XDOF_SUBTASK` | Generic X-DOF (subtask variant) | `xdof_relative_eef_relative_joint_subtask` |
|
| 56 |
+
| `REAL_G1` | Real-world Unitree G1 (relative EEF + joint) | `real_g1_relative_eef_relative_joints` |
|
| 57 |
+
| `REAL_R1_PRO_SHARPA` | Real-world R1 Pro Sharpa (relative EEF) | `real_r1_pro_sharpa_relative_eef` |
|
| 58 |
+
| `REAL_R1_PRO_SHARPA_HUMAN` | R1 Pro Sharpa — human teleop data | `real_r1_pro_sharpa_relative_eef_human` |
|
| 59 |
+
| `REAL_R1_PRO_SHARPA_MAXINSIGHTS` | R1 Pro Sharpa — MaxInsights (single-cam) | `real_r1_pro_sharpa_relative_eef_maxinsights` |
|
| 60 |
+
| `REAL_R1_PRO_SHARPA_MECKA` | R1 Pro Sharpa — Mecka (single-cam) | `real_r1_pro_sharpa_relative_eef_mecka` |
|
| 61 |
+
|
| 62 |
+
**Posttrain tags** — require a finetuned checkpoint (not usable with the base model directly):
|
| 63 |
+
|
| 64 |
+
| Tag | Robot | Value | Checkpoint |
|
| 65 |
+
|-----|-------|-------|------------|
|
| 66 |
+
| `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | DROID (relative EEF + joint) | `oxe_droid_relative_eef_relative_joint` | `nvidia/GR00T-N1.7-DROID` |
|
| 67 |
+
| `LIBERO_PANDA` | LIBERO Panda | `libero_sim` | `nvidia/GR00T-N1.7-LIBERO` |
|
| 68 |
+
| `SIMPLER_ENV_GOOGLE` | SimplerEnv Google Robot | `simpler_env_google` | `nvidia/GR00T-N1.7-SimplerEnv-Fractal` |
|
| 69 |
+
| `SIMPLER_ENV_WIDOWX` | SimplerEnv WidowX | `simpler_env_widowx` | `nvidia/GR00T-N1.7-SimplerEnv-Bridge` |
|
| 70 |
+
|
| 71 |
+
**Generic tag** for any new robot: `NEW_EMBODIMENT` (requires `--modality-config-path`)
|
| 72 |
+
|
| 73 |
+
> **`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` appears in both tables by design.** DROID is supported both zero-shot (via the base model) and via the finetuned `nvidia/GR00T-N1.7-DROID` checkpoint. Pass the tag with either `--model-path nvidia/GR00T-N1.7-3B` (zero-shot) or `--model-path nvidia/GR00T-N1.7-DROID` (finetuned); see `examples/DROID/README.md`.
|
| 74 |
+
|
| 75 |
+
> **Important:** Pretrain tags work with the base model for zero-shot inference. Posttrain tags require a finetuned checkpoint — using them with the base model will fail with an error listing the supported tags. You also cannot mix embodiment tags and datasets (e.g., `--embodiment-tag LIBERO_PANDA` expects LIBERO state keys and will fail on an SO100 dataset).
|
| 76 |
+
|
| 77 |
+
### `--traj-ids`
|
| 78 |
+
|
| 79 |
+
Which episode indices to evaluate. Check your dataset's `meta/episodes.jsonl` to see available episodes. For example, `--traj-ids 0 1 2` runs on the first 3 episodes.
|
| 80 |
+
|
| 81 |
+
### `--action-horizon`
|
| 82 |
+
|
| 83 |
+
Number of future action steps predicted per inference call. The model's maximum is 16 (from model config). Common values:
|
| 84 |
+
- `16` — full horizon, used for open-loop evaluation
|
| 85 |
+
- `8` — shorter horizon, common for real-time deployment where actions are re-planned frequently
|
| 86 |
+
|
| 87 |
+
This parameter is robot-agnostic — the same value works across different datasets and embodiments.
|
| 88 |
+
|
| 89 |
+
### `--inference-mode`
|
| 90 |
+
|
| 91 |
+
- `pytorch` — standard PyTorch inference (default, no setup required)
|
| 92 |
+
- `tensorrt` — accelerated inference using TensorRT engine (requires ONNX export + engine build first, see [Deployment Guide](../scripts/deployment/README.md))
|
| 93 |
+
|
| 94 |
+
### Expected Output (PyTorch mode)
|
| 95 |
+
|
| 96 |
+
The inference scripts produce:
|
| 97 |
+
- Per-trajectory **MSE** and **MAE** (unnormalized action prediction error vs ground truth)
|
| 98 |
+
- **Timing stats**: model load time, avg/min/max/P90 inference time per step
|
| 99 |
+
- **Summary**: average MSE/MAE across all trajectories
|
| 100 |
+
|
| 101 |
+
### Example: Matching Parameters to Dataset
|
| 102 |
+
|
| 103 |
+
| Dataset | Embodiment Tag | Notes |
|
| 104 |
+
|---------|---------------|-------|
|
| 105 |
+
| `demo_data/droid_sample` | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | DROID — works with base model (zero-shot) or finetuned `nvidia/GR00T-N1.7-DROID` |
|
| 106 |
+
| `demo_data/libero_demo` | `LIBERO_PANDA` | LIBERO Panda — uses finetuned checkpoint from `nvidia/GR00T-N1.7-LIBERO` (must be downloaded locally first, see [README](../README.md)) |
|
| 107 |
+
| `demo_data/cube_to_bowl_5` | `NEW_EMBODIMENT` | SO100 arm — only works with a finetuned checkpoint, not the base model |
|
| 108 |
+
|
| 109 |
+
## Understanding the Observation Format
|
| 110 |
+
|
| 111 |
+
The policy expects observations as a nested dictionary with three modalities:
|
| 112 |
+
|
| 113 |
+
```python
|
| 114 |
+
observation = {
|
| 115 |
+
"video": {
|
| 116 |
+
"camera_name": np.ndarray, # Shape: (B, T, H, W, 3), dtype: uint8
|
| 117 |
+
# ... one entry per camera
|
| 118 |
+
},
|
| 119 |
+
"state": {
|
| 120 |
+
"state_name": np.ndarray, # Shape: (B, T, D), dtype: float32
|
| 121 |
+
# ... one entry per state stream
|
| 122 |
+
},
|
| 123 |
+
"language": {
|
| 124 |
+
"task": [[str]], # Shape: (B, 1), list of lists of strings
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Dimensions
|
| 130 |
+
|
| 131 |
+
- **`B`**: Batch size (number of parallel environments)
|
| 132 |
+
- **`T`**: Temporal horizon (number of historical observations)
|
| 133 |
+
- **`H, W`**: Image height and width
|
| 134 |
+
- **`D`**: State dimension
|
| 135 |
+
- **`C`**: Number of channels (must be 3 for RGB)
|
| 136 |
+
|
| 137 |
+
### Data Type Requirements
|
| 138 |
+
|
| 139 |
+
- **Videos** must be `np.uint8` arrays with RGB pixel values in range [0, 255]
|
| 140 |
+
- **States** must be `np.float32` arrays
|
| 141 |
+
- **Language** instructions are lists of lists of strings
|
| 142 |
+
|
| 143 |
+
### Important Notes
|
| 144 |
+
|
| 145 |
+
- The temporal horizon `T` is determined by your model's training configuration
|
| 146 |
+
- Different modalities may have different temporal horizons (query via `get_modality_config()`)
|
| 147 |
+
- Language instructions are typically single timestep (`T=1`)
|
| 148 |
+
- All arrays in a batch must have the same batch size `B`
|
| 149 |
+
|
| 150 |
+
## Understanding the Action Format
|
| 151 |
+
|
| 152 |
+
The policy returns actions in a similar nested structure:
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
action = {
|
| 156 |
+
"action_name": np.ndarray, # Shape: (B, T, D), dtype: float32
|
| 157 |
+
# ... one entry per action stream
|
| 158 |
+
}
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### Dimensions
|
| 162 |
+
|
| 163 |
+
- **`B`**: Batch size (matches input batch size)
|
| 164 |
+
- **`T`**: Action horizon (number of future action steps to predict)
|
| 165 |
+
- **`D`**: Action dimension (e.g., 7 for arm joints, 1 for gripper)
|
| 166 |
+
|
| 167 |
+
### Important Notes
|
| 168 |
+
|
| 169 |
+
- Actions are returned in **physical units** (e.g., joint positions in radians, velocities in rad/s)
|
| 170 |
+
- Actions are **not normalized** - they're ready to send to your robot controller
|
| 171 |
+
- The action horizon `T` allows predicting multiple future steps (useful for action chunking)
|
| 172 |
+
|
| 173 |
+
## Running Inference
|
| 174 |
+
|
| 175 |
+
Use the `get_action()` method to compute actions from observations:
|
| 176 |
+
|
| 177 |
+
```python
|
| 178 |
+
# Get action from current observation
|
| 179 |
+
action, info = policy.get_action(observation)
|
| 180 |
+
|
| 181 |
+
# Access the action array
|
| 182 |
+
arm_action = action["action_name"] # Shape: (B, T, D)
|
| 183 |
+
|
| 184 |
+
# Extract the first action to execute
|
| 185 |
+
next_action = arm_action[:, 0, :] # Shape: (B, D)
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
The method returns a tuple of:
|
| 189 |
+
- `action`: Dictionary of action arrays
|
| 190 |
+
- `info`: Dictionary of additional information (currently empty, reserved for future use)
|
| 191 |
+
|
| 192 |
+
## Querying Modality Configurations
|
| 193 |
+
|
| 194 |
+
To understand what observations your policy expects and what actions it produces, query the modality configuration:
|
| 195 |
+
|
| 196 |
+
```python
|
| 197 |
+
# Get modality configs for your embodiment
|
| 198 |
+
modality_configs = policy.get_modality_config()
|
| 199 |
+
|
| 200 |
+
# Check what camera keys are expected
|
| 201 |
+
video_keys = modality_configs["video"].modality_keys
|
| 202 |
+
print(f"Expected cameras: {video_keys}")
|
| 203 |
+
|
| 204 |
+
# Check video temporal horizon
|
| 205 |
+
video_horizon = len(modality_configs["video"].delta_indices)
|
| 206 |
+
print(f"Video frames needed: {video_horizon}")
|
| 207 |
+
|
| 208 |
+
# Check state keys and horizon
|
| 209 |
+
state_keys = modality_configs["state"].modality_keys
|
| 210 |
+
state_horizon = len(modality_configs["state"].delta_indices)
|
| 211 |
+
print(f"Expected states: {state_keys}, horizon: {state_horizon}")
|
| 212 |
+
|
| 213 |
+
# Check action keys and horizon
|
| 214 |
+
action_keys = modality_configs["action"].modality_keys
|
| 215 |
+
action_horizon = len(modality_configs["action"].delta_indices)
|
| 216 |
+
print(f"Action outputs: {action_keys}, horizon: {action_horizon}")
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
This is especially useful when:
|
| 220 |
+
- You're unsure what observations your trained model expects
|
| 221 |
+
- You need to verify the temporal horizons for each modality
|
| 222 |
+
- You're debugging observation/action format mismatches
|
| 223 |
+
|
| 224 |
+
## Resetting the Policy
|
| 225 |
+
|
| 226 |
+
Reset the policy between episodes:
|
| 227 |
+
|
| 228 |
+
```python
|
| 229 |
+
# Reset policy state (if any) between episodes
|
| 230 |
+
info = policy.reset()
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
Currently, the policy is stateless, but calling `reset()` is good practice for future compatibility.
|
| 234 |
+
|
| 235 |
+
## Adapting the Policy to Your Environment
|
| 236 |
+
|
| 237 |
+
Most environments use different observation/action formats than the Policy API expects. You'll typically need to write a **policy wrapper** that:
|
| 238 |
+
|
| 239 |
+
1. **Transforms observations**: Convert your environment's observation format to the Policy API format
|
| 240 |
+
2. **Calls the policy**: Use `policy.get_action()` to compute actions
|
| 241 |
+
3. **Transforms actions**: Convert the policy's actions back to your environment's format
|
| 242 |
+
|
| 243 |
+
### Example Workflow
|
| 244 |
+
|
| 245 |
+
```python
|
| 246 |
+
# In your environment loop
|
| 247 |
+
env_obs = env.reset() # Environment-specific format
|
| 248 |
+
|
| 249 |
+
# Transform to Policy API format
|
| 250 |
+
policy_obs = transform_observation(env_obs)
|
| 251 |
+
|
| 252 |
+
# Get action from policy
|
| 253 |
+
policy_action, _ = policy.get_action(policy_obs)
|
| 254 |
+
|
| 255 |
+
# Transform back to environment format
|
| 256 |
+
env_action = transform_action(policy_action)
|
| 257 |
+
|
| 258 |
+
# Execute in environment
|
| 259 |
+
env_obs, reward, done, info = env.step(env_action)
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
### Using Server-Client Architecture for Remote Inference
|
| 263 |
+
|
| 264 |
+
For many use cases, especially when working with real robots or distributed systems, you may want to run the policy on a separate machine (e.g., a GPU server) and send observations/actions over the network. GR00T provides a built-in server-client architecture using ZeroMQ for this purpose.
|
| 265 |
+
|
| 266 |
+
#### Why Use Server-Client Architecture?
|
| 267 |
+
|
| 268 |
+
- **Separate compute resources**: Run policy inference on a GPU server while controlling the robot from a different machine
|
| 269 |
+
- **Dependency isolation**: Avoid dependency issues with the client policy
|
| 270 |
+
|
| 271 |
+
```mermaid
|
| 272 |
+
sequenceDiagram
|
| 273 |
+
participant Robot as Robot / Sim Client
|
| 274 |
+
participant Client as PolicyClient (ZMQ REQ)
|
| 275 |
+
participant Server as PolicyServer (ZMQ REP)
|
| 276 |
+
participant Policy as Gr00tPolicy (GPU)
|
| 277 |
+
|
| 278 |
+
Robot->>Client: observation dict
|
| 279 |
+
Client->>Server: msgpack(endpoint="get_action", data=obs)
|
| 280 |
+
Server->>Policy: policy.get_action(obs)
|
| 281 |
+
Policy-->>Server: (action_dict, info_dict)
|
| 282 |
+
Server-->>Client: msgpack(action, info)
|
| 283 |
+
Client-->>Robot: action dict
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
#### Starting the Policy Server
|
| 287 |
+
|
| 288 |
+
Launch the server using the `run_gr00t_server.py` script:
|
| 289 |
+
|
| 290 |
+
```bash
|
| 291 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 292 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 293 |
+
--model-path /path/to/your/checkpoint \
|
| 294 |
+
--device cuda:0 \
|
| 295 |
+
--host 0.0.0.0 \
|
| 296 |
+
--port 5555
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
**Parameters:**
|
| 300 |
+
- `--embodiment-tag`: The embodiment tag for your robot (e.g., `NEW_EMBODIMENT`)
|
| 301 |
+
- `--model-path`: Path to your trained model checkpoint directory
|
| 302 |
+
- `--device`: Device to run inference on (`cuda:0`, `cuda:1`, `cpu`, etc.)
|
| 303 |
+
- `--host`: Host address (`127.0.0.1` for local only, `0.0.0.0` to accept external connections)
|
| 304 |
+
- `--port`: Port number (default: 5555)
|
| 305 |
+
- `--strict` / `--no-strict`: Enable or disable input/output validation (default: True)
|
| 306 |
+
- `--use-sim-policy-wrapper`: Whether to use `Gr00tSimPolicyWrapper` for GR00T simulation environments (default: False)
|
| 307 |
+
|
| 308 |
+
Once started, the server will display:
|
| 309 |
+
```
|
| 310 |
+
Starting GR00T inference server...
|
| 311 |
+
Embodiment tag: NEW_EMBODIMENT
|
| 312 |
+
Model path: /path/to/your/checkpoint
|
| 313 |
+
Device: cuda:0
|
| 314 |
+
Host: 0.0.0.0
|
| 315 |
+
Port: 5555
|
| 316 |
+
Server is ready and listening on tcp://0.0.0.0:5555
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
#### Using the Policy Client
|
| 320 |
+
|
| 321 |
+
On the client side (your environment/robot control code), use `PolicyClient` to connect to the server:
|
| 322 |
+
|
| 323 |
+
```python
|
| 324 |
+
from gr00t.policy.server_client import PolicyClient
|
| 325 |
+
|
| 326 |
+
# Connect to the policy server
|
| 327 |
+
policy = PolicyClient(
|
| 328 |
+
host="localhost", # or IP address of your GPU server
|
| 329 |
+
port=5555,
|
| 330 |
+
timeout_ms=15000, # 15 second timeout for inference
|
| 331 |
+
strict=False, # leave the validation to the server
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# Verify connection
|
| 335 |
+
if not policy.ping():
|
| 336 |
+
raise RuntimeError("Cannot connect to policy server!")
|
| 337 |
+
|
| 338 |
+
# Use just like a regular policy
|
| 339 |
+
observation = get_observation() # Your observation in Policy API format
|
| 340 |
+
action, info = policy.get_action(observation)
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
**Parameters:**
|
| 344 |
+
- `host`: Hostname or IP address of the policy server
|
| 345 |
+
- `port`: Port number (must match server port)
|
| 346 |
+
- `timeout_ms`: Timeout in milliseconds for network requests (default: 15000)
|
| 347 |
+
- `api_token`: Optional API token for authentication (default: None)
|
| 348 |
+
- `strict`: Enable client-side validation (usually False since server validates)
|
| 349 |
+
|
| 350 |
+
#### Client API
|
| 351 |
+
|
| 352 |
+
The `PolicyClient` implements the same `BasePolicy` interface, so it's a drop-in replacement:
|
| 353 |
+
|
| 354 |
+
```python
|
| 355 |
+
# Get modality configuration from the server
|
| 356 |
+
modality_configs = policy.get_modality_config()
|
| 357 |
+
|
| 358 |
+
# Get action — returns (action_dict, info_dict)
|
| 359 |
+
action, info = policy.get_action(observation, options=None)
|
| 360 |
+
|
| 361 |
+
# Reset policy state (e.g., switch episode in ReplayPolicy)
|
| 362 |
+
info = policy.reset(options=None)
|
| 363 |
+
|
| 364 |
+
# Check server health — returns True if server responds
|
| 365 |
+
is_alive = policy.ping()
|
| 366 |
+
|
| 367 |
+
# Shutdown the server remotely (optional)
|
| 368 |
+
policy.kill_server()
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
#### Server API Reference
|
| 372 |
+
|
| 373 |
+
| Parameter | Type | Default | Description |
|
| 374 |
+
|-----------|------|---------|-------------|
|
| 375 |
+
| `policy` | `BasePolicy` | *(required)* | The policy instance to serve (e.g., `Gr00tPolicy`, `ReplayPolicy`) |
|
| 376 |
+
| `host` | `str` | `"*"` | Bind address. `"*"` accepts connections on all interfaces |
|
| 377 |
+
| `port` | `int` | `5555` | TCP port for ZMQ REP socket |
|
| 378 |
+
| `api_token` | `str` | `None` | If set, clients must include a matching token in every request |
|
| 379 |
+
|
| 380 |
+
**Built-in endpoints:** `get_action`, `reset`, `get_modality_config`, `ping`, `kill`. Custom endpoints can be added via `server.register_endpoint(name, handler)`.
|
| 381 |
+
|
| 382 |
+
#### Error Handling
|
| 383 |
+
|
| 384 |
+
The server-client uses ZeroMQ REQ/REP sockets over TCP with msgpack serialization.
|
| 385 |
+
|
| 386 |
+
- **Timeout:** If the server does not respond within `timeout_ms`, the ZMQ socket will raise `zmq.error.Again`. The default 15 s timeout accommodates cold-start model loading on the first call.
|
| 387 |
+
- **Connection loss:** If `ping()` returns `False`, the client automatically recreates its ZMQ socket for the next attempt. Your control loop should retry or halt.
|
| 388 |
+
- **Server-side errors:** Exceptions in the policy are caught, serialized as `{"error": "..."}`, and re-raised as `RuntimeError` on the client side.
|
| 389 |
+
|
| 390 |
+
#### Debugging with ReplayPolicy
|
| 391 |
+
|
| 392 |
+
When developing a new environment integration or debugging your inference loop, running a full model inference can be cumbersome. `ReplayPolicy` allows you to **replay recorded actions from an existing dataset**, helping you verify that:
|
| 393 |
+
|
| 394 |
+
- Your environment setup works correctly
|
| 395 |
+
- Observations are formatted properly
|
| 396 |
+
- Action execution behaves as expected
|
| 397 |
+
- The server-client communication is functioning
|
| 398 |
+
|
| 399 |
+
This eliminates the need for a trained model during the development phase.
|
| 400 |
+
|
| 401 |
+
##### Starting the Server with ReplayPolicy
|
| 402 |
+
|
| 403 |
+
Instead of providing `--model-path`, use `--dataset-path` to start the server in replay mode:
|
| 404 |
+
|
| 405 |
+
```bash
|
| 406 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 407 |
+
--dataset-path /path/to/lerobot_dataset \
|
| 408 |
+
--embodiment-tag NEW_EMBODIMENT \
|
| 409 |
+
--host 0.0.0.0 \
|
| 410 |
+
--port 5555 \
|
| 411 |
+
--execution-horizon 8 # should match the executed action horizon in the environment
|
| 412 |
+
```
|
| 413 |
+
|
| 414 |
+
**Parameters:**
|
| 415 |
+
- `--dataset-path`: Path to a LeRobot-compatible dataset directory
|
| 416 |
+
- `--embodiment-tag`: The embodiment tag for modality configuration
|
| 417 |
+
- `--execution-horizon`: Number of steps to advance the dataset per `get_action()` call. Should match the number of executed action steps in the environment.
|
| 418 |
+
- `--modality-config-path`: (Optional) Path to custom modality config JSON file. If not provided, uses the config from `embodiment-tag`
|
| 419 |
+
- `--use-sim-policy-wrapper`: Apply `Gr00tSimPolicyWrapper` for GR00T simulation environments
|
| 420 |
+
|
| 421 |
+
##### Using ReplayPolicy from the Client
|
| 422 |
+
|
| 423 |
+
On the client side, use `PolicyClient` exactly as you would with a real model:
|
| 424 |
+
|
| 425 |
+
```python
|
| 426 |
+
from gr00t.policy.server_client import PolicyClient
|
| 427 |
+
|
| 428 |
+
# Connect to the replay policy server
|
| 429 |
+
policy = PolicyClient(host="localhost", port=5555)
|
| 430 |
+
|
| 431 |
+
# Use exactly like a regular policy
|
| 432 |
+
action, info = policy.get_action(observation)
|
| 433 |
+
|
| 434 |
+
# info contains replay metadata
|
| 435 |
+
print(f"Replaying step {info['current_step']} of episode {info['episode_index']}")
|
| 436 |
+
```
|
| 437 |
+
|
| 438 |
+
##### Switching Episodes
|
| 439 |
+
|
| 440 |
+
ReplayPolicy starts with episode 0 by default. To switch to a different episode:
|
| 441 |
+
|
| 442 |
+
```python
|
| 443 |
+
# Reset to a specific episode
|
| 444 |
+
policy.reset(options={"episode_index": 5})
|
| 445 |
+
|
| 446 |
+
# Optionally start from a specific step within the episode
|
| 447 |
+
policy.reset(options={"episode_index": 5, "step_index": 10})
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
The number of available episodes can be queried via the `info` dict returned from `reset()` or `get_action()`.
|
| 451 |
+
|
| 452 |
+
##### Example: Validating a LIBERO Environment
|
| 453 |
+
|
| 454 |
+
Here's a complete example of using ReplayPolicy to validate a simulation setup:
|
| 455 |
+
|
| 456 |
+
```bash
|
| 457 |
+
# Terminal 1: Start the replay server
|
| 458 |
+
uv run python gr00t/eval/run_gr00t_server.py \
|
| 459 |
+
--dataset-path <your_dataset_path> \
|
| 460 |
+
--embodiment-tag <YOUR_EMBODIMENT_TAG> \
|
| 461 |
+
--action-horizon 8 \
|
| 462 |
+
--use-sim-policy-wrapper
|
| 463 |
+
|
| 464 |
+
# Terminal 2: Run evaluation with the replay policy
|
| 465 |
+
uv run python gr00t/eval/rollout_policy.py \
|
| 466 |
+
--n-episodes 1 \
|
| 467 |
+
--policy-client-host 127.0.0.1 \
|
| 468 |
+
--policy-client-port 5555 \
|
| 469 |
+
--max-episode-steps 720 \
|
| 470 |
+
--env-name <env_prefix>/<task_name> \
|
| 471 |
+
--n-action-steps 8 \
|
| 472 |
+
--n-envs 1
|
| 473 |
+
```
|
| 474 |
+
|
| 475 |
+
If your environment is set up correctly, replaying ground-truth actions should achieve high (often 100%) success rates. Low success rates indicate issues with:
|
| 476 |
+
- Environment reset state not matching the dataset
|
| 477 |
+
- Observation preprocessing differences
|
| 478 |
+
- Action space mismatches
|
| 479 |
+
|
| 480 |
+
> **Tip:** ReplayPolicy is an excellent first step when integrating a new environment. Debug with replay first, then switch to model inference once the pipeline is validated.
|
| 481 |
+
|
| 482 |
+
#### Integrating the GR00T N1.7 Client Into Your Deployment Pipeline
|
| 483 |
+
|
| 484 |
+
GR00T's server–client architecture allows you to keep the **client side extremely lightweight**, making it easy to embed into any custom deployment pipeline without pulling in the full dependency stack.
|
| 485 |
+
|
| 486 |
+
For a minimal working example, see
|
| 487 |
+
[`eval_so100.py`](../gr00t/eval/real_robot/SO100/eval_so100.py).
|
| 488 |
+
|
| 489 |
+
In most cases, your deployment environment only needs to install the local GR00T client code:
|
| 490 |
+
|
| 491 |
+
```bash
|
| 492 |
+
uv pip install -e . --verbose --no-deps
|
| 493 |
+
```
|
| 494 |
+
|
| 495 |
+
The client relies solely on a small set of interfaces:
|
| 496 |
+
- `gr00t/policy/server_client.py`
|
| 497 |
+
- `gr00t/policy/policy.py`
|
| 498 |
+
- `gr00t/data/types.py`
|
| 499 |
+
- `gr00t/data/embodiment_tags.py`
|
| 500 |
+
|
| 501 |
+
## Common Patterns
|
| 502 |
+
|
| 503 |
+
### Batched Inference
|
| 504 |
+
|
| 505 |
+
The policy supports batched inference for efficiency:
|
| 506 |
+
|
| 507 |
+
```python
|
| 508 |
+
# Run 4 environments in parallel
|
| 509 |
+
batch_size = 4
|
| 510 |
+
observation = {
|
| 511 |
+
"video": {"wrist_cam": np.zeros((batch_size, T_video, H, W, 3), dtype=np.uint8)},
|
| 512 |
+
"state": {"joints": np.zeros((batch_size, T_state, D_state), dtype=np.float32)},
|
| 513 |
+
"language": {"task": [["pick up the cube"]] * batch_size},
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
action, _ = policy.get_action(observation)
|
| 517 |
+
# action["action_name"] has shape (batch_size, action_horizon, action_dim)
|
| 518 |
+
```
|
| 519 |
+
|
| 520 |
+
### Single Environment Inference
|
| 521 |
+
|
| 522 |
+
For single environments, use batch size of 1:
|
| 523 |
+
|
| 524 |
+
```python
|
| 525 |
+
# Add batch dimension (B=1)
|
| 526 |
+
observation = {
|
| 527 |
+
"video": {"wrist_cam": video[np.newaxis, ...]}, # (1, T, H, W, 3)
|
| 528 |
+
"state": {"joints": state[np.newaxis, ...]}, # (1, T, D)
|
| 529 |
+
"language": {"task": [["pick up the cube"]]}, # List of length 1
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
action, _ = policy.get_action(observation)
|
| 533 |
+
|
| 534 |
+
# Remove batch dimension
|
| 535 |
+
single_action = action["action_name"][0] # (action_horizon, action_dim)
|
| 536 |
+
```
|
| 537 |
+
|
| 538 |
+
### Action Chunking
|
| 539 |
+
|
| 540 |
+
When the action horizon `T > 1`, you can use action chunking:
|
| 541 |
+
|
| 542 |
+
```python
|
| 543 |
+
action, _ = policy.get_action(observation)
|
| 544 |
+
action_chunk = action["action_name"][:, :, :] # (B, T, D)
|
| 545 |
+
|
| 546 |
+
# Execute actions over multiple timesteps
|
| 547 |
+
for t in range(action_chunk.shape[1]):
|
| 548 |
+
env.step(action_chunk[:, t, :])
|
| 549 |
+
```
|
| 550 |
+
|
| 551 |
+
### Training Dataloading Optimization
|
| 552 |
+
|
| 553 |
+
When training a model, you can optimize the dataloading speed vs memory usage via various command line arguments.
|
| 554 |
+
|
| 555 |
+
examples:
|
| 556 |
+
```bash
|
| 557 |
+
uv run python gr00t/experiment/launch_finetune.py \
|
| 558 |
+
.... \
|
| 559 |
+
--num-shards-per-epoch 100 \
|
| 560 |
+
--dataloader-num-workers 2
|
| 561 |
+
--shard-size 512 \
|
| 562 |
+
```
|
| 563 |
+
|
| 564 |
+
If vram is limited, you can reduce the all the numbers above to reduce the memory usage.
|
| 565 |
+
|
| 566 |
+
To ensure more IID during sampling of shards, you can reduce the `episode_sampling_rate` to 0.05 or lower.
|
| 567 |
+
|
| 568 |
+
## Troubleshooting
|
| 569 |
+
|
| 570 |
+
1. **Enable strict mode** during development: `strict=True`
|
| 571 |
+
2. **Print modality configs** to understand expected formats
|
| 572 |
+
3. **Check shapes** of your observations before calling `get_action()`
|
| 573 |
+
4. **Use the reference wrapper** (`Gr00tSimPolicyWrapper`) as a template
|
| 574 |
+
5. **Validate incrementally**: Test with dummy observations first before connecting to real environments
|
getting_started/real_world_deployment.md
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GR00T Real-World Deployment Guide
|
| 2 |
+
|
| 3 |
+
This guide covers building an end-to-end real-world VLA pipeline—from data collection and training to deployment—with practical engineering recommendations.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
A typical GR00T real-world deployment workflow includes:
|
| 8 |
+
|
| 9 |
+
1. **[Hardware Preparation](#1-hardware-and-environment-preparation-device-requirements)**: Verify that the robot platform, sensors, and compute resources are ready.
|
| 10 |
+
2. **[Data Collection](#2-data-collection)**: Choose an appropriate teleoperation setup and collect at least 100 valid episodes.
|
| 11 |
+
3. **[Data Preprocessing](#3-data-preprocessing)**: Clean data, align timestamps, and convert to LeRobot format.
|
| 12 |
+
4. **[Model Training](#4-vla-model-training)**: Fine-tune GR00T N1.*.
|
| 13 |
+
5. **[Model Evaluation](#validation)**: Run open-loop evaluation to validate convergence and model quality.
|
| 14 |
+
6. **[Deployment Setup](#5-deployment-and-closed-loop-control)**: Build a ZMQ Server-Client architecture.
|
| 15 |
+
7. **[Closed-Loop Testing](#5-deployment-and-closed-loop-control)**: Run closed-loop control on real hardware and monitor jittering and stop-and-go behavior.
|
| 16 |
+
8. **[Optimization](#6-common-issues-jittering-and-stop-and-go)**: Tune RTC parameters and trajectory smoothing strategies based on real-world performance.
|
| 17 |
+
|
| 18 |
+
## 1. Hardware and Environment Preparation (Device Requirements)
|
| 19 |
+
|
| 20 |
+
Ensure your robot hardware, sensor pipeline, and control interfaces are stable and available.
|
| 21 |
+
|
| 22 |
+
### Robot Platform
|
| 23 |
+
|
| 24 |
+
- **Recommended platforms**: Robotic arms with SDK-level control support (e.g., Franka, UR, Piper, SO101).
|
| 25 |
+
- **Basic requirements**:
|
| 26 |
+
- Real-time joint state feedback.
|
| 27 |
+
- High-frequency action execution (30 FPS recommended).
|
| 28 |
+
- Stable control interface.
|
| 29 |
+
|
| 30 |
+
### Multimodal Sensors
|
| 31 |
+
|
| 32 |
+
| Sensor Type | Specification | Purpose |
|
| 33 |
+
|-------------|---------------|---------|
|
| 34 |
+
| **Wrist-mounted camera** | 30 FPS, RGB | Capture close-range manipulation visuals |
|
| 35 |
+
| **Third-person camera (3rd view)** | 30 FPS, RGB | Capture global scene context |
|
| 36 |
+
| **Robot proprioceptive state** | Real-time acquisition | Joint states and gripper state |
|
| 37 |
+
|
| 38 |
+
### Compute Resources
|
| 39 |
+
|
| 40 |
+
- **Training phase**: NVIDIA GPU servers (e.g., H100 or H20) are recommended for larger batch sizes.
|
| 41 |
+
- **Deployment phase**: Edge hardware such as Jetson AGX Thor supports on-device inference.
|
| 42 |
+
|
| 43 |
+
> For details, see the [hardware recommendation guide](hardware_recommendation.md).
|
| 44 |
+
|
| 45 |
+
### Teleoperation Devices
|
| 46 |
+
|
| 47 |
+
Teleoperation device selection is critical for data quality.
|
| 48 |
+
|
| 49 |
+
### Teleoperation Device Comparison
|
| 50 |
+
|
| 51 |
+
In the table below:
|
| 52 |
+
- **Embodiment dependency**: how similar the teleoperation device and target robot must be in joint topology, degrees of freedom (DoF), and workspace. Higher dependency implies harder cross-embodiment transfer.
|
| 53 |
+
- **Operational intuition**: how naturally operator inputs map to robot motion. Higher intuition means faster onboarding and lower demonstration error.
|
| 54 |
+
|
| 55 |
+
| Device Type | Cost Level (Reference) | Embodiment Dependency | Operational Intuition | Notes |
|
| 56 |
+
|-------------|------------------------|-----------------------|-----------------------|-------|
|
| 57 |
+
| **Keyboard/Gamepad/SpaceMouse/Joylo** | Low | Low: command mapping via keys/controls | Medium: requires adaptation to key-motion mapping | Low entry cost; a good starting point and useful in mobile scenarios |
|
| 58 |
+
| **Master-Slave arm systems** | Medium | High: master/slave arms usually require similar kinematics and workspace | High: near one-to-one human-robot mapping | Suitable for single-robot setups; commonly used by robot OEMs; can reduce the risk of reaching joint limits during demonstrations |
|
| 59 |
+
| **UMI / Fast-UMI / Pika Sense** | Medium | Low: hardware-agnostic action representation reusable across arms | High: after calibration, end-effector (EEF) following is intuitive | Suitable for training general VLA models; low-DoF arms may still hit joint limits |
|
| 60 |
+
| **VR-based teleoperation** | Medium (headset + rendering + network) | Low: mainly depends on software integration | Medium: depends on immersive visual feedback and tracking quality | A flexible solution, but with higher integration overhead |
|
| 61 |
+
| **Glove / Motion Capture** | High (commercial mocap suite + data gloves) | Low: retarget through kinematic mapping to different embodiments | High: intuitive full-hand/full-body control | Suitable for full-body control and dexterous-hand tasks |
|
| 62 |
+
| **Exoskeleton** | High | High: usually requires matched joint structure | High: natural action correspondence | Extendable to multi-joint humanoid control |
|
| 63 |
+
|
| 64 |
+
## 2. Data Collection
|
| 65 |
+
|
| 66 |
+
Key considerations for data collection:
|
| 67 |
+
|
| 68 |
+
### Timestamp Synchronization
|
| 69 |
+
|
| 70 |
+
- The FPS of both camera streams should be strictly matched, and capture triggers should be as synchronized as possible.
|
| 71 |
+
- Joint state sampling frequency should exceed camera FPS to enable accurate downsampling.
|
| 72 |
+
- Record full timestamps during collection for downstream temporal alignment.
|
| 73 |
+
|
| 74 |
+
### Action Representation
|
| 75 |
+
|
| 76 |
+
- If training and collection use the same embodiment (e.g., master-slave arms), log joint-space `Joint States` during collection. For task-space models, compute EEF pose via forward kinematics (FK) in post-processing.
|
| 77 |
+
- If embodiments differ (e.g., collect with UMI, deploy on Piper), directly record task-space EEF pose during collection.
|
| 78 |
+
|
| 79 |
+
### Data Distribution
|
| 80 |
+
|
| 81 |
+
- Current imitation-learning-based models perform more reliably in previously seen scenarios. In early-stage experiments, start with data collection and validation in a limited domain.
|
| 82 |
+
- After pipeline validation, gradually expand the domain by varying lighting, object placement, and initial robot poses to improve generalization.
|
| 83 |
+
|
| 84 |
+
### Scene Consistency
|
| 85 |
+
|
| 86 |
+
- Keep third-person camera extrinsics fixed and ensure a rigid wrist-camera mount.
|
| 87 |
+
- In early experiments, prioritize scene consistency; avoid varying lighting, object placement, or initial robot poses.
|
| 88 |
+
|
| 89 |
+
### Joint Limits
|
| 90 |
+
|
| 91 |
+
- If collecting joint-space data, avoid operating near joint limits to reduce the number of samples in those regions.
|
| 92 |
+
|
| 93 |
+
## 3. Data Preprocessing
|
| 94 |
+
|
| 95 |
+
Raw data must be cleaned, synchronized, and converted before training.
|
| 96 |
+
|
| 97 |
+
### Trajectory Filtering
|
| 98 |
+
|
| 99 |
+
Data filtering is recommended in two stages: script-based filtering and manual review.
|
| 100 |
+
|
| 101 |
+
#### Script Filtering
|
| 102 |
+
|
| 103 |
+
- Check image timestamps and remove samples with:
|
| 104 |
+
1. Excessive latency in a single camera stream.
|
| 105 |
+
2. Excessive timestamp difference between the two camera streams.
|
| 106 |
+
- Detect and remove abnormal jumps in robot state sequences.
|
| 107 |
+
|
| 108 |
+
#### Manual Filtering
|
| 109 |
+
|
| 110 |
+
Replay trajectories with synchronized visualization to catch issues missed by scripts:
|
| 111 |
+
|
| 112 |
+
- Remove samples with poor synchronization between image and action sequences.
|
| 113 |
+
- Remove blurry frames.
|
| 114 |
+
- Remove failed task executions.
|
| 115 |
+
- Remove low-quality trajectories (e.g., redundant paths, discontinuous actions).
|
| 116 |
+
|
| 117 |
+
### Trajectory Preprocessing
|
| 118 |
+
|
| 119 |
+
1. Timestamp alignment: align camera frames and robot joint states to a shared time base.
|
| 120 |
+
2. Head-tail trimming: remove idle segments at the start and end of trajectories.
|
| 121 |
+
3. Split long trajectories (several minutes) into multiple subtasks.
|
| 122 |
+
|
| 123 |
+
### Format Conversion
|
| 124 |
+
|
| 125 |
+
Convert all data to a standard format (e.g., LeRobot) for GR00T compatibility:
|
| 126 |
+
|
| 127 |
+
- See the [data preparation guide](data_preparation.md) for format requirements.
|
| 128 |
+
- Use the provided conversion scripts to convert data to GR00T LeRobot format.
|
| 129 |
+
|
| 130 |
+
## 4. VLA Model Training
|
| 131 |
+
|
| 132 |
+
### Training Parameter Configuration
|
| 133 |
+
|
| 134 |
+
**Dataset size recommendations**
|
| 135 |
+
|
| 136 |
+
For single-task `finetune`:
|
| 137 |
+
|
| 138 |
+
- **Minimum data size**: Prepare at least **100 valid episodes**. For very narrow task domains, ~30 episodes may suffice. A capture frequency of 20–50 Hz is recommended for manipulation tasks.
|
| 139 |
+
- **Episode length**: No hard limit, but each episode must contain a complete action cycle with idle frames removed. Split overly long episodes into subtasks.
|
| 140 |
+
- **Recommended data size**: 200+ episodes usually provide more stable performance.
|
| 141 |
+
|
| 142 |
+
**Core parameters**
|
| 143 |
+
|
| 144 |
+
- **Input/output mode**: Default to `State-relative Action Prediction`. Compared with `Absolute Action`, it converges more easily and improves inter-chunk consistency.
|
| 145 |
+
- **Training space**: Both joint space and task space are valid. For low-DoF arms, joint space is often preferred to reduce singularity-related risks.
|
| 146 |
+
- **Action Chunk Size**: Default is 16. If combined with RTC to mitigate stop-and-go, set it to at least 32.
|
| 147 |
+
- **Batch Size**: Increase the batch size as much as GPU memory allows.
|
| 148 |
+
|
| 149 |
+
> For additional training options, see the [fine-tuning guide](finetune_new_embodiment.md).
|
| 150 |
+
|
| 151 |
+
**Compute resources**
|
| 152 |
+
|
| 153 |
+
- Fine-tuning requires significantly less compute than pretraining.
|
| 154 |
+
- A single compute node (8 x H100 or 8 x H20) is usually sufficient.
|
| 155 |
+
|
| 156 |
+
### Validation
|
| 157 |
+
|
| 158 |
+
After training, run open-loop validation to confirm convergence, then proceed to closed-loop deployment validation.
|
| 159 |
+
|
| 160 |
+
> Open-loop validation is only a preliminary check. Final performance must be verified with closed-loop testing on real robots. For details, see the [fine-tuning guide](finetune_new_embodiment.md).
|
| 161 |
+
|
| 162 |
+
## 5. Deployment and Closed-Loop Control
|
| 163 |
+
|
| 164 |
+
### System Architecture
|
| 165 |
+
|
| 166 |
+
GR00T supports two inference modes:
|
| 167 |
+
|
| 168 |
+
1. **Direct `Gr00tPolicy` usage**: Suitable when model inference and robot control run on the same machine.
|
| 169 |
+
2. **ZMQ Server-Client architecture**: Suitable for real-world deployment and decouples local robot control (`Local Client`) from remote inference (`Model Server`).
|
| 170 |
+
|
| 171 |
+
For real-world deployment, **ZMQ inference service** is recommended:
|
| 172 |
+
|
| 173 |
+
- Move compute-intensive inference to GPU servers.
|
| 174 |
+
- Keep robot-side control code lightweight.
|
| 175 |
+
- Avoid installing the full inference dependency stack on the robot side.
|
| 176 |
+
|
| 177 |
+
### On-Device Deployment Logic
|
| 178 |
+
|
| 179 |
+
Deployment code has two phases: **initialization** and the **main control loop**.
|
| 180 |
+
The pseudo-code below uses a synchronous workflow, which may cause stop-and-go. See later sections for mitigation via asynchronous execution + RTC.
|
| 181 |
+
|
| 182 |
+
**Pseudo-code workflow:**
|
| 183 |
+
|
| 184 |
+
```python
|
| 185 |
+
# ========== Initialization ==========
|
| 186 |
+
# 1. Initialize and test cameras
|
| 187 |
+
hand_camera = initialize_hand_camera() # e.g., OrbbecSDK
|
| 188 |
+
env_camera = initialize_env_camera() # e.g., RealSense
|
| 189 |
+
test_cameras() # Show preview and verify normal operation
|
| 190 |
+
|
| 191 |
+
# 2. Connect and test robot
|
| 192 |
+
robot = connect_robot() # e.g., Piper SDK
|
| 193 |
+
robot.enable()
|
| 194 |
+
robot.reset_to_initial_position()
|
| 195 |
+
test_robot() # Send test command and verify robot response
|
| 196 |
+
|
| 197 |
+
# 3. Connect and test GR00T model server
|
| 198 |
+
gr00t_client = connect_to_gr00t_server(host, port)
|
| 199 |
+
if not gr00t_client.ping():
|
| 200 |
+
raise ConnectionError("Failed to connect to model server")
|
| 201 |
+
test_model() # Send test observation and verify inference
|
| 202 |
+
|
| 203 |
+
# ========== Main control loop ==========
|
| 204 |
+
while True:
|
| 205 |
+
# 1. Acquire sensor data
|
| 206 |
+
hand_image = hand_camera.get_frame()
|
| 207 |
+
env_image = env_camera.get_frame()
|
| 208 |
+
joint_states = robot.get_joint_states()
|
| 209 |
+
gripper_state = robot.get_gripper_state()
|
| 210 |
+
|
| 211 |
+
# 2. Format observation
|
| 212 |
+
observation = format_observation(
|
| 213 |
+
hand_image,
|
| 214 |
+
env_image,
|
| 215 |
+
joint_states,
|
| 216 |
+
gripper_state,
|
| 217 |
+
task_description,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# 3. Model inference (via ZMQ)
|
| 221 |
+
actions = gr00t_client.get_action(observation)
|
| 222 |
+
|
| 223 |
+
# 4. Trajectory post-processing
|
| 224 |
+
actions_arm = actions["joint_states"]
|
| 225 |
+
actions_arm = smooth_trajectory(actions_arm) # smoothing
|
| 226 |
+
actions_arm = check_safety_limits(actions_arm) # safety checks
|
| 227 |
+
|
| 228 |
+
# 5. Execute actions
|
| 229 |
+
for action_step in actions_arm:
|
| 230 |
+
robot.execute_action(action_step)
|
| 231 |
+
sleep(1.0 / 30.0) # 30 FPS
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
### Key Implementation Notes
|
| 235 |
+
|
| 236 |
+
**Important notes:**
|
| 237 |
+
|
| 238 |
+
- **Image format**: Use compressed formats such as JPG to reduce transmission bandwidth.
|
| 239 |
+
- **Safe operation**:
|
| 240 |
+
- **Soft Limits**: Add joint-angle and EEF pose range checks. If a predicted action exceeds workspace bounds, raise an alarm and stop immediately.
|
| 241 |
+
- **E-Stop logic**: Bind an emergency stop hotkey (e.g., Space) on the control PC, or use a physical E-Stop switch.
|
| 242 |
+
- **Action smoothing**: Apply interpolation and smoothing to predicted action sequences.
|
| 243 |
+
|
| 244 |
+
> For more deployment details, see the [policy API guide](policy.md).
|
| 245 |
+
|
| 246 |
+
## 6. Common Issues: Jittering and Stop-and-Go
|
| 247 |
+
|
| 248 |
+
The most common issues in real-world deployment are **jittering** and **stop-and-go**.
|
| 249 |
+
|
| 250 |
+
### Fixing Jittering
|
| 251 |
+
|
| 252 |
+
**Jittering** here refers to visible shaking or vibration of the end-effector or joints during task execution.
|
| 253 |
+
|
| 254 |
+
Jittering typically originates from **inconsistent model outputs** or **insufficient robot-side control quality**. Analyze these two components separately to localize the issue. The suggestions below are general guidelines and may not apply to every robot platform or control stack — always verify against your own hardware and environment.
|
| 255 |
+
|
| 256 |
+
```mermaid
|
| 257 |
+
flowchart TD
|
| 258 |
+
A[Jittering observed] --> B[Save & visualize action chunks in 3D]
|
| 259 |
+
B --> C{Where is the jitter?}
|
| 260 |
+
C -->|Inside each chunk| D[Case A: Model undertrained or poor data quality]
|
| 261 |
+
C -->|Between consecutive chunks| E[Case B: Inconsistent chunk predictions]
|
| 262 |
+
C -->|Chunks look smooth| F[Case C: Robot hardware / low-level control issue]
|
| 263 |
+
D --> D1[Add more data, train longer, check train/eval consistency]
|
| 264 |
+
E --> E1[Use state-relative actions + RTC chunking strategy]
|
| 265 |
+
F --> F1[Check drive control, interpolation, hardware status]
|
| 266 |
+
```
|
| 267 |
+
|
| 268 |
+
**Diagnosis and mitigation**
|
| 269 |
+
|
| 270 |
+
1. **Save and visualize Action Chunks**
|
| 271 |
+
- Save all predicted `Action Chunks`.
|
| 272 |
+
- Visualize continuous TCP (tool center point) trajectories in 3D.
|
| 273 |
+
- **Note**: Convert joint-space outputs to task space via FK before visualization.
|
| 274 |
+
|
| 275 |
+
2. **Analyze visualization results**
|
| 276 |
+
|
| 277 |
+
**Case A: Significant jitter inside each chunk**
|
| 278 |
+
- **Cause**: The model is undertrained, or data quality is insufficient.
|
| 279 |
+
- **Solution**: Improve data quality, add more training data, or train longer. Keep training and validation environments consistent.
|
| 280 |
+
|
| 281 |
+
**Case B: Significant jitter between chunks**
|
| 282 |
+
- **Cause**: Inconsistent adjacent `Action Chunk` predictions.
|
| 283 |
+
- **Solution**:
|
| 284 |
+
- Use `State-relative Action Prediction`. Predicting actions relative to the current state produces a more uniform output distribution, making the network easier to train.
|
| 285 |
+
- Use RTC (`Real-Time Chunking`) or similar strategies.
|
| 286 |
+
|
| 287 |
+
**Case C: Little jitter after visualization**
|
| 288 |
+
- **Cause**: Likely a robot hardware or low-level control issue.
|
| 289 |
+
- **Solution**: Check drive control, interpolation, and hardware status.
|
| 290 |
+
|
| 291 |
+
**Quantitative diagnostic metrics**
|
| 292 |
+
|
| 293 |
+
Trajectory jitter can also be quantified using these three metrics:
|
| 294 |
+
|
| 295 |
+
**Metric 1: Mean intra-chunk acceleration magnitude**
|
| 296 |
+
|
| 297 |
+
Measures intra-chunk smoothness. Only valid under fixed sampling frequency.
|
| 298 |
+
|
| 299 |
+
Formula: $a_t = pos_{t+1} - 2 \cdot pos_t + pos_{t-1}$
|
| 300 |
+
|
| 301 |
+
```python
|
| 302 |
+
def metric_intra_accel(chunks):
|
| 303 |
+
"""
|
| 304 |
+
Args:
|
| 305 |
+
chunks: numpy array with shape (N_chunks, Chunk_Length, Joint_Dim)
|
| 306 |
+
|
| 307 |
+
Returns:
|
| 308 |
+
float: Mean acceleration magnitude
|
| 309 |
+
"""
|
| 310 |
+
velocity = np.diff(chunks, axis=1) # first-order difference
|
| 311 |
+
acceleration = np.diff(velocity, axis=1) # second-order difference
|
| 312 |
+
acc_magnitude = np.linalg.norm(acceleration, axis=-1) # L2 norm per step
|
| 313 |
+
return np.mean(acc_magnitude)
|
| 314 |
+
```
|
| 315 |
+
|
| 316 |
+
**Metric 2: Position jump at chunk boundary (L2 distance)**
|
| 317 |
+
|
| 318 |
+
Measures position continuity between chunks by comparing the last executed step of `Chunk[i]` with step 0 of `Chunk[i+1]`.
|
| 319 |
+
|
| 320 |
+
```python
|
| 321 |
+
def metric_boundary_jump(chunks, execute_steps=None):
|
| 322 |
+
"""
|
| 323 |
+
Args:
|
| 324 |
+
chunks: numpy array with shape (N_chunks, Chunk_Length, Joint_Dim)
|
| 325 |
+
execute_steps: number of executed steps per chunk; if None, use full chunk length
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
float: Mean position jump
|
| 329 |
+
"""
|
| 330 |
+
chunks = np.array(chunks)
|
| 331 |
+
exec_steps = chunks.shape[1] if execute_steps is None else execute_steps
|
| 332 |
+
|
| 333 |
+
last_frame_prev = chunks[:-1, exec_steps - 1, :] # last frame of previous chunk
|
| 334 |
+
first_frame_curr = chunks[1:, 0, :] # first frame of current chunk
|
| 335 |
+
jumps = np.linalg.norm(first_frame_curr - last_frame_prev, axis=-1) # Euclidean distance
|
| 336 |
+
return np.mean(jumps)
|
| 337 |
+
```
|
| 338 |
+
|
| 339 |
+
**Metric 3: Cosine similarity of velocity direction at chunk boundary**
|
| 340 |
+
|
| 341 |
+
Measures velocity-direction consistency between chunks. Values closer to 1 indicate better consistency.
|
| 342 |
+
|
| 343 |
+
```python
|
| 344 |
+
def metric_momentum_shift(chunks, execute_steps=None):
|
| 345 |
+
"""
|
| 346 |
+
Args:
|
| 347 |
+
chunks: numpy array with shape (N_chunks, Chunk_Length, Joint_Dim)
|
| 348 |
+
execute_steps: number of executed steps per chunk; if None, use full chunk length
|
| 349 |
+
|
| 350 |
+
Returns:
|
| 351 |
+
float: Mean cosine similarity
|
| 352 |
+
"""
|
| 353 |
+
chunks = np.array(chunks)
|
| 354 |
+
exec_steps = chunks.shape[1] if execute_steps is None else execute_steps
|
| 355 |
+
|
| 356 |
+
# velocity at the end of previous chunk
|
| 357 |
+
idx = exec_steps - 1
|
| 358 |
+
if idx < 1:
|
| 359 |
+
raise ValueError("execute_steps must be >= 2 to compute end velocity")
|
| 360 |
+
v_end = chunks[:-1, idx, :] - chunks[:-1, idx - 1, :]
|
| 361 |
+
|
| 362 |
+
# velocity at the start of current chunk
|
| 363 |
+
v_start = chunks[1:, 1, :] - chunks[1:, 0, :]
|
| 364 |
+
|
| 365 |
+
# cosine similarity
|
| 366 |
+
dot_product = np.sum(v_end * v_start, axis=-1)
|
| 367 |
+
norm_prev = np.linalg.norm(v_end, axis=-1)
|
| 368 |
+
norm_curr = np.linalg.norm(v_start, axis=-1)
|
| 369 |
+
epsilon = 1e-8
|
| 370 |
+
cosine_sim = dot_product / (norm_prev * norm_curr + epsilon)
|
| 371 |
+
|
| 372 |
+
return np.mean(cosine_sim)
|
| 373 |
+
```
|
| 374 |
+
|
| 375 |
+
### Fixing Stop-and-Go
|
| 376 |
+
Stop-and-Go here refers to a behavior in which the robot intermittently pauses during motion, producing periodic stop-and-go behavior.
|
| 377 |
+
|
| 378 |
+
#### Root Cause
|
| 379 |
+
|
| 380 |
+
In **synchronous single-step closed-loop** control, stop-and-go occurs when the **end-to-end latency** (observation capture → VLA inference → action conversion) exceeds control-frequency requirements.
|
| 381 |
+
|
| 382 |
+
- **Control-frequency requirement**: At 30 FPS, latency must stay below ~33 ms.
|
| 383 |
+
- **Typical latency sources**: Data capture, network transfer, model inference, and post-processing often exceed 33 ms combined.
|
| 384 |
+
- **Consequence**: The next prediction is not ready when the current action finishes, causing pauses.
|
| 385 |
+
|
| 386 |
+
#### Solutions
|
| 387 |
+
|
| 388 |
+
**Option 1: Optimize the inference pipeline (direct but difficult)**
|
| 389 |
+
|
| 390 |
+
Reduce full workflow latency below 33 ms:
|
| 391 |
+
|
| 392 |
+
- Optimize network bandwidth (reduce transfer time).
|
| 393 |
+
- Use edge inference (reduce network latency).
|
| 394 |
+
- Quantize the VLA model (speed up inference).
|
| 395 |
+
- Use a smaller model (e.g., ACT).
|
| 396 |
+
|
| 397 |
+
**Limitation**: For VLA models, meeting strict real-time requirements through optimization alone is often impractical.
|
| 398 |
+
|
| 399 |
+
**Option 2: Use algorithmic scheduling strategies (recommended)**
|
| 400 |
+
|
| 401 |
+
When direct optimization is insufficient, use one or more of the following:
|
| 402 |
+
|
| 403 |
+
- **Asynchronous Inference**: A background thread runs inference while the main thread executes actions.
|
| 404 |
+
- **Receding Horizon**: Execute only the first few steps of each `Action Chunk` before triggering a new inference.
|
| 405 |
+
- **Temporal Ensemble**: Aggregate predictions across multiple timesteps.
|
| 406 |
+
- **Real-Time Chunking (RTC)**: Overlap the start of the current prediction with unexecuted steps from the previous one.
|
| 407 |
+
|
| 408 |
+
**Recommended strategy**: `Asynchronous Inference + RTC` is usually the most effective.
|
| 409 |
+
|
| 410 |
+
#### Real-Time Chunking (RTC) Details
|
| 411 |
+
|
| 412 |
+
**Principle**
|
| 413 |
+
|
| 414 |
+
RTC treats action prediction as an inpainting problem: overlapping the start of the new prediction with unexecuted steps from the previous one ensures smooth transitions.
|
| 415 |
+
|
| 416 |
+
**Applicability**
|
| 417 |
+
|
| 418 |
+
- Validated for **diffusion / flow-based** VLA policies.
|
| 419 |
+
- Requires `Action Chunk` length ≥ 32 steps.
|
| 420 |
+
- Should be combined with asynchronous inference.
|
| 421 |
+
|
| 422 |
+
**Implementation essentials**
|
| 423 |
+
|
| 424 |
+
1. **Predict longer Action Chunks**:
|
| 425 |
+
- Increase from the default 16 steps to at least 32.
|
| 426 |
+
- Provide a larger soft fusion window.
|
| 427 |
+
|
| 428 |
+
2. **Asynchronous inference architecture**:
|
| 429 |
+
- **Background thread**: Continuously infer, capture observations, and prepare action batches.
|
| 430 |
+
- **Main thread**: Execute the current action sequence.
|
| 431 |
+
- Buffer predictions in a queue to avoid blocking.
|
| 432 |
+
|
| 433 |
+
3. **Action fusion mechanism**:
|
| 434 |
+
- Use RTC for soft fusion in the overlap region.
|
| 435 |
+
- Ensure smooth transitions between adjacent chunks.
|
| 436 |
+
|
| 437 |
+
**Pseudocode: Async Inference + RTC**
|
| 438 |
+
|
| 439 |
+
In the RTC (Real-Time Chunking) framework, two key parameters control how adjacent action chunks overlap and transition:
|
| 440 |
+
|
| 441 |
+
- **`overlap`**: The number of action steps retained from the previous prediction to constrain the current one, ensuring temporal consistency between consecutive chunks.
|
| 442 |
+
- **`frozen`**: The number of steps that remain completely frozen (i.e., not updated by the new prediction), typically set to match the inference latency.
|
| 443 |
+
|
| 444 |
+
Below is a simplified async inference + RTC loop. Note that official RTC support for GR00T is coming soon; the current implementation may require manual adaptation.
|
| 445 |
+
|
| 446 |
+
```
|
| 447 |
+
actions = policy.infer(obs) # blocking first call
|
| 448 |
+
|
| 449 |
+
loop:
|
| 450 |
+
for i in range(action_horizon):
|
| 451 |
+
if i == action_horizon - overlap - 1:
|
| 452 |
+
future = async policy.infer(new_obs) # non-blocking
|
| 453 |
+
robot.execute(actions[i])
|
| 454 |
+
if i == action_horizon - frozen - 1:
|
| 455 |
+
actions = future.get() # swap in next chunk
|
| 456 |
+
break # discard frozen tail
|
| 457 |
+
```
|
| 458 |
+
|
| 459 |
+
|
gr00t/__init__.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _patch_hf_local_first() -> None:
|
| 20 |
+
"""Patch from_pretrained to prefer the local HF snapshot cache over network calls.
|
| 21 |
+
|
| 22 |
+
When a HF repo ID is passed we try snapshot_download(local_files_only=True)
|
| 23 |
+
first; if the model is not cached we fall through to the normal download path.
|
| 24 |
+
This avoids 429 rate-limit errors when many CI jobs run concurrently.
|
| 25 |
+
|
| 26 |
+
Covers: PreTrainedModel, PretrainedConfig, ProcessorMixin, AutoConfig,
|
| 27 |
+
AutoProcessor — every transformers from_pretrained entrypoint.
|
| 28 |
+
|
| 29 |
+
Triggered by GROOT_HF_LOCAL_FIRST (set by conftest.py, survives uv run) or
|
| 30 |
+
PYTEST_CURRENT_TEST (set automatically by pytest).
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def _resolve(name_or_path: str) -> str:
|
| 34 |
+
hf_home = os.environ.get("HF_HOME")
|
| 35 |
+
hf_hub = os.environ.get("HUGGINGFACE_HUB_CACHE")
|
| 36 |
+
hf_cache_info = f"HF_HOME={hf_home} HUGGINGFACE_HUB_CACHE={hf_hub}"
|
| 37 |
+
if os.path.isdir(name_or_path):
|
| 38 |
+
print(f"[groot/hf] local path: {name_or_path} | {hf_cache_info}", flush=True)
|
| 39 |
+
return name_or_path
|
| 40 |
+
try:
|
| 41 |
+
from huggingface_hub import snapshot_download
|
| 42 |
+
|
| 43 |
+
resolved = snapshot_download(name_or_path, local_files_only=True)
|
| 44 |
+
print(
|
| 45 |
+
f"[groot/hf] cache hit: {name_or_path} -> {resolved} | {hf_cache_info}", flush=True
|
| 46 |
+
)
|
| 47 |
+
return resolved
|
| 48 |
+
except Exception:
|
| 49 |
+
print(
|
| 50 |
+
f"[groot/hf] cache miss (will download): {name_or_path} | {hf_cache_info}",
|
| 51 |
+
flush=True,
|
| 52 |
+
)
|
| 53 |
+
return name_or_path
|
| 54 |
+
|
| 55 |
+
def _wrap(cls: type) -> None:
|
| 56 |
+
if "from_pretrained" not in cls.__dict__:
|
| 57 |
+
return
|
| 58 |
+
original = cls.from_pretrained
|
| 59 |
+
if getattr(original, "_groot_hf_local_patched", False):
|
| 60 |
+
return
|
| 61 |
+
|
| 62 |
+
def _make_patched(orig):
|
| 63 |
+
@classmethod # type: ignore[misc]
|
| 64 |
+
def patched(klass, pretrained_model_name_or_path, *args, **kwargs):
|
| 65 |
+
resolved = _resolve(str(pretrained_model_name_or_path))
|
| 66 |
+
|
| 67 |
+
return orig.__func__(klass, resolved, *args, **kwargs)
|
| 68 |
+
|
| 69 |
+
patched._groot_hf_local_patched = True # type: ignore[attr-defined]
|
| 70 |
+
return patched
|
| 71 |
+
|
| 72 |
+
cls.from_pretrained = _make_patched(original)
|
| 73 |
+
|
| 74 |
+
try:
|
| 75 |
+
import transformers as _transformers
|
| 76 |
+
|
| 77 |
+
for _attr in (
|
| 78 |
+
"PreTrainedModel",
|
| 79 |
+
"PretrainedConfig",
|
| 80 |
+
"ProcessorMixin",
|
| 81 |
+
"AutoConfig",
|
| 82 |
+
"AutoProcessor",
|
| 83 |
+
):
|
| 84 |
+
_cls = getattr(_transformers, _attr, None)
|
| 85 |
+
if _cls is not None:
|
| 86 |
+
_wrap(_cls)
|
| 87 |
+
except Exception:
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _patch_mistral() -> None:
|
| 92 |
+
"""Suppress 429 / connection errors from the HuggingFace Hub in mistral regex patching.
|
| 93 |
+
|
| 94 |
+
transformers calls model_info() inside a nested is_base_mistral() function
|
| 95 |
+
unconditionally even when loading from a fully local checkpoint. Qwen3VL /
|
| 96 |
+
Cosmos is never Mistral, so returning the tokenizer unchanged on any network
|
| 97 |
+
failure is correct.
|
| 98 |
+
|
| 99 |
+
NOTE: is_base_mistral is a *nested* function inside _patch_mistral_regex, so
|
| 100 |
+
it is not accessible as a module-level attribute — we must wrap the classmethod.
|
| 101 |
+
|
| 102 |
+
Triggered by GROOT_PATCH_MISTRAL (set by conftest.py, survives uv run) or
|
| 103 |
+
PYTEST_CURRENT_TEST (set automatically by pytest, belt-and-suspenders).
|
| 104 |
+
"""
|
| 105 |
+
try:
|
| 106 |
+
import transformers.tokenization_utils_base as _tub
|
| 107 |
+
|
| 108 |
+
_cls = _tub.PreTrainedTokenizerBase
|
| 109 |
+
_orig = _cls._patch_mistral_regex.__func__
|
| 110 |
+
if getattr(_orig, "_groot_patched", False):
|
| 111 |
+
return
|
| 112 |
+
|
| 113 |
+
def _safe(cls, tokenizer, pretrained_model_name_or_path, **kwargs):
|
| 114 |
+
try:
|
| 115 |
+
return _orig(cls, tokenizer, pretrained_model_name_or_path, **kwargs)
|
| 116 |
+
except Exception:
|
| 117 |
+
return tokenizer
|
| 118 |
+
|
| 119 |
+
_safe._groot_patched = True # type: ignore[attr-defined]
|
| 120 |
+
_cls._patch_mistral_regex = classmethod(_safe)
|
| 121 |
+
except Exception:
|
| 122 |
+
pass
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("GROOT_HF_LOCAL_FIRST"):
|
| 126 |
+
_patch_hf_local_first()
|
| 127 |
+
|
| 128 |
+
if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("GROOT_PATCH_MISTRAL"):
|
| 129 |
+
_patch_mistral()
|
gr00t/configs/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
gr00t/configs/base_config.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
import json
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import List, Optional
|
| 20 |
+
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
from gr00t.data.types import ActionConfig, ActionFormat, ActionRepresentation, ActionType
|
| 24 |
+
|
| 25 |
+
from .data.data_config import DataConfig, SingleDatasetConfig
|
| 26 |
+
from .model import create_model_union_type
|
| 27 |
+
from .model.gr00t_n1d7 import Gr00tN1d7Config
|
| 28 |
+
from .training.training_config import TrainingConfig
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
ModelUnionType = create_model_union_type()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class Config:
|
| 36 |
+
"""Complete configuration."""
|
| 37 |
+
|
| 38 |
+
load_config_path: Optional[str] = None
|
| 39 |
+
model: ModelUnionType = field(default_factory=lambda: Gr00tN1d7Config())
|
| 40 |
+
data: DataConfig = field(default_factory=DataConfig)
|
| 41 |
+
training: TrainingConfig = field(default_factory=TrainingConfig)
|
| 42 |
+
|
| 43 |
+
def save(self, path: Path):
|
| 44 |
+
"""Save configuration to YAML file."""
|
| 45 |
+
path = Path(path)
|
| 46 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 47 |
+
|
| 48 |
+
with open(path, "w") as f:
|
| 49 |
+
yaml.dump(self, f)
|
| 50 |
+
|
| 51 |
+
def load(self, path: Path):
|
| 52 |
+
"""Load configuration from YAML file."""
|
| 53 |
+
data = yaml.load(path.read_text(), Loader=yaml.Loader)
|
| 54 |
+
if isinstance(data, dict): # for training
|
| 55 |
+
self.load_dict(data)
|
| 56 |
+
elif isinstance(data, self.__class__):
|
| 57 |
+
self = data
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"Invalid config file: {path}")
|
| 60 |
+
# config = cls(**config) # if yaml.dump(self.__dict__, ...) is used
|
| 61 |
+
return self
|
| 62 |
+
|
| 63 |
+
def load_dict(self, data: dict):
|
| 64 |
+
if "model" in data:
|
| 65 |
+
self.model = self.model.__class__(**data["model"])
|
| 66 |
+
if "data" in data:
|
| 67 |
+
self.data = DataConfig(**data["data"])
|
| 68 |
+
# Ensure nested datasets are converted to dataclass instances
|
| 69 |
+
converted: List[SingleDatasetConfig] = []
|
| 70 |
+
for ds in self.data.datasets:
|
| 71 |
+
if isinstance(ds, dict):
|
| 72 |
+
converted.append(SingleDatasetConfig(**ds))
|
| 73 |
+
else:
|
| 74 |
+
converted.append(ds)
|
| 75 |
+
self.data.datasets = converted
|
| 76 |
+
if "training" in data:
|
| 77 |
+
self.training = TrainingConfig(**data["training"])
|
| 78 |
+
return self
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def from_pretrained(cls, path: Path) -> "Config":
|
| 82 |
+
"""Load configuration from YAML file."""
|
| 83 |
+
data = yaml.load(path.read_text(), Loader=yaml.Loader)
|
| 84 |
+
return data
|
| 85 |
+
|
| 86 |
+
def get_deepspeed_config(self) -> dict:
|
| 87 |
+
"""Generate DeepSpeed configuration."""
|
| 88 |
+
stage = self.training.deepspeed_stage
|
| 89 |
+
|
| 90 |
+
gr00t_dir = Path(__file__).parent.parent
|
| 91 |
+
if stage == 2:
|
| 92 |
+
config = json.load(open(gr00t_dir / "configs/deepspeed/zero2_config.json"))
|
| 93 |
+
elif stage == 3:
|
| 94 |
+
config = json.load(open(gr00t_dir / "configs/deepspeed/zero3_config.json"))
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError(f"Invalid DeepSpeed stage: {stage}")
|
| 97 |
+
|
| 98 |
+
return config
|
| 99 |
+
|
| 100 |
+
def validate(self):
|
| 101 |
+
"""Validate configuration."""
|
| 102 |
+
# Check dataset path(s)
|
| 103 |
+
embodiment_tags = set()
|
| 104 |
+
for d_cfg in self.data.datasets:
|
| 105 |
+
# (Disable missing data check because we now support caching PDX data sources.)
|
| 106 |
+
# if not Path(d_cfg.dataset_path).exists():
|
| 107 |
+
# raise ValueError(f"Dataset path does not exist: {d_cfg.dataset_path}")
|
| 108 |
+
if d_cfg.dataset_type == "physical_embodiment" and not d_cfg.embodiment_tag:
|
| 109 |
+
raise ValueError(f"Embodiment tag is empty for dataset {d_cfg.dataset_path}")
|
| 110 |
+
if d_cfg.embodiment_tag is not None:
|
| 111 |
+
embodiment_tags.add(d_cfg.embodiment_tag)
|
| 112 |
+
|
| 113 |
+
stripped_modality_configs = {}
|
| 114 |
+
for embodiment_tag in embodiment_tags:
|
| 115 |
+
modality_cfg = self.data.modality_configs.get(embodiment_tag)
|
| 116 |
+
if modality_cfg is None:
|
| 117 |
+
raise ValueError(
|
| 118 |
+
f"No modality config registered for embodiment tag '{embodiment_tag}'. "
|
| 119 |
+
f"Available tags: {sorted(self.data.modality_configs.keys())}. "
|
| 120 |
+
f"Provide --modality-config-path to register a custom modality config, "
|
| 121 |
+
f"or use one of the pre-registered tags."
|
| 122 |
+
)
|
| 123 |
+
stripped_modality_configs[embodiment_tag] = modality_cfg
|
| 124 |
+
self.data.modality_configs = stripped_modality_configs
|
| 125 |
+
|
| 126 |
+
# ensure mix ratios are valid
|
| 127 |
+
total_ratio = sum(d.mix_ratio for d in self.data.datasets)
|
| 128 |
+
if total_ratio <= 0:
|
| 129 |
+
raise ValueError("Sum of mix_ratio must be greater than zero")
|
| 130 |
+
|
| 131 |
+
# Fill in default values for action configs
|
| 132 |
+
for embodiment_tag in self.data.modality_configs:
|
| 133 |
+
# Fill in default values for action representation, type and format
|
| 134 |
+
if self.data.modality_configs[embodiment_tag]["action"].action_configs is None:
|
| 135 |
+
self.data.modality_configs[embodiment_tag]["action"].action_configs = [
|
| 136 |
+
ActionConfig(
|
| 137 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 138 |
+
type=ActionType.NON_EEF,
|
| 139 |
+
format=ActionFormat.DEFAULT,
|
| 140 |
+
)
|
| 141 |
+
] * len(self.data.modality_configs[embodiment_tag]["action"].modality_keys)
|
| 142 |
+
|
| 143 |
+
# Validate precision settings
|
| 144 |
+
if self.training.fp16 and self.training.bf16:
|
| 145 |
+
raise ValueError("Cannot use both fp16 and bf16")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def get_default_config() -> Config:
|
| 149 |
+
"""Get default configuration."""
|
| 150 |
+
return Config()
|
gr00t/configs/data/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
gr00t/configs/data/data_config.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Any, List, Optional
|
| 18 |
+
|
| 19 |
+
from gr00t.data.types import ModalityConfig
|
| 20 |
+
|
| 21 |
+
from .embodiment_configs import MODALITY_CONFIGS
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class SingleDatasetConfig:
|
| 26 |
+
"""Configuration for a single dataset in a mixed-training setup.
|
| 27 |
+
|
| 28 |
+
A list of these objects can be supplied in ``DataConfig.datasets`` to mix
|
| 29 |
+
multiple datasets at arbitrary ratios. For convenience the *legacy*
|
| 30 |
+
single-dataset fields still exist; if ``datasets`` is non-empty they take
|
| 31 |
+
precedence.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
# Path to the dataset root directory (can be strings or dicts for complex configs)
|
| 35 |
+
dataset_paths: List[Any]
|
| 36 |
+
|
| 37 |
+
# Robot embodiment identifier (e.g. "gr1", "franka")
|
| 38 |
+
embodiment_tag: Optional[str] = None
|
| 39 |
+
|
| 40 |
+
# Relative sampling probability (will be normalised across the list)
|
| 41 |
+
mix_ratio: float = 1.0
|
| 42 |
+
|
| 43 |
+
dataset_type: str = "physical_embodiment"
|
| 44 |
+
|
| 45 |
+
# Optional validation dataset path for open-loop evaluation
|
| 46 |
+
# If not provided, falls back to dataset_paths for evaluation
|
| 47 |
+
val_dataset_path: Optional[str] = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class DataConfig:
|
| 52 |
+
"""Dataset configuration (supports single or multiple datasets)."""
|
| 53 |
+
|
| 54 |
+
# Leave empty by default for backwards-compatibility with the original
|
| 55 |
+
# single-dataset workflow. Users can supply one or more configs via CLI or
|
| 56 |
+
# YAML when they need mixing.
|
| 57 |
+
datasets: List[SingleDatasetConfig] = field(default_factory=list)
|
| 58 |
+
|
| 59 |
+
# Modality configs
|
| 60 |
+
# There are three sources of modality configs:
|
| 61 |
+
# 1. Default modality configs in code: gr00t/configs/data/embodiment_configs.py
|
| 62 |
+
# 2. Modality configs supplied through command line: --data.modality_configs (although rare and inconvenient)
|
| 63 |
+
# 1 and 2 are unified through `config.data.modality_configs`.
|
| 64 |
+
# 3. modality configs saved in the pretrained checkpoint.
|
| 65 |
+
modality_configs: dict[str, dict[str, ModalityConfig]] = field(
|
| 66 |
+
default_factory=lambda: MODALITY_CONFIGS
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Sharded dataset configuration
|
| 70 |
+
download_cache: bool = False
|
| 71 |
+
shard_size: int = 2**10
|
| 72 |
+
episode_sampling_rate: float = 0.1
|
| 73 |
+
num_shards_per_epoch: int = int(1e5)
|
| 74 |
+
|
| 75 |
+
# Override statistics from the pretrained checkpoint
|
| 76 |
+
override_pretraining_statistics: bool = True
|
| 77 |
+
|
| 78 |
+
# General task / mode config (shared across datasets)
|
| 79 |
+
mode: str = "single_turn"
|
| 80 |
+
random_chop: float = 0.0
|
| 81 |
+
mock_dataset_mode: bool = False # if True, cache the first datapoint of each dataset and always return one of them to simulate best-case dataloading
|
| 82 |
+
|
| 83 |
+
# Data loading
|
| 84 |
+
shuffle: bool = True
|
| 85 |
+
seed: int = 42
|
| 86 |
+
multiprocessing_context: str = "fork" # Options: "fork", "spawn", and "forkserver"
|
| 87 |
+
allow_padding: bool = False
|
| 88 |
+
|
| 89 |
+
# Subsample ratio for the dataset
|
| 90 |
+
subsample_ratio: float = 1.0
|
| 91 |
+
|
| 92 |
+
# DP Image Config
|
| 93 |
+
image_crop_size: List[int] = field(default_factory=lambda: [244, 244])
|
| 94 |
+
image_target_size: List[int] = field(default_factory=lambda: [224, 224])
|
| 95 |
+
video_backend: str = "torchcodec"
|
gr00t/configs/data/embodiment_configs.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from gr00t.data.embodiment_tags import EmbodimentTag
|
| 17 |
+
from gr00t.data.types import (
|
| 18 |
+
ActionConfig,
|
| 19 |
+
ActionFormat,
|
| 20 |
+
ActionRepresentation,
|
| 21 |
+
ActionType,
|
| 22 |
+
ModalityConfig,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
MODALITY_CONFIGS = {
|
| 27 |
+
##### Pre-registered pretrain configurations #####
|
| 28 |
+
"oxe_droid_relative_eef_relative_joint": {
|
| 29 |
+
"video": ModalityConfig(
|
| 30 |
+
delta_indices=[-15, 0],
|
| 31 |
+
modality_keys=["exterior_image_1_left", "wrist_image_left"],
|
| 32 |
+
),
|
| 33 |
+
"state": ModalityConfig(
|
| 34 |
+
delta_indices=[0],
|
| 35 |
+
modality_keys=["eef_9d", "gripper_position", "joint_position"],
|
| 36 |
+
),
|
| 37 |
+
"action": ModalityConfig(
|
| 38 |
+
delta_indices=list(range(40)),
|
| 39 |
+
modality_keys=["eef_9d", "gripper_position", "joint_position"],
|
| 40 |
+
action_configs=[
|
| 41 |
+
ActionConfig(
|
| 42 |
+
rep=ActionRepresentation.RELATIVE,
|
| 43 |
+
type=ActionType.EEF,
|
| 44 |
+
format=ActionFormat.XYZ_ROT6D,
|
| 45 |
+
state_key="eef_9d",
|
| 46 |
+
),
|
| 47 |
+
ActionConfig(
|
| 48 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 49 |
+
type=ActionType.NON_EEF,
|
| 50 |
+
format=ActionFormat.DEFAULT,
|
| 51 |
+
state_key="gripper_position",
|
| 52 |
+
),
|
| 53 |
+
ActionConfig(
|
| 54 |
+
rep=ActionRepresentation.RELATIVE,
|
| 55 |
+
type=ActionType.NON_EEF,
|
| 56 |
+
format=ActionFormat.DEFAULT,
|
| 57 |
+
state_key="joint_position",
|
| 58 |
+
),
|
| 59 |
+
],
|
| 60 |
+
),
|
| 61 |
+
"language": ModalityConfig(
|
| 62 |
+
delta_indices=[0],
|
| 63 |
+
modality_keys=["annotation.language.language_instruction"],
|
| 64 |
+
),
|
| 65 |
+
},
|
| 66 |
+
##### Pre-registered posttrain configurations #####
|
| 67 |
+
"unitree_g1_full_body_with_waist_height_nav_cmd": {
|
| 68 |
+
"video": ModalityConfig(
|
| 69 |
+
delta_indices=[0],
|
| 70 |
+
modality_keys=["ego_view"],
|
| 71 |
+
),
|
| 72 |
+
"state": ModalityConfig(
|
| 73 |
+
delta_indices=[0],
|
| 74 |
+
modality_keys=[
|
| 75 |
+
"left_leg",
|
| 76 |
+
"right_leg",
|
| 77 |
+
"waist",
|
| 78 |
+
"left_arm",
|
| 79 |
+
"right_arm",
|
| 80 |
+
"left_hand",
|
| 81 |
+
"right_hand",
|
| 82 |
+
],
|
| 83 |
+
),
|
| 84 |
+
"action": ModalityConfig(
|
| 85 |
+
delta_indices=list(range(50)),
|
| 86 |
+
modality_keys=[
|
| 87 |
+
"left_arm",
|
| 88 |
+
"right_arm",
|
| 89 |
+
"left_hand",
|
| 90 |
+
"right_hand",
|
| 91 |
+
"waist",
|
| 92 |
+
"base_height_command",
|
| 93 |
+
"navigate_command",
|
| 94 |
+
],
|
| 95 |
+
action_configs=[
|
| 96 |
+
# left_arm
|
| 97 |
+
ActionConfig(
|
| 98 |
+
rep=ActionRepresentation.RELATIVE,
|
| 99 |
+
type=ActionType.NON_EEF,
|
| 100 |
+
format=ActionFormat.DEFAULT,
|
| 101 |
+
),
|
| 102 |
+
# right_arm
|
| 103 |
+
ActionConfig(
|
| 104 |
+
rep=ActionRepresentation.RELATIVE,
|
| 105 |
+
type=ActionType.NON_EEF,
|
| 106 |
+
format=ActionFormat.DEFAULT,
|
| 107 |
+
),
|
| 108 |
+
# left_hand
|
| 109 |
+
ActionConfig(
|
| 110 |
+
rep=ActionRepresentation.ABSOLUTE, # G1 hand is controlled by binary signals like a gripper
|
| 111 |
+
type=ActionType.NON_EEF,
|
| 112 |
+
format=ActionFormat.DEFAULT,
|
| 113 |
+
),
|
| 114 |
+
# right_hand
|
| 115 |
+
ActionConfig(
|
| 116 |
+
rep=ActionRepresentation.ABSOLUTE, # G1 hand is controlled by binary signals like a gripper
|
| 117 |
+
type=ActionType.NON_EEF,
|
| 118 |
+
format=ActionFormat.DEFAULT,
|
| 119 |
+
),
|
| 120 |
+
# waist
|
| 121 |
+
ActionConfig(
|
| 122 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 123 |
+
type=ActionType.NON_EEF,
|
| 124 |
+
format=ActionFormat.DEFAULT,
|
| 125 |
+
),
|
| 126 |
+
# base_height_command
|
| 127 |
+
ActionConfig(
|
| 128 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 129 |
+
type=ActionType.NON_EEF,
|
| 130 |
+
format=ActionFormat.DEFAULT,
|
| 131 |
+
),
|
| 132 |
+
# navigate_command
|
| 133 |
+
ActionConfig(
|
| 134 |
+
rep=ActionRepresentation.ABSOLUTE,
|
| 135 |
+
type=ActionType.NON_EEF,
|
| 136 |
+
format=ActionFormat.DEFAULT,
|
| 137 |
+
),
|
| 138 |
+
],
|
| 139 |
+
),
|
| 140 |
+
"language": ModalityConfig(
|
| 141 |
+
delta_indices=[0],
|
| 142 |
+
modality_keys=["annotation.human.task_description"],
|
| 143 |
+
),
|
| 144 |
+
},
|
| 145 |
+
"libero_sim": {
|
| 146 |
+
"video": ModalityConfig(
|
| 147 |
+
delta_indices=[0],
|
| 148 |
+
modality_keys=["image", "wrist_image"],
|
| 149 |
+
),
|
| 150 |
+
"state": ModalityConfig(
|
| 151 |
+
delta_indices=[0],
|
| 152 |
+
modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
|
| 153 |
+
),
|
| 154 |
+
"action": ModalityConfig(
|
| 155 |
+
delta_indices=list(range(16)),
|
| 156 |
+
modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
|
| 157 |
+
),
|
| 158 |
+
"language": ModalityConfig(
|
| 159 |
+
delta_indices=[0],
|
| 160 |
+
modality_keys=["annotation.human.action.task_description"],
|
| 161 |
+
),
|
| 162 |
+
},
|
| 163 |
+
"simpler_env_widowx": {
|
| 164 |
+
"video": ModalityConfig(
|
| 165 |
+
delta_indices=[0],
|
| 166 |
+
modality_keys=["image_0"],
|
| 167 |
+
),
|
| 168 |
+
"state": ModalityConfig(
|
| 169 |
+
delta_indices=[0],
|
| 170 |
+
modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"],
|
| 171 |
+
),
|
| 172 |
+
"action": ModalityConfig(
|
| 173 |
+
delta_indices=list(range(8)),
|
| 174 |
+
modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
|
| 175 |
+
),
|
| 176 |
+
"language": ModalityConfig(
|
| 177 |
+
delta_indices=[0],
|
| 178 |
+
modality_keys=["annotation.human.action.task_description"],
|
| 179 |
+
),
|
| 180 |
+
},
|
| 181 |
+
"simpler_env_google": {
|
| 182 |
+
"video": ModalityConfig(
|
| 183 |
+
delta_indices=[0],
|
| 184 |
+
modality_keys=["image"],
|
| 185 |
+
),
|
| 186 |
+
"state": ModalityConfig(
|
| 187 |
+
delta_indices=[0],
|
| 188 |
+
modality_keys=["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"],
|
| 189 |
+
),
|
| 190 |
+
"action": ModalityConfig(
|
| 191 |
+
delta_indices=list(range(8)),
|
| 192 |
+
modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
|
| 193 |
+
),
|
| 194 |
+
"language": ModalityConfig(
|
| 195 |
+
delta_indices=[0],
|
| 196 |
+
modality_keys=["annotation.human.action.task_description"],
|
| 197 |
+
),
|
| 198 |
+
},
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def register_modality_config(
|
| 203 |
+
config: dict, embodiment_tag: EmbodimentTag = EmbodimentTag.NEW_EMBODIMENT
|
| 204 |
+
):
|
| 205 |
+
assert embodiment_tag.value not in MODALITY_CONFIGS, (
|
| 206 |
+
f"Embodiment tag {embodiment_tag} already registered"
|
| 207 |
+
)
|
| 208 |
+
MODALITY_CONFIGS[embodiment_tag.value] = config
|
gr00t/configs/deepspeed/zero2_config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"checkpoint": {
|
| 3 |
+
"load_universal": false
|
| 4 |
+
},
|
| 5 |
+
"train_batch_size": "auto",
|
| 6 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 7 |
+
"gradient_accumulation_steps": "auto",
|
| 8 |
+
"gradient_clipping": "auto",
|
| 9 |
+
"zero_allow_untested_optimizer": true,
|
| 10 |
+
"fp16": {
|
| 11 |
+
"enabled": "auto",
|
| 12 |
+
"loss_scale": 0,
|
| 13 |
+
"loss_scale_window": 1000,
|
| 14 |
+
"initial_scale_power": 16,
|
| 15 |
+
"hysteresis": 2,
|
| 16 |
+
"min_loss_scale": 1
|
| 17 |
+
},
|
| 18 |
+
"bf16": {
|
| 19 |
+
"enabled": "auto"
|
| 20 |
+
},
|
| 21 |
+
"communication_data_type": "bf16",
|
| 22 |
+
"zero_optimization": {
|
| 23 |
+
"stage": 2,
|
| 24 |
+
"overlap_comm": true,
|
| 25 |
+
"reduce_scatter": true,
|
| 26 |
+
"allgather_partitions": true,
|
| 27 |
+
"contiguous_gradients": true,
|
| 28 |
+
"sub_group_size": 1e9,
|
| 29 |
+
"reduce_bucket_size": 1e8,
|
| 30 |
+
"allgather_bucket_size": 1e8
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
|
gr00t/configs/deepspeed/zero3_config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_batch_size": "auto",
|
| 3 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 4 |
+
"gradient_accumulation_steps": "auto",
|
| 5 |
+
"gradient_clipping": "auto",
|
| 6 |
+
"zero_allow_untested_optimizer": true,
|
| 7 |
+
"fp16": {
|
| 8 |
+
"enabled": "auto",
|
| 9 |
+
"loss_scale": 0,
|
| 10 |
+
"loss_scale_window": 1000,
|
| 11 |
+
"initial_scale_power": 16,
|
| 12 |
+
"hysteresis": 2,
|
| 13 |
+
"min_loss_scale": 1
|
| 14 |
+
},
|
| 15 |
+
"bf16": {
|
| 16 |
+
"enabled": "auto"
|
| 17 |
+
},
|
| 18 |
+
"zero_optimization": {
|
| 19 |
+
"stage": 3,
|
| 20 |
+
"overlap_comm": true,
|
| 21 |
+
"contiguous_gradients": true,
|
| 22 |
+
"sub_group_size": 1e9,
|
| 23 |
+
"reduce_bucket_size": "auto",
|
| 24 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 25 |
+
"stage3_param_persistence_threshold": "auto",
|
| 26 |
+
"stage3_max_live_parameters": 1e9,
|
| 27 |
+
"stage3_max_reuse_distance": 1e9,
|
| 28 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
gr00t/configs/finetune_config.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Finetune config used for single node post-training.
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class FinetuneConfig:
|
| 22 |
+
"""
|
| 23 |
+
Configuration for fine-tuning a Vision-Language-Action (VLA) model.
|
| 24 |
+
|
| 25 |
+
This dataclass defines all parameters needed to launch a fine-tuning job
|
| 26 |
+
on a pretrained base model using a custom dataset and embodiment-specific
|
| 27 |
+
modality configuration. It controls model tuning options, data augmentation,
|
| 28 |
+
and training hyperparameters.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
# --- Data and Model Paths ---
|
| 32 |
+
base_model_path: str
|
| 33 |
+
"""Path to the pretrained base model checkpoint (e.g., Hugging Face model hub or local directory)."""
|
| 34 |
+
|
| 35 |
+
dataset_path: str
|
| 36 |
+
"""Path to the dataset root directory containing trajectory data for fine-tuning."""
|
| 37 |
+
|
| 38 |
+
embodiment_tag: str
|
| 39 |
+
"""Embodiment tag (name or value, case-insensitive). See EmbodimentTag for known tags."""
|
| 40 |
+
|
| 41 |
+
modality_config_path: str | None = None
|
| 42 |
+
"""
|
| 43 |
+
Path to a Python file defining the modality configuration for the given embodiment.
|
| 44 |
+
If None, use the pre-registered modality config in `gr00t/configs/data/embodiment_configs.py`.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
# --- Model Tuning Flags ---
|
| 48 |
+
tune_llm: bool = False
|
| 49 |
+
"""If True, fine-tune the language model (LLM) backbone during training."""
|
| 50 |
+
|
| 51 |
+
tune_visual: bool = False
|
| 52 |
+
"""If True, fine-tune the visual encoder (e.g., ViT or CNN backbone)."""
|
| 53 |
+
|
| 54 |
+
tune_projector: bool = True
|
| 55 |
+
"""If True, fine-tune the multimodal projector layers that map vision/language features to a shared space."""
|
| 56 |
+
|
| 57 |
+
tune_diffusion_model: bool = True
|
| 58 |
+
"""If True, fine-tune the diffusion-based action decoder (if present in the model)."""
|
| 59 |
+
|
| 60 |
+
state_dropout_prob: float = 0.2
|
| 61 |
+
"""
|
| 62 |
+
Dropout probability applied to state inputs for regularization during training.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
# --- Data Augmentation ---
|
| 66 |
+
random_rotation_angle: int | None = None
|
| 67 |
+
"""Maximum rotation angle (in degrees) for random rotation augmentation of input images."""
|
| 68 |
+
|
| 69 |
+
color_jitter_params: dict[str, float] | None = None
|
| 70 |
+
"""
|
| 71 |
+
Parameters for color jitter augmentation on images.
|
| 72 |
+
|
| 73 |
+
Expected keys include:
|
| 74 |
+
- "brightness": float
|
| 75 |
+
- "contrast": float
|
| 76 |
+
- "saturation": float
|
| 77 |
+
- "hue": float
|
| 78 |
+
Example: {"brightness": 0.4, "contrast": 0.4, "saturation": 0.4, "hue": 0.1}
|
| 79 |
+
|
| 80 |
+
If None, applying the default color jitter augmentation from the pretrained model.
|
| 81 |
+
"""
|
| 82 |
+
extra_augmentation_config: str | None = None
|
| 83 |
+
"""
|
| 84 |
+
JSON string for extra image augmentations (mask-based and others).
|
| 85 |
+
|
| 86 |
+
Expected keys include:
|
| 87 |
+
- "background_noise_transforms": list of dicts for noise on mask regions
|
| 88 |
+
- "target_mask_values": list of int (e.g., [0])
|
| 89 |
+
- "p": float (probability of applying)
|
| 90 |
+
- "masked_region_transforms": list of dicts for color tint on mask regions
|
| 91 |
+
- "target_mask_values": list of int (e.g., [4] or [5])
|
| 92 |
+
- "p": float (probability of applying)
|
| 93 |
+
- "alpha_range": [min, max] for random_tint intensity
|
| 94 |
+
|
| 95 |
+
Example: {"background_noise_transforms": [{"target_mask_values": [0], "p": 0.9}],
|
| 96 |
+
"masked_region_transforms": [{"target_mask_values": [4], "p": 1.0, "alpha_range": [0, 1]}]}
|
| 97 |
+
|
| 98 |
+
If None, no extra augmentations are applied.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
# --- Training Configuration ---
|
| 102 |
+
global_batch_size: int = 64
|
| 103 |
+
"""Total effective batch size across all GPUs and accumulation steps."""
|
| 104 |
+
|
| 105 |
+
dataloader_num_workers: int = 2
|
| 106 |
+
"""Number of parallel worker processes used for data loading."""
|
| 107 |
+
|
| 108 |
+
learning_rate: float = 1e-4
|
| 109 |
+
"""Initial learning rate for optimizer."""
|
| 110 |
+
|
| 111 |
+
gradient_accumulation_steps: int = 1
|
| 112 |
+
"""Number of forward passes to accumulate before performing a backward/update step."""
|
| 113 |
+
|
| 114 |
+
output_dir: str = "./outputs"
|
| 115 |
+
"""Directory where model checkpoints, logs, and outputs are saved."""
|
| 116 |
+
|
| 117 |
+
experiment_name: str | None = None
|
| 118 |
+
"""Optional experiment name used as the W&B run name. Defaults to the output directory basename."""
|
| 119 |
+
|
| 120 |
+
wandb_project: str = "finetune-gr00t-n1d7"
|
| 121 |
+
"""W&B project name to log runs to."""
|
| 122 |
+
|
| 123 |
+
save_steps: int = 1000
|
| 124 |
+
"""Frequency (in training steps) at which to save checkpoints."""
|
| 125 |
+
|
| 126 |
+
save_total_limit: int = 5
|
| 127 |
+
"""Maximum number of checkpoints to keep before older ones are deleted."""
|
| 128 |
+
|
| 129 |
+
num_gpus: int = 1
|
| 130 |
+
"""Number of GPUs available for distributed or single-node training."""
|
| 131 |
+
|
| 132 |
+
use_wandb: bool = False
|
| 133 |
+
"""
|
| 134 |
+
If True, log metrics and artifacts to Weights & Biases (wandb).
|
| 135 |
+
The project is `finetune-gr00t-n1d7`.
|
| 136 |
+
You need to login to wandb to view the logs.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
max_steps: int = 10000
|
| 140 |
+
"""Total number of training steps to run before stopping."""
|
| 141 |
+
|
| 142 |
+
weight_decay: float = 1e-5
|
| 143 |
+
"""Weight decay coefficient for optimizer (L2 regularization)."""
|
| 144 |
+
|
| 145 |
+
warmup_ratio: float = 0.05
|
| 146 |
+
"""Proportion of total training steps used for learning rate warm-up."""
|
| 147 |
+
|
| 148 |
+
shard_size: int = 2**10
|
| 149 |
+
"""Size of the shard to use for the dataset during preloading."""
|
| 150 |
+
|
| 151 |
+
episode_sampling_rate: float = 0.1
|
| 152 |
+
"""Sampling rate for the episodes."""
|
| 153 |
+
|
| 154 |
+
num_shards_per_epoch: int = int(1e5)
|
| 155 |
+
"""Number of shards to use for the dataset. reduce this number if vram is limited."""
|
| 156 |
+
|
| 157 |
+
save_only_model: bool = False
|
| 158 |
+
"""If True, save only model weights (skip optimizer/scheduler/RNG states). Cannot resume training from these checkpoints."""
|
| 159 |
+
|
| 160 |
+
skip_weight_loading: bool = False
|
| 161 |
+
"""If True, skip loading model weights from base_model_path (architecture only).
|
| 162 |
+
The processor (tokenizer/config) is still loaded from base_model_path.
|
| 163 |
+
Useful for CI/testing to skip the slow checkpoint shard loading."""
|
gr00t/configs/model/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import importlib
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import typing
|
| 19 |
+
|
| 20 |
+
import tyro
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
MODEL_CONFIG_TYPES: dict[str, type] = {}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def register_model_config(shortname: str, configtype: type):
|
| 27 |
+
MODEL_CONFIG_TYPES[shortname] = configtype
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
for file in Path(__file__).parent.glob("*.py"):
|
| 31 |
+
if file.stem.startswith("_"):
|
| 32 |
+
continue
|
| 33 |
+
try:
|
| 34 |
+
importlib.import_module(f".{file.stem}", __name__)
|
| 35 |
+
except KeyboardInterrupt:
|
| 36 |
+
raise
|
| 37 |
+
except Exception as e:
|
| 38 |
+
print(f"Error importing module gr00t.configs.model.{file.stem}: {e}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def create_model_union_type():
|
| 42 |
+
if not MODEL_CONFIG_TYPES:
|
| 43 |
+
# A Union of no types is invalid, so just return None
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
annotated_types = tuple(
|
| 47 |
+
typing.Annotated[model_type, tyro.conf.subcommand(model_shortname)]
|
| 48 |
+
for model_shortname, model_type in MODEL_CONFIG_TYPES.items()
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Create the Union dynamically
|
| 52 |
+
return typing.Union.__getitem__(annotated_types)
|
gr00t/configs/model/gr00t_n1d7.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import MISSING, asdict, dataclass, field, is_dataclass
|
| 17 |
+
from enum import Enum
|
| 18 |
+
import json
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import PretrainedConfig
|
| 23 |
+
|
| 24 |
+
from . import register_model_config
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class Gr00tN1d7Config(PretrainedConfig):
|
| 29 |
+
"""Unified configuration for Gr00tN1d7 model with backbone and action head.
|
| 30 |
+
|
| 31 |
+
Gr00tN1d7 uses the Cosmos-Reason2-2B (Qwen3-VL architecture) VLM backbone,
|
| 32 |
+
replacing the Eagle backbone used in Gr00tN1d6.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
# Model identification
|
| 36 |
+
model_type: str = "Gr00tN1d7"
|
| 37 |
+
model_dtype: str = "bfloat16" # Use bfloat16 for Flash Attention compatibility
|
| 38 |
+
|
| 39 |
+
# Backbone configuration
|
| 40 |
+
model_name: str = "nvidia/Cosmos-Reason2-2B"
|
| 41 |
+
backbone_model_type: str = "qwen"
|
| 42 |
+
model_revision: str | None = None
|
| 43 |
+
tune_top_llm_layers: int = 0 # Number of top LLM layers to tune
|
| 44 |
+
backbone_embedding_dim: int = 2048 # project_to_dim; must match Cosmos-Reason2-2B hidden size
|
| 45 |
+
tune_llm: bool = False
|
| 46 |
+
tune_visual: bool = False
|
| 47 |
+
select_layer: int = 12
|
| 48 |
+
reproject_vision: bool = False
|
| 49 |
+
use_flash_attention: bool = True
|
| 50 |
+
load_bf16: bool = False # Enable BF16 loading
|
| 51 |
+
backbone_trainable_params_fp32: bool = True
|
| 52 |
+
|
| 53 |
+
### Processing parameters
|
| 54 |
+
image_crop_size: tuple[int, int] | None = (230, 230)
|
| 55 |
+
image_target_size: tuple[int, int] | None = (256, 256)
|
| 56 |
+
|
| 57 |
+
shortest_image_edge: int | None = None
|
| 58 |
+
crop_fraction: float | None = None
|
| 59 |
+
|
| 60 |
+
random_rotation_angle: int | None = None
|
| 61 |
+
color_jitter_params: dict[str, float] | None = None
|
| 62 |
+
use_albumentations_transforms: bool = True
|
| 63 |
+
# Extra augmentation config (mask-based and others).
|
| 64 |
+
extra_augmentation_config: dict | None = None
|
| 65 |
+
formalize_language: bool = True
|
| 66 |
+
apply_sincos_state_encoding: bool = (
|
| 67 |
+
False # Global flag to enable per-embodiment sin/cos encoding
|
| 68 |
+
)
|
| 69 |
+
use_percentiles: bool = True
|
| 70 |
+
use_relative_action: bool = False
|
| 71 |
+
|
| 72 |
+
# Action head configuration parameters
|
| 73 |
+
max_state_dim: int = 132 # Default from state_shape
|
| 74 |
+
max_action_dim: int = 132 # Default from action_shape
|
| 75 |
+
action_horizon: int = 40
|
| 76 |
+
hidden_size: int = 1024
|
| 77 |
+
input_embedding_dim: int = 1536
|
| 78 |
+
|
| 79 |
+
# State history: number of consecutive state timesteps fed to the state encoder
|
| 80 |
+
state_history_length: int = 1
|
| 81 |
+
|
| 82 |
+
# Global parameters
|
| 83 |
+
add_pos_embed: bool = True
|
| 84 |
+
attn_dropout: float = 0.2
|
| 85 |
+
use_vlln: bool = True
|
| 86 |
+
max_seq_len: int = 1024
|
| 87 |
+
use_alternate_vl_dit: bool = True # True for AlternateVLDiT, False for DiT
|
| 88 |
+
attend_text_every_n_blocks: int = 2
|
| 89 |
+
|
| 90 |
+
diffusion_model_cfg: dict = field(
|
| 91 |
+
default_factory=lambda: {
|
| 92 |
+
"positional_embeddings": None,
|
| 93 |
+
"num_layers": 16,
|
| 94 |
+
"num_attention_heads": 32,
|
| 95 |
+
"attention_head_dim": 48,
|
| 96 |
+
"norm_type": "ada_norm",
|
| 97 |
+
"dropout": 0.2,
|
| 98 |
+
"final_dropout": True,
|
| 99 |
+
"output_dim": 1024,
|
| 100 |
+
"interleave_self_attention": True,
|
| 101 |
+
}
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Flow matching parameters
|
| 105 |
+
num_inference_timesteps: int = 4
|
| 106 |
+
noise_beta_alpha: float = 1.5
|
| 107 |
+
noise_beta_beta: float = 1.0
|
| 108 |
+
noise_s: float = 0.999
|
| 109 |
+
num_timestep_buckets: int = 1000
|
| 110 |
+
|
| 111 |
+
# Training parameters
|
| 112 |
+
tune_projector: bool = True
|
| 113 |
+
tune_diffusion_model: bool = True
|
| 114 |
+
tune_vlln: bool = True
|
| 115 |
+
|
| 116 |
+
# State augmentation parameters
|
| 117 |
+
state_dropout_prob: float = 0.8 # State dropout probability
|
| 118 |
+
exclude_state: bool = False # Zero out all state inputs (ablation)
|
| 119 |
+
use_mean_std: bool = False # Use mean/std normalization instead of min/max
|
| 120 |
+
|
| 121 |
+
# Multi-embodiment parameters
|
| 122 |
+
max_num_embodiments: int = 32
|
| 123 |
+
|
| 124 |
+
def __init__(self, **kwargs):
|
| 125 |
+
super().__init__(**kwargs)
|
| 126 |
+
for key, value in kwargs.items():
|
| 127 |
+
setattr(self, key, value)
|
| 128 |
+
|
| 129 |
+
# Ensures that all dataclass defaults (including those using default_factory)
|
| 130 |
+
# are explicitly assigned to the instance, even if dataclasses initialization or subclassing
|
| 131 |
+
# (PretrainedConfig) interferes with normal default injection.
|
| 132 |
+
for f in self.__dataclass_fields__.values():
|
| 133 |
+
if not hasattr(self, f.name):
|
| 134 |
+
if f.default is not MISSING:
|
| 135 |
+
setattr(self, f.name, f.default)
|
| 136 |
+
elif getattr(f, "default_factory", MISSING) is not MISSING:
|
| 137 |
+
setattr(self, f.name, f.default_factory())
|
| 138 |
+
|
| 139 |
+
def to_filtered_dict(self, exclude_augment: bool = True) -> dict:
|
| 140 |
+
"""Return a dictionary representation of this config, optionally excluding augmentation keys."""
|
| 141 |
+
if is_dataclass(self):
|
| 142 |
+
cfg = asdict(self)
|
| 143 |
+
else:
|
| 144 |
+
cfg = dict(self.__dict__)
|
| 145 |
+
|
| 146 |
+
if exclude_augment:
|
| 147 |
+
exclude_keys = {
|
| 148 |
+
"random_rotation_angle",
|
| 149 |
+
"color_jitter_params",
|
| 150 |
+
"use_albumentations_transforms",
|
| 151 |
+
"formalize_language",
|
| 152 |
+
"image_crop_size",
|
| 153 |
+
"image_target_size",
|
| 154 |
+
"shortest_image_edge",
|
| 155 |
+
"crop_fraction",
|
| 156 |
+
}
|
| 157 |
+
cfg = {k: v for k, v in cfg.items() if k not in exclude_keys}
|
| 158 |
+
|
| 159 |
+
return cfg
|
| 160 |
+
|
| 161 |
+
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
|
| 162 |
+
"""Return a JSON string of this config, optionally excluding augmentation keys."""
|
| 163 |
+
|
| 164 |
+
def default(o):
|
| 165 |
+
if isinstance(o, (Path, torch.dtype, torch.device)):
|
| 166 |
+
return str(o)
|
| 167 |
+
if isinstance(o, Enum):
|
| 168 |
+
return o.value
|
| 169 |
+
return str(o)
|
| 170 |
+
|
| 171 |
+
return json.dumps(
|
| 172 |
+
self.to_filtered_dict(exclude_augment),
|
| 173 |
+
indent=2,
|
| 174 |
+
default=default,
|
| 175 |
+
**kwargs,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
register_model_config("Gr00tN1d7", Gr00tN1d7Config)
|
gr00t/configs/training/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
gr00t/configs/training/training_config.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class TrainingConfig:
|
| 22 |
+
"""Training configuration."""
|
| 23 |
+
|
| 24 |
+
# Output
|
| 25 |
+
output_dir: str = "./outputs"
|
| 26 |
+
experiment_name: Optional[str] = None
|
| 27 |
+
|
| 28 |
+
# Basic training
|
| 29 |
+
max_steps: int = 30000 # this will override num_epochs
|
| 30 |
+
global_batch_size: int = 1024
|
| 31 |
+
batch_size: Optional[int] = None
|
| 32 |
+
gradient_accumulation_steps: int = 1
|
| 33 |
+
|
| 34 |
+
# Optimization
|
| 35 |
+
learning_rate: float = 1e-4
|
| 36 |
+
lr_scheduler_type: str = "cosine"
|
| 37 |
+
weight_decay: float = 1e-5
|
| 38 |
+
warmup_ratio: float = 0.05
|
| 39 |
+
warmup_steps: int = 0 # this will override warmup_ratio
|
| 40 |
+
max_grad_norm: float = 1.0
|
| 41 |
+
|
| 42 |
+
# Optimizer choice (huggingface TrainingArguments.optim)
|
| 43 |
+
# Options include: 'adamw_torch', 'adamw_torch_fused', 'paged_adamw_32bit',
|
| 44 |
+
# 'paged_adamw_8bit' (requires bitsandbytes), 'adafactor', etc.
|
| 45 |
+
optim: str = "adamw_torch_fused"
|
| 46 |
+
|
| 47 |
+
start_from_checkpoint: Optional[str] = None
|
| 48 |
+
skip_weight_loading: bool = False # skip loading checkpoint weights (architecture only)
|
| 49 |
+
|
| 50 |
+
# Mixed precision
|
| 51 |
+
tf32: bool = True
|
| 52 |
+
fp16: bool = False
|
| 53 |
+
bf16: bool = True
|
| 54 |
+
eval_bf16: bool = True
|
| 55 |
+
|
| 56 |
+
# Logging and saving
|
| 57 |
+
logging_steps: int = 10
|
| 58 |
+
save_steps: int = 1000
|
| 59 |
+
save_total_limit: int = 5
|
| 60 |
+
|
| 61 |
+
# Model saving
|
| 62 |
+
save_vl_model: bool = False # Control whether to save VL model and processor in callbacks
|
| 63 |
+
save_only_model: bool = False # Skip optimizer/scheduler/RNG states — cannot resume training
|
| 64 |
+
|
| 65 |
+
# Checkpoint uploading
|
| 66 |
+
upload_checkpoints: bool = False
|
| 67 |
+
upload_every: int = 1000
|
| 68 |
+
upload_last_n_checkpoints: int = 5
|
| 69 |
+
max_concurrent_uploads: int = 2
|
| 70 |
+
|
| 71 |
+
# Evaluation
|
| 72 |
+
eval_strategy: str = "no" # no, steps, epoch
|
| 73 |
+
eval_steps: int = 500
|
| 74 |
+
eval_set_split_ratio: float = 0.1
|
| 75 |
+
eval_batch_size: int = 2
|
| 76 |
+
save_best_eval_metric_name: str = ""
|
| 77 |
+
save_best_eval_metric_greater_is_better: bool = True
|
| 78 |
+
|
| 79 |
+
# DeepSpeed (default)
|
| 80 |
+
deepspeed_stage: int = 2 # ZeRO stage (1, 2, or 3)
|
| 81 |
+
gradient_checkpointing: bool = False
|
| 82 |
+
|
| 83 |
+
# Transformers loading parameters
|
| 84 |
+
transformers_trust_remote_code: bool = True
|
| 85 |
+
transformers_local_files_only: bool = False
|
| 86 |
+
transformers_cache_dir: str | None = None
|
| 87 |
+
transformers_access_token: str | None = None # Access token for HuggingFace Hub
|
| 88 |
+
|
| 89 |
+
# DDP
|
| 90 |
+
use_ddp: bool = False
|
| 91 |
+
ddp_bucket_cap_mb: int = 100
|
| 92 |
+
|
| 93 |
+
# Hardware
|
| 94 |
+
num_gpus: int = 1
|
| 95 |
+
dataloader_num_workers: int = 2
|
| 96 |
+
|
| 97 |
+
# Data handling
|
| 98 |
+
remove_unused_columns: bool = False
|
| 99 |
+
|
| 100 |
+
# Experiment tracking
|
| 101 |
+
use_wandb: bool = False
|
| 102 |
+
wandb_project: str = "finetune-gr00t-n1d7"
|
| 103 |
+
|
| 104 |
+
# Profiling
|
| 105 |
+
enable_profiling: bool = False
|
| 106 |
+
|
| 107 |
+
# Max number of retries in training for fault tolerance
|
| 108 |
+
max_retries: int = 3
|
| 109 |
+
|
| 110 |
+
# For testing.
|
| 111 |
+
assert_loss_less_than: float | None = None
|
| 112 |
+
|
| 113 |
+
# RL
|
| 114 |
+
add_rl_callback: bool = False
|
| 115 |
+
|
| 116 |
+
# Open-loop evaluation
|
| 117 |
+
enable_open_loop_eval: bool = False
|
| 118 |
+
"""Enable open-loop evaluation on saved checkpoints."""
|
| 119 |
+
|
| 120 |
+
open_loop_eval_traj_ids: list[int] = field(default_factory=lambda: [0])
|
| 121 |
+
"""List of trajectory IDs to evaluate."""
|
| 122 |
+
|
| 123 |
+
open_loop_eval_steps_per_traj: int = 100
|
| 124 |
+
"""Number of steps to evaluate per trajectory."""
|
| 125 |
+
|
| 126 |
+
open_loop_eval_plot_indices: Optional[list[int]] = None
|
| 127 |
+
"""List of action indices to plot. If None, plots all indices."""
|
gr00t/data/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
gr00t/data/collator/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from .collators import BasicDataCollator
|