yqi19 commited on
Commit
b60a439
·
verified ·
1 Parent(s): b184199

add: source files (batch 1)

Browse files
Files changed (50) hide show
  1. .github/ISSUE_TEMPLATE/bug_report.yml +46 -0
  2. .github/ISSUE_TEMPLATE/documentation.yml +21 -0
  3. .github/ISSUE_TEMPLATE/feature_request.yml +26 -0
  4. .github/actions/setup-venv/action.yml +55 -0
  5. .github/pull_request_template.md +17 -0
  6. .github/workflows/main.yml +68 -0
  7. AGENTS.md +80 -0
  8. ATTRIBUTIONS.md +0 -0
  9. CLAUDE.md +80 -0
  10. CONTRIBUTING.md +7 -0
  11. FAQ.md +81 -0
  12. README.md +555 -0
  13. examples/DROID/README.md +110 -0
  14. examples/DROID/main_gr00t.py +469 -0
  15. examples/DROID/server_client.py +365 -0
  16. examples/DROID/utils.py +81 -0
  17. examples/LIBERO/README.md +196 -0
  18. examples/LIBERO/modality.json +75 -0
  19. examples/SO100/README.md +87 -0
  20. examples/SO100/modality.json +35 -0
  21. examples/SO100/so100_config.py +70 -0
  22. examples/SimplerEnv/README.md +141 -0
  23. examples/SimplerEnv/bridge_modality.json +77 -0
  24. examples/SimplerEnv/convert_av1_to_h264.py +129 -0
  25. examples/SimplerEnv/fractal_modality.json +77 -0
  26. examples/finetune.sh +158 -0
  27. examples/mask-guided-background-suppression/README.md +203 -0
  28. examples/mask-guided-background-suppression/so101_config.py +62 -0
  29. examples/mask-guided-background-suppression/test_extra_augmentation.py +198 -0
  30. getting_started/data_config.md +331 -0
  31. getting_started/data_preparation.md +164 -0
  32. getting_started/finetune_new_embodiment.md +153 -0
  33. getting_started/hardware_recommendation.md +95 -0
  34. getting_started/policy.md +574 -0
  35. getting_started/real_world_deployment.md +459 -0
  36. gr00t/__init__.py +129 -0
  37. gr00t/configs/__init__.py +14 -0
  38. gr00t/configs/base_config.py +150 -0
  39. gr00t/configs/data/__init__.py +14 -0
  40. gr00t/configs/data/data_config.py +95 -0
  41. gr00t/configs/data/embodiment_configs.py +208 -0
  42. gr00t/configs/deepspeed/zero2_config.json +33 -0
  43. gr00t/configs/deepspeed/zero3_config.json +31 -0
  44. gr00t/configs/finetune_config.py +163 -0
  45. gr00t/configs/model/__init__.py +52 -0
  46. gr00t/configs/model/gr00t_n1d7.py +179 -0
  47. gr00t/configs/training/__init__.py +14 -0
  48. gr00t/configs/training/training_config.py +127 -0
  49. gr00t/data/__init__.py +14 -0
  50. gr00t/data/collator/__init__.py +16 -0
.github/ISSUE_TEMPLATE/bug_report.yml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 🐛 Bug Report
2
+ description: Create a report to help us reproduce and fix the bug
3
+ labels: 'bug'
4
+
5
+ body:
6
+ - type: markdown
7
+ attributes:
8
+ value: >
9
+ #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/Isaac-GR00T/issues?q=is%3Aissue+sort%3Acreated-desc+).
10
+ - type: textarea
11
+ attributes:
12
+ label: 🐛 Describe the bug
13
+ description: |
14
+ Please provide a clear and concise description of what the bug is.
15
+
16
+ If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example:
17
+
18
+ ```python
19
+ # All necessary imports at the beginning
20
+ import gr00t
21
+
22
+ # A succinct reproducing example trimmed down to the essential parts:
23
+ assert False is True, "Oh no!"
24
+ ```
25
+
26
+ If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
27
+
28
+ Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
29
+ placeholder: |
30
+ A clear and concise description of what the bug is.
31
+ validations:
32
+ required: true
33
+ - type: textarea
34
+ attributes:
35
+ label: Versions
36
+ description: |
37
+ Please run the following and paste the output below.
38
+ ```sh
39
+ python --version && pip freeze
40
+ ```
41
+ validations:
42
+ required: true
43
+ - type: markdown
44
+ attributes:
45
+ value: >
46
+ Thanks for contributing 🎉!
.github/ISSUE_TEMPLATE/documentation.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 📚 Documentation
2
+ description: Report an issue related to https://github.com/NVIDIA/Isaac-GR00T
3
+ labels: 'documentation'
4
+
5
+ body:
6
+ - type: textarea
7
+ attributes:
8
+ label: 📚 The doc issue
9
+ description: >
10
+ A clear and concise description of what content in https://github.com/NVIDIA/Isaac-GR00T is an issue.
11
+ validations:
12
+ required: true
13
+ - type: textarea
14
+ attributes:
15
+ label: Suggest a potential alternative/fix
16
+ description: >
17
+ Tell us how we could improve the documentation in this regard.
18
+ - type: markdown
19
+ attributes:
20
+ value: >
21
+ Thanks for contributing 🎉!
.github/ISSUE_TEMPLATE/feature_request.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 🚀 Feature request
2
+ description: Submit a proposal/request for a new feature
3
+ labels: 'feature request'
4
+
5
+ body:
6
+ - type: textarea
7
+ attributes:
8
+ label: 🚀 The feature, motivation and pitch
9
+ description: >
10
+ A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
11
+ validations:
12
+ required: true
13
+ - type: textarea
14
+ attributes:
15
+ label: Alternatives
16
+ description: >
17
+ A description of any alternative solutions or features you've considered, if any.
18
+ - type: textarea
19
+ attributes:
20
+ label: Additional context
21
+ description: >
22
+ Add any other context or screenshots about the feature request.
23
+ - type: markdown
24
+ attributes:
25
+ value: >
26
+ Thanks for contributing 🎉!
.github/actions/setup-venv/action.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Python virtualenv
2
+ description: Set up a Python virtual environment with caching
3
+ inputs:
4
+ python-version:
5
+ description: The Python version to use
6
+ required: true
7
+ cache-prefix:
8
+ description: Update this to invalidate the cache
9
+ required: true
10
+ default: v4
11
+ runs:
12
+ using: composite
13
+ steps:
14
+ - name: Setup Python
15
+ uses: actions/setup-python@v4
16
+ with:
17
+ python-version: ${{ inputs.python-version }}
18
+
19
+ - shell: bash
20
+ run: |
21
+ # Install prerequisites.
22
+ pip install --upgrade pip setuptools wheel virtualenv
23
+
24
+ - shell: bash
25
+ run: |
26
+ # Get the exact Python version to use in the cache key.
27
+ echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV
28
+
29
+ - uses: actions/cache@v4
30
+ id: virtualenv-cache
31
+ with:
32
+ path: .venv
33
+ key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('pyproject.toml') }}
34
+
35
+ - if: steps.virtualenv-cache.outputs.cache-hit != 'true'
36
+ shell: bash
37
+ run: |
38
+ # Set up virtual environment without cache hit.
39
+ test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
40
+ . .venv/bin/activate
41
+ pip install ruff
42
+
43
+ - if: steps.virtualenv-cache.outputs.cache-hit == 'true'
44
+ shell: bash
45
+ run: |
46
+ # Set up virtual environment from cache hit.
47
+ . .venv/bin/activate
48
+
49
+ - shell: bash
50
+ run: |
51
+ # Show environment info.
52
+ . .venv/bin/activate
53
+ echo "✓ Installed $(python --version) virtual environment to $(which python)"
54
+ echo "Packages:"
55
+ pip freeze
.github/pull_request_template.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- To ensure we can review your pull request promptly please complete this template entirely. -->
2
+
3
+ <!-- Please reference the issue number here. You can replace "Fixes" with "Closes" if it makes more sense. -->
4
+ Fixes #
5
+
6
+ Changes proposed in this pull request:
7
+ <!-- Please list all changes/additions here. -->
8
+ -
9
+
10
+ ## Before submitting
11
+
12
+ <!-- Please complete this checklist BEFORE submitting your PR to speed along the review process. -->
13
+ - [ ] I've read and followed all steps in the [Making a pull request](https://github.com/NVIDIA/Isaac-GR00T/blob/main/CONTRIBUTING.md#making-a-pull-request)
14
+ section of the `CONTRIBUTING` docs.
15
+ - [ ] I've updated or added any relevant docstrings.
16
+ - [ ] If this PR fixes a bug, I've added a test that will fail without my fix.
17
+ - [ ] If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.
.github/workflows/main.yml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Main
2
+
3
+ concurrency:
4
+ group: ${{ github.workflow }}-${{ github.ref }}
5
+ cancel-in-progress: true
6
+
7
+ on:
8
+ pull_request:
9
+ branches:
10
+ - main
11
+ push:
12
+ branches:
13
+ - main
14
+ tags:
15
+ - "v*.*.*"
16
+
17
+ env:
18
+ # Change this to invalidate existing cache.
19
+ CACHE_PREFIX: v4
20
+ PYTHONPATH: ./
21
+
22
+ jobs:
23
+ checks:
24
+ name: Python ${{ matrix.python }} - ${{ matrix.task.name }}
25
+ runs-on: [ubuntu-latest]
26
+ timeout-minutes: 15
27
+ strategy:
28
+ fail-fast: false
29
+ matrix:
30
+ include:
31
+ - python: "3.10"
32
+ task:
33
+ name: Lint
34
+ run: |
35
+ ruff check .
36
+ ruff format --check .
37
+
38
+ steps:
39
+ - uses: actions/checkout@v4
40
+ with:
41
+ lfs: true
42
+
43
+ - name: Pull LFS objects
44
+ run: git lfs pull
45
+
46
+ - name: Setup Python environment
47
+ uses: ./.github/actions/setup-venv
48
+ with:
49
+ python-version: ${{ matrix.python }}
50
+ cache-prefix: ${{ env.CACHE_PREFIX }}
51
+
52
+ - name: ${{ matrix.task.name }}
53
+ run: |
54
+ . .venv/bin/activate
55
+ ${{ matrix.task.run }}
56
+
57
+ - name: Upload package distribution files
58
+ if: matrix.task.name == 'Build'
59
+ uses: actions/upload-artifact@v4
60
+ with:
61
+ name: package
62
+ path: dist
63
+
64
+ - name: Clean up
65
+ if: always()
66
+ run: |
67
+ . .venv/bin/activate
68
+ pip uninstall -y gr00t
AGENTS.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md — Isaac GR00T N1.7
2
+
3
+ ## Project overview
4
+
5
+ Isaac GR00T N1.7 is an open vision-language-action (VLA) model for generalized humanoid robot skills.
6
+ The repo contains the model, training pipeline, evaluation harness, and deployment tooling.
7
+
8
+ - **Language:** Python 3.10 (dGPU, Orin); Python 3.12 (Thor, DGX Spark — see deployment dir)
9
+ - **Package manager:** [uv](https://docs.astral.sh/uv/)
10
+ - **Build system:** setuptools (see `pyproject.toml`)
11
+ - **CI:** internal GitLab CI (`.gitlab-ci.yml` + includes under `ci/`, not shipped to the public GitHub EA repo); public GitHub Actions (`.github/workflows/`)
12
+
13
+ ## Quick-start commands
14
+
15
+ ```bash
16
+ # Install (dev mode with all extras)
17
+ uv sync --all-extras
18
+
19
+ # Lint and format (uses ruff via pre-commit)
20
+ pre-commit run --all-files
21
+
22
+ # Run CPU tests
23
+ python -m pytest tests/ -m "not gpu" -v --timeout=300
24
+
25
+ # Run GPU tests
26
+ python -m pytest tests/ -m gpu -v --timeout=300
27
+
28
+ # Build package
29
+ uv build
30
+
31
+ # Validate lockfile
32
+ uv lock --locked
33
+ ```
34
+
35
+ ## Code style
36
+
37
+ - Formatter: `ruff format` (double quotes, spaces, line-length 100)
38
+ - Linter: `ruff check` with rules E, F, I (ignores E501)
39
+ - Config lives in `pyproject.toml` under `[tool.ruff]`
40
+ - Run `pre-commit run --all-files` before committing
41
+
42
+ ## Directory layout
43
+
44
+ ```
45
+ gr00t/ # Main package
46
+ configs/ # Training, data, and model configs
47
+ data/ # Data loading, embodiment tags, dataset processing
48
+ eval/ # Evaluation (run_gr00t_server.py)
49
+ experiment/ # Training pipeline (launch_finetune.py, trainer.py)
50
+ model/ # Model architecture (N1.7, base, modules)
51
+ policy/ # Policy inference (Gr00tPolicy, server/client)
52
+ examples/ # Per-embodiment example configs and READMEs
53
+ scripts/ # Deployment, conversion, and utility scripts
54
+ deployment/ # Platform install scripts (dgpu, orin, thor, spark)
55
+ tests/ # pytest suite (markers: gpu, not gpu)
56
+ getting_started/ # User-facing guides and notebooks
57
+ ```
58
+
59
+ ## Key entry points
60
+
61
+ - **Fine-tune:** `bash examples/finetune.sh --base-model-path <path> --dataset-path <path> --embodiment-tag <tag> --output-dir <dir>`
62
+ - **Inference server:** `python gr00t/eval/run_gr00t_server.py --model-path <path> --embodiment-tag <tag>`
63
+ - **ONNX export:** `python scripts/deployment/export_onnx_n1d7.py`
64
+ - **TensorRT build:** `python scripts/deployment/build_trt_pipeline.py`
65
+ - **Benchmark:** `python scripts/deployment/benchmark_inference.py`
66
+
67
+ ## Testing
68
+
69
+ - Test markers: `gpu` (requires GPU), default is CPU-safe
70
+ - Fixtures live in `tests/fixtures/` and `demo_data/`
71
+ - CI runs CPU and GPU tests in separate jobs with 300s timeout
72
+
73
+ ## Deployment platforms
74
+
75
+ - **dGPU (H100, A100, RTX):** CUDA 12.8 — install via `scripts/deployment/dgpu/install_deps.sh`, container via top-level `docker/Dockerfile` (supports x86_64 and aarch64)
76
+ - **Jetson Orin:** CUDA 12.6 — install via `scripts/deployment/orin/install_deps.sh`, container via `scripts/deployment/orin/Dockerfile`
77
+ - **Jetson Thor:** CUDA 13.0 — install via `scripts/deployment/thor/install_deps.sh`, container via `scripts/deployment/thor/Dockerfile`
78
+ - **DGX Spark:** CUDA 13.0 — install via `scripts/deployment/spark/install_deps.sh`, container via `scripts/deployment/spark/Dockerfile`
79
+
80
+ Each Jetson/Spark platform ships an `activate_*.sh` helper (`scripts/activate_orin.sh`, `scripts/activate_spark.sh`, `scripts/activate_thor.sh`) that exports platform-specific library paths. For dGPU, the standard `source .venv/bin/activate` is sufficient.
ATTRIBUTIONS.md ADDED
The diff for this file is too large to render. See raw diff
 
CLAUDE.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md — Isaac GR00T N1.7
2
+
3
+ ## Project overview
4
+
5
+ Isaac GR00T N1.7 is an open vision-language-action (VLA) model for generalized humanoid robot skills.
6
+ The repo contains the model, training pipeline, evaluation harness, and deployment tooling.
7
+
8
+ - **Language:** Python 3.10 (dGPU, Orin); Python 3.12 (Thor, DGX Spark — see deployment dir)
9
+ - **Package manager:** [uv](https://docs.astral.sh/uv/)
10
+ - **Build system:** setuptools (see `pyproject.toml`)
11
+ - **CI:** internal GitLab CI (`.gitlab-ci.yml` + includes under `ci/`, not shipped to the public GitHub EA repo); public GitHub Actions (`.github/workflows/`)
12
+
13
+ ## Quick-start commands
14
+
15
+ ```bash
16
+ # Install (dev mode with all extras)
17
+ uv sync --all-extras
18
+
19
+ # Lint and format (uses ruff via pre-commit)
20
+ pre-commit run --all-files
21
+
22
+ # Run CPU tests
23
+ python -m pytest tests/ -m "not gpu" -v --timeout=300
24
+
25
+ # Run GPU tests
26
+ python -m pytest tests/ -m gpu -v --timeout=300
27
+
28
+ # Build package
29
+ uv build
30
+
31
+ # Validate lockfile
32
+ uv lock --locked
33
+ ```
34
+
35
+ ## Code style
36
+
37
+ - Formatter: `ruff format` (double quotes, spaces, line-length 100)
38
+ - Linter: `ruff check` with rules E, F, I (ignores E501)
39
+ - Config lives in `pyproject.toml` under `[tool.ruff]`
40
+ - Run `pre-commit run --all-files` before committing
41
+
42
+ ## Directory layout
43
+
44
+ ```
45
+ gr00t/ # Main package
46
+ configs/ # Training, data, and model configs
47
+ data/ # Data loading, embodiment tags, dataset processing
48
+ eval/ # Evaluation (run_gr00t_server.py)
49
+ experiment/ # Training pipeline (launch_finetune.py, trainer.py)
50
+ model/ # Model architecture (N1.7, base, modules)
51
+ policy/ # Policy inference (Gr00tPolicy, server/client)
52
+ examples/ # Per-embodiment example configs and READMEs
53
+ scripts/ # Deployment, conversion, and utility scripts
54
+ deployment/ # Platform install scripts (dgpu, orin, thor, spark)
55
+ tests/ # pytest suite (markers: gpu, not gpu)
56
+ getting_started/ # User-facing guides and notebooks
57
+ ```
58
+
59
+ ## Key entry points
60
+
61
+ - **Fine-tune:** `bash examples/finetune.sh --base-model-path <path> --dataset-path <path> --embodiment-tag <tag> --output-dir <dir>`
62
+ - **Inference server:** `python gr00t/eval/run_gr00t_server.py --model-path <path> --embodiment-tag <tag>`
63
+ - **ONNX export:** `python scripts/deployment/export_onnx_n1d7.py`
64
+ - **TensorRT build:** `python scripts/deployment/build_trt_pipeline.py`
65
+ - **Benchmark:** `python scripts/deployment/benchmark_inference.py`
66
+
67
+ ## Testing
68
+
69
+ - Test markers: `gpu` (requires GPU), default is CPU-safe
70
+ - Fixtures live in `tests/fixtures/` and `demo_data/`
71
+ - CI runs CPU and GPU tests in separate jobs with 300s timeout
72
+
73
+ ## Deployment platforms
74
+
75
+ - **dGPU (H100, A100, RTX):** CUDA 12.8 — install via `scripts/deployment/dgpu/install_deps.sh`, container via top-level `docker/Dockerfile` (supports x86_64 and aarch64)
76
+ - **Jetson Orin:** CUDA 12.6 — install via `scripts/deployment/orin/install_deps.sh`, container via `scripts/deployment/orin/Dockerfile`
77
+ - **Jetson Thor:** CUDA 13.0 — install via `scripts/deployment/thor/install_deps.sh`, container via `scripts/deployment/thor/Dockerfile`
78
+ - **DGX Spark:** CUDA 13.0 — install via `scripts/deployment/spark/install_deps.sh`, container via `scripts/deployment/spark/Dockerfile`
79
+
80
+ Each Jetson/Spark platform ships an `activate_*.sh` helper (`scripts/activate_orin.sh`, `scripts/activate_spark.sh`, `scripts/activate_thor.sh`) that exports platform-specific library paths. For dGPU, the standard `source .venv/bin/activate` is sufficient.
CONTRIBUTING.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Contributions
2
+
3
+ During Early Access we are not accepting pull requests while the codebase stabilizes. If you encounter issues or have suggestions, please open an [Issue](https://github.com/NVIDIA/Isaac-GR00T/issues) in this repository.
4
+
5
+ ## Support
6
+
7
+ Support during Early Access is best-effort. We will continue iterating toward a more stable General Availability (GA) release.
FAQ.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GR00T N1.7 FAQ
2
+
3
+ ## Infrastructure & Hardware
4
+
5
+ ### Is the data loader GPU-accelerated?
6
+
7
+ No, the current data loader is CPU-based. However, it has been heavily optimized for multimodal data to ensure it does not become a training bottleneck. We validated this on various configurations, including GB200, H100, and local desktops with RTX 4090 GPUs. We are actively exploring GPU-accelerated approaches for future releases.
8
+
9
+ ### Is the same data loader used for both pre-training and post-training?
10
+
11
+ Yes, the data loading pipeline is unified across both training stages.
12
+
13
+ ### What is the role of the Policy Remote Server in the deployment diagram?
14
+
15
+ The Policy Remote Server decouples inference from the physical robot. This allows users to run the policy on a high-compute cluster (e.g., H100s) for faster inference while the robot operates in a separate environment. It separates dependencies and enables scaling beyond the robot's onboard compute.
16
+
17
+ ## Workflow & Architecture
18
+
19
+ ### Why retain only specific LLM layers (e.g., 16 layers) during fine-tuning?
20
+
21
+ This configuration was empirically tuned for the backbone (e.g., Eagle/Cosmos-Reason). Research suggests early layers capture grammatical structure, while middle-to-late layers are highly expressive. However, the very last layers are often over-optimized for next-token prediction; pruning or freezing them can sometimes yield better representations for vision-language-action alignment.
22
+
23
+ ### How do you verify if the language model is successfully aligned with the action space?
24
+
25
+ We evaluate this end-to-end via downstream task success. We design evaluation tasks that are ambiguous without language instructions (e.g., "pick the pear" from a bowl of mixed fruit). If the robot succeeds, it confirms the model is correctly grounding language commands into physical actions.
26
+
27
+ ## Data Strategy & Volume
28
+
29
+ ### How much data is required for post-training on a new embodiment or task?
30
+
31
+ Data requirements depend heavily on task complexity and scene variation. Typical guidelines include:
32
+
33
+ - **Simple, fixed-location tasks (Pick & Place):** ~100 trajectories.
34
+ - **Complex scenes or multi-step tasks:** ~500+ trajectories.
35
+ - **High-DoF humanoid tasks:** ~2,000+ trajectories (e.g., shelf-picking with G1).
36
+ - **Fine manipulation:** ~100–500 episodes, ideally with human motion pre-training.
37
+
38
+ ### What is the recommended strategy for improving success rates on hard tasks?
39
+
40
+ We recommend an iterative approach: start with ~100 teleoperated demonstrations, train a policy, and then use HG-DAgger (Human Gated Dataset Aggregation). Run the policy, intervene when it fails, and add the corrections from those trajectories to the dataset. This helps the model cover out-of-distribution states that pure behavior cloning (BC) might miss, and recover from partial failure states (e.g., a grip slipping or imprecise item placement).
41
+
42
+ ### Does including real-robot data from other embodiments help if I only care about one robot?
43
+
44
+ Yes. Even if cross-embodiment generalization is not your goal, including diverse real-robot data adds visual diversity and robustness to the VLA's backbone, improving performance on your specific target robot.
45
+
46
+ ### Does GR00T N1.7 support synthetic data generation via Cosmos?
47
+
48
+ While research models (like DreamGen) show promise, a robust, product-ready pipeline for generating synthetic training data via Cosmos is currently in development and not yet part of the standard release.
49
+
50
+ ## Model Capabilities
51
+
52
+ ### Can the model handle lighting changes or different object colors?
53
+
54
+ VLMs can struggle with drastic appearance changes (e.g., hard shadows or significant hue shifts). While we haven't released specific lighting ablations, we strongly recommend using color jitter augmentation during training and collecting diverse data (20–50 episodes) under different lighting conditions to prevent overfitting.
55
+
56
+ ### Can GR00T models perform reasoning or Visual Question Answering (VQA)?
57
+
58
+ The GR00T N1.x series is optimized specifically for action generation, not open-ended reasoning or VQA. Capabilities requiring complex semantic reasoning are targeted for future N2 releases.
59
+
60
+ ### Can the model learn "retry" behaviors?
61
+
62
+ The current architecture is stateless and does not inherently "know" if a previous attempt failed. While some retry behavior may emerge from high-quality data, explicit recovery strategies are best achieved through DAgger (collecting data on recovery from failure) or Reinforcement Learning (RL), rather than pure Imitation Learning.
63
+
64
+ ### Does the model distinguish between left and right arms in bimanual tasks?
65
+
66
+ Yes, provided the training data is distinct or annotated (e.g., instructions specifying "left arm" vs. "right arm"). If the dataset contains mixed, unannotated data where both arms perform identical tasks indiscriminately, the model may struggle to distinguish them.
67
+
68
+ ### Is there a zero-shot cross-embodiment VLA model?
69
+
70
+ No. While cross-embodiment data improves generalization, a true "zero-shot" model (one that works perfectly on a new robot without *any* fine-tuning) does not currently exist in the open VLA landscape.
71
+
72
+ ### Will differences in object shape between training and deployment cause the success rate to drop?
73
+
74
+ It depends on the degree of deviation. If the target object's shape differs drastically from the training data, performance will likely drop significantly. However, if the shape variation is minor and shares a similar grasping affordance (e.g., a slightly different bottle shape that is still grasped from the side), the model may still succeed, though with potentially lower reliability than on the original objects.
75
+
76
+ ### Has the impact of large viewpoint changes (e.g., head movement) on task difficulty been studied?
77
+
78
+ Yes. Large viewpoint changes effectively change the observation distribution, which can complicate simple tasks. For example, a "simple" handover becomes complex if the robot's head moves significantly, altering the camera's perspective of its own hands.
79
+
80
+ - **Current Status:** Most public GR00T demos feature a relatively fixed head position to stabilize observations.
81
+ - **Mitigation:** To handle natural head movement, we recommend training with aggressive camera pose augmentation or collecting data that explicitly includes head motion to ensure the policy becomes robust to viewpoint shifts.
README.md ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <img src="media/header_compress.png" width="800" alt="NVIDIA Isaac GR00T N1.7 Header">
4
+
5
+ <!-- --- -->
6
+
7
+ <p style="font-size: 1.2em;">
8
+ <a href="https://developer.nvidia.com/isaac/gr00t"><strong>Website</strong></a> |
9
+ <a href="https://huggingface.co/collections/nvidia/gr00t-n17"><strong>Model</strong></a> |
10
+ <a href="https://huggingface.co/collections/nvidia/physical-ai"><strong>Dataset</strong></a> |
11
+ <a href="https://arxiv.org/abs/2503.14734"><strong>Paper</strong></a> |
12
+ <a href="https://developer.nvidia.com/isaac"><strong>NVIDIA Isaac</strong></a> |
13
+ <a href="FAQ.md"><strong>FAQ</strong></a>
14
+ </p>
15
+ </div>
16
+
17
+ ## Table of Contents
18
+
19
+ - [NVIDIA Isaac GR00T](#nvidia-isaac-gr00t)
20
+ - [What's New in GR00T N1.7](#whats-new-in-gr00t-n17)
21
+ - [Installation](#installation)
22
+ - [Model Checkpoints & Embodiment Tags](#model-checkpoints--embodiment-tags)
23
+ - [Data Format](#data-format)
24
+ - [Inference](#inference)
25
+ - [Fine-tuning](#fine-tuning)
26
+ - [Evaluation](#evaluation)
27
+ - [Contributions](#contributions)
28
+ - [License](#license)
29
+ - [Citation](#citation)
30
+
31
+ ---
32
+
33
+ ## NVIDIA Isaac GR00T
34
+
35
+ <table style="width:100%; table-layout:fixed;">
36
+ <tr>
37
+ <td style="width:33.33%; text-align:center;">
38
+ <img src="media/unitree_g1.gif" style="max-width:100%; height:auto;">
39
+ </td>
40
+ <td style="width:33.33%; text-align:center;">
41
+ <img src="media/agibot_g1.gif" style="max-width:100%; height:auto;">
42
+ </td>
43
+ <td style="width:33.33%; text-align:center;">
44
+ <img src="media/yam.gif" style="max-width:100%; height:auto;">
45
+ </td>
46
+ </tr>
47
+ </table>
48
+
49
+ > We just released GR00T N1.7 Early Access, the latest version of GR00T N1 with a new VLM backbone (Cosmos-Reason2-2B / Qwen3-VL) and improved performance.
50
+
51
+ > **This is an Early Access (EA) release.** You are welcome to download the model, explore the codebase, and begin building on the stack, with the understanding that support and stability guarantees are limited until the GA release.
52
+ >
53
+ > **What's available:**
54
+ > - Pre-trained GR00T N1.7 model weights and reference code
55
+ > - Fine-tuning and inference with custom robot data or demonstrations
56
+ > - Experimentation, prototyping, and research use cases
57
+ >
58
+ > **Available at GA:**
59
+ > - Production deployment with commercial support
60
+ > - Complete benchmarks and a fully validated, stable feature set
61
+ > - Pull request contributions
62
+ >
63
+ > We welcome feedback - please feel free to raise issues in this repository.
64
+
65
+ > To use older versions: [N1.6](https://github.com/NVIDIA/Isaac-GR00T/releases/tag/n1.6-release) | [N1.5](https://github.com/NVIDIA/Isaac-GR00T/tree/n1.5-release)
66
+
67
+ NVIDIA Isaac GR00T N1.7 is an open vision-language-action (VLA) model for generalized humanoid robot skills. This cross-embodiment model takes multimodal input, including language and images, to perform manipulation tasks in diverse environments.
68
+
69
+ GR00T N1.7 is trained on a diverse mixture of robot data including bimanual, semi-humanoid and an expansive humanoid dataset. It is adaptable through post-training for specific embodiments, tasks and environments.
70
+
71
+ GR00T N1.7 is fully commercially licensable under Apache 2.0. It delivers comparable performance to N1.6, with improved generalization and language-following capabilities driven by the inclusion of 20K hours of EgoScale human video data in pretraining.
72
+
73
+ The neural network architecture of GR00T N1.7 is a combination of vision-language foundation model and diffusion transformer head that denoises continuous actions. Here is a schematic diagram of the architecture:
74
+
75
+ <div align="center">
76
+ <img src="media/model-architecture.png" width="800" alt="model-architecture">
77
+ </div>
78
+
79
+ ### Workflow Overview
80
+
81
+ 1. **Prepare data** — Collect robot demonstrations (video, state, action) and convert them to the [GR00T LeRobot format](#data-format). Demo datasets are included for quick testing.
82
+ 2. **Run inference** — Try zero-shot inference with the base model on [pretrain embodiments](#embodiment-tags), or use a [finetuned checkpoint](#checkpoints) for benchmark tasks.
83
+ 3. **Fine-tune** — Adapt the model to your robot using [`launch_finetune.py`](#fine-tuning) with your own data and modality config.
84
+ 4. **Evaluate** — Validate with [open-loop evaluation](#open-loop-evaluation), then test in [simulation benchmarks](#benchmark-examples) or on real hardware via the [Policy API](getting_started/policy.md).
85
+ 5. **Deploy** — Connect `Gr00tPolicy` to your robot controller, optionally accelerated with [TensorRT](scripts/deployment/README.md).
86
+
87
+ ## What's New in GR00T N1.7
88
+
89
+ GR00T N1.7 builds on N1.6 with a new VLM backbone and code-level improvements.
90
+
91
+ 1. **Relative EEF Action Space** — N1.7 adopts a relative end-effector action space shared across robot and human embodiments. Representing actions as deltas from the current pose (rather than absolute targets) improves generalization and is a key factor in the model's cross-embodiment performance. See [`getting_started/finetune_new_embodiment.md`](getting_started/finetune_new_embodiment.md) for guidance on configuring relative EEF for your own robot.
92
+
93
+ 2. **Human Video Pretraining** — N1.7 is pretrained on 20K hours of EgoScale human video data alongside diverse robot demonstrations. Because the relative EEF action representation is consistent across both human and robot data, the model can transfer manipulation priors learned from human video directly to robot control.
94
+
95
+ ### Key Changes from N1.6
96
+
97
+ - **New VLM backbone:** Cosmos-Reason2-2B (Qwen3-VL architecture), replacing the Eagle backbone used in N1.6. Supports flexible resolution and encodes images in their native aspect ratio without padding.
98
+ - Simplified data processing pipeline (`processing_gr00t_n1d7.py`).
99
+ - Added full pipeline export to ONNX and TensorRT with improved frequency.
100
+
101
+ ---
102
+
103
+ ## Installation
104
+
105
+ ### Hardware Requirements
106
+
107
+ **Inference:** 1 GPU with 16 GB+ VRAM (e.g., RTX 4090, L40, H100, Jetson AGX Thor/Orin, DGX Spark).
108
+
109
+ **Fine-tuning:** 1 or more GPUs with 40 GB+ VRAM recommended. We recommend H100 or L40 nodes for optimal performance. Other hardware (e.g., A6000) works but may require longer training time. See the [Hardware Recommendation Guide](getting_started/hardware_recommendation.md) for detailed specs.
110
+
111
+ **CUDA / Python per platform:** dGPU on CUDA 12.8 with Python 3.10; Jetson Orin on CUDA 12.6 with Python 3.10; Jetson Thor and DGX Spark on CUDA 13.0 with Python 3.12. The per-platform install scripts and Dockerfiles live under `scripts/deployment/`; see the [Deployment & Inference Guide](scripts/deployment/README.md) for the full matrix.
112
+
113
+ ### Clone the Repository
114
+
115
+ GR00T relies on submodules for certain dependencies. Include them when cloning:
116
+
117
+ **Note:** `git-lfs` is **required** to download parquet data files in `/demo_data`. Install it before cloning: `sudo apt install git-lfs && git lfs install`.
118
+ ```sh
119
+ git clone --recurse-submodules https://github.com/NVIDIA/Isaac-GR00T
120
+ cd Isaac-GR00T
121
+ ```
122
+
123
+ If you've already cloned without submodules, initialize them separately:
124
+
125
+ ```sh
126
+ git submodule update --init --recursive
127
+ ```
128
+
129
+ ### Set Up the Environment
130
+
131
+ GR00T uses [uv](https://github.com/astral-sh/uv) for fast, reproducible dependency management. Install uv first:
132
+
133
+ ```sh
134
+ curl -LsSf https://astral.sh/uv/install.sh | sh
135
+ ```
136
+
137
+ #### dGPU (x86_64) — Default
138
+
139
+ Install FFmpeg (required by `torchcodec`, the default video backend):
140
+ ```sh
141
+ sudo apt-get update && sudo apt-get install -y ffmpeg
142
+ ```
143
+
144
+ Create the environment and install GR00T:
145
+ ```sh
146
+ uv sync --python 3.10
147
+ ```
148
+ GPU dependencies (flash-attn, TensorRT, etc.) are included in the default install.
149
+
150
+ Verify the installation:
151
+ ```sh
152
+ uv run python -c "import gr00t; print('GR00T installed successfully')"
153
+ ```
154
+
155
+ > **`flash-attn` message on every `uv run`:** You may see `Installing flash-attn...` each time you run `uv run`. This is a known `uv` behavior with URL-pinned wheel sources — `uv` re-validates the cached wheel against the source URL on each invocation. It is **not** rebuilding from source; the wheel is already cached locally and the operation takes 2-3 seconds. This only affects x86_64 platforms.
156
+ > To suppress it, remove the `flash-attn` entries under `[tool.uv.sources]` in your local `pyproject.toml` after the initial install. But that will break `uv lock` and cause flash-attn to build from source on next lock regeneration.
157
+
158
+ <details>
159
+ <summary><strong>Alternative: pip install (without uv)</strong></summary>
160
+
161
+ If you prefer pip/conda over uv, create a Python 3.10 virtualenv and install:
162
+ ```sh
163
+ python3.10 -m venv .venv && source .venv/bin/activate
164
+ pip install -e .
165
+ ```
166
+ Note: GPU dependencies (flash-attn, TensorRT) may require manual installation with pip. The `uv` workflow handles these automatically.
167
+ </details>
168
+
169
+ > **If fine-tuning fails with `CUDA_HOME is unset`:** Run `bash scripts/deployment/dgpu/install_deps.sh` once to configure CUDA paths, or manually `export CUDA_HOME=/usr/local/cuda`.
170
+
171
+ > **CUDA 13.x Users (Thor, Spark, and other CUDA 13+ platforms):** PyTorch 2.7 pins Triton to 3.3.1, which does not recognize CUDA major version 13+. This causes a `RuntimeError` in Triton's `ptx_get_version()`. Run the patch script to fix:
172
+ > ```sh
173
+ > uv run bash scripts/patch_triton_cuda13.sh
174
+ > ```
175
+
176
+ > **GB300 (sm_103) Users:** Triton 3.3.1 (pinned by PyTorch 2.7) does not support the GB300 GPU architecture (sm_103). `torch.compile` will fail on GB300. Use PyTorch eager mode or TensorRT inference instead. Triton 3.5.1+ adds sm_103 support but is not yet compatible with the pinned PyTorch version.
177
+
178
+ > **aarch64 Video Backend:** On aarch64 platforms (Thor, Orin, Spark), `torchcodec` is the required video backend. `install_deps.sh` prefers the prebuilt aarch64 wheel under `scripts/deployment/dgpu/wheels/` (shared by Thor/Spark against FFmpeg 6; Orin uses a matching build against FFmpeg 4) and falls back to a source build only if the wheel is missing. If you encounter `NotImplementedError` from the video backend, ensure `torchcodec` was installed successfully during setup. Other backends (decord, pyav) are not supported on aarch64.
179
+
180
+ <details>
181
+ <summary><strong>DGX Spark</strong> (tested with DGX Spark GB10)</summary>
182
+
183
+ ```bash
184
+ bash scripts/deployment/spark/install_deps.sh
185
+ source .venv/bin/activate
186
+ source scripts/activate_spark.sh
187
+ ```
188
+
189
+ See the [Spark setup guide](scripts/deployment/README.md#dgx-spark-setup) for Docker and bare metal details.
190
+ </details>
191
+
192
+ <details>
193
+ <summary><strong>Jetson AGX Thor</strong> (tested with JetPack 7.1)</summary>
194
+
195
+ > **flash-attn on older systems (e.g., Ubuntu 20.04 with glibc < 2.35):** The pre-built `flash-attn` wheel may fail with `ImportError: glibc_compat.so: cannot open shared object file`. To fix this, build from source:
196
+ > ```sh
197
+ > uv pip install flash-attn==2.7.4.post1 --no-binary flash-attn --no-cache
198
+ > ```
199
+ > This compiles locally (~10-30 minutes) and avoids the glibc compatibility issue.
200
+
201
+ ```bash
202
+ bash scripts/deployment/thor/install_deps.sh
203
+ source .venv/bin/activate
204
+ source scripts/activate_thor.sh
205
+ ```
206
+
207
+ See the [Thor setup guide](scripts/deployment/README.md#jetson-thor-setup) for Docker and bare metal details.
208
+ </details>
209
+
210
+ <details>
211
+ <summary><strong>Jetson Orin</strong> (tested with JetPack 6.2)</summary>
212
+
213
+ ```bash
214
+ bash scripts/deployment/orin/install_deps.sh
215
+ source .venv/bin/activate
216
+ source scripts/activate_orin.sh
217
+ ```
218
+
219
+ See the [Orin setup guide](scripts/deployment/README.md#jetson-orin-setup) for Docker and bare metal details.
220
+ </details>
221
+
222
+ For a containerized setup that avoids system-level dependency conflicts, see our [Docker Setup Guide](docker/README.md).
223
+
224
+ ---
225
+
226
+ ## Model Checkpoints & Embodiment Tags
227
+
228
+ ### Checkpoints
229
+
230
+ | Checkpoint | Type | Embodiment Tag | Description |
231
+ |------------|------|---------------|-------------|
232
+ | [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B) | Base | See [pretrain tags](getting_started/policy.md#--embodiment-tag) | Base model (3B params) — zero-shot inference on pretrain embodiments, or finetune for new tasks |
233
+ | [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO) | Finetuned | `LIBERO_PANDA` | Finetuned on [LIBERO](https://libero-project.github.io/) benchmark (Franka Panda) |
234
+ | [`nvidia/GR00T-N1.7-DROID`](https://huggingface.co/nvidia/GR00T-N1.7-DROID) | Finetuned | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | Finetuned on [DROID](https://droid-dataset.github.io/) dataset |
235
+ | [`nvidia/GR00T-N1.7-SimplerEnv-Bridge`](https://huggingface.co/nvidia/GR00T-N1.7-SimplerEnv-Bridge) | Finetuned | `SIMPLER_ENV_WIDOWX` | Finetuned on SimplerEnv Bridge (WidowX) |
236
+ | [`nvidia/GR00T-N1.7-SimplerEnv-Fractal`](https://huggingface.co/nvidia/GR00T-N1.7-SimplerEnv-Fractal) | Finetuned | `SIMPLER_ENV_GOOGLE` | Finetuned on SimplerEnv Fractal (Google Robot) |
237
+
238
+ > Older versions: [N1.6 checkpoints](https://github.com/NVIDIA/Isaac-GR00T/tree/n1.6-release) | [N1.5 checkpoints](https://github.com/NVIDIA/Isaac-GR00T/tree/n1.5-release)
239
+
240
+ ### Embodiment Tags
241
+
242
+ Every inference or finetuning command requires an `--embodiment-tag`. The tag determines which modality config (state/action keys, normalization) the model uses. Tags are **case-insensitive**.
243
+
244
+ For the full list of pretrain and posttrain tags, see the [Policy API Guide — Embodiment Tags](getting_started/policy.md#--embodiment-tag).
245
+
246
+ ---
247
+
248
+ ## Data Format
249
+
250
+ GR00T uses a flavor of the [LeRobot v2 dataset format](https://github.com/huggingface/lerobot) with an additional `meta/modality.json` file that describes state/action/video structure. A dataset looks like:
251
+
252
+ ```
253
+ my_dataset/
254
+ meta/
255
+ info.json # dataset metadata
256
+ episodes.jsonl # episode index and lengths
257
+ tasks.jsonl # language task descriptions
258
+ modality.json # state/action/video key mapping (GR00T-specific)
259
+ data/chunk-000/ # parquet files (state, action per timestep)
260
+ videos/chunk-000/ # mp4 video files per episode
261
+ ```
262
+
263
+ The `modality.json` maps how the concatenated state/action arrays split into named fields (e.g., `x`, `y`, `z`, `gripper`) and which video keys are available. This is what the embodiment tag uses to interpret the data.
264
+
265
+ **Included demo datasets** (ready to use, no download needed):
266
+
267
+ | Dataset | Robot | Embodiment Tag | Use Case |
268
+ |---------|-------|---------------|----------|
269
+ | `demo_data/droid_sample` | DROID (3 episodes) | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | Zero-shot or finetuned inference (DROID) |
270
+ | `demo_data/libero_demo` | LIBERO Panda (5 episodes) | `LIBERO_PANDA` | Inference with finetuned checkpoint |
271
+ | `demo_data/simplerenv_bridge_sample` | WidowX (SimplerEnv Bridge) | `SIMPLER_ENV_WIDOWX` | Inference with finetuned SimplerEnv Bridge checkpoint |
272
+ | `demo_data/simplerenv_fractal_sample` | Google Robot (SimplerEnv Fractal) | `SIMPLER_ENV_GOOGLE` | Inference with finetuned SimplerEnv Fractal checkpoint |
273
+ | `demo_data/cube_to_bowl_5` | SO100 arm (5 episodes) | `NEW_EMBODIMENT` | Fine-tuning custom embodiment example |
274
+ | `demo_data/cube_to_bowl_5_with_mask` | SO100 arm + per-frame masks | `NEW_EMBODIMENT` | [Mask-guided background suppression](examples/mask-guided-background-suppression/README.md) example |
275
+
276
+ > To generate more DROID episodes: `python scripts/download_droid_sample.py --num-episodes 10`
277
+
278
+ **Using your own data:** Convert your demonstrations to the format above. If coming from LeRobot v3, use the conversion script: `python scripts/lerobot_conversion/convert_v3_to_v2.py`. See the full [Data Preparation Guide](getting_started/data_preparation.md) for schema details and examples.
279
+
280
+ ---
281
+
282
+ ## Inference
283
+
284
+ ### Zero-Shot Inference (Base Model)
285
+
286
+ The included `demo_data/droid_sample` dataset works with the base model out of the box — no finetuning or checkpoint download needed:
287
+
288
+ ```bash
289
+ uv run python scripts/deployment/standalone_inference_script.py \
290
+ --model-path nvidia/GR00T-N1.7-3B \
291
+ --dataset-path demo_data/droid_sample \
292
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
293
+ --traj-ids 1 2 \
294
+ --inference-mode pytorch \
295
+ --action-horizon 8
296
+ ```
297
+
298
+ This runs open-loop inference on 2 DROID episodes, comparing predicted actions against ground truth. The base model downloads automatically from HuggingFace on first run (~6 GB).
299
+
300
+ ### Finetuned Inference
301
+
302
+ For posttrain embodiments, use a finetuned checkpoint. Most finetuned checkpoints (e.g., DROID, SimplerEnv) have a flat file structure and can be passed directly as a HuggingFace model ID — no manual download needed:
303
+
304
+ ```bash
305
+ uv run python scripts/deployment/standalone_inference_script.py \
306
+ --model-path nvidia/GR00T-N1.7-DROID \
307
+ --dataset-path demo_data/droid_sample \
308
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
309
+ --traj-ids 1 2 \
310
+ --inference-mode pytorch \
311
+ --action-horizon 8
312
+ ```
313
+
314
+ Some checkpoints (e.g., LIBERO) use a nested folder structure with model files under a subfolder. HuggingFace does not support nested repo paths in `--model-path`, so you must download first:
315
+
316
+ ```bash
317
+ uv run hf download nvidia/GR00T-N1.7-LIBERO \
318
+ --include "libero_10/config.json" "libero_10/embodiment_id.json" \
319
+ "libero_10/model-*.safetensors" "libero_10/model.safetensors.index.json" \
320
+ "libero_10/processor_config.json" "libero_10/statistics.json" \
321
+ --local-dir checkpoints/GR00T-N1.7-LIBERO
322
+ ```
323
+
324
+ ```bash
325
+ uv run python scripts/deployment/standalone_inference_script.py \
326
+ --model-path checkpoints/GR00T-N1.7-LIBERO/libero_10 \
327
+ --dataset-path demo_data/libero_demo \
328
+ --embodiment-tag LIBERO_PANDA \
329
+ --traj-ids 0 1 2 \
330
+ --inference-mode pytorch \
331
+ --action-horizon 8
332
+ ```
333
+
334
+ ### Server-Client Inference (for Deployment)
335
+
336
+ For real-world deployment or simulation evaluation, use the server-client architecture. The policy runs on a GPU server; a lightweight client sends observations and receives actions over ZMQ.
337
+
338
+ **Terminal 1 — Start the policy server:**
339
+ ```bash
340
+ uv run python gr00t/eval/run_gr00t_server.py \
341
+ --model-path nvidia/GR00T-N1.7-3B \
342
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
343
+ --device cuda:0
344
+ ```
345
+
346
+ **Terminal 2 — Run open-loop evaluation as a client:**
347
+ ```bash
348
+ uv run python gr00t/eval/open_loop_eval.py \
349
+ --dataset-path demo_data/droid_sample \
350
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
351
+ --host 127.0.0.1 \
352
+ --port 5555 \
353
+ --traj-ids 1 2 \
354
+ --action-horizon 8
355
+ ```
356
+
357
+ > **Tip:** If you get `ZMQError: Address already in use`, the default port 5555 is occupied. Use `--port <other_port>`.
358
+
359
+ For connecting to a real robot (e.g., DROID hardware), see [examples/DROID/README.md](examples/DROID/README.md). For faster inference with TensorRT, see the [Deployment & Inference Guide](scripts/deployment/README.md).
360
+
361
+ See the complete [Policy API Guide](getting_started/policy.md) for documentation on observation/action formats, batched inference, and troubleshooting.
362
+
363
+ ---
364
+
365
+ ## Fine-tuning
366
+
367
+ ### Reproducing Benchmark Results
368
+
369
+ Each benchmark has a self-contained README with dataset download, finetune, and evaluation commands:
370
+
371
+ | Benchmark | Embodiment | Guide |
372
+ |-----------|-----------|-------|
373
+ | LIBERO | `LIBERO_PANDA` | [examples/LIBERO/README.md](examples/LIBERO/README.md) |
374
+ | SimplerEnv (Fractal) | `SIMPLER_ENV_GOOGLE` | [examples/SimplerEnv/README.md](examples/SimplerEnv/README.md) |
375
+ | SimplerEnv (Bridge) | `SIMPLER_ENV_WIDOWX` | [examples/SimplerEnv/README.md](examples/SimplerEnv/README.md) |
376
+ | SO100 | `NEW_EMBODIMENT` | [examples/SO100/README.md](examples/SO100/README.md) |
377
+
378
+ ### Fine-tune on Your Own Robot ("NEW_EMBODIMENT")
379
+
380
+ To finetune GR00T on your own robot data and configuration, follow the detailed tutorial at [`getting_started/finetune_new_embodiment.md`](getting_started/finetune_new_embodiment.md).
381
+
382
+ Ensure your input data follows the [GR00T LeRobot format](#data-format), and specify your modality configuration via `--modality-config-path`.
383
+
384
+ **Single GPU:**
385
+ ```bash
386
+ CUDA_VISIBLE_DEVICES=0 uv run python \
387
+ gr00t/experiment/launch_finetune.py \
388
+ --base-model-path nvidia/GR00T-N1.7-3B \
389
+ --dataset-path demo_data/cube_to_bowl_5 \
390
+ --embodiment-tag NEW_EMBODIMENT \
391
+ --modality-config-path examples/SO100/so100_config.py \
392
+ --num-gpus 1 \
393
+ --output-dir /tmp/test_finetune \
394
+ --max-steps 2000 \
395
+ --global-batch-size 32 \
396
+ --dataloader-num-workers 4
397
+ ```
398
+
399
+ **Multi-GPU (e.g., 8xH100):**
400
+ ```bash
401
+ uv run torchrun --nproc_per_node=8 --master_port=29500 \
402
+ gr00t/experiment/launch_finetune.py \
403
+ --base-model-path nvidia/GR00T-N1.7-3B \
404
+ --dataset-path demo_data/cube_to_bowl_5 \
405
+ --embodiment-tag NEW_EMBODIMENT \
406
+ --modality-config-path examples/SO100/so100_config.py \
407
+ --num-gpus 8 \
408
+ --output-dir /tmp/test_finetune_8gpu \
409
+ --max-steps 2000 \
410
+ --global-batch-size 32 \
411
+ --dataloader-num-workers 4
412
+ ```
413
+
414
+ Replace `demo_data/cube_to_bowl_5` and `examples/SO100/so100_config.py` with your own dataset and modality config. See [`examples/SO100`](examples/SO100/README.md) for a complete walkthrough.
415
+
416
+ > **Note:** Use `uv run torchrun` (not bare `torchrun`) to ensure the correct virtual environment is used. Add `--use-wandb` to enable Weights & Biases logging. For more extensive configuration, use `gr00t/experiment/launch_train.py`.
417
+
418
+ ### Training Tips
419
+
420
+ - Maximize batch size for your hardware and train for a few thousand steps.
421
+ - Users may observe 5-6% variance between runs due to non-deterministic image augmentations. Keep this in mind when comparing to reported benchmarks.
422
+ - **`--state_dropout_prob`** (model config default: 0.8; finetune CLI default: 0.2; see `gr00t/configs/finetune_config.py`): Randomly drops state inputs during training to improve generalization and reduce state-dependency. The shipped benchmark scripts override the CLI default per suite: LIBERO 10-Long uses 0.2 (the CLI default), SimplerEnv Bridge uses 0.8, SimplerEnv Fractal uses 0.5. If your task relies heavily on proprioceptive state, lower this value.
423
+
424
+ ---
425
+
426
+ ## Evaluation
427
+
428
+ ### Open-Loop Evaluation
429
+
430
+ Compare predicted actions against ground truth from your dataset:
431
+
432
+ ```bash
433
+ uv run python gr00t/eval/open_loop_eval.py \
434
+ --dataset-path <DATASET_PATH> \
435
+ --embodiment-tag NEW_EMBODIMENT \
436
+ --model-path <CHECKPOINT_PATH> \
437
+ --traj-ids 0 \
438
+ --action-horizon 16
439
+ ```
440
+
441
+ This generates a visualization at `/tmp/open_loop_eval/traj_{traj_id}.jpeg` with ground truth vs. predicted actions and MSE metrics. Use `--save-plot-path <dir>` to save plots to a custom location.
442
+
443
+ ### Closed-Loop Evaluation
444
+
445
+ Test your model in simulation or on real hardware using the server-client architecture:
446
+
447
+ ```bash
448
+ # Start the policy server
449
+ uv run python gr00t/eval/run_gr00t_server.py \
450
+ --embodiment-tag NEW_EMBODIMENT \
451
+ --model-path <CHECKPOINT_PATH> \
452
+ --device cuda:0 \
453
+ --host 0.0.0.0 --port 5555
454
+ ```
455
+
456
+ ```python
457
+ from gr00t.policy.server_client import PolicyClient
458
+
459
+ policy = PolicyClient(host="localhost", port=5555)
460
+ env = YourEnvironment()
461
+ obs, info = env.reset()
462
+ action, info = policy.get_action(obs)
463
+ obs, reward, done, truncated, info = env.step(action)
464
+ ```
465
+
466
+ **Debugging with ReplayPolicy:** To verify your environment setup without a trained model, start the server with `--dataset-path <DATASET_PATH>` (omit `--model-path`) to replay recorded actions from the dataset.
467
+
468
+ See the complete [Policy API Guide](getting_started/policy.md) for observation/action formats, batched inference, and troubleshooting.
469
+
470
+ ### Benchmark Examples
471
+
472
+ We support evaluation on public benchmarks using a server-client architecture. The policy server reuses the project root's uv environment; simulation clients have individual setup scripts.
473
+
474
+ You can use [the verification script](scripts/eval/check_sim_eval_ready.py) to verify that all dependencies are properly configured.
475
+
476
+ **Zero-shot** (evaluate with the base model, no finetuning):
477
+ - [DROID](examples/DROID/README.md) — real-world DROID robot (also available as the finetuned `nvidia/GR00T-N1.7-DROID` checkpoint; `examples/DROID/README.md` covers both paths)
478
+
479
+ **Finetuned** (evaluate with finetuned checkpoints):
480
+ - [DROID](examples/DROID/README.md) — real-world DROID robot via `nvidia/GR00T-N1.7-DROID`
481
+ - [LIBERO](examples/LIBERO/README.md) — LIBERO benchmark (Franka Panda)
482
+ - [SimplerEnv](examples/SimplerEnv/README.md) — Google Robot (Fractal) and WidowX (Bridge)
483
+ - [SO100](examples/SO100/README.md) — SO100 custom embodiment workflow
484
+
485
+ <details>
486
+ <summary><strong>Adding a New Sim Benchmark</strong></summary>
487
+
488
+ Each sim benchmark registers its environments under a gym env_name with the format `{prefix}/{task_name}` (e.g., `libero_sim/LIVING_ROOM_SCENE2_put_soup_in_basket`). The evaluation framework uses the prefix to look up the corresponding `EmbodimentTag` via a mapping in [`gr00t/eval/sim/env_utils.py`](gr00t/eval/sim/env_utils.py).
489
+
490
+ > **Important:** The env_name prefix and the `EmbodimentTag` value are often different. For example, `libero_sim` maps to `EmbodimentTag.LIBERO_PANDA` (`"libero_sim"`). Do not assume they match.
491
+
492
+ To add a new benchmark:
493
+
494
+ 1. Add an entry to `ENV_PREFIX_TO_EMBODIMENT_TAG` in `gr00t/eval/sim/env_utils.py`:
495
+ ```python
496
+ ENV_PREFIX_TO_EMBODIMENT_TAG = {
497
+ ...
498
+ "my_new_benchmark": EmbodimentTag.MY_ROBOT,
499
+ }
500
+ ```
501
+ 2. If the benchmark has multiple env_name prefixes (e.g., `my_benchmark_v1`, `my_benchmark_v2`), all related prefixes **must** map to the same `EmbodimentTag`.
502
+ 3. Add corresponding test cases in `tests/gr00t/eval/sim/test_env_utils.py` and update the `test_all_known_prefixes_present` test.
503
+ </details>
504
+
505
+
506
+
507
+ ---
508
+
509
+ # Contributions
510
+
511
+ During Early Access we are not accepting pull requests while the codebase stabilizes. If you encounter issues or have suggestions, please open an [Issue](https://github.com/NVIDIA/Isaac-GR00T/issues) in this repository.
512
+
513
+ # Support
514
+
515
+ Support during Early Access is best-effort. We will continue iterating toward a more stable General Availability (GA) release.
516
+
517
+
518
+ ## License
519
+
520
+ - **Code:** Apache 2.0 — see [LICENSE](LICENSE)
521
+ - **Model weights:** [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/)
522
+
523
+ ```
524
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
525
+ # SPDX-License-Identifier: Apache-2.0
526
+ #
527
+ # Licensed under the Apache License, Version 2.0 (the "License");
528
+ # you may not use this file except in compliance with the License.
529
+ # You may obtain a copy of the License at
530
+ #
531
+ # http://www.apache.org/licenses/LICENSE-2.0
532
+ #
533
+ # Unless required by applicable law or agreed to in writing, software
534
+ # distributed under the License is distributed on an "AS IS" BASIS,
535
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
536
+ # See the License for the specific language governing permissions and
537
+ # limitations under the License.
538
+ ```
539
+
540
+
541
+ ## Citation
542
+
543
+ [Paper Site](https://research.nvidia.com/labs/lpr/publication/gr00tn1_2025/)
544
+ ```bibtex
545
+ @inproceedings{gr00tn1_2025,
546
+ archivePrefix = {arxiv},
547
+ eprint = {2503.14734},
548
+ title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
549
+ author = {NVIDIA and Johan Bjorck and Fernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
550
+ month = {March},
551
+ year = {2025},
552
+ booktitle = {ArXiv Preprint},
553
+ }
554
+ ```
555
+
examples/DROID/README.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GR00T DROID
2
+
3
+ The N1.7 base model supports DROID inference out of the box via the `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` pretrain tag. A finetuned checkpoint is also available at [`nvidia/GR00T-N1.7-DROID`](https://huggingface.co/nvidia/GR00T-N1.7-DROID).
4
+
5
+ > **Note:** The DROID dataset contains multiple language instruction paraphrases per episode (`language_instruction`, `language_instruction_2`, `language_instruction_3`). These are used for language augmentation during training. At inference time, only the first language key is used.
6
+
7
+ ## Data Format
8
+
9
+ The DROID embodiment expects the following modality structure:
10
+
11
+ | Modality | Keys | Dimensions |
12
+ |----------|------|------------|
13
+ | Video | `exterior_image_1_left`, `wrist_image_left` | 2 cameras |
14
+ | State | `eef_9d`, `gripper_position`, `joint_position` | 9D + 1D + 7D = 17D |
15
+ | Action | `eef_9d`, `gripper_position`, `joint_position` | 9D + 1D + 7D = 17D |
16
+ | Language | `annotation.language.language_instruction` | text |
17
+
18
+ Action representations:
19
+ - `eef_9d`: relative end-effector (XYZ + rotation 6D)
20
+ - `gripper_position`: absolute (1D)
21
+ - `joint_position`: relative joint positions (7D)
22
+
23
+ ### Preparing DROID Demo Data
24
+
25
+ The full DROID dataset ([lerobot/droid_1.0.1](https://huggingface.co/datasets/lerobot/droid_1.0.1)) is ~358 GB with 95k+ episodes in LeRobot v3.0 format. To create a small sample for testing:
26
+
27
+ ```bash
28
+ uv pip install jsonlines # one-time dependency
29
+ python scripts/download_droid_sample.py
30
+ ```
31
+
32
+ This downloads the first data/video chunk (~170 MB) and extracts 3 episodes into `demo_data/droid_sample/` in GR00T LeRobot v2.0 format.
33
+
34
+ **Key conversion notes:**
35
+ - Source is LeRobot v3.0 (consolidated parquet + concatenated videos) — the script converts to v2.0 (per-episode parquet + per-episode mp4).
36
+ - Video keys in the raw dataset (`exterior_1_left`, `wrist_left`) differ from the model config keys (`exterior_image_1_left`, `wrist_image_left`). The data loader auto-maps by position — no manual renaming needed.
37
+ - Language instructions are loaded via the `task_index` column mapped through `tasks.jsonl`.
38
+
39
+ ## 1. Standalone Inference (with demo data)
40
+
41
+ After preparing demo data, run inference directly (no server needed):
42
+
43
+ ```bash
44
+ uv run python scripts/deployment/standalone_inference_script.py \
45
+ --model-path nvidia/GR00T-N1.7-3B \
46
+ --dataset-path demo_data/droid_sample \
47
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
48
+ --traj-ids 0 1 \
49
+ --inference-mode pytorch \
50
+ --action-horizon 8
51
+ ```
52
+
53
+ > **Note:** Episode 0 may have an empty language instruction. If inference fails on episode 0, try `--traj-ids 1 2`.
54
+
55
+ Expected zero-shot performance on the base model (not finetuned):
56
+
57
+ | Metric | Value |
58
+ |--------|-------|
59
+ | Average MSE | ~0.0149 |
60
+ | Average MAE | ~0.0753 |
61
+ | Inference per step (base) | ~262 ms (H100) |
62
+ | Inference per step (finetuned) | ~253 ms (H100) |
63
+
64
+ ## 2. Inference Server (for real-world deployment)
65
+
66
+ ### Using the base model (zero-shot):
67
+
68
+ ```bash
69
+ uv run python gr00t/eval/run_gr00t_server.py \
70
+ --model-path nvidia/GR00T-N1.7-3B \
71
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT
72
+ ```
73
+
74
+ ### Using the finetuned model:
75
+
76
+ ```bash
77
+ uv run python gr00t/eval/run_gr00t_server.py \
78
+ --model-path nvidia/GR00T-N1.7-DROID \
79
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT
80
+ ```
81
+
82
+ ## 3. Fine-tuning
83
+
84
+ Fine-tune the base model on DROID data using the shared launcher:
85
+
86
+ ```bash
87
+ NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
88
+ --base-model-path nvidia/GR00T-N1.7-3B \
89
+ --dataset-path demo_data/droid_sample \
90
+ --embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT \
91
+ --output-dir /tmp/droid_finetune
92
+ ```
93
+
94
+ > **Note:** The above uses the small `demo_data/droid_sample` (3 episodes) for quick validation. For production training, replace `--dataset-path` with the full DROID dataset.
95
+
96
+ ## 4. Robot Control Script
97
+
98
+ 1. Install the DROID package on the robot control laptop/workstation — [instructions](https://droid-dataset.github.io/droid/software-setup/host-installation.html#configuring-the-laptopworkstation)
99
+
100
+ 2. Install dependencies for the GR00T control script in the environment from step 1:
101
+ ```bash
102
+ pip install tyro moviepy==1.0.3 pydantic numpy==1.26.4
103
+ ```
104
+
105
+ 3. Enter the camera IDs for your ZED cameras in `examples/DROID/main_gr00t.py`.
106
+
107
+ 4. Start the control script:
108
+ ```bash
109
+ python examples/DROID/main_gr00t.py --external-camera="left" # or "right"
110
+ ```
examples/DROID/main_gr00t.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # ruff: noqa
17
+ # NOTE: this requires installation of the droid repo.
18
+ # Adapted from https://github.com/Physical-Intelligence/openpi/blob/main/examples/droid/main.py
19
+
20
+ from __future__ import annotations
21
+
22
+ import contextlib
23
+ import dataclasses
24
+ import datetime
25
+ import faulthandler
26
+ import os
27
+ import signal
28
+ import time
29
+ from collections import deque
30
+
31
+ import cv2
32
+ import imageio
33
+ import numpy as np
34
+ import pandas as pd
35
+ import tqdm
36
+ import tyro
37
+ from moviepy.editor import ImageSequenceClip
38
+ from PIL import Image
39
+
40
+ from droid.robot_env import RobotEnv
41
+ from server_client import PolicyClient
42
+ from utils import resize_with_pad
43
+ from scipy.spatial.transform import Rotation
44
+
45
+ faulthandler.enable()
46
+
47
+ # DROID data collection frequency -- we slow down execution to match this frequency
48
+ DROID_CONTROL_FREQUENCY = 15
49
+ RESOLUTION = (180, 320) # resize images to this resolution before sending to the policy server
50
+
51
+ # Egocentric frame correction: R_euler is post-multiplied by this matrix
52
+ # to match the OXE DROID training pipeline (TFG convention).
53
+ DROID_EEF_ROTATION_CORRECT = np.array(
54
+ [[0, 0, -1], [-1, 0, 0], [0, 1, 0]],
55
+ dtype=np.float64,
56
+ )
57
+
58
+
59
+ def compute_eef_9d(cartesian_position: np.ndarray) -> np.ndarray:
60
+ """Convert cartesian_position (XYZ + euler 3D) to eef_9d (XYZ + rot6d).
61
+
62
+ Uses extrinsic XYZ Euler convention (scipy ``"XYZ"``, equivalent to
63
+ ``tfg.rotation_matrix_3d.from_euler``) and post-multiplies by
64
+ ``DROID_EEF_ROTATION_CORRECT`` to match the pretrained model.
65
+ """
66
+ c = np.asarray(cartesian_position, dtype=np.float64).reshape(6)
67
+ xyz = c[:3]
68
+ euler = c[3:6]
69
+ rot_robot = Rotation.from_euler("XYZ", euler).as_matrix()
70
+ rot_mat = rot_robot @ DROID_EEF_ROTATION_CORRECT
71
+ rot6d = rot_mat[:2, :].reshape(6)
72
+ return np.concatenate([xyz, rot6d]).astype(np.float32)
73
+
74
+
75
+ @dataclasses.dataclass
76
+ class Args:
77
+ # Hardware parameters
78
+
79
+ left_camera_id: str = "<SET THIS>" # e.g., "24259877"
80
+ right_camera_id: str = "<SET THIS>" # e.g., "24514023"
81
+ wrist_camera_id: str = "<SET THIS>" # e.g., "13062452"
82
+
83
+ # Policy parameters
84
+ policy_host: str = "localhost"
85
+ policy_port: int = 5555
86
+ policy_api_token: str = None
87
+
88
+ results_dir: str = None # if None, will use the current timestamp as the results directory
89
+
90
+ # Rollout parameters
91
+ max_timesteps: int = 600 # how many steps to run each rollout
92
+
93
+ # How many actions to execute from a predicted action chunk before querying policy server again
94
+ open_loop_horizon: int = 15
95
+ external_camera: str = (
96
+ "left" # which exterior camera to use for the policy server, choose from ["left", "right"]
97
+ )
98
+ render_camera: str = "left" # which camera to render saved video from
99
+ render_fps: int = 50
100
+
101
+ debug: bool = False
102
+ vis_cameras: bool = False
103
+
104
+ delay_seconds: int = 5
105
+
106
+
107
+ # We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
108
+ # waiting for a new action chunk, it will raise an exception and the server connection dies.
109
+ # This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
110
+ @contextlib.contextmanager
111
+ def prevent_keyboard_interrupt():
112
+ """Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
113
+ interrupted = False
114
+ original_handler = signal.getsignal(signal.SIGINT)
115
+
116
+ def handler(signum, frame):
117
+ nonlocal interrupted
118
+ interrupted = True
119
+
120
+ signal.signal(signal.SIGINT, handler)
121
+ try:
122
+ yield
123
+ finally:
124
+ signal.signal(signal.SIGINT, original_handler)
125
+ if interrupted:
126
+ raise KeyboardInterrupt
127
+
128
+
129
+ def main(args: Args):
130
+ assert args.external_camera in ["left", "right"], (
131
+ f"Invalid exterior camera: {args.exterior_camera}"
132
+ )
133
+
134
+ if args.results_dir is None:
135
+ results_dir = f"results_gr00t_{datetime.datetime.now().strftime('%Y_%m_%d')}"
136
+ else:
137
+ results_dir = args.results_dir
138
+
139
+ # Initialize the Panda environment.
140
+ env = RobotEnv(action_space="joint_position", gripper_action_space="position")
141
+ print("Created the droid env!")
142
+
143
+ os.makedirs(results_dir, exist_ok=True)
144
+
145
+ policy_client = PolicyClient(
146
+ host=args.policy_host, port=args.policy_port, api_token=args.policy_api_token
147
+ )
148
+
149
+ modality_config = policy_client.get_modality_config()
150
+ video_delta = modality_config["video"].delta_indices
151
+ video_T = len(video_delta)
152
+ video_history_len = max(-min(video_delta), 0) + 1 if video_delta else 1
153
+ video_keys = modality_config["video"].modality_keys
154
+ state_keys = modality_config["state"].modality_keys
155
+ state_T = len(modality_config["state"].delta_indices)
156
+ print(
157
+ f"Model config — video T={video_T} (delta={video_delta}), "
158
+ f"state T={state_T}, keys: video={video_keys}, state={state_keys}"
159
+ )
160
+
161
+ df = pd.DataFrame(columns=["success", "duration", "video_filename"])
162
+
163
+ if args.debug:
164
+ debug_dir = os.path.join(results_dir, "debug_data")
165
+ os.makedirs(debug_dir, exist_ok=True)
166
+ os.makedirs(os.path.join(debug_dir, "videos/wrist_image/"), exist_ok=True)
167
+ os.makedirs(os.path.join(debug_dir, "videos/exterior_image_1_left/"), exist_ok=True)
168
+
169
+ instruction = None
170
+ while True:
171
+ if instruction is None:
172
+ instruction = input("Enter instruction: ")
173
+ else:
174
+ if input("Change instruction? (enter y or n) ").lower() == "y":
175
+ instruction = input("Enter instruction: ")
176
+
177
+ time.sleep(args.delay_seconds)
178
+
179
+ # Rollout parameters
180
+ actions_from_chunk_completed = 0
181
+ pred_action_chunk = None
182
+
183
+ # Prepare to save video of rollout
184
+ timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
185
+ video = []
186
+ if args.debug:
187
+ model_wrist_image_writer = imageio.get_writer(
188
+ os.path.join(
189
+ debug_dir, "videos/wrist_image/", f"model_wrist_image_{timestamp}.mp4"
190
+ ),
191
+ fps=5,
192
+ )
193
+ model_exterior_image_1_left_writer = imageio.get_writer(
194
+ os.path.join(
195
+ debug_dir,
196
+ "videos/exterior_image_1_left/",
197
+ f"model_exterior_image_1_left_{timestamp}.mp4",
198
+ ),
199
+ fps=5,
200
+ )
201
+
202
+ bar = tqdm.tqdm(range(args.max_timesteps))
203
+ print("Running rollout... press Ctrl+C to stop early.")
204
+
205
+ # Profiling variables (reset for each rollout)
206
+ rollout_start_time = time.time()
207
+ obs_times = deque(maxlen=50) # Track observation collection times
208
+ server_times = deque(maxlen=50) # Track server response times
209
+ action_count = 0
210
+ frame_buffer = deque(maxlen=video_history_len)
211
+
212
+ for t_step in bar:
213
+ step_start_time = time.time()
214
+ try:
215
+ # Get the current observation
216
+ obs_start_time = time.time()
217
+ curr_obs = _extract_observation(
218
+ args,
219
+ env.get_observation(),
220
+ # Save the first observation to disk
221
+ save_to_disk=t_step == 0,
222
+ )
223
+ obs_time = time.time() - obs_start_time
224
+ obs_times.append(obs_time)
225
+
226
+ video.append(curr_obs[f"{args.render_camera}_image"])
227
+
228
+ # Resize every step so the rolling frame buffer stays current.
229
+ left_image = resize_with_pad(curr_obs["left_image"], RESOLUTION[0], RESOLUTION[1])
230
+ right_image = resize_with_pad(curr_obs["right_image"], RESOLUTION[0], RESOLUTION[1])
231
+ wrist_image = resize_with_pad(curr_obs["wrist_image"], RESOLUTION[0], RESOLUTION[1])
232
+
233
+ if args.external_camera == "left":
234
+ ext_image = left_image
235
+ elif args.external_camera == "right":
236
+ ext_image = right_image
237
+
238
+ frame_buffer.append({"ext": ext_image, "wrist": wrist_image})
239
+
240
+ # Send websocket request to policy server if it's time to predict a new chunk
241
+ if (
242
+ actions_from_chunk_completed == 0
243
+ or actions_from_chunk_completed >= args.open_loop_horizon
244
+ ):
245
+ actions_from_chunk_completed = 0
246
+
247
+ if args.debug:
248
+ model_wrist_image_writer.append_data(wrist_image)
249
+ model_exterior_image_1_left_writer.append_data(ext_image)
250
+
251
+ # Build video tensor with T frames derived from the model's
252
+ # delta_indices (e.g. [-15, 0] -> T=2, [0] -> T=1).
253
+ if video_T == 1:
254
+ video_dict = {
255
+ "exterior_image_1_left": ext_image[None, None, ...],
256
+ "wrist_image_left": wrist_image[None, None, ...],
257
+ } # (B=1, T=1, H, W, C)
258
+ else:
259
+ hist_frame = frame_buffer[0]
260
+ cur_frame = frame_buffer[-1]
261
+ video_dict = {
262
+ "exterior_image_1_left": np.stack(
263
+ [hist_frame["ext"], cur_frame["ext"]]
264
+ )[None, ...],
265
+ "wrist_image_left": np.stack([hist_frame["wrist"], cur_frame["wrist"]])[
266
+ None, ...
267
+ ],
268
+ } # (B=1, T=video_T, H, W, C)
269
+
270
+ # Build state dict from the model's reported state keys.
271
+ state_dict = {}
272
+ state_source = {
273
+ "eef_9d": curr_obs["eef_9d"],
274
+ "gripper_position": curr_obs["gripper_position"],
275
+ "joint_position": curr_obs["joint_position"],
276
+ }
277
+ for key in state_keys:
278
+ state_dict[key] = state_source[key][None, None, ...].astype(
279
+ np.float32
280
+ ) # (B=1, T=1, D)
281
+
282
+ lang_key = modality_config["language"].modality_keys[0]
283
+ request_data = {
284
+ "video": video_dict,
285
+ "state": state_dict,
286
+ "language": {lang_key: [[instruction]]},
287
+ }
288
+
289
+ if args.vis_cameras:
290
+ # viz the left image 1 and wrist image and use cv2 to display them side by side
291
+ left_image_display = cv2.resize(
292
+ left_image, (wrist_image.shape[1], wrist_image.shape[0])
293
+ )
294
+ combined_display = np.concatenate([left_image_display, wrist_image], axis=1)
295
+ # convert to bgr
296
+ combined_display = combined_display[..., ::-1]
297
+ cv2.imshow("Camera Views", combined_display)
298
+ cv2.waitKey(1)
299
+
300
+ # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
301
+ # Ctrl+C will be handled after the server call is complete
302
+ server_start_time = time.time()
303
+ with prevent_keyboard_interrupt():
304
+ # this returns action chunk [N, 8] of joint position actions (7) + gripper position (1)
305
+ response = policy_client.get_action(request_data)
306
+ server_time = time.time() - server_start_time
307
+ server_times.append(server_time)
308
+
309
+ pred_action_chunk = np.concatenate(
310
+ (
311
+ response[0]["joint_position"][0],
312
+ response[0]["gripper_position"][0],
313
+ ),
314
+ axis=1,
315
+ )
316
+
317
+ # Select current action to execute from chunk
318
+ action = pred_action_chunk[actions_from_chunk_completed]
319
+ actions_from_chunk_completed += 1
320
+
321
+ # Binarize gripper action
322
+ if action[-1].item() > 0.5:
323
+ action = np.concatenate([action[:-1], np.ones((1,))])
324
+ else:
325
+ action = np.concatenate([action[:-1], np.zeros((1,))])
326
+
327
+ env.step(action)
328
+ action_count += 1
329
+
330
+ # Sleep to match DROID data collection frequency
331
+ elapsed_time = time.time() - step_start_time
332
+ if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
333
+ time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
334
+
335
+ # profiling stats
336
+ if obs_times:
337
+ avg_obs_time = np.mean(obs_times) * 1000
338
+ min_obs_time = np.min(obs_times) * 1000
339
+ max_obs_time = np.max(obs_times) * 1000
340
+ else:
341
+ avg_obs_time = min_obs_time = max_obs_time = 0
342
+
343
+ if server_times:
344
+ avg_server_time = np.mean(server_times) * 1000
345
+ min_server_time = np.min(server_times) * 1000
346
+ max_server_time = np.max(server_times) * 1000
347
+ else:
348
+ avg_server_time = min_server_time = max_server_time = 0
349
+
350
+ total_elapsed = time.time() - rollout_start_time
351
+ actions_per_sec = action_count / total_elapsed if total_elapsed > 0 else 0
352
+
353
+ bar.set_description(
354
+ f"Obs: {avg_obs_time:.1f}ms [{min_obs_time:.1f}-{max_obs_time:.1f}] | "
355
+ f"Server: {avg_server_time:.1f}ms [{min_server_time:.1f}-{max_server_time:.1f}] | "
356
+ f"Actions/sec: {actions_per_sec:.2f}"
357
+ )
358
+ except KeyboardInterrupt:
359
+ break
360
+
361
+ os.makedirs(os.path.join(results_dir, "videos"), exist_ok=True)
362
+ video = np.stack(video)
363
+ # replace whitespace with underscores in instruction
364
+ sanitized_instruction = instruction.replace(" ", "_")
365
+ save_filename = os.path.join(
366
+ results_dir, "videos", f"{sanitized_instruction}_video_" + timestamp
367
+ )
368
+ ImageSequenceClip(list(video), fps=args.render_fps).write_videofile(
369
+ save_filename + ".mp4", codec="libx264"
370
+ )
371
+
372
+ if args.debug:
373
+ model_wrist_image_writer.close()
374
+ model_exterior_image_1_left_writer.close()
375
+
376
+ success: str | float | None = None
377
+ while not isinstance(success, float):
378
+ success = input(
379
+ "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
380
+ )
381
+ if success == "y":
382
+ success = 1.0
383
+ elif success == "n":
384
+ success = 0.0
385
+
386
+ success = float(success) / 100
387
+ if not (0 <= success <= 1):
388
+ print(f"Success must be a number in [0, 100] but got: {success * 100}")
389
+
390
+ new_row = {
391
+ "success": success,
392
+ "duration": t_step,
393
+ "video_filename": save_filename,
394
+ }
395
+ new_index = len(df)
396
+ df.loc[new_index] = new_row
397
+
398
+ if input("Do one more eval? (enter y or n) ").lower() != "y":
399
+ break
400
+ env.reset(randomize=False)
401
+
402
+ timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
403
+ csv_filename = os.path.join(results_dir, f"eval_{timestamp}.csv")
404
+ df.to_csv(csv_filename)
405
+ print(f"Results saved to {csv_filename}")
406
+
407
+
408
+ def _extract_observation(args: Args, obs_dict, *, stereo_camera="left", save_to_disk=False):
409
+ image_observations = obs_dict["image"]
410
+ key_left = f"{args.left_camera_id}_{stereo_camera}"
411
+ key_right = f"{args.right_camera_id}_{stereo_camera}"
412
+ key_wrist = f"{args.wrist_camera_id}_{stereo_camera}"
413
+
414
+ left_image = image_observations.get(key_left)
415
+ right_image = image_observations.get(key_right)
416
+ wrist_image = image_observations.get(key_wrist)
417
+
418
+ available = list(image_observations.keys())
419
+ assert left_image is not None, (
420
+ f"Left camera not found for key {key_left!r}. Available keys: {available}. "
421
+ "Set --left-camera-id to the ZED serial used in observation keys."
422
+ )
423
+ assert right_image is not None, (
424
+ f"Right camera not found for key {key_right!r}. Available keys: {available}. "
425
+ "Set --right-camera-id to the ZED serial used in observation keys."
426
+ )
427
+ assert wrist_image is not None, (
428
+ f"Wrist camera not found for key {key_wrist!r}. Available keys: {available}. "
429
+ "Set --wrist-camera-id to the ZED serial used in observation keys."
430
+ )
431
+
432
+ # Drop the alpha dimension
433
+ left_image = left_image[..., :3]
434
+ right_image = right_image[..., :3]
435
+ wrist_image = wrist_image[..., :3]
436
+
437
+ # Convert to RGB
438
+ left_image = left_image[..., ::-1]
439
+ right_image = right_image[..., ::-1]
440
+ wrist_image = wrist_image[..., ::-1]
441
+
442
+ # In addition to image observations, also capture the proprioceptive state
443
+ robot_state = obs_dict["robot_state"]
444
+ cartesian_position = np.array(robot_state["cartesian_position"])
445
+ joint_position = np.array(robot_state["joint_positions"])
446
+ gripper_position = np.array([robot_state["gripper_position"]])
447
+ eef_9d = compute_eef_9d(cartesian_position)
448
+
449
+ # Save the images to disk so that they can be viewed live while the robot is running
450
+ # Create one combined image to make live viewing easy
451
+ if save_to_disk:
452
+ combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
453
+ combined_image = Image.fromarray(combined_image)
454
+ combined_image.save("robot_camera_views.png")
455
+
456
+ return {
457
+ "left_image": left_image,
458
+ "right_image": right_image,
459
+ "wrist_image": wrist_image,
460
+ "cartesian_position": cartesian_position,
461
+ "eef_9d": eef_9d,
462
+ "joint_position": joint_position,
463
+ "gripper_position": gripper_position,
464
+ }
465
+
466
+
467
+ if __name__ == "__main__":
468
+ args: Args = tyro.cli(Args)
469
+ main(args)
examples/DROID/server_client.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from abc import ABC, abstractmethod
19
+ from dataclasses import asdict, dataclass, is_dataclass
20
+ from enum import Enum
21
+ import io
22
+ from typing import Any
23
+
24
+ import msgpack
25
+ import numpy as np
26
+ import zmq
27
+
28
+
29
+ def to_json_serializable(obj: Any) -> Any:
30
+ """
31
+ Recursively convert dataclasses and numpy arrays to JSON-serializable format.
32
+
33
+ Args:
34
+ obj: Object to convert (can be dataclass, numpy array, dict, list, etc.)
35
+
36
+ Returns:
37
+ JSON-serializable representation of the object
38
+ """
39
+ if is_dataclass(obj) and not isinstance(obj, type):
40
+ # Convert dataclass to dict, then recursively process the dict
41
+ return to_json_serializable(asdict(obj))
42
+ elif isinstance(obj, np.ndarray):
43
+ # Convert numpy array to list
44
+ return obj.tolist()
45
+ elif isinstance(obj, np.integer):
46
+ # Convert numpy integers to Python int
47
+ return int(obj)
48
+ elif isinstance(obj, np.floating):
49
+ # Convert numpy floats to Python float
50
+ return float(obj)
51
+ elif isinstance(obj, np.bool_):
52
+ # Convert numpy bool to Python bool
53
+ return bool(obj)
54
+ elif isinstance(obj, dict):
55
+ # Recursively process dictionary values
56
+ return {key: to_json_serializable(value) for key, value in obj.items()}
57
+ elif isinstance(obj, (list, tuple)):
58
+ # Recursively process list/tuple elements
59
+ return [to_json_serializable(item) for item in obj]
60
+ elif isinstance(obj, set):
61
+ # Convert set to list
62
+ return [to_json_serializable(item) for item in obj]
63
+ elif isinstance(obj, (str, int, float, bool, type(None))):
64
+ # Already JSON-serializable
65
+ return obj
66
+ elif isinstance(obj, Enum):
67
+ return obj.name
68
+ else:
69
+ # For other types, try to convert to string as fallback
70
+ # You might want to handle specific types differently
71
+ return str(obj)
72
+
73
+
74
+ class MessageType(Enum):
75
+ START_OF_EPISODE = "start_of_episode"
76
+ END_OF_EPISODE = "end_of_episode"
77
+ EPISODE_STEP = "episode_step"
78
+ IMAGE = "image"
79
+ TEXT = "text"
80
+
81
+
82
+ class ActionRepresentation(Enum):
83
+ RELATIVE = "relative"
84
+ DELTA = "delta"
85
+ ABSOLUTE = "absolute"
86
+
87
+
88
+ class ActionType(Enum):
89
+ EEF = "eef"
90
+ NON_EEF = "non_eef"
91
+
92
+
93
+ class ActionFormat(Enum):
94
+ DEFAULT = "default"
95
+ XYZ_ROT6D = "xyz+rot6d"
96
+ XYZ_ROTVEC = "xyz+rotvec"
97
+
98
+
99
+ @dataclass
100
+ class ActionConfig:
101
+ rep: ActionRepresentation
102
+ type: ActionType
103
+ format: ActionFormat
104
+ state_key: str | None = None
105
+
106
+
107
+ @dataclass
108
+ class ModalityConfig:
109
+ """Configuration for a modality defining how data should be sampled and loaded.
110
+
111
+ This class specifies which indices to sample relative to a base index and which
112
+ keys to load for a particular modality (e.g., video, state, action).
113
+ """
114
+
115
+ delta_indices: list[int]
116
+ """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices."""
117
+ modality_keys: list[str]
118
+ """The keys to load for the modality in the dataset."""
119
+ sin_cos_embedding_keys: list[str] | None = None
120
+ """Optional list of keys to apply sin/cos encoding. If None or empty, use min/max normalization for all keys."""
121
+ mean_std_embedding_keys: list[str] | None = None
122
+ """Optional list of keys to apply mean/std normalization. If None or empty, use min/max normalization for all keys."""
123
+ action_configs: list[ActionConfig] | None = None
124
+
125
+ def __post_init__(self):
126
+ """Set default values for action-related fields if not specified."""
127
+ if self.action_configs is not None:
128
+ assert len(self.action_configs) == len(self.modality_keys), (
129
+ f"Number of action configs ({len(self.action_configs)}) must match number of modality keys ({len(self.modality_keys)})"
130
+ )
131
+ parsed_action_configs = []
132
+ for action_config in self.action_configs:
133
+ if isinstance(action_config, dict):
134
+ action_config = ActionConfig(
135
+ rep=ActionRepresentation[action_config["rep"]],
136
+ type=ActionType[action_config["type"]],
137
+ format=ActionFormat[action_config["format"]],
138
+ state_key=action_config.get("state_key", None),
139
+ )
140
+ parsed_action_configs.append(action_config)
141
+ self.action_configs = parsed_action_configs
142
+
143
+
144
+ class MsgSerializer:
145
+ @staticmethod
146
+ def to_bytes(data: Any) -> bytes:
147
+ return msgpack.packb(data, default=MsgSerializer.encode_custom_classes)
148
+
149
+ @staticmethod
150
+ def from_bytes(data: bytes) -> Any:
151
+ return msgpack.unpackb(data, object_hook=MsgSerializer.decode_custom_classes)
152
+
153
+ @staticmethod
154
+ def decode_custom_classes(obj):
155
+ if not isinstance(obj, dict):
156
+ return obj
157
+ if "__ModalityConfig_class__" in obj:
158
+ return ModalityConfig(**obj["as_json"])
159
+ if "__ndarray_class__" in obj:
160
+ return np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False)
161
+ return obj
162
+
163
+ @staticmethod
164
+ def encode_custom_classes(obj):
165
+ if isinstance(obj, ModalityConfig):
166
+ # Convert to dict and let msgpack recursively handle nested objects
167
+ return {"__ModalityConfig_class__": True, "as_json": to_json_serializable(obj)}
168
+ if isinstance(obj, np.ndarray):
169
+ output = io.BytesIO()
170
+ np.save(output, obj, allow_pickle=False)
171
+ return {"__ndarray_class__": True, "as_npy": output.getvalue()}
172
+ return obj
173
+
174
+
175
+ class BasePolicy(ABC):
176
+ """Abstract base class for robotic control policies.
177
+
178
+ This class defines the interface that all policies must implement, including
179
+ methods for action computation, input/output validation, and state management.
180
+
181
+ Subclasses must implement:
182
+ - check_observation(): Validate observation format
183
+ - check_action(): Validate action format
184
+ - _get_action(): Core action computation logic
185
+ - reset(): Reset policy to initial state
186
+ """
187
+
188
+ def __init__(self, *, strict: bool = True):
189
+ self.strict = strict
190
+
191
+ @abstractmethod
192
+ def check_observation(self, observation: dict[str, Any]) -> None:
193
+ """Check if the observation is valid.
194
+
195
+ Args:
196
+ observation: Dictionary containing the current state/observation of the environment
197
+
198
+ Raises:
199
+ AssertionError: If the observation is invalid.
200
+ """
201
+ pass
202
+
203
+ @abstractmethod
204
+ def check_action(self, action: dict[str, Any]) -> None:
205
+ """Check if the action is valid.
206
+
207
+ Args:
208
+ action: Dictionary containing the action to be executed
209
+
210
+ Raises:
211
+ AssertionError: If the action is invalid.
212
+ """
213
+ pass
214
+
215
+ @abstractmethod
216
+ def _get_action(
217
+ self, observation: dict[str, Any], options: dict[str, Any] | None = None
218
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
219
+ """Compute and return the next action based on current observation.
220
+
221
+ This method should be overridden by subclasses to implement policy-specific
222
+ action computation. Input validation is handled by the public get_action() method.
223
+
224
+ Args:
225
+ observation: Dictionary containing the current state/observation
226
+ options: Optional configuration dict for action computation
227
+
228
+ Returns:
229
+ Tuple of (action, info):
230
+ - action: Dictionary containing the action to be executed
231
+ - info: Dictionary containing additional metadata (e.g., confidence scores)
232
+ """
233
+ pass
234
+
235
+ def get_action(
236
+ self, observation: dict[str, Any], options: dict[str, Any] | None = None
237
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
238
+ """Compute and return the next action based on current observation with validation.
239
+
240
+ This is the main public interface. It validates the observation, calls
241
+ the internal _get_action(), and validates the resulting action.
242
+
243
+ Args:
244
+ observation: Dictionary containing the current state/observation
245
+ options: Optional configuration dict for action computation
246
+
247
+ Returns:
248
+ Tuple of (action, info):
249
+ - action: Dictionary containing the validated action
250
+ - info: Dictionary containing additional metadata
251
+
252
+ Raises:
253
+ AssertionError/ValueError: If observation or action validation fails
254
+ """
255
+ if self.strict:
256
+ self.check_observation(observation)
257
+ action, info = self._get_action(observation, options)
258
+ if self.strict:
259
+ self.check_action(action)
260
+ return action, info
261
+
262
+ @abstractmethod
263
+ def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
264
+ """Reset the policy to its initial state.
265
+
266
+ Args:
267
+ options: Dictionary containing the options for the reset
268
+
269
+ Returns:
270
+ Dictionary containing the info after resetting the policy
271
+ """
272
+ pass
273
+
274
+
275
+ class PolicyClient(BasePolicy):
276
+ def __init__(
277
+ self,
278
+ host: str = "localhost",
279
+ port: int = 5555,
280
+ timeout_ms: int = 15000,
281
+ api_token: str = None,
282
+ strict: bool = False,
283
+ ):
284
+ super().__init__(strict=strict)
285
+ self.context = zmq.Context()
286
+ self.host = host
287
+ self.port = port
288
+ self.timeout_ms = timeout_ms
289
+ self.api_token = api_token
290
+ self._init_socket()
291
+
292
+ def _init_socket(self):
293
+ """Initialize or reinitialize the socket with current settings"""
294
+ self.socket = self.context.socket(zmq.REQ)
295
+ self.socket.connect(f"tcp://{self.host}:{self.port}")
296
+
297
+ def ping(self) -> bool:
298
+ try:
299
+ self.call_endpoint("ping", requires_input=False)
300
+ return True
301
+ except zmq.error.ZMQError:
302
+ self._init_socket() # Recreate socket for next attempt
303
+ return False
304
+
305
+ def kill_server(self):
306
+ """
307
+ Kill the server.
308
+ """
309
+ self.call_endpoint("kill", requires_input=False)
310
+
311
+ def call_endpoint(
312
+ self, endpoint: str, data: dict | None = None, requires_input: bool = True
313
+ ) -> Any:
314
+ """
315
+ Call an endpoint on the server.
316
+
317
+ Args:
318
+ endpoint: The name of the endpoint.
319
+ data: The input data for the endpoint.
320
+ requires_input: Whether the endpoint requires input data.
321
+ """
322
+ request: dict = {"endpoint": endpoint}
323
+ if requires_input:
324
+ request["data"] = data
325
+ if self.api_token:
326
+ request["api_token"] = self.api_token
327
+
328
+ self.socket.send(MsgSerializer.to_bytes(request))
329
+ message = self.socket.recv()
330
+ if message == b"ERROR":
331
+ raise RuntimeError("Server error. Make sure we are running the correct policy server.")
332
+ response = MsgSerializer.from_bytes(message)
333
+
334
+ if isinstance(response, dict) and "error" in response:
335
+ raise RuntimeError(f"Server error: {response['error']}")
336
+ return response
337
+
338
+ def __del__(self):
339
+ """Cleanup resources on destruction"""
340
+ self.socket.close()
341
+ self.context.term()
342
+
343
+ def _get_action(
344
+ self, observation: dict[str, Any], options: dict[str, Any] | None = None
345
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
346
+ response = self.call_endpoint(
347
+ "get_action", {"observation": observation, "options": options}
348
+ )
349
+ return tuple(response) # Convert list (from msgpack) to tuple of (action, info)
350
+
351
+ def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
352
+ return self.call_endpoint("reset", {"options": options})
353
+
354
+ def get_modality_config(self) -> dict[str, ModalityConfig]:
355
+ return self.call_endpoint("get_modality_config", requires_input=False)
356
+
357
+ def check_observation(self, observation: dict[str, Any]) -> None:
358
+ raise NotImplementedError(
359
+ "check_observation is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
360
+ )
361
+
362
+ def check_action(self, action: dict[str, Any]) -> None:
363
+ raise NotImplementedError(
364
+ "check_action is not implemented. Please use `strict=False` to disable strict mode or implement this method in the subclass."
365
+ )
examples/DROID/utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Taken from https://github.com/Physical-Intelligence/openpi/tree/main/packages/openpi-client/src/openpi_client
18
+ """
19
+
20
+ import numpy as np
21
+ from PIL import Image
22
+
23
+
24
+ def convert_to_uint8(img: np.ndarray) -> np.ndarray:
25
+ """Converts an image to uint8 if it is a float image.
26
+
27
+ This is important for reducing the size of the image when sending it over the network.
28
+ """
29
+ if np.issubdtype(img.dtype, np.floating):
30
+ img = (255 * img).astype(np.uint8)
31
+ return img
32
+
33
+
34
+ def resize_with_pad(
35
+ images: np.ndarray, height: int, width: int, method=Image.BILINEAR
36
+ ) -> np.ndarray:
37
+ """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
38
+
39
+ Args:
40
+ images: A batch of images in [..., height, width, channel] format.
41
+ height: The target height of the image.
42
+ width: The target width of the image.
43
+ method: The interpolation method to use. Default is bilinear.
44
+
45
+ Returns:
46
+ The resized images in [..., height, width, channel].
47
+ """
48
+ # If the images are already the correct size, return them as is.
49
+ if images.shape[-3:-1] == (height, width):
50
+ return images
51
+
52
+ original_shape = images.shape
53
+
54
+ images = images.reshape(-1, *original_shape[-3:])
55
+ resized = np.stack(
56
+ [_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images]
57
+ )
58
+ return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
59
+
60
+
61
+ def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
62
+ """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
63
+ width without distortion by padding with zeros.
64
+
65
+ Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
66
+ """
67
+ cur_width, cur_height = image.size
68
+ if cur_width == width and cur_height == height:
69
+ return image # No need to resize if the image is already the correct size.
70
+
71
+ ratio = max(cur_width / width, cur_height / height)
72
+ resized_height = int(cur_height / ratio)
73
+ resized_width = int(cur_width / ratio)
74
+ resized_image = image.resize((resized_width, resized_height), resample=method)
75
+
76
+ zero_image = Image.new(resized_image.mode, (width, height), 0)
77
+ pad_height = max(0, int((height - resized_height) / 2))
78
+ pad_width = max(0, int((width - resized_width) / 2))
79
+ zero_image.paste(resized_image, (pad_width, pad_height))
80
+ assert zero_image.size == (width, height)
81
+ return zero_image
examples/LIBERO/README.md ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LIBERO
2
+
3
+ Benchmark for studying knowledge transfer in lifelong robot learning. Includes multiple suites: **Spatial** (spatial reasoning), **Object** (object generalization), **Goal** (goal-conditioned learning), and **10 Long** (long-horizon multi-step tasks). Provides RGB images, proprioception data, and language task specifications.
4
+
5
+ For more information, see the [official website](https://libero-project.github.io/main.html).
6
+
7
+ ---
8
+
9
+ # LIBERO evaluation benchmark result
10
+
11
+ > **Note:** The full task list is attached at the end of this document.
12
+
13
+ All four suites were finetuned with the same hyper-parameters, including
14
+ `--state-dropout-prob 0.2` (the finetune CLI default from
15
+ `gr00t/configs/finetune_config.py`).
16
+
17
+ | Task | Success rate | max_steps | grad_accum_steps | batch_size |
18
+ |-----------|--------------------|-----------|------------------|------------|
19
+ | Spatial | 195/200 (97.65%) | 20K | 1 | 640 |
20
+ | Goal | 195/200 (97.5%) | 20K | 1 | 640 |
21
+ | Object | 197/200 (98.45%) | 20K | 1 | 640 |
22
+ | 10 (Long) | 189/200 (94.35%) | 20K | 1 | 640 |
23
+
24
+ # Fine-tune LIBERO 10 (long)
25
+
26
+ To reproduce our finetune results, use the following commands to setup dataset and launch finetune experiments. Please remember to set `WANDB_API_KEY` since `--use-wandb` is turned on by default. If you don't have a WANDB account, please remove this argument:
27
+
28
+ ```bash
29
+ uv run hf download \
30
+ --repo-type dataset IPEC-COMMUNITY/libero_10_no_noops_1.0.0_lerobot \
31
+ --local-dir examples/LIBERO/libero_10_no_noops_1.0.0_lerobot/
32
+
33
+ # Copy the patches and run the finetune script
34
+ cp -r examples/LIBERO/modality.json examples/LIBERO/libero_10_no_noops_1.0.0_lerobot/meta/
35
+ ```
36
+
37
+ Run the shared finetune launcher:
38
+ ```bash
39
+ NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
40
+ --base-model-path nvidia/GR00T-N1.7-3B \
41
+ --dataset-path examples/LIBERO/libero_10_no_noops_1.0.0_lerobot/ \
42
+ --embodiment-tag LIBERO_PANDA \
43
+ --output-dir /tmp/libero_10 \
44
+ --state-dropout-prob 0.2
45
+ ```
46
+
47
+ # Fine-tune LIBERO goal
48
+
49
+ ```bash
50
+ uv run hf download \
51
+ --repo-type dataset IPEC-COMMUNITY/libero_goal_no_noops_1.0.0_lerobot \
52
+ --local-dir examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/
53
+
54
+ # Copy the patches and run the finetune script
55
+ cp -r examples/LIBERO/modality.json examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/meta/
56
+ ## This is a patch for one of the episode where the image seems to be corrupted.
57
+ cp examples/LIBERO/patches/episode_000082.mp4 examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/videos/chunk-000/observation.images.wrist_image/
58
+ ```
59
+
60
+ Run the shared finetune launcher:
61
+ ```bash
62
+ NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
63
+ --base-model-path nvidia/GR00T-N1.7-3B \
64
+ --dataset-path examples/LIBERO/libero_goal_no_noops_1.0.0_lerobot/ \
65
+ --embodiment-tag LIBERO_PANDA \
66
+ --output-dir /tmp/libero_goal
67
+ ```
68
+
69
+ # Fine-tune LIBERO object
70
+
71
+ ```bash
72
+ uv run hf download \
73
+ --repo-type dataset IPEC-COMMUNITY/libero_object_no_noops_1.0.0_lerobot \
74
+ --local-dir examples/LIBERO/libero_object_no_noops_1.0.0_lerobot/
75
+
76
+ # Copy the patches and run the finetune script
77
+ cp -r examples/LIBERO/modality.json examples/LIBERO/libero_object_no_noops_1.0.0_lerobot/meta/
78
+ ```
79
+
80
+ Run the shared finetune launcher:
81
+ ```bash
82
+ NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
83
+ --base-model-path nvidia/GR00T-N1.7-3B \
84
+ --dataset-path examples/LIBERO/libero_object_no_noops_1.0.0_lerobot/ \
85
+ --embodiment-tag LIBERO_PANDA \
86
+ --output-dir /tmp/libero_object
87
+ ```
88
+
89
+ # Fine-tune LIBERO spatial
90
+
91
+ ```bash
92
+ uv run hf download \
93
+ --repo-type dataset IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
94
+ --local-dir examples/LIBERO/libero_spatial_no_noops_1.0.0_lerobot/
95
+
96
+ # Copy the patches and run the finetune script
97
+ cp -r examples/LIBERO/modality.json examples/LIBERO/libero_spatial_no_noops_1.0.0_lerobot/meta/
98
+ ```
99
+
100
+ Run the shared finetune launcher:
101
+ ```bash
102
+ NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=640 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
103
+ --base-model-path nvidia/GR00T-N1.7-3B \
104
+ --dataset-path examples/LIBERO/libero_spatial_no_noops_1.0.0_lerobot/ \
105
+ --embodiment-tag LIBERO_PANDA \
106
+ --output-dir /tmp/libero_spatial
107
+ ```
108
+
109
+ # Evaluate checkpoint
110
+
111
+ First, setup the evaluation simulation environment. This only needs to run once for each simulation benchmark. After it's done, we only need to launch server and client.
112
+
113
+ ```bash
114
+ sudo apt update
115
+ sudo apt install libegl1-mesa-dev libglu1-mesa
116
+ bash gr00t/eval/sim/LIBERO/setup_libero.sh
117
+ ```
118
+
119
+ Then, download the finetuned model to a local directory (HuggingFace does not support nested repo paths directly):
120
+ ```bash
121
+ uv run hf download nvidia/GR00T-N1.7-LIBERO --include "libero_10/config.json" "libero_10/embodiment_id.json" "libero_10/model-*.safetensors" "libero_10/model.safetensors.index.json" "libero_10/processor_config.json" "libero_10/statistics.json" --local-dir checkpoints/GR00T-N1.7-LIBERO
122
+ ```
123
+
124
+ Run client server evaluation under the project root directory in separate terminals:
125
+
126
+ **Terminal 1 - Server:**
127
+ ```bash
128
+ uv run python gr00t/eval/run_gr00t_server.py \
129
+ --model-path checkpoints/GR00T-N1.7-LIBERO/libero_10 \
130
+ --embodiment-tag LIBERO_PANDA \
131
+ --use-sim-policy-wrapper
132
+ ```
133
+
134
+ > **Note:** Replace `checkpoints/GR00T-N1.7-LIBERO/libero_10` with your own checkpoint path (e.g., `/tmp/libero_10/checkpoint-20000/`) if evaluating a locally finetuned model.
135
+
136
+ **Terminal 2 - Client:**
137
+ ```bash
138
+ gr00t/eval/sim/LIBERO/libero_uv/.venv/bin/python gr00t/eval/rollout_policy.py \
139
+ --n-episodes 10 \
140
+ --policy-client-host 127.0.0.1 \
141
+ --policy-client-port 5555 \
142
+ --max-episode-steps 720 \
143
+ --env-name libero_sim/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it \
144
+ --n-action-steps 8 \
145
+ --n-envs 5
146
+ ```
147
+
148
+ # Full task list
149
+
150
+ ## Libero 10 (Long)
151
+ - `libero_sim/LIVING_ROOM_SCENE2_put_both_the_alphabet_soup_and_the_tomato_sauce_in_the_basket`
152
+ - `libero_sim/LIVING_ROOM_SCENE2_put_both_the_cream_cheese_box_and_the_butter_in_the_basket`
153
+ - `libero_sim/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it`
154
+ - `libero_sim/KITCHEN_SCENE4_put_the_black_bowl_in_the_bottom_drawer_of_the_cabinet_and_close_it`
155
+ - `libero_sim/LIVING_ROOM_SCENE5_put_the_white_mug_on_the_left_plate_and_put_the_yellow_and_white_mug_on_the_right_plate`
156
+ - `libero_sim/STUDY_SCENE1_pick_up_the_book_and_place_it_in_the_back_compartment_of_the_caddy`
157
+ - `libero_sim/LIVING_ROOM_SCENE6_put_the_white_mug_on_the_plate_and_put_the_chocolate_pudding_to_the_right_of_the_plate`
158
+ - `libero_sim/LIVING_ROOM_SCENE1_put_both_the_alphabet_soup_and_the_cream_cheese_box_in_the_basket`
159
+ - `libero_sim/KITCHEN_SCENE8_put_both_moka_pots_on_the_stove`
160
+ - `libero_sim/KITCHEN_SCENE6_put_the_yellow_and_white_mug_in_the_microwave_and_close_it`
161
+
162
+ ## Libero Goal
163
+ - `libero_sim/open_the_middle_drawer_of_the_cabinet`
164
+ - `libero_sim/put_the_bowl_on_the_stove`
165
+ - `libero_sim/put_the_wine_bottle_on_top_of_the_cabinet`
166
+ - `libero_sim/open_the_top_drawer_and_put_the_bowl_inside`
167
+ - `libero_sim/put_the_bowl_on_top_of_the_cabinet`
168
+ - `libero_sim/push_the_plate_to_the_front_of_the_stove`
169
+ - `libero_sim/put_the_cream_cheese_in_the_bowl`
170
+ - `libero_sim/turn_on_the_stove`
171
+ - `libero_sim/put_the_bowl_on_the_plate`
172
+ - `libero_sim/put_the_wine_bottle_on_the_rack`
173
+
174
+ ## Libero Object
175
+ - `libero_sim/pick_up_the_alphabet_soup_and_place_it_in_the_basket`
176
+ - `libero_sim/pick_up_the_cream_cheese_and_place_it_in_the_basket`
177
+ - `libero_sim/pick_up_the_salad_dressing_and_place_it_in_the_basket`
178
+ - `libero_sim/pick_up_the_bbq_sauce_and_place_it_in_the_basket`
179
+ - `libero_sim/pick_up_the_ketchup_and_place_it_in_the_basket`
180
+ - `libero_sim/pick_up_the_tomato_sauce_and_place_it_in_the_basket`
181
+ - `libero_sim/pick_up_the_butter_and_place_it_in_the_basket`
182
+ - `libero_sim/pick_up_the_milk_and_place_it_in_the_basket`
183
+ - `libero_sim/pick_up_the_chocolate_pudding_and_place_it_in_the_basket`
184
+ - `libero_sim/pick_up_the_orange_juice_and_place_it_in_the_basket`
185
+
186
+ ## Libero Spatial
187
+ - `libero_sim/pick_up_the_black_bowl_between_the_plate_and_the_ramekin_and_place_it_on_the_plate`
188
+ - `libero_sim/pick_up_the_black_bowl_next_to_the_ramekin_and_place_it_on_the_plate`
189
+ - `libero_sim/pick_up_the_black_bowl_from_table_center_and_place_it_on_the_plate`
190
+ - `libero_sim/pick_up_the_black_bowl_on_the_cookie_box_and_place_it_on_the_plate`
191
+ - `libero_sim/pick_up_the_black_bowl_in_the_top_drawer_of_the_wooden_cabinet_and_place_it_on_the_plate`
192
+ - `libero_sim/pick_up_the_black_bowl_on_the_ramekin_and_place_it_on_the_plate`
193
+ - `libero_sim/pick_up_the_black_bowl_next_to_the_cookie_box_and_place_it_on_the_plate`
194
+ - `libero_sim/pick_up_the_black_bowl_on_the_stove_and_place_it_on_the_plate`
195
+ - `libero_sim/pick_up_the_black_bowl_next_to_the_plate_and_place_it_on_the_plate`
196
+ - `libero_sim/pick_up_the_black_bowl_on_the_wooden_cabinet_and_place_it_on_the_plate`
examples/LIBERO/modality.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "state": {
3
+ "x": {
4
+ "start": 0,
5
+ "end": 1
6
+ },
7
+ "y": {
8
+ "start": 1,
9
+ "end": 2
10
+ },
11
+ "z": {
12
+ "start": 2,
13
+ "end": 3
14
+ },
15
+ "roll": {
16
+ "start": 3,
17
+ "end": 4
18
+ },
19
+ "pitch": {
20
+ "start": 4,
21
+ "end": 5
22
+ },
23
+ "yaw": {
24
+ "start": 5,
25
+ "end": 6
26
+ },
27
+ "gripper": {
28
+ "start": 6,
29
+ "end": 8
30
+ }
31
+ },
32
+ "action": {
33
+ "x": {
34
+ "start": 0,
35
+ "end": 1
36
+ },
37
+ "y": {
38
+ "start": 1,
39
+ "end": 2
40
+ },
41
+ "z": {
42
+ "start": 2,
43
+ "end": 3
44
+ },
45
+ "roll": {
46
+ "start": 3,
47
+ "end": 4
48
+ },
49
+ "pitch": {
50
+ "start": 4,
51
+ "end": 5
52
+ },
53
+ "yaw": {
54
+ "start": 5,
55
+ "end": 6
56
+ },
57
+ "gripper": {
58
+ "start": 6,
59
+ "end": 7
60
+ }
61
+ },
62
+ "video": {
63
+ "image": {
64
+ "original_key": "observation.images.image"
65
+ },
66
+ "wrist_image": {
67
+ "original_key": "observation.images.wrist_image"
68
+ }
69
+ },
70
+ "annotation": {
71
+ "human.action.task_description": {
72
+ "original_key": "task_index"
73
+ }
74
+ }
75
+ }
examples/SO100/README.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Finetuning SO100 Model
2
+
3
+ This guide shows how to finetune dataset collected from [SO100](https://huggingface.co/docs/lerobot/en/so101) robot, and evaluate the model on the real robot.
4
+
5
+
6
+ ## Dataset
7
+
8
+ To collect the dataset via teleoperation, please refer to the official documentation in lerobot: https://huggingface.co/docs/lerobot/il_robots?teleoperate_so101=Command
9
+
10
+ **Dataset Path:** [izuluaga/finish_sandwich](https://huggingface.co/datasets/izuluaga/finish_sandwich)
11
+
12
+ Visualize it with this [link](https://huggingface.co/spaces/lerobot/visualize_dataset?path=%2Fizuluaga%2Ffinish_sandwich%2Fepisode_0)
13
+
14
+ ## Handling the dataset
15
+
16
+ ```bash
17
+ uv run --project scripts/lerobot_conversion \
18
+ python scripts/lerobot_conversion/convert_v3_to_v2.py \
19
+ --repo-id izuluaga/finish_sandwich \
20
+ --root examples/SO100/finish_sandwich_lerobot
21
+ ```
22
+
23
+ Then move the `modality.json` file to the root of the dataset.
24
+ ```bash
25
+ cp examples/SO100/modality.json examples/SO100/finish_sandwich_lerobot/izuluaga/finish_sandwich/meta/modality.json
26
+ ```
27
+
28
+ ## Finetuning
29
+
30
+ Run the shared finetune launcher directly, using absolute joint positions (feel free to experiment with relative positions):
31
+ ```bash
32
+ CUDA_VISIBLE_DEVICES=0 NUM_GPUS=1 uv run bash examples/finetune.sh \
33
+ --base-model-path nvidia/GR00T-N1.7-3B \
34
+ --dataset-path examples/SO100/finish_sandwich_lerobot/izuluaga/finish_sandwich \
35
+ --modality-config-path examples/SO100/so100_config.py \
36
+ --embodiment-tag NEW_EMBODIMENT \
37
+ --output-dir /tmp/so100_finetune
38
+ ```
39
+
40
+ ## Open-Loop Evaluation
41
+
42
+ Evaluate the finetuned model with the following command:
43
+ ```bash
44
+ uv run python gr00t/eval/open_loop_eval.py \
45
+ --dataset-path examples/SO100/finish_sandwich_lerobot/izuluaga/finish_sandwich/ \
46
+ --embodiment-tag NEW_EMBODIMENT \
47
+ --model-path /tmp/so100_finetune/checkpoint-10000 \
48
+ --traj-ids 0 \
49
+ --action-horizon 16 \
50
+ --steps 400
51
+ ```
52
+
53
+ ### Evaluation Results
54
+
55
+ The evaluation produces visualizations comparing predicted actions against ground truth trajectories:
56
+
57
+ <img src="../../media/open_loop_eval_so100.jpg" width="800" alt="Open-loop evaluation results showing predicted vs ground truth trajectories" />
58
+
59
+ ## Closed-Loop Evaluation
60
+
61
+ Please refer to [eval_so100.py](../../gr00t/eval/real_robot/SO100/eval_so100.py) for how to write SO100 deployment code using Policy API.
62
+
63
+ 1. set up client side deps
64
+
65
+ ```bash
66
+ cd gr00t/eval/real_robot/SO100
67
+ uv venv
68
+ source .venv/bin/activate
69
+ uv pip install -e . --verbose
70
+ uv pip install --no-deps -e ../../../../
71
+ ```
72
+
73
+ 2. Start policy server
74
+ ```bash
75
+ uv run python gr00t/eval/run_gr00t_server.py \
76
+ --model-path /tmp/so100_finetune/checkpoint-10000 \
77
+ --embodiment-tag NEW_EMBODIMENT
78
+ ```
79
+
80
+ 3. Run the eval script, as client.
81
+ ```bash
82
+ uv run python gr00t/eval/real_robot/SO100/eval_so100.py \
83
+ --robot.type=so101_follower --robot.port=/dev/ttyACM2 \
84
+ --robot.id=orange_follower \
85
+ --robot.cameras="{ wrist: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30}}" \
86
+ --policy-host=localhost --policy-port=5555 --lang-instruction="finish the ham cheese olives sandwich"
87
+ ```
examples/SO100/modality.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "state": {
3
+ "single_arm": {
4
+ "start": 0,
5
+ "end": 5
6
+ },
7
+ "gripper": {
8
+ "start": 5,
9
+ "end": 6
10
+ }
11
+ },
12
+ "action": {
13
+ "single_arm": {
14
+ "start": 0,
15
+ "end": 5
16
+ },
17
+ "gripper": {
18
+ "start": 5,
19
+ "end": 6
20
+ }
21
+ },
22
+ "video": {
23
+ "front": {
24
+ "original_key": "observation.images.front"
25
+ },
26
+ "wrist": {
27
+ "original_key": "observation.images.wrist"
28
+ }
29
+ },
30
+ "annotation": {
31
+ "human.task_description": {
32
+ "original_key": "task_index"
33
+ }
34
+ }
35
+ }
examples/SO100/so100_config.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from gr00t.configs.data.embodiment_configs import register_modality_config
17
+ from gr00t.data.embodiment_tags import EmbodimentTag
18
+ from gr00t.data.types import (
19
+ ActionConfig,
20
+ ActionFormat,
21
+ ActionRepresentation,
22
+ ActionType,
23
+ ModalityConfig,
24
+ )
25
+
26
+
27
+ so100_config = {
28
+ # Video: current frame only; keys must match "video" entries in meta/modality.json
29
+ "video": ModalityConfig(
30
+ delta_indices=[0],
31
+ modality_keys=["front", "wrist"], # front third-person view + wrist egocentric
32
+ ),
33
+ # State: current proprioceptive reading; keys must match "state" entries in meta/modality.json
34
+ "state": ModalityConfig(
35
+ delta_indices=[0],
36
+ modality_keys=[
37
+ "single_arm", # joint positions
38
+ "gripper", # gripper state
39
+ ],
40
+ ),
41
+ # Action: 16-step prediction horizon; one ActionConfig per modality key
42
+ "action": ModalityConfig(
43
+ delta_indices=list(range(0, 16)), # predict 16 future steps
44
+ modality_keys=[
45
+ "single_arm",
46
+ "gripper",
47
+ ],
48
+ action_configs=[
49
+ # single_arm: RELATIVE = delta from current state (better generalization)
50
+ ActionConfig(
51
+ rep=ActionRepresentation.RELATIVE,
52
+ type=ActionType.NON_EEF, # joint-space, not end-effector
53
+ format=ActionFormat.DEFAULT,
54
+ ),
55
+ # gripper: ABSOLUTE = target position (binary open/close works better absolute)
56
+ ActionConfig(
57
+ rep=ActionRepresentation.ABSOLUTE,
58
+ type=ActionType.NON_EEF,
59
+ format=ActionFormat.DEFAULT,
60
+ ),
61
+ ],
62
+ ),
63
+ # Language: task instruction from annotation field in the dataset
64
+ "language": ModalityConfig(
65
+ delta_indices=[0],
66
+ modality_keys=["annotation.human.task_description"],
67
+ ),
68
+ }
69
+
70
+ register_modality_config(so100_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
examples/SimplerEnv/README.md ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SimplerEnv
2
+
3
+ Framework for evaluating real-world robot manipulation policies (RT-1, RT-1-X, Octo) in simulation. Replicates common setups like Google Robot and WidowX+Bridge, with GPU-accelerated simulations (10-15x speedup). Offers visual matching and variant aggregation evaluation methods for robust policy assessment.
4
+
5
+ For more information, see the [official repository](https://github.com/simpler-env/SimplerEnv).
6
+
7
+ ---
8
+
9
+ # Fine-tune Simpler Env bridge dataset (WidowX robot)
10
+
11
+ To reproduce our finetune results, use the following commands to setup dataset and launch finetune experiments. Please remember to set `WANDB_API_KEY` since `--use-wandb` is turned on by default. If you don't have a WANDB account, please remove this argument:
12
+
13
+ ```bash
14
+ uv run hf download \
15
+ --repo-type dataset IPEC-COMMUNITY/bridge_orig_lerobot \
16
+ --local-dir examples/SimplerEnv/bridge_orig_lerobot/
17
+
18
+ # Copy the patches and run the finetune script
19
+ cp examples/SimplerEnv/bridge_modality.json examples/SimplerEnv/bridge_orig_lerobot/meta/modality.json
20
+ ```
21
+
22
+ ```bash
23
+ NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=1024 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
24
+ --base-model-path nvidia/GR00T-N1.7-3B \
25
+ --dataset-path examples/SimplerEnv/bridge_orig_lerobot/ \
26
+ --embodiment-tag SIMPLER_ENV_WIDOWX \
27
+ --output-dir /tmp/bridge_finetune \
28
+ --state-dropout-prob 0.8
29
+ ```
30
+
31
+ # Fine-tune Simpler Env fractal dataset (Google robot)
32
+
33
+ ```bash
34
+ uv run hf download \
35
+ --repo-type dataset IPEC-COMMUNITY/fractal20220817_data_lerobot \
36
+ --local-dir examples/SimplerEnv/fractal20220817_data_lerobot/
37
+
38
+ # Copy the patches and run the finetune script
39
+ cp -r examples/SimplerEnv/fractal_modality.json examples/SimplerEnv/fractal20220817_data_lerobot/meta/modality.json
40
+ uv run python examples/SimplerEnv/convert_av1_to_h264.py examples/SimplerEnv/fractal20220817_data_lerobot --jobs 16 # (Optional) if AV1 doesn't work on your machine
41
+ ```
42
+
43
+ ```bash
44
+ NUM_GPUS=8 MAX_STEPS=20000 GLOBAL_BATCH_SIZE=1024 SAVE_STEPS=1000 uv run bash examples/finetune.sh \
45
+ --base-model-path nvidia/GR00T-N1.7-3B \
46
+ --dataset-path examples/SimplerEnv/fractal20220817_data_lerobot/ \
47
+ --embodiment-tag SIMPLER_ENV_GOOGLE \
48
+ --output-dir /tmp/fractal_finetune \
49
+ --state-dropout-prob 0.5
50
+ ```
51
+
52
+ # Evaluate checkpoint
53
+
54
+ First, setup the evaluation simulation environment. This only needs to run once for each simulation benchmark. After it's done, we only need to launch server and client.
55
+
56
+ ```bash
57
+ sudo apt update
58
+ sudo apt install libegl1-mesa-dev libglu1-mesa
59
+ bash gr00t/eval/sim/SimplerEnv/setup_SimplerEnv.sh
60
+ ```
61
+
62
+ Then, run client server evaluation under the project root directory in separate terminals:
63
+
64
+ ## Fractal (Google Robot) Evaluation
65
+
66
+ **Terminal 1 - Server:**
67
+
68
+ You can use either a local finetuned checkpoint path or the remote finetuned checkpoint (provided by us):
69
+
70
+ **Option 1: Local finetuned checkpoint**
71
+ ```bash
72
+ uv run python gr00t/eval/run_gr00t_server.py \
73
+ --model-path /tmp/fractal_finetune/checkpoint-30000 \
74
+ --embodiment-tag SIMPLER_ENV_GOOGLE \
75
+ --use-sim-policy-wrapper
76
+ ```
77
+
78
+ **Option 2: Remote finetuned checkpoint (directly runnable)**
79
+ ```bash
80
+ uv run python gr00t/eval/run_gr00t_server.py \
81
+ --model-path nvidia/GR00T-N1.7-SimplerEnv-Fractal \
82
+ --embodiment-tag SIMPLER_ENV_GOOGLE \
83
+ --use-sim-policy-wrapper
84
+ ```
85
+
86
+ **Terminal 2 - Client:**
87
+ ```bash
88
+ gr00t/eval/sim/SimplerEnv/simpler_uv/.venv/bin/python gr00t/eval/rollout_policy.py \
89
+ --n-episodes 10 \
90
+ --policy-client-host 127.0.0.1 \
91
+ --policy-client-port 5555 \
92
+ --max-episode-steps 300 \
93
+ --env-name simpler_env_google/google_robot_pick_coke_can \
94
+ --n-action-steps 1 \
95
+ --n-envs 5
96
+ ```
97
+
98
+ ## Bridge (WidowX) Evaluation
99
+
100
+ **Terminal 1 - Server:**
101
+
102
+ **Option 1: Local finetuned checkpoint**
103
+ ```bash
104
+ uv run python gr00t/eval/run_gr00t_server.py \
105
+ --model-path /tmp/bridge_finetune/checkpoint-30000 \
106
+ --embodiment-tag SIMPLER_ENV_WIDOWX \
107
+ --use-sim-policy-wrapper
108
+ ```
109
+
110
+ **Option 2: Remote finetuned checkpoint (directly runnable)**
111
+ ```bash
112
+ uv run python gr00t/eval/run_gr00t_server.py \
113
+ --model-path nvidia/GR00T-N1.7-SimplerEnv-Bridge \
114
+ --embodiment-tag SIMPLER_ENV_WIDOWX \
115
+ --use-sim-policy-wrapper
116
+ ```
117
+
118
+ **Terminal 2 - Client:**
119
+ ```bash
120
+ gr00t/eval/sim/SimplerEnv/simpler_uv/.venv/bin/python gr00t/eval/rollout_policy.py \
121
+ --n-episodes 10 \
122
+ --policy-client-host 127.0.0.1 \
123
+ --policy-client-port 5555 \
124
+ --max-episode-steps 300 \
125
+ --env-name simpler_env_widowx/widowx_spoon_on_towel \
126
+ --n-action-steps 4 \
127
+ --n-envs 5
128
+ ```
129
+
130
+ Other supported tasks are:
131
+ ```
132
+ simpler_env_google/google_robot_pick_object
133
+ simpler_env_google/google_robot_move_near
134
+ simpler_env_google/google_robot_open_drawer
135
+ ...
136
+ simpler_env_widowx/widowx_spoon_on_towel
137
+ simpler_env_widowx/widowx_carrot_on_plate
138
+ simpler_env_widowx/widowx_stack_cube
139
+ ```
140
+
141
+ you can replace the env_name with the corresponding tasks listed in the SimplerEnv fork this repo pins at `external_dependencies/SimplerEnv` (see `.gitmodules`).
examples/SimplerEnv/bridge_modality.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "state": {
3
+ "x": {
4
+ "start": 0,
5
+ "end": 1
6
+ },
7
+ "y": {
8
+ "start": 1,
9
+ "end": 2
10
+ },
11
+ "z": {
12
+ "start": 2,
13
+ "end": 3
14
+ },
15
+ "roll": {
16
+ "start": 3,
17
+ "end": 4
18
+ },
19
+ "pitch": {
20
+ "start": 4,
21
+ "end": 5
22
+ },
23
+ "yaw": {
24
+ "start": 5,
25
+ "end": 6
26
+ },
27
+ "pad": {
28
+ "start": 6,
29
+ "end": 7
30
+ },
31
+ "gripper": {
32
+ "start": 7,
33
+ "end": 8
34
+ }
35
+ },
36
+ "action": {
37
+ "x": {
38
+ "start": 0,
39
+ "end": 1
40
+ },
41
+ "y": {
42
+ "start": 1,
43
+ "end": 2
44
+ },
45
+ "z": {
46
+ "start": 2,
47
+ "end": 3
48
+ },
49
+ "roll": {
50
+ "start": 3,
51
+ "end": 4
52
+ },
53
+ "pitch": {
54
+ "start": 4,
55
+ "end": 5
56
+ },
57
+ "yaw": {
58
+ "start": 5,
59
+ "end": 6
60
+ },
61
+ "gripper": {
62
+ "start": 6,
63
+ "end": 7
64
+ }
65
+ },
66
+ "video": {
67
+ "image_0": {
68
+ "original_key": "observation.images.image_0"
69
+ }
70
+ },
71
+ "annotation": {
72
+ "human.action.task_description": {
73
+ "original_key": "task_index"
74
+ },
75
+ "human.validity": {}
76
+ }
77
+ }
examples/SimplerEnv/convert_av1_to_h264.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import argparse
19
+ from concurrent.futures import ThreadPoolExecutor
20
+ import os
21
+ from pathlib import Path
22
+ import subprocess
23
+
24
+
25
+ VIDEO_EXTS = {".mp4", ".mov", ".mkv"}
26
+
27
+
28
+ def run(cmd):
29
+ return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
30
+
31
+
32
+ def is_av1(path: Path) -> bool:
33
+ cmd = [
34
+ "ffprobe",
35
+ "-v",
36
+ "error",
37
+ "-select_streams",
38
+ "v:0",
39
+ "-show_entries",
40
+ "stream=codec_name",
41
+ "-of",
42
+ "default=nw=1:nk=1",
43
+ str(path),
44
+ ]
45
+ proc = run(cmd)
46
+ if proc.returncode != 0:
47
+ print(f"[ffprobe FAIL] {path}: {proc.stderr.strip()}")
48
+ return False
49
+ codec = proc.stdout.strip()
50
+ return codec in ("av01", "av1")
51
+
52
+
53
+ def convert_file(path: Path):
54
+ if not is_av1(path):
55
+ print(f"[SKIP] {path}")
56
+ return
57
+
58
+ tmp = path.with_suffix(path.suffix + ".mp4")
59
+ print(f"[CONVERT] {path} -> {tmp}")
60
+
61
+ cmd = [
62
+ "ffmpeg",
63
+ "-y",
64
+ "-i",
65
+ str(path),
66
+ "-c:v",
67
+ "libx264",
68
+ "-qp",
69
+ "0",
70
+ "-pix_fmt",
71
+ "yuv420p",
72
+ "-c:a",
73
+ "copy",
74
+ "-vsync",
75
+ "passthrough",
76
+ "-copyts",
77
+ "-muxdelay",
78
+ "0",
79
+ "-muxpreload",
80
+ "0",
81
+ str(tmp),
82
+ ]
83
+ proc = run(cmd)
84
+ if proc.returncode != 0:
85
+ print(f"[ffmpeg FAIL] {path}: {proc.stderr.strip()}")
86
+ if tmp.exists():
87
+ tmp.unlink()
88
+ return
89
+
90
+ tmp.replace(path)
91
+ print(f"[DONE] {path}")
92
+
93
+
94
+ def find_videos(root: Path):
95
+ for dirpath, _, filenames in os.walk(root):
96
+ for name in filenames:
97
+ p = Path(dirpath) / name
98
+ if p.suffix.lower() in VIDEO_EXTS:
99
+ yield p
100
+
101
+
102
+ def main():
103
+ ap = argparse.ArgumentParser(
104
+ description="Recursively convert AV1 videos to H.264 (lossless-ish) in place."
105
+ )
106
+ ap.add_argument("root", nargs="?", default=".", help="Root directory (default: .)")
107
+ ap.add_argument(
108
+ "-j",
109
+ "--jobs",
110
+ type=int,
111
+ default=os.cpu_count() or 4,
112
+ help="Number of parallel workers (default: CPU count)",
113
+ )
114
+ args = ap.parse_args()
115
+
116
+ root = Path(args.root).resolve()
117
+ files = list(find_videos(root))
118
+ print(f"Scanning {root}, found {len(files)} candidate video files")
119
+
120
+ if not files:
121
+ return
122
+
123
+ with ThreadPoolExecutor(max_workers=args.jobs) as ex:
124
+ for p in files:
125
+ ex.submit(convert_file, p)
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()
examples/SimplerEnv/fractal_modality.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "state": {
3
+ "x": {
4
+ "start": 0,
5
+ "end": 1
6
+ },
7
+ "y": {
8
+ "start": 1,
9
+ "end": 2
10
+ },
11
+ "z": {
12
+ "start": 2,
13
+ "end": 3
14
+ },
15
+ "rx": {
16
+ "start": 3,
17
+ "end": 4
18
+ },
19
+ "ry": {
20
+ "start": 4,
21
+ "end": 5
22
+ },
23
+ "rz": {
24
+ "start": 5,
25
+ "end": 6
26
+ },
27
+ "rw": {
28
+ "start": 6,
29
+ "end": 7
30
+ },
31
+ "gripper": {
32
+ "start": 7,
33
+ "end": 8
34
+ }
35
+ },
36
+ "action": {
37
+ "x": {
38
+ "start": 0,
39
+ "end": 1
40
+ },
41
+ "y": {
42
+ "start": 1,
43
+ "end": 2
44
+ },
45
+ "z": {
46
+ "start": 2,
47
+ "end": 3
48
+ },
49
+ "roll": {
50
+ "start": 3,
51
+ "end": 4
52
+ },
53
+ "pitch": {
54
+ "start": 4,
55
+ "end": 5
56
+ },
57
+ "yaw": {
58
+ "start": 5,
59
+ "end": 6
60
+ },
61
+ "gripper": {
62
+ "start": 6,
63
+ "end": 7
64
+ }
65
+ },
66
+ "video": {
67
+ "image": {
68
+ "original_key": "observation.images.image"
69
+ }
70
+ },
71
+ "annotation": {
72
+ "human.action.task_description": {
73
+ "original_key": "task_index"
74
+ },
75
+ "human.validity": {}
76
+ }
77
+ }
examples/finetune.sh ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ set -x -euo pipefail
4
+
5
+ NUM_GPUS="${NUM_GPUS:-1}"
6
+ MASTER_PORT="${MASTER_PORT:-29500}"
7
+ SAVE_STEPS="${SAVE_STEPS:-1000}"
8
+ MAX_STEPS="${MAX_STEPS:-10000}"
9
+ USE_WANDB="${USE_WANDB:-1}"
10
+ DATALOADER_NUM_WORKERS="${DATALOADER_NUM_WORKERS:-4}"
11
+ GLOBAL_BATCH_SIZE="${GLOBAL_BATCH_SIZE:-32}"
12
+ SHARD_SIZE="${SHARD_SIZE:-1024}"
13
+ NUM_SHARDS_PER_EPOCH="${NUM_SHARDS_PER_EPOCH:-100000}"
14
+ EPISODE_SAMPLING_RATE="${EPISODE_SAMPLING_RATE:-0.1}"
15
+
16
+ BASE_MODEL_PATH=""
17
+ DATASET_PATH=""
18
+ MODALITY_CONFIG_PATH=""
19
+ EMBODIMENT_TAG=""
20
+ OUTPUT_DIR=""
21
+ EXPERIMENT_NAME=""
22
+ WANDB_PROJECT=""
23
+ STATE_DROPOUT_PROB=""
24
+ EXTRA_ARGS=()
25
+
26
+ usage() {
27
+ cat <<'EOF'
28
+ Usage: bash examples/finetune.sh \
29
+ --base-model-path <path> \
30
+ --dataset-path <path> \
31
+ --embodiment-tag <tag> \
32
+ --output-dir <path> \
33
+ [--modality-config-path <path>] \
34
+ [--state-dropout-prob <value>] \
35
+ [--save-only-model] \
36
+ [-- <extra launch_finetune.py args>...]
37
+ EOF
38
+ }
39
+
40
+ while [ "$#" -gt 0 ]; do
41
+ case "$1" in
42
+ --base-model-path)
43
+ BASE_MODEL_PATH="$2"
44
+ shift 2
45
+ ;;
46
+ --dataset-path)
47
+ DATASET_PATH="$2"
48
+ shift 2
49
+ ;;
50
+ --modality-config-path)
51
+ MODALITY_CONFIG_PATH="$2"
52
+ shift 2
53
+ ;;
54
+ --embodiment-tag)
55
+ EMBODIMENT_TAG="$2"
56
+ shift 2
57
+ ;;
58
+ --output-dir)
59
+ OUTPUT_DIR="$2"
60
+ shift 2
61
+ ;;
62
+ --experiment-name)
63
+ EXPERIMENT_NAME="$2"
64
+ shift 2
65
+ ;;
66
+ --wandb-project)
67
+ WANDB_PROJECT="$2"
68
+ shift 2
69
+ ;;
70
+ --state-dropout-prob)
71
+ STATE_DROPOUT_PROB="$2"
72
+ shift 2
73
+ ;;
74
+ --save-only-model)
75
+ SAVE_ONLY_MODEL=1
76
+ shift
77
+ ;;
78
+ --help|-h)
79
+ usage
80
+ exit 0
81
+ ;;
82
+ --)
83
+ shift
84
+ EXTRA_ARGS=("$@")
85
+ break
86
+ ;;
87
+ *)
88
+ echo "Unknown argument: $1" >&2
89
+ usage >&2
90
+ exit 1
91
+ ;;
92
+ esac
93
+ done
94
+
95
+ for required_var in BASE_MODEL_PATH DATASET_PATH EMBODIMENT_TAG OUTPUT_DIR; do
96
+ if [ -z "${!required_var}" ]; then
97
+ echo "Missing required argument: ${required_var}" >&2
98
+ usage >&2
99
+ exit 1
100
+ fi
101
+ done
102
+
103
+ WANDB_FLAG=()
104
+ if [ "$USE_WANDB" = "1" ]; then
105
+ WANDB_FLAG+=(--use_wandb)
106
+ fi
107
+
108
+ LAUNCH_CMD=(
109
+ gr00t/experiment/launch_finetune.py
110
+ --base_model_path "$BASE_MODEL_PATH"
111
+ --dataset_path "$DATASET_PATH"
112
+ --embodiment_tag "$EMBODIMENT_TAG"
113
+ --num_gpus "$NUM_GPUS"
114
+ --output_dir "$OUTPUT_DIR"
115
+ --save_steps "$SAVE_STEPS"
116
+ --save_total_limit 5
117
+ --max_steps "$MAX_STEPS"
118
+ --warmup_ratio 0.05
119
+ --weight_decay 1e-5
120
+ --learning_rate 1e-4
121
+ "${WANDB_FLAG[@]}"
122
+ --global_batch_size "$GLOBAL_BATCH_SIZE"
123
+ --color_jitter_params brightness 0.3 contrast 0.4 saturation 0.5 hue 0.08
124
+ --dataloader_num_workers "$DATALOADER_NUM_WORKERS"
125
+ --shard_size "$SHARD_SIZE"
126
+ --num_shards_per_epoch "$NUM_SHARDS_PER_EPOCH"
127
+ --episode_sampling_rate "$EPISODE_SAMPLING_RATE"
128
+ )
129
+
130
+ if [ -n "$MODALITY_CONFIG_PATH" ]; then
131
+ LAUNCH_CMD+=(--modality_config_path "$MODALITY_CONFIG_PATH")
132
+ fi
133
+ if [ -n "$EXPERIMENT_NAME" ]; then
134
+ LAUNCH_CMD+=(--experiment_name "$EXPERIMENT_NAME")
135
+ fi
136
+ if [ -n "$WANDB_PROJECT" ]; then
137
+ LAUNCH_CMD+=(--wandb_project "$WANDB_PROJECT")
138
+ fi
139
+
140
+ if [ -n "$STATE_DROPOUT_PROB" ]; then
141
+ LAUNCH_CMD+=(--state_dropout_prob "$STATE_DROPOUT_PROB")
142
+ fi
143
+ if [ -n "${SAVE_ONLY_MODEL:-}" ]; then
144
+ LAUNCH_CMD+=(--save_only_model)
145
+ fi
146
+
147
+ if [ "${#EXTRA_ARGS[@]}" -gt 0 ]; then
148
+ LAUNCH_CMD+=("${EXTRA_ARGS[@]}")
149
+ fi
150
+
151
+ if [ "$NUM_GPUS" = "1" ]; then
152
+ # Restrict to a single GPU so HF Trainer doesn't wrap the model in DataParallel,
153
+ # which crashes with a StopIteration error in the model's device property.
154
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
155
+ exec python "${LAUNCH_CMD[@]}"
156
+ fi
157
+
158
+ exec torchrun --nproc_per_node="$NUM_GPUS" --master_port="$MASTER_PORT" "${LAUNCH_CMD[@]}"
examples/mask-guided-background-suppression/README.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mask-Guided Background Suppression
2
+
3
+ Mask-guided augmentations leverage per-frame segmentation masks to apply targeted image transformations during training. This enables **domain randomization** on specific regions (e.g., replacing backgrounds with noise, tinting foreground objects) without affecting the rest of the image.
4
+
5
+ This feature is controlled via the `--extra_augmentation_config` argument, which accepts a JSON string specifying which mask regions to augment and how.
6
+
7
+ ---
8
+
9
+ ## Prerequisites
10
+
11
+ 1. **Segmentation masks** must be pre-generated and stored alongside your dataset. The dataset's `info.json` must include a `mask_path` template, and `modality.json` must define a `"mask"` section mapping camera views.
12
+
13
+ 2. **Albumentations transforms** are enabled by default in N1.7 (`use_albumentations_transforms=True` in model config). No extra flag is needed.
14
+
15
+ ---
16
+
17
+ ## Supported Augmentation Types
18
+
19
+ ### 1. Background Noise Transform
20
+
21
+ Replaces pixels in specified mask regions with **random RGB noise**. Useful for sim-to-real transfer or preventing the model from overfitting to static backgrounds.
22
+
23
+ | Parameter | Type | Description |
24
+ |-----------|------|-------------|
25
+ | `target_mask_values` | `list[int]` | Mask label values to replace with noise (e.g., `[0]` for background) |
26
+ | `p` | `float` | Probability of applying the transform per frame (0.0 to 1.0) |
27
+
28
+ ### 2. Masked Region Color Transform
29
+
30
+ Applies a **random color tint** to pixels in specified mask regions. Useful for augmenting the appearance of specific objects (e.g., tables, tools) to improve color generalization.
31
+
32
+ | Parameter | Type | Description |
33
+ |-----------|------|-------------|
34
+ | `target_mask_values` | `list[int]` | Mask label values to apply the tint to (e.g., `[4]`, `[5]`) |
35
+ | `p` | `float` | Probability of applying the transform per frame (0.0 to 1.0) |
36
+ | `alpha_range` | `[min, max]` | Range for blending intensity between original and tint color (default: `[0.3, 1.0]`) |
37
+
38
+ ---
39
+
40
+ ## Configuration Format
41
+
42
+ The `--extra_augmentation_config` argument takes a JSON string with two optional keys:
43
+
44
+ ```json
45
+ {
46
+ "background_noise_transforms": [
47
+ {"target_mask_values": [0], "p": 0.9}
48
+ ],
49
+ "masked_region_transforms": [
50
+ {"target_mask_values": [4], "p": 1.0, "alpha_range": [0.0, 1.0]}
51
+ ]
52
+ }
53
+ ```
54
+
55
+ Multiple transforms of each type can be specified (e.g., different mask values with different probabilities).
56
+
57
+ ---
58
+
59
+ ## Quick Start with Demo Data
60
+
61
+ The included demo dataset `demo_data/cube_to_bowl_5_with_mask` contains a single episode with front and wrist camera views, along with pre-generated segmentation masks. The masks were generated using [SAM 3](https://github.com/facebookresearch/sam3) with the text prompt `"background"`, then converted so that background pixels = `0` and foreground pixels = `1` (see [Generating mask files](#generating-mask-files) below).
62
+
63
+ ### 1. Background noise only
64
+
65
+ Replace background (mask=0) with random noise:
66
+
67
+ ```bash
68
+ uv run python test_extra_augmentation.py \
69
+ --dataset_path ../../demo_data/cube_to_bowl_5_with_mask \
70
+ --embodiment_tag NEW_EMBODIMENT \
71
+ --modality_config_path so101_config.py \
72
+ --extra_augmentation_config '{"background_noise_transforms": [{"target_mask_values": [0], "p": 1.0}]}'
73
+ ```
74
+
75
+ ### 2. Background noise + foreground color tint
76
+
77
+ Apply both transforms together:
78
+
79
+ ```bash
80
+ uv run python test_extra_augmentation.py \
81
+ --dataset_path ../../demo_data/cube_to_bowl_5_with_mask \
82
+ --embodiment_tag NEW_EMBODIMENT \
83
+ --modality_config_path so101_config.py \
84
+ --extra_augmentation_config '{"background_noise_transforms": [{"target_mask_values": [0], "p": 1.0}], "masked_region_transforms": [{"target_mask_values": [1], "p": 1.0, "alpha_range": [0.3, 1.0]}]}' \
85
+ --output_dir /tmp/augmentation_vis --num_frames 5
86
+ ```
87
+
88
+ Both commands save side-by-side comparison images (**Original | Augmented | Mask**) under `output_dir/<view_name>/`, with frames sampled evenly across the episode.
89
+
90
+ ### 3. Fine-tune with mask-guided augmentation
91
+
92
+ ```bash
93
+ export NUM_GPUS=8
94
+
95
+ torchrun --nproc_per_node=$NUM_GPUS --master_port=29500 \
96
+ gr00t/experiment/launch_finetune.py \
97
+ --base-model-path nvidia/GR00T-N1.7-3B \
98
+ --dataset-path <YOUR_DATASET_WITH_MASKS> \
99
+ --embodiment-tag <YOUR_EMBODIMENT_TAG> \
100
+ --num-gpus $NUM_GPUS \
101
+ --output-dir /tmp/mask_augmentation_run \
102
+ --save-steps 1000 \
103
+ --save-total-limit 5 \
104
+ --max-steps 20000 \
105
+ --warmup-ratio 0.05 \
106
+ --weight-decay 1e-5 \
107
+ --learning-rate 1e-4 \
108
+ --use-wandb \
109
+ --global-batch-size 640 \
110
+ --dataloader-num-workers 4 \
111
+ --extra-augmentation-config '{"background_noise_transforms": [{"target_mask_values": [0], "p": 0.9}], "masked_region_transforms": [{"target_mask_values": [4], "p": 1.0, "alpha_range": [0, 1]}]}'
112
+ ```
113
+
114
+ ---
115
+
116
+ ## Dataset Setup
117
+
118
+ To use mask-guided augmentation with your own dataset, ensure:
119
+
120
+ 1. **Mask files** are stored as `.npz` files under a `masks/` directory, following the same chunk/episode structure as videos. Each `.npz` contains a single `uint8` array of shape `(num_frames, H, W)` where each pixel holds an integer semantic label (e.g., `0` = background, `1` = object A, `2` = object B).
121
+
122
+ ```
123
+ masks/
124
+ └── chunk-000/
125
+ └── observation.images.front/
126
+ └── episode_000000_masks.npz
127
+ ```
128
+
129
+ See [Generating mask files](#generating-mask-files) below for how to produce these files.
130
+
131
+ 2. **`info.json`** includes a `mask_path` template:
132
+
133
+ ```json
134
+ {
135
+ "mask_path": "masks/chunk-{episode_chunk:03d}/{mask_key}/episode_{episode_index:06d}_masks.npz"
136
+ }
137
+ ```
138
+
139
+ 3. **`modality.json`** includes a `"mask"` section mapping view names to their original keys. The keys should match the actual camera view names in your dataset:
140
+
141
+ ```json
142
+ {
143
+ "mask": {
144
+ "<view_name>": {
145
+ "original_key": "<observation.images.xxx>"
146
+ }
147
+ }
148
+ }
149
+ ```
150
+
151
+ For example, if your dataset has `front` and `wrist` cameras:
152
+
153
+ ```json
154
+ {
155
+ "mask": {
156
+ "front": {
157
+ "original_key": "observation.images.front"
158
+ },
159
+ "wrist": {
160
+ "original_key": "observation.images.wrist"
161
+ }
162
+ }
163
+ }
164
+ ```
165
+
166
+ ---
167
+
168
+ ## Generating Mask Files
169
+
170
+ You can generate mask files using any video segmentation model that produces per-pixel labels. The demo masks in this example were created with [SAM 3](https://github.com/facebookresearch/sam3) (see the [video predictor example](https://github.com/facebookresearch/sam3/blob/main/examples/sam3_video_predictor_example.ipynb) for SAM 3 usage). The workflow was:
171
+
172
+ 1. Run SAM 3 on each episode video with a text prompt such as `"background"`. SAM 3 returns per-frame binary masks via `propagate_in_video`.
173
+ 2. Convert the binary masks into the label format expected by this pipeline (`0` = background, non-zero = foreground categories) and save as `.npz`:
174
+
175
+ ```python
176
+ import numpy as np
177
+
178
+ # sam3_binary_masks: (num_frames, H, W) bool array from SAM 3 (True where prompt matched)
179
+ # For a "background" prompt, invert so that background=0 and foreground=1:
180
+ label_masks = (~sam3_binary_masks).astype(np.uint8)
181
+
182
+ # For multiple prompts, merge into one label array instead:
183
+ # label_masks = np.zeros((num_frames, H, W), dtype=np.uint8)
184
+ # label_masks[prompt_0_masks] = 1
185
+ # label_masks[prompt_1_masks] = 2
186
+
187
+ np.savez_compressed("episode_000000_masks.npz", label_masks)
188
+ ```
189
+
190
+ The pipeline loads the array from the `.npz` file (it expects the key `arr_0`, which is the default for `np.savez_compressed`). A single `.npy` file containing the `(num_frames, H, W)` array also works.
191
+
192
+ ---
193
+
194
+ ## How It Works
195
+
196
+ The augmentation pipeline applies mask-based transforms **per-frame** before the standard augmentations (crop, resize, color jitter, etc.):
197
+
198
+ 1. For each frame, the corresponding segmentation mask is loaded.
199
+ 2. `BackgroundNoiseTransform` replaces all pixels where `mask == target_value` with random RGB noise.
200
+ 3. `MaskedColorTransform` blends a random color into all pixels where `mask == target_value`, controlled by `alpha_range`.
201
+ 4. Standard augmentations (shared across views via replay) are then applied on top.
202
+
203
+ This ordering ensures that mask-guided augmentations are applied independently per frame, while standard augmentations remain consistent across camera views within the same timestep.
examples/mask-guided-background-suppression/so101_config.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from gr00t.configs.data.embodiment_configs import register_modality_config
17
+ from gr00t.data.embodiment_tags import EmbodimentTag
18
+ from gr00t.data.types import (
19
+ ActionConfig,
20
+ ActionFormat,
21
+ ActionRepresentation,
22
+ ActionType,
23
+ ModalityConfig,
24
+ )
25
+
26
+
27
+ so101_config = {
28
+ "video": ModalityConfig(
29
+ delta_indices=[0],
30
+ modality_keys=["front", "wrist"],
31
+ ),
32
+ "mask": ModalityConfig(
33
+ delta_indices=[0],
34
+ modality_keys=["front", "wrist"],
35
+ ),
36
+ "state": ModalityConfig(
37
+ delta_indices=[0],
38
+ modality_keys=["single_arm", "gripper"],
39
+ ),
40
+ "action": ModalityConfig(
41
+ delta_indices=list(range(16)),
42
+ modality_keys=["single_arm", "gripper"],
43
+ action_configs=[
44
+ ActionConfig(
45
+ rep=ActionRepresentation.RELATIVE,
46
+ type=ActionType.NON_EEF,
47
+ format=ActionFormat.DEFAULT,
48
+ ),
49
+ ActionConfig(
50
+ rep=ActionRepresentation.ABSOLUTE,
51
+ type=ActionType.NON_EEF,
52
+ format=ActionFormat.DEFAULT,
53
+ ),
54
+ ],
55
+ ),
56
+ "language": ModalityConfig(
57
+ delta_indices=[0],
58
+ modality_keys=["annotation.human.task_description"],
59
+ ),
60
+ }
61
+
62
+ register_modality_config(so101_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
examples/mask-guided-background-suppression/test_extra_augmentation.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Smoke test: apply extra_augmentation_config to raw frames and save comparison images."""
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import importlib
24
+ import json
25
+ import os
26
+ from pathlib import Path
27
+ import sys
28
+
29
+ from gr00t.configs.data.embodiment_configs import MODALITY_CONFIGS
30
+ from gr00t.data.dataset.lerobot_episode_loader import LeRobotEpisodeLoader
31
+ from gr00t.data.embodiment_tags import EmbodimentTag
32
+ from gr00t.model.gr00t_n1d7.image_augmentations import (
33
+ apply_with_replay,
34
+ build_image_transformations_albumentations,
35
+ )
36
+ import numpy as np
37
+ from PIL import Image
38
+
39
+
40
+ def save_comparison(original, augmented, mask, output_path):
41
+ orig_arr = np.array(original)
42
+ aug_arr = augmented.transpose(1, 2, 0) if augmented.shape[0] == 3 else augmented
43
+
44
+ panels = [orig_arr, aug_arr]
45
+ if mask is not None:
46
+ mask_vis = np.where(mask[..., None] > 0, 255, 0).astype(np.uint8)
47
+ mask_vis = np.broadcast_to(mask_vis, (*mask.shape[:2], 3)).copy()
48
+ if mask_vis.shape[:2] != orig_arr.shape[:2]:
49
+ mask_vis = np.array(
50
+ Image.fromarray(mask_vis).resize(
51
+ (orig_arr.shape[1], orig_arr.shape[0]), Image.NEAREST
52
+ )
53
+ )
54
+ panels.append(mask_vis)
55
+
56
+ h = panels[0].shape[0]
57
+ resized = []
58
+ for p in panels:
59
+ if p.shape[0] != h:
60
+ new_w = int(p.shape[1] * h / p.shape[0])
61
+ p = np.array(Image.fromarray(p).resize((new_w, h), Image.BILINEAR))
62
+ resized.append(p)
63
+
64
+ Image.fromarray(np.concatenate(resized, axis=1)).save(output_path)
65
+
66
+
67
+ def main():
68
+ parser = argparse.ArgumentParser()
69
+ parser.add_argument("--dataset_path", required=True)
70
+ parser.add_argument("--embodiment_tag", required=True)
71
+ parser.add_argument("--modality_config_path", default=None)
72
+ parser.add_argument("--extra_augmentation_config", type=str, required=True)
73
+ parser.add_argument("--output_dir", type=str, default="/tmp/augmentation_vis")
74
+ parser.add_argument("--num_frames", type=int, default=5)
75
+ parser.add_argument("--video_backend", type=str, default="torchcodec")
76
+ args = parser.parse_args()
77
+
78
+ if args.modality_config_path:
79
+ path = Path(args.modality_config_path)
80
+ sys.path.append(str(path.parent))
81
+ importlib.import_module(path.stem)
82
+
83
+ embodiment_tag = EmbodimentTag[args.embodiment_tag].value
84
+ modality_configs = MODALITY_CONFIGS[embodiment_tag]
85
+ extra_aug_config = json.loads(args.extra_augmentation_config)
86
+
87
+ train_transform, _ = build_image_transformations_albumentations(
88
+ image_target_size=[224, 224],
89
+ image_crop_size=[224, 224],
90
+ random_rotation_angle=0,
91
+ color_jitter_params=None,
92
+ shortest_image_edge=512,
93
+ crop_fraction=0.95,
94
+ extra_augmentation_config=extra_aug_config,
95
+ )
96
+
97
+ loader = LeRobotEpisodeLoader(
98
+ dataset_path=args.dataset_path,
99
+ modality_configs=modality_configs,
100
+ video_backend=args.video_backend,
101
+ )
102
+ episode_df = loader[0]
103
+
104
+ video_cols = [c for c in episode_df.columns if c.startswith("video.")]
105
+ mask_cols = [c for c in episode_df.columns if c.startswith("mask.")]
106
+ print(f"Video columns: {video_cols}")
107
+ print(f"Mask columns: {mask_cols}")
108
+
109
+ num_frames = min(args.num_frames, len(episode_df))
110
+ frame_indices = np.linspace(0, len(episode_df) - 1, num_frames, dtype=int)
111
+
112
+ os.makedirs(args.output_dir, exist_ok=True)
113
+
114
+ for vcol in video_cols:
115
+ view_name = vcol.replace("video.", "")
116
+ mcol = f"mask.{view_name}"
117
+ has_mask = mcol in mask_cols
118
+
119
+ view_dir = os.path.join(args.output_dir, view_name.replace(".", "_"))
120
+ os.makedirs(view_dir, exist_ok=True)
121
+
122
+ for fidx in frame_indices:
123
+ orig_img = episode_df[vcol].iloc[fidx]
124
+ mask_arr = np.array(episode_df[mcol].iloc[fidx]) if has_mask else None
125
+ masks_list = [mask_arr] if mask_arr is not None else None
126
+
127
+ transformed, _ = apply_with_replay(train_transform, [orig_img], masks_list)
128
+ aug_arr = transformed[0].numpy()
129
+
130
+ out_path = os.path.join(view_dir, f"frame_{fidx:04d}.png")
131
+ save_comparison(orig_img, aug_arr, mask_arr, out_path)
132
+ print(f" Saved: {out_path}")
133
+
134
+ print(f"\nDone! {num_frames} frames x {len(video_cols)} views saved to {args.output_dir}")
135
+
136
+ print("\n" + "=" * 60)
137
+ print("Testing full training pipeline (processor + dataloader) ...")
138
+ print("=" * 60)
139
+
140
+ from gr00t.configs.base_config import get_default_config
141
+ from gr00t.data.dataset.factory import DatasetFactory
142
+ from gr00t.model.gr00t_n1d7.processing_gr00t_n1d7 import Gr00tN1d7Processor
143
+
144
+ config = get_default_config()
145
+ config = config.load_dict(
146
+ {
147
+ "data": {
148
+ "download_cache": False,
149
+ "video_backend": args.video_backend,
150
+ "datasets": [
151
+ {
152
+ "dataset_paths": [args.dataset_path],
153
+ "mix_ratio": 1.0,
154
+ "embodiment_tag": embodiment_tag,
155
+ }
156
+ ],
157
+ }
158
+ }
159
+ )
160
+ config.model.extra_augmentation_config = extra_aug_config
161
+ config.model.use_albumentations_transforms = True
162
+
163
+ processor = Gr00tN1d7Processor(
164
+ modality_configs=config.data.modality_configs,
165
+ statistics=None,
166
+ image_crop_size=config.model.image_crop_size,
167
+ image_target_size=config.model.image_target_size,
168
+ random_rotation_angle=config.model.random_rotation_angle,
169
+ color_jitter_params=config.model.color_jitter_params,
170
+ model_name=config.model.model_name,
171
+ model_type=config.model.backbone_model_type,
172
+ formalize_language=config.model.formalize_language,
173
+ max_state_dim=config.model.max_state_dim,
174
+ max_action_dim=config.model.max_action_dim,
175
+ apply_sincos_state_encoding=config.model.apply_sincos_state_encoding,
176
+ max_action_horizon=config.model.action_horizon,
177
+ use_albumentations=config.model.use_albumentations_transforms,
178
+ extra_augmentation_config=config.model.extra_augmentation_config,
179
+ shortest_image_edge=config.model.shortest_image_edge,
180
+ crop_fraction=config.model.crop_fraction,
181
+ use_relative_action=config.model.use_relative_action,
182
+ )
183
+ processor.train()
184
+
185
+ dataset_factory = DatasetFactory(config=config)
186
+ train_dataset, _ = dataset_factory.build(processor=processor)
187
+ sample = next(iter(train_dataset))
188
+
189
+ print(f"Sample keys: {list(sample.keys())}")
190
+ print(f"VLM keys: {list(sample['vlm_content'].keys())}")
191
+ for k, v in sample.items():
192
+ if hasattr(v, "shape"):
193
+ print(f" {k}: shape={v.shape}, dtype={v.dtype}")
194
+ print("\nPipeline test PASSED!")
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
getting_started/data_config.md ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to prepare your modality configuration
2
+
3
+ ## Overview
4
+
5
+ The modality configuration defines how your robot's data should be loaded, processed, and interpreted by the model. This configuration bridges your dataset's physical structure (defined in `meta/modality.json`) and the model's data processing pipeline.
6
+
7
+ Each embodiment requires a Python configuration file that specifies:
8
+ - Which observations to use (video cameras, proprioceptive states)
9
+ - How to sample data temporally (current frame, historical frames, future action horizons)
10
+ - How actions should be interpreted and transformed
11
+ - Which language annotations to use
12
+
13
+ ## Configuration Structure
14
+
15
+ A modality configuration is a Python dictionary containing four top-level keys: `"video"`, `"state"`, `"action"`, and `"language"`. Each key maps to a `ModalityConfig` object.
16
+
17
+ Here's the [SO-100 example](../examples/SO100/so100_config.py):
18
+
19
+ ```python
20
+ from gr00t.configs.data.embodiment_configs import register_modality_config
21
+ from gr00t.data.types import ModalityConfig, ActionConfig, ActionRepresentation, ActionType, ActionFormat
22
+
23
+ so100_config = {
24
+ "video": ModalityConfig(...),
25
+ "state": ModalityConfig(...),
26
+ "action": ModalityConfig(...),
27
+ "language": ModalityConfig(...),
28
+ }
29
+
30
+ register_modality_config(so100_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
31
+ ```
32
+
33
+ ## Understanding `ModalityConfig`
34
+
35
+ Each `ModalityConfig` specifies two required fields and several optional ones:
36
+
37
+ ### Required Fields
38
+
39
+ **1. `delta_indices` (list[int])**
40
+
41
+ Defines which temporal offsets to sample relative to the current timestep:
42
+ - Current observation: Use [0] for the current timestep (recommended for video and state)
43
+ - Future actions: Use positive indices (e.g., list(range(0, 16))) for action prediction horizons
44
+
45
+ > **Note:** Negative indices (e.g., [-2, -1, 0]) are supported by the data loader for historical context, but no current N1.7 embodiment config uses them. Stick with [0] for video and state unless you have a specific reason to stack frames.
46
+
47
+ Examples:
48
+ ```python
49
+ # Single current frame for video
50
+ delta_indices=[0]
51
+
52
+
53
+ # 16-step action prediction horizon
54
+ delta_indices=list(range(0, 16))
55
+ ```
56
+
57
+ > **Note:** If you modify `delta_indices` for the action modality (e.g., changing the action horizon from 16 to 8), you **must** regenerate the dataset statistics by re-running `python gr00t/data/stats.py --dataset-path <dataset_path> --embodiment-tag <embodiment_tag>`. The normalization statistics (especially `meta/relative_stats.json`) are computed based on the original `delta_indices` length, and a mismatch will cause errors during training.
58
+
59
+ <details>
60
+ <summary>Example: What happens if you change <code>delta_indices</code> without regenerating stats?</summary>
61
+
62
+ Suppose your action config originally uses a 16-step horizon:
63
+
64
+ ```python
65
+ "action": ModalityConfig(
66
+ delta_indices=list(range(0, 16)), # 16 steps
67
+ ...
68
+ )
69
+ ```
70
+
71
+ Running `python gr00t/data/stats.py` generates `meta/relative_stats.json` with per-step statistics of shape `(16, D)`, where `D` is the action dimension.
72
+
73
+ If you later change the horizon to 8 steps:
74
+
75
+ ```python
76
+ "action": ModalityConfig(
77
+ delta_indices=list(range(0, 8)), # 8 steps
78
+ ...
79
+ )
80
+ ```
81
+
82
+ The training data will now have shape `(8, D)`, but the normalization parameters from `relative_stats.json` still have shape `(16, D)`. This dimension mismatch causes an `IndexError` during normalization:
83
+
84
+ ```
85
+ IndexError: boolean index did not match indexed array along dimension 0;
86
+ dimension is 8 but corresponding boolean dimension is 16
87
+ ```
88
+
89
+ **Fix:** Re-run `python gr00t/data/stats.py --dataset-path <dataset_path> --embodiment-tag <embodiment_tag>` after changing `delta_indices` to regenerate matching statistics.
90
+
91
+ </details>
92
+
93
+
94
+ **2. `modality_keys` (list[str])**
95
+
96
+ Specifies which keys to load from your dataset. These keys **must match** the keys defined in your `meta/modality.json` file.
97
+
98
+ For the SO-100 example:
99
+ - **Video keys**: Must match keys in `meta/modality.json` under `"video"` (e.g., `"front"`, `"wrist"`)
100
+ - **State keys**: Must match keys in `meta/modality.json` under `"state"` (e.g., `"single_arm"`, `"gripper"`)
101
+ - **Action keys**: Must match keys in `meta/modality.json` under `"action"` (e.g., `"single_arm"`, `"gripper"`)
102
+ - **Language keys**: Must match keys in `meta/modality.json` under `"annotation"` (e.g., `"annotation.human.task_description"` for SO-100)
103
+
104
+ ### Optional Fields
105
+
106
+ **3. `sin_cos_embedding_keys` (list[str] | None)**
107
+
108
+ Specifies which state keys should use sine/cosine encoding. Best for dimensions that are in radians (e.g., joint angles). If not specified, min-max normalization is used. Note that this will duplicate the number of dimensions by 2, and is only recommended for proprioceptive states.
109
+
110
+ ```python
111
+ "state": ModalityConfig(
112
+ delta_indices=[0],
113
+ modality_keys=["single_arm", "gripper"],
114
+ sin_cos_embedding_keys=["single_arm"], # Apply sin/cos to joint angles
115
+ )
116
+ ```
117
+
118
+ **4. `mean_std_embedding_keys` (list[str] | None)**
119
+
120
+ Specifies which keys should use mean/standard deviation normalization instead of min-max normalization.
121
+
122
+ **5. `action_configs` (list[ActionConfig] | None)**
123
+
124
+ Required for the `"action"` modality. Defines how each action modality should be interpreted and transformed. The list must have the **same length and same order** as `modality_keys` — `action_configs[0]` applies to `modality_keys[0]`, `action_configs[1]` to `modality_keys[1]`, etc. A mismatch in ordering will silently apply the wrong representation (e.g., RELATIVE to a gripper that should be ABSOLUTE). See more details in the [Action Modality](#understanding-actionconfig) section.
125
+
126
+ ## Configuring Each Modality
127
+
128
+ ### Video Modality
129
+
130
+ Defines which camera views to use:
131
+
132
+ ```python
133
+ "video": ModalityConfig(
134
+ delta_indices=[0], # Current frame only
135
+ modality_keys=[
136
+ "front", # Must match a key in meta/modality.json under "video"
137
+ ],
138
+ )
139
+ ```
140
+
141
+ For multiple cameras:
142
+ ```python
143
+ "video": ModalityConfig(
144
+ delta_indices=[0],
145
+ modality_keys=["front", "wrist"],
146
+ )
147
+ ```
148
+
149
+ ### State Modality
150
+
151
+ Defines proprioceptive observations (joint positions, gripper states, etc.):
152
+
153
+ ```python
154
+ "state": ModalityConfig(
155
+ delta_indices=[0], # Current state
156
+ modality_keys=[
157
+ "single_arm", # Must match keys in meta/modality.json under "state"
158
+ "gripper",
159
+ ],
160
+ )
161
+ ```
162
+
163
+ ### Action Modality
164
+
165
+ Defines the action space and prediction horizon:
166
+
167
+ ```python
168
+ "action": ModalityConfig(
169
+ delta_indices=list(range(0, 16)), # Predict 16 steps into the future
170
+ modality_keys=[
171
+ "single_arm", # Must match keys in meta/modality.json under "action"
172
+ "gripper",
173
+ ],
174
+ action_configs=[
175
+ # One ActionConfig per modality_key
176
+ # single_arm
177
+ ActionConfig(
178
+ rep=ActionRepresentation.RELATIVE, # relative control of the single arm
179
+ type=ActionType.NON_EEF,
180
+ format=ActionFormat.DEFAULT,
181
+ ),
182
+ # gripper
183
+ ActionConfig(
184
+ rep=ActionRepresentation.ABSOLUTE, # absolute control of the gripper
185
+ type=ActionType.NON_EEF,
186
+ format=ActionFormat.DEFAULT,
187
+ ),
188
+ ],
189
+ )
190
+ ```
191
+
192
+ #### Understanding `ActionConfig`
193
+
194
+ Each `ActionConfig` has three required fields and one optional field:
195
+
196
+ **1. `rep` (ActionRepresentation)**
197
+
198
+ Defines how actions should be interpreted:
199
+ - `RELATIVE`: Actions are deltas from the current state (introduced in the UMI paper)
200
+ - `ABSOLUTE`: Actions are target positions
201
+
202
+ Using relative actions will lead to smoother actions, but might suffer from drifting. If you want to use relative actions, please make sure the state and action stored in the dataset are absolute, and the absolute to relative will be handled in the processor.
203
+
204
+ **2. `type` (ActionType)**
205
+
206
+ Specifies the control space:
207
+ - `EEF`: End-effector/Cartesian space control (Expecting a 9-dimensional vector: x, y, z positions + rotation 6D)
208
+ - `NON_EEF`: Joint space control and other non-EEF control spaces (joint angles, positions, gripper positions, etc.)
209
+
210
+ **3. `format` (ActionFormat)**
211
+
212
+ Defines the action representation format:
213
+ - `DEFAULT`: Standard format (e.g., joint angles, gripper positions)
214
+ - `XYZ_ROT6D`: 3D position + 6D rotation representation for end-effector control
215
+ - `XYZ_ROTVEC`: 3D position + rotation vector for end-effector control
216
+
217
+ **4. `state_key` (str | None)**
218
+
219
+ Optional. Specifies the corresponding reference state key for computing relative actions when `rep=RELATIVE`. If not provided, the system will use the action key as the reference state key.
220
+
221
+ Example with `state_key`:
222
+ ```python
223
+ "joint_pos_action_left": ActionConfig(
224
+ rep=ActionRepresentation.RELATIVE,
225
+ type=ActionType.NON_EEF,
226
+ format=ActionFormat.DEFAULT,
227
+ state_key="joint_pos_obs_left", # Use this state to compute relative action
228
+ )
229
+ ```
230
+
231
+ ### Language Modality
232
+
233
+ Defines which language annotations to use:
234
+
235
+ ```python
236
+ "language": ModalityConfig(
237
+ delta_indices=[0],
238
+ modality_keys=["annotation.human.task_description"], # Must match annotation keys in meta/modality.json
239
+ )
240
+ ```
241
+
242
+ ## Complete Example: SO-100
243
+
244
+ Here's the complete SO-100 configuration with explanations:
245
+
246
+ ```python
247
+ so100_config = {
248
+ "video": ModalityConfig(
249
+ delta_indices=[0],
250
+ modality_keys=["front", "wrist"],
251
+ ),
252
+ "state": ModalityConfig(
253
+ delta_indices=[0],
254
+ modality_keys=[
255
+ "single_arm",
256
+ "gripper",
257
+ ],
258
+ ),
259
+ "action": ModalityConfig(
260
+ delta_indices=list(range(0, 16)),
261
+ modality_keys=[
262
+ "single_arm",
263
+ "gripper",
264
+ ],
265
+ action_configs=[
266
+ ActionConfig(
267
+ rep=ActionRepresentation.RELATIVE,
268
+ type=ActionType.NON_EEF,
269
+ format=ActionFormat.DEFAULT,
270
+ ),
271
+ ActionConfig(
272
+ rep=ActionRepresentation.ABSOLUTE,
273
+ type=ActionType.NON_EEF,
274
+ format=ActionFormat.DEFAULT,
275
+ ),
276
+ ],
277
+ ),
278
+ "language": ModalityConfig(
279
+ delta_indices=[0],
280
+ modality_keys=["annotation.human.task_description"],
281
+ ),
282
+ }
283
+ ```
284
+
285
+ ## Key Relationships with `meta/modality.json`
286
+
287
+ The modality configuration's `modality_keys` must reference keys that exist in your dataset's `meta/modality.json`:
288
+
289
+ **Example `meta/modality.json`:**
290
+ ```json
291
+ {
292
+ "state": {
293
+ "single_arm": {"start": 0, "end": 5},
294
+ "gripper": {"start": 5, "end": 6},
295
+ },
296
+ "action": {
297
+ "single_arm": {"start": 0, "end": 5},
298
+ "gripper": {"start": 5, "end": 6},
299
+ },
300
+ "video": {
301
+ "front": {"original_key": "observation.images.front"},
302
+ "wrist": {"original_key": "observation.images.wrist"},
303
+ },
304
+ "annotation": {
305
+ "human.task_description": {
306
+ "original_key": "task_index"
307
+ }
308
+ }
309
+ }
310
+ ```
311
+
312
+ The system will:
313
+ 1. Use `modality_keys` to look up the corresponding entries in `meta/modality.json`
314
+ 2. Extract the correct slices from the concatenated state/action arrays
315
+ 3. Apply the specified transformations (normalization, action representation conversion)
316
+
317
+ ## Registering Your Configuration
318
+
319
+ After defining your configuration, register it so it's available to the training and inference pipelines:
320
+
321
+ ```python
322
+ from gr00t.configs.data.embodiment_configs import register_modality_config
323
+
324
+ your_modality_config = {
325
+ ...
326
+ }
327
+
328
+ register_modality_config(your_modality_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
329
+ ```
330
+
331
+ Save your configuration to a Python file and pass the path to the `modality_config_path` argument when running the finetuning script.
getting_started/data_preparation.md ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Robot Data Preparation Guide
2
+
3
+ ## Overview
4
+
5
+ This guide shows how to convert your robot data to work with our flavor of the [LeRobot dataset V2 format](https://github.com/huggingface/lerobot?tab=readme-ov-file#the-lerobotdataset-format) ([LeRobot docs](https://huggingface.co/docs/lerobot)) -- `GR00T LeRobot`. While we have added additional structure, our schema maintains full compatibility with the upstream LeRobot v2. The additional metadata and structure allow for more detailed specification and language annotations for your robot data.
6
+
7
+ > The TLDR: Add a `meta/modality.json` file to your LeRobot v2 dataset and follow the schema below.
8
+
9
+ ## LeRobot v2 Requirements
10
+
11
+ If you already have a dataset in the LeRobot v2 format, you can skip this section.
12
+
13
+ If you have a dataset in the LeRobot v3.0 format, please use [this script](../scripts/lerobot_conversion/convert_v3_to_v2.py) to convert it to the LeRobot v2 format.
14
+
15
+ > **Why LeRobot v2?** GR00T currently uses the LeRobot v2 data format because many upstream datasets (DROID, LIBERO, Bridge, etc.) are published in v2. We plan to support both v2 and v3 formats natively in a future release. For now, please convert v3 datasets to v2 using the script above.
16
+
17
+ If you have a dataset in another format, please convert it to the LeRobot v2 format satisfying the following requirements.
18
+
19
+ ### Structure Requirements
20
+
21
+ The folder should follow a similar structure as below and contain these core folders and files:
22
+
23
+ ```
24
+ .
25
+ ├─meta
26
+ │ ├─episodes.jsonl
27
+ │ ├─modality.json # -> GR00T LeRobot specific
28
+ │ ├─info.json
29
+ │ └─tasks.jsonl
30
+ ├─videos
31
+ │ └─chunk-000
32
+ │ └─observation.images.ego_view
33
+ │ └─episode_000001.mp4
34
+ │ └─episode_000000.mp4
35
+ └─data
36
+ └─chunk-000
37
+ ├─episode_000001.parquet
38
+ └─episode_000000.parquet
39
+ ```
40
+
41
+ ### Video Observations (video/chunk-*)
42
+ The videos folder will contain the mp4 files associated with each episode following episode_00000X.mp4 naming format where X indicates the episode number.
43
+ **Requirements**:
44
+ - Must be stored as MP4 files.
45
+ - Should be named using the format: `observation.images.<video_name>`
46
+
47
+
48
+ ### Data (data/chunk-*)
49
+ The data folder will contain all of the parquet files associated with each episode following episode_00000X.parquet naming format where X indicates the episode number.
50
+ Each parquet file will contain:
51
+ - State information: stored as observation.state which is a 1D concatenated array of all state modalities.
52
+ - Action: stored as action which is a 1D concatenated array of all action modalities.
53
+ - Timestamp: stored as timestamp which is a float point number of the starting time.
54
+ - Annotations: stored as annotation.<annotation_source>.<annotation_type>(.<annotation_name>) (see the annotation field in the example configuration for example naming.). No other columns should have the annotation prefix, see the (multiple-annotation-support) if interested in adding multiple annotations.
55
+
56
+ #### Example Parquet File
57
+ Here is a sample of the `cube_to_bowl` dataset that is present in the [demo_data](../demo_data/cube_to_bowl_5/) directory.
58
+ ```
59
+ {
60
+ "observation.state":[-0.01,...,0], // 1D array: all state modalities concatenated per modality.json order
61
+ "action":[-0.010,...,0], // 1D array: all action modalities concatenated per modality.json order
62
+ "timestamp":0.049, // float: wall-clock time of this observation (seconds)
63
+ "annotation.human.action.task_description":0, // int: index into meta/tasks.jsonl for the language instruction
64
+ "task_index":0, // int: task identifier (same as annotation index for single-task)
65
+ "annotation.human.validity":1, // int: index into meta/tasks.jsonl for validity label
66
+ "episode_index":0, // int: which episode this frame belongs to
67
+ "index":0, // int: global frame index across all episodes in the dataset
68
+ "next.reward":0, // float: reward at the next timestep (0 if unused)
69
+ "next.done":false // bool: true if this is the last frame of the episode
70
+ }
71
+ ```
72
+
73
+ ### Meta
74
+
75
+ - `episodes.jsonl` contains a list of all the episodes in the entire dataset. Each episode contains a list of tasks and the length of the episode.
76
+ - `tasks.jsonl` contains a list of all the tasks in the entire dataset.
77
+ - `info.json` contains the dataset information.
78
+
79
+ #### meta/tasks.jsonl
80
+ Here is a sample of the `meta/tasks.jsonl` file that contains the task descriptions.
81
+ ```
82
+ {"task_index": 0, "task": "pick the squash from the counter and place it in the plate"}
83
+ {"task_index": 1, "task": "valid"}
84
+ ```
85
+
86
+ You can refer the task index in the parquet file to get the task description. So in this case, the `annotation.human.action.task_description` for the first observation is "pick the squash from the counter and place it in the plate" and `annotation.human.validity` is "valid".
87
+
88
+ `tasks.jsonl` contains a list of all the tasks in the entire dataset.
89
+
90
+ #### meta/episodes.jsonl
91
+
92
+ Here is a sample of the `meta/episodes.jsonl` file that contains the episode information.
93
+
94
+ ```
95
+ {"episode_index": 0, "tasks": [...], "length": 416}
96
+ {"episode_index": 1, "tasks": [...], "length": 470}
97
+ ```
98
+
99
+ `episodes.jsonl` contains a list of all the episodes in the entire dataset. Each episode contains a list of tasks and the length of the episode.
100
+
101
+ ## GR00T LeRobot Specific Requirements
102
+
103
+ ### The `meta/modality.json` Configuration
104
+
105
+ We require an additional metadata file `meta/modality.json` that is not present in the standard LeRobot format. This file provides detailed metadata about state and action modalities, enabling:
106
+
107
+ - **Separate Data Storage and Interpretation:**
108
+ - **State and Action:** Stored as concatenated float32 arrays. The `modality.json` file supplies the metadata necessary to interpret these arrays as distinct, fine-grained fields.
109
+ - **Video:** Stored as separate files, with the configuration file allowing them to be renamed to a standardized format.
110
+ - **Annotations:** Keeps track of all annotation fields. If there are no annotations, do not include the `annotation` field in the configuration file.
111
+ - **Fine-Grained Splitting:** Divides the state and action arrays into more semantically meaningful fields.
112
+ - **Clear Mapping:** Explicit mapping of data dimensions.
113
+ - **Sophisticated Data Transformations:** Supports field-specific normalization and rotation transformations during training.
114
+
115
+ #### Schema
116
+
117
+ ```json
118
+ {
119
+ "state": {
120
+ "<state_key>": {
121
+ "start": <int>, // Starting index in the state array
122
+ "end": <int> // Ending index in the state array
123
+ }
124
+ },
125
+ "action": {
126
+ "<action_key>": {
127
+ "start": <int>, // Starting index in the action array
128
+ "end": <int> // Ending index in the action array
129
+ }
130
+ },
131
+ "video": {
132
+ "<new_key>": {
133
+ "original_key": "<original_video_key>"
134
+ }
135
+ },
136
+ "annotation": {
137
+ "<annotation_key>": {} // Empty dictionary to maintain consistency with other modalities
138
+ }
139
+ }
140
+ ```
141
+
142
+ #### Example
143
+
144
+ For a concrete example of `modality.json` and the full dataset structure, see the publicly available datasets on HuggingFace:
145
+ [nvidia/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim](https://huggingface.co/datasets/nvidia/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/tree/main).
146
+
147
+ You can also find a working example in the included demo data at [`demo_data/cube_to_bowl_5/meta/modality.json`](../demo_data/cube_to_bowl_5/meta/modality.json).
148
+
149
+ #### Notes
150
+
151
+ - All indices are zero-based and follow Python's array slicing convention (`[start:end]`).
152
+
153
+ ## GR00T LeRobot Extensions to Standard LeRobot
154
+ GR00T LeRobot is a flavor of the standard LeRobot format with more opinionated requirements:
155
+ - We will compute `meta/stats.json` and `meta/relative_stats.json` for each dataset, and store them in the `meta` folder.
156
+ - Proprioceptive states must always be included in the "observation.state" keys.
157
+ - We support multi-channel annotation formats (e.g., coarsegrained, finetuned), allowing users to add as many annotation channels as needed via the `annotation.<annotation_source>.<annotation_type>` key.
158
+ - We require an additional metadata file `meta/modality.json` that is not present in the standard LeRobot format.
159
+
160
+ ### Multiple Annotation Support
161
+
162
+ To support multiple annotations within a single parquet file, users may add extra columns to the parquet file. Users should treat these columns the same way as the `task_index` column in the original LeRobot v2 dataset:
163
+
164
+ In LeRobot v2, actual language descriptions are stored in a row of the `meta/tasks.jsonl` file, while the parquet file stores only the corresponding index in the `task_index` column. We follow the same convention and store the corresponding index for each annotation in the `annotation.<annotation_source>.<annotation_type>` column. Although the `task_index` column may still be used for the default annotation, a dedicated column `annotation.<annotation_source>.<annotation_type>` is required to ensure it is loadable by our custom data loader.
getting_started/finetune_new_embodiment.md ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tune on Custom Embodiments ("NEW_EMBODIMENT")
2
+
3
+ This guide demonstrates how to finetune GR00T on your own robot data and configuration. We provide a complete example for the Huggingface [SO-100](https://github.com/TheRobotStudio/SO-ARM100) robot under `examples/SO100`, which uses `demo_data/cube_to_bowl_5` as the demo dataset.
4
+
5
+ ## Step 1: Prepare Your Data
6
+
7
+ Prepare your data in **GR00T-flavored LeRobot v2 format** by following the [data preparation guide](data_preparation.md).
8
+
9
+ ## Step 2: Prepare Your Modality Configuration
10
+
11
+ Define your own modality configuration by following the [modality config guide](data_config.md). Below is an example configuration that corresponds to the demo data:
12
+ ```python
13
+ from gr00t.configs.data.embodiment_configs import register_modality_config
14
+ from gr00t.data.embodiment_tags import EmbodimentTag
15
+ from gr00t.data.types import (
16
+ ActionConfig,
17
+ ActionFormat,
18
+ ActionRepresentation,
19
+ ActionType,
20
+ ModalityConfig,
21
+ )
22
+
23
+
24
+ so100_config = {
25
+ # Video: use current frame only ([0]); list camera view names matching modality.json
26
+ "video": ModalityConfig(
27
+ delta_indices=[0],
28
+ modality_keys=[
29
+ "front",
30
+ "wrist",
31
+ ],
32
+ ),
33
+ # State: current proprioceptive reading; keys must match modality.json "state" entries
34
+ "state": ModalityConfig(
35
+ delta_indices=[0],
36
+ modality_keys=[
37
+ "single_arm",
38
+ "gripper",
39
+ ],
40
+ ),
41
+ # Action: 16-step prediction horizon; each key needs an ActionConfig
42
+ "action": ModalityConfig(
43
+ delta_indices=list(range(0, 16)), # predict 16 future steps
44
+ modality_keys=[
45
+ "single_arm",
46
+ "gripper",
47
+ ],
48
+ action_configs=[
49
+ # single_arm: RELATIVE = delta from current state (better generalization)
50
+ ActionConfig(
51
+ rep=ActionRepresentation.RELATIVE,
52
+ type=ActionType.NON_EEF, # joint-space, not end-effector
53
+ format=ActionFormat.DEFAULT,
54
+ ),
55
+ # gripper: ABSOLUTE = target position (binary open/close works better absolute)
56
+ ActionConfig(
57
+ rep=ActionRepresentation.ABSOLUTE,
58
+ type=ActionType.NON_EEF,
59
+ format=ActionFormat.DEFAULT,
60
+ ),
61
+ ],
62
+ ),
63
+ # Language: task instruction from annotation field in the dataset
64
+ "language": ModalityConfig(
65
+ delta_indices=[0],
66
+ modality_keys=["annotation.human.task_description"],
67
+ ),
68
+ }
69
+
70
+ # Important: always register under EmbodimentTag.NEW_EMBODIMENT for custom embodiments
71
+ register_modality_config(so100_config, embodiment_tag=EmbodimentTag.NEW_EMBODIMENT)
72
+ ```
73
+
74
+ ## Step 3: Run Fine-tuning
75
+
76
+ We'll use `gr00t/experiment/launch_finetune.py` as the entry point. Ensure that the uv environment is enabled before launching. You can do this by running the command `uv run bash <example_script_name>`.
77
+
78
+ ### View Available Arguments
79
+ ```bash
80
+ # Display all available arguments
81
+ uv run python gr00t/experiment/launch_finetune.py --help
82
+ ```
83
+
84
+ ### Execute Fine-tuning
85
+ ```bash
86
+ # Configure for single GPU
87
+ export NUM_GPUS=1
88
+ CUDA_VISIBLE_DEVICES=0 uv run python \
89
+ gr00t/experiment/launch_finetune.py \
90
+ --base-model-path nvidia/GR00T-N1.7-3B \
91
+ --dataset-path ./demo_data/cube_to_bowl_5 \
92
+ --embodiment-tag NEW_EMBODIMENT \
93
+ --modality-config-path examples/SO100/so100_config.py \
94
+ --num-gpus $NUM_GPUS \
95
+ --output-dir /tmp/so100 \
96
+ --save-total-limit 5 \
97
+ --save-steps 2000 \
98
+ --max-steps 2000 \
99
+ --use-wandb \
100
+ --global-batch-size 32 \
101
+ --color-jitter-params brightness 0.3 contrast 0.4 saturation 0.5 hue 0.08 \
102
+ --dataloader-num-workers 4
103
+ ```
104
+
105
+ ### Key Parameters
106
+
107
+ | Parameter | Description |
108
+ |-----------|-------------|
109
+ | `--base-model-path` | Path to the pre-trained base model checkpoint |
110
+ | `--dataset-path` | Path to your training dataset |
111
+ | `--embodiment-tag` | Tag to identify your robot embodiment |
112
+ | `--modality-config-path` | Path to user-specified modality config (required only for `NEW_EMBODIMENT` tag) |
113
+ | `--output-dir` | Directory where checkpoints will be saved |
114
+ | `--save-steps` | Save checkpoint every N steps |
115
+ | `--max-steps` | Total number of training steps |
116
+ | `--use-wandb` | Enable Weights & Biases logging for experiment tracking |
117
+
118
+ > **Note:** Validation during fine-tuning is disabled by default (`eval_strategy="no"` in the training config). To enable periodic validation, pass `--eval-strategy steps --eval-steps 500` (runs validation every 500 steps) or `--eval-strategy epoch` (runs validation every epoch). You can also adjust `--eval-batch-size` (default: 2).
119
+
120
+ ## Step 4: Open Loop Evaluation
121
+
122
+ After finetuning, evaluate the model's performance using open loop evaluation:
123
+ ```bash
124
+ uv run python gr00t/eval/open_loop_eval.py \
125
+ --dataset-path ./demo_data/cube_to_bowl_5 \
126
+ --embodiment-tag NEW_EMBODIMENT \
127
+ --model-path /tmp/so100/checkpoint-2000 \
128
+ --traj-ids 0 \
129
+ --action-horizon 16 \
130
+ --steps 400 \
131
+ --modality-keys single_arm gripper
132
+ ```
133
+
134
+ ### `open_loop_eval.py` Parameters
135
+
136
+ | Parameter | Default | Description |
137
+ |-----------|---------|-------------|
138
+ | `--dataset-path` | `demo_data/cube_to_bowl_5/` | Path to LeRobot-format dataset |
139
+ | `--embodiment-tag` | `new_embodiment` | Robot embodiment tag (case-insensitive) |
140
+ | `--model-path` | `None` | Path to checkpoint. If omitted, connects to a running server via `--host`/`--port` |
141
+ | `--traj-ids` | `[0]` | Episode indices to evaluate (space-separated, e.g., `0 1 2`) |
142
+ | `--action-horizon` | `16` | Action steps predicted per inference call |
143
+ | `--steps` | `200` | Max steps per trajectory (capped by actual trajectory length) |
144
+ | `--denoising-steps` | `4` | Diffusion denoising iterations |
145
+ | `--save-plot-path` | `None` | Directory to save GT-vs-predicted comparison plots |
146
+ | `--modality-keys` | `None` | Action keys to plot. If omitted, plots all action dimensions |
147
+ | `--host` / `--port` | `127.0.0.1` / `5555` | Server address when `--model-path` is omitted |
148
+
149
+ ### Example Evaluation Result
150
+
151
+ The evaluation generates visualizations comparing predicted actions against ground truth trajectories:
152
+
153
+ <img src="../media/open_loop_eval_so100.jpg" width="800" alt="Open loop evaluation results showing predicted vs ground truth trajectories" />
getting_started/hardware_recommendation.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hardware Recommendations
2
+
3
+ GR00T N1.7 has two hardware profiles: **fine-tuning** (needs GPU VRAM and compute) and **inference/deployment** (needs low latency). This guide helps you choose the right hardware for each.
4
+
5
+ ![Workflow Diagram](../media/GR00T-reference-arch-diagram.png "Post-training and deployment workflow")
6
+
7
+ ---
8
+
9
+ ## Inference Hardware
10
+
11
+ **Minimum:** 1 GPU with 16 GB+ VRAM, CUDA 12.6+.
12
+
13
+ The table below summarizes end-to-end inference frequency across tested platforms (GR00T N1.7, 4 denoising steps, 1 camera):
14
+
15
+ | Platform | VRAM | PyTorch Eager | With TensorRT | Use Case |
16
+ |----------|------|---------------|---------------|----------|
17
+ | H100 80GB HBM3 | 80 GB | 11.7 Hz | 35.9 Hz | High-frequency control, multi-env batch inference |
18
+ | H20 96GB HBM3 | 96 GB | 12.0 Hz | 29.4 Hz | Cost-effective datacenter inference |
19
+ | RTX Pro 6000 Blackwell | 96 GB | 12.8 Hz | 35.9 Hz | Workstation inference, development |
20
+ | RTX Pro 5000 72GB | 72 GB | 7.9 Hz | 24.7 Hz | Workstation inference |
21
+ | L40 | 48 GB | 7.8 Hz | 26.0 Hz | Cloud inference |
22
+ | L20 | 48 GB | 7.1 Hz | 23.3 Hz | Cloud inference |
23
+ | DGX Spark | 128 GB shared | 7.9 Hz | 10.1 Hz | Desktop edge, prototyping |
24
+ | AGX Thor | 128 GB shared | 6.9 Hz | 10.7 Hz | Robot-mounted edge deployment |
25
+ | Orin* | 64 GB shared | 2.9 Hz | 4.6 Hz | Legacy Jetson edge |
26
+
27
+ > *Orin uses DiT-only TensorRT (TRT 10.3 does not support the backbone engine). All other platforms use the full TensorRT pipeline.
28
+
29
+ ### Key Insights
30
+
31
+ - **30+ Hz** (H100, RTX Pro 6000 with TensorRT): suitable for high-frequency closed-loop control where sub-30 ms latency matters.
32
+ - **10+ Hz** (Thor, Spark with TRT; most dGPUs with torch.compile): sufficient for typical manipulation tasks running at a 10 Hz control rate.
33
+ - **< 5 Hz** (Orin): only suitable for slow, non-reactive tasks. Orin's TRT 10.3 cannot accelerate the backbone — gains are limited to DiT-only mode.
34
+ - **TensorRT Full Pipeline** provides 1.5--3.3x speedup over PyTorch Eager depending on platform. Biggest gains are on datacenter GPUs where backbone acceleration is significant.
35
+ - **torch.compile** is a good zero-effort middle ground (no engine build step), achieving 1.1--1.9x speedup across all platforms.
36
+
37
+ > For full per-component latency breakdown, see the [Deployment Benchmark Results](../scripts/deployment/README.md#benchmark-results).
38
+
39
+ ---
40
+
41
+ ## Fine-Tuning Hardware
42
+
43
+ **Minimum:** 1 GPU with 40 GB+ VRAM. GR00T N1.7 is a ~3B parameter model (bfloat16).
44
+
45
+ | Setup | GPUs | VRAM per GPU | Global Batch Size | Notes |
46
+ |-------|------|-------------|-------------------|-------|
47
+ | Quick start / prototyping | 1x H100, L40, or A100 | 40--80 GB | 32 | Single GPU; sufficient for demo datasets |
48
+ | Recommended | 4--8x H100 or L40 | 40--80 GB each | 64--640 | Multi-GPU via torchrun; faster convergence |
49
+ | Full scale | 8x RTX Pro 6000 or DGX | 96 GB each | 640 | Large datasets, production fine-tuning |
50
+
51
+ ### Key Details
52
+
53
+ - **Default fine-tuning** tunes the projector + diffusion action head (not the full LLM backbone), keeping peak VRAM under ~35 GB per GPU.
54
+ - **Enabling `--tune-llm` or `--tune-visual`** significantly increases VRAM — 80 GB+ per GPU recommended.
55
+ - **`--gradient-accumulation-steps`** can compensate for fewer GPUs. For example, 4 GPUs with 8 accumulation steps and per-GPU batch of 8 gives an effective global batch size of 256.
56
+ - **Reduce `--num-shards-per-epoch`** if host memory (not VRAM) is limited — this controls how much dataset is preloaded into RAM.
57
+
58
+ ---
59
+
60
+ ## Software Requirements
61
+
62
+ | Requirement | Version |
63
+ |-------------|---------|
64
+ | Python | 3.10 |
65
+ | CUDA | 12.6+ (dGPU, Orin) / 13.0 (Thor, Spark) |
66
+ | PyTorch | 2.7+ |
67
+ | OS | Ubuntu 22.04+ (dGPU), JetPack 6.2 (Orin), Ubuntu 24.04 (Thor, Spark) |
68
+ | Package manager | [uv](https://docs.astral.sh/uv/) (recommended) |
69
+
70
+ Platform-specific installation instructions: see the [Deployment Guide](../scripts/deployment/README.md).
71
+
72
+ ---
73
+
74
+ ## Recommended Configurations
75
+
76
+ ### Starter Kit
77
+
78
+ For development, small-scale fine-tuning, and edge deployment:
79
+
80
+ | Component | Recommendation |
81
+ |-----------|---------------|
82
+ | Training | 1--4x L40 (48 GB) or RTX Pro 5000/6000 workstation |
83
+ | Edge Deployment | [Jetson AGX Thor](https://developer.nvidia.com/embedded/jetson) Developer Kit (128 GB shared memory, Blackwell GPU) |
84
+ | Storage | 500 GB+ SSD (datasets + checkpoints) |
85
+
86
+ ### Center of Excellence
87
+
88
+ For production fine-tuning and high-throughput inference:
89
+
90
+ | Component | Recommendation |
91
+ |-----------|---------------|
92
+ | Training | DGX with 8x H100/B200, or RTX Pro Server with 8x RTX Pro 6000 Blackwell |
93
+ | Inference Server | H100 or H20 node with TensorRT Full Pipeline (35+ Hz per GPU) |
94
+ | Edge Deployment | [Jetson AGX Thor](https://developer.nvidia.com/embedded/jetson) or [DGX Spark](https://developer.nvidia.com/dgx-spark) |
95
+ | Storage | Scalable networked storage (NFS/S3) for large-scale datasets |
getting_started/policy.md ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Understanding the GR00T Policy API
2
+
3
+ This guide explains how to use the `Gr00tPolicy` class to load and run inference with your trained model. After training, you'll use this API to integrate your model with evaluation environments.
4
+
5
+ ## Loading the Policy
6
+
7
+ Initialize a policy by providing the embodiment tag, model checkpoint path, and device:
8
+
9
+ ```python
10
+ from gr00t.policy import Gr00tPolicy
11
+ from gr00t.data.embodiment_tags import EmbodimentTag
12
+
13
+ # Load your trained model
14
+ policy = Gr00tPolicy(
15
+ model_path="/path/to/your/checkpoint",
16
+ embodiment_tag=EmbodimentTag.NEW_EMBODIMENT, # or other embodiment tags
17
+ device="cuda:0", # or "cpu", or device index like 0
18
+ strict=True # Enable input/output validation (recommended during development)
19
+ )
20
+ ```
21
+
22
+ **Parameters:**
23
+
24
+ | Parameter | Type | Default | Description |
25
+ |-----------|------|---------|-------------|
26
+ | `embodiment_tag` | `EmbodimentTag \| str` | *(required)* | Robot type; accepts enum or case-insensitive string (e.g., `"NEW_EMBODIMENT"`) |
27
+ | `model_path` | `str` | *(required)* | Path to model checkpoint directory (local path or HuggingFace model ID) |
28
+ | `device` | `str \| int` | *(required)* | Inference device: `"cuda:0"`, `0`, or `"cpu"` |
29
+ | `strict` | `bool` | `True` | Validates observation shapes and dtypes at runtime. Recommended during development; disable in production for speed |
30
+
31
+ ## Inference Parameter Guide
32
+
33
+ When running inference scripts (e.g., `standalone_inference_script.py`, `open_loop_eval.py`), the key parameters are:
34
+
35
+ ### `--embodiment-tag`
36
+
37
+ Determines which modality config the model uses (state/action keys, normalization). **Must match the robot type of your dataset.**
38
+
39
+ The tag is **case-insensitive** and accepts either the enum name or the string value.
40
+ For example, `--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` and `--embodiment-tag LIBERO_PANDA` all resolve correctly. An unknown tag will produce an error listing all known options.
41
+
42
+ - **Pretrain tags** (e.g., `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`, `XDOF`, `REAL_G1`) — use for zero-shot inference on datasets that match the pretrained embodiment. The modality config is loaded from the base model checkpoint.
43
+ - **Posttrain tags** (`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`, `LIBERO_PANDA`, `SIMPLER_ENV_GOOGLE`, `SIMPLER_ENV_WIDOWX`) — require a finetuned checkpoint. Passing these to the base model will produce an error.
44
+ - **`NEW_EMBODIMENT`** — use for custom robots. Requires a `--modality-config-path` during finetuning. After finetuning, the config is saved in the checkpoint and loaded automatically during inference.
45
+ - Only one `NEW_EMBODIMENT` modality config may be registered per Python process. Examples like [`examples/SO100/so100_config.py`](../examples/SO100/so100_config.py) and [`examples/mask-guided-background-suppression/so101_config.py`](../examples/mask-guided-background-suppression/so101_config.py) each register under this tag; importing both in the same process will fail. In normal CLI use the selected `--modality-config-path` is the only one imported, so this is not an issue — just don't wire both configs into the same script.
46
+
47
+ #### Known Embodiment Tags
48
+
49
+ **Pretrain tags** — baked into the base model (`nvidia/GR00T-N1.7-3B`), ready for zero-shot inference:
50
+
51
+ | Tag | Robot / Data Source | Value |
52
+ |-----|---------------------|-------|
53
+ | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | DROID (relative EEF + joint) | `oxe_droid_relative_eef_relative_joint` |
54
+ | `XDOF` | Generic X-DOF (relative EEF + joint) | `xdof_relative_eef_relative_joint` |
55
+ | `XDOF_SUBTASK` | Generic X-DOF (subtask variant) | `xdof_relative_eef_relative_joint_subtask` |
56
+ | `REAL_G1` | Real-world Unitree G1 (relative EEF + joint) | `real_g1_relative_eef_relative_joints` |
57
+ | `REAL_R1_PRO_SHARPA` | Real-world R1 Pro Sharpa (relative EEF) | `real_r1_pro_sharpa_relative_eef` |
58
+ | `REAL_R1_PRO_SHARPA_HUMAN` | R1 Pro Sharpa — human teleop data | `real_r1_pro_sharpa_relative_eef_human` |
59
+ | `REAL_R1_PRO_SHARPA_MAXINSIGHTS` | R1 Pro Sharpa — MaxInsights (single-cam) | `real_r1_pro_sharpa_relative_eef_maxinsights` |
60
+ | `REAL_R1_PRO_SHARPA_MECKA` | R1 Pro Sharpa — Mecka (single-cam) | `real_r1_pro_sharpa_relative_eef_mecka` |
61
+
62
+ **Posttrain tags** — require a finetuned checkpoint (not usable with the base model directly):
63
+
64
+ | Tag | Robot | Value | Checkpoint |
65
+ |-----|-------|-------|------------|
66
+ | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | DROID (relative EEF + joint) | `oxe_droid_relative_eef_relative_joint` | `nvidia/GR00T-N1.7-DROID` |
67
+ | `LIBERO_PANDA` | LIBERO Panda | `libero_sim` | `nvidia/GR00T-N1.7-LIBERO` |
68
+ | `SIMPLER_ENV_GOOGLE` | SimplerEnv Google Robot | `simpler_env_google` | `nvidia/GR00T-N1.7-SimplerEnv-Fractal` |
69
+ | `SIMPLER_ENV_WIDOWX` | SimplerEnv WidowX | `simpler_env_widowx` | `nvidia/GR00T-N1.7-SimplerEnv-Bridge` |
70
+
71
+ **Generic tag** for any new robot: `NEW_EMBODIMENT` (requires `--modality-config-path`)
72
+
73
+ > **`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` appears in both tables by design.** DROID is supported both zero-shot (via the base model) and via the finetuned `nvidia/GR00T-N1.7-DROID` checkpoint. Pass the tag with either `--model-path nvidia/GR00T-N1.7-3B` (zero-shot) or `--model-path nvidia/GR00T-N1.7-DROID` (finetuned); see `examples/DROID/README.md`.
74
+
75
+ > **Important:** Pretrain tags work with the base model for zero-shot inference. Posttrain tags require a finetuned checkpoint — using them with the base model will fail with an error listing the supported tags. You also cannot mix embodiment tags and datasets (e.g., `--embodiment-tag LIBERO_PANDA` expects LIBERO state keys and will fail on an SO100 dataset).
76
+
77
+ ### `--traj-ids`
78
+
79
+ Which episode indices to evaluate. Check your dataset's `meta/episodes.jsonl` to see available episodes. For example, `--traj-ids 0 1 2` runs on the first 3 episodes.
80
+
81
+ ### `--action-horizon`
82
+
83
+ Number of future action steps predicted per inference call. The model's maximum is 16 (from model config). Common values:
84
+ - `16` — full horizon, used for open-loop evaluation
85
+ - `8` — shorter horizon, common for real-time deployment where actions are re-planned frequently
86
+
87
+ This parameter is robot-agnostic — the same value works across different datasets and embodiments.
88
+
89
+ ### `--inference-mode`
90
+
91
+ - `pytorch` — standard PyTorch inference (default, no setup required)
92
+ - `tensorrt` — accelerated inference using TensorRT engine (requires ONNX export + engine build first, see [Deployment Guide](../scripts/deployment/README.md))
93
+
94
+ ### Expected Output (PyTorch mode)
95
+
96
+ The inference scripts produce:
97
+ - Per-trajectory **MSE** and **MAE** (unnormalized action prediction error vs ground truth)
98
+ - **Timing stats**: model load time, avg/min/max/P90 inference time per step
99
+ - **Summary**: average MSE/MAE across all trajectories
100
+
101
+ ### Example: Matching Parameters to Dataset
102
+
103
+ | Dataset | Embodiment Tag | Notes |
104
+ |---------|---------------|-------|
105
+ | `demo_data/droid_sample` | `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` | DROID — works with base model (zero-shot) or finetuned `nvidia/GR00T-N1.7-DROID` |
106
+ | `demo_data/libero_demo` | `LIBERO_PANDA` | LIBERO Panda — uses finetuned checkpoint from `nvidia/GR00T-N1.7-LIBERO` (must be downloaded locally first, see [README](../README.md)) |
107
+ | `demo_data/cube_to_bowl_5` | `NEW_EMBODIMENT` | SO100 arm — only works with a finetuned checkpoint, not the base model |
108
+
109
+ ## Understanding the Observation Format
110
+
111
+ The policy expects observations as a nested dictionary with three modalities:
112
+
113
+ ```python
114
+ observation = {
115
+ "video": {
116
+ "camera_name": np.ndarray, # Shape: (B, T, H, W, 3), dtype: uint8
117
+ # ... one entry per camera
118
+ },
119
+ "state": {
120
+ "state_name": np.ndarray, # Shape: (B, T, D), dtype: float32
121
+ # ... one entry per state stream
122
+ },
123
+ "language": {
124
+ "task": [[str]], # Shape: (B, 1), list of lists of strings
125
+ }
126
+ }
127
+ ```
128
+
129
+ ### Dimensions
130
+
131
+ - **`B`**: Batch size (number of parallel environments)
132
+ - **`T`**: Temporal horizon (number of historical observations)
133
+ - **`H, W`**: Image height and width
134
+ - **`D`**: State dimension
135
+ - **`C`**: Number of channels (must be 3 for RGB)
136
+
137
+ ### Data Type Requirements
138
+
139
+ - **Videos** must be `np.uint8` arrays with RGB pixel values in range [0, 255]
140
+ - **States** must be `np.float32` arrays
141
+ - **Language** instructions are lists of lists of strings
142
+
143
+ ### Important Notes
144
+
145
+ - The temporal horizon `T` is determined by your model's training configuration
146
+ - Different modalities may have different temporal horizons (query via `get_modality_config()`)
147
+ - Language instructions are typically single timestep (`T=1`)
148
+ - All arrays in a batch must have the same batch size `B`
149
+
150
+ ## Understanding the Action Format
151
+
152
+ The policy returns actions in a similar nested structure:
153
+
154
+ ```python
155
+ action = {
156
+ "action_name": np.ndarray, # Shape: (B, T, D), dtype: float32
157
+ # ... one entry per action stream
158
+ }
159
+ ```
160
+
161
+ ### Dimensions
162
+
163
+ - **`B`**: Batch size (matches input batch size)
164
+ - **`T`**: Action horizon (number of future action steps to predict)
165
+ - **`D`**: Action dimension (e.g., 7 for arm joints, 1 for gripper)
166
+
167
+ ### Important Notes
168
+
169
+ - Actions are returned in **physical units** (e.g., joint positions in radians, velocities in rad/s)
170
+ - Actions are **not normalized** - they're ready to send to your robot controller
171
+ - The action horizon `T` allows predicting multiple future steps (useful for action chunking)
172
+
173
+ ## Running Inference
174
+
175
+ Use the `get_action()` method to compute actions from observations:
176
+
177
+ ```python
178
+ # Get action from current observation
179
+ action, info = policy.get_action(observation)
180
+
181
+ # Access the action array
182
+ arm_action = action["action_name"] # Shape: (B, T, D)
183
+
184
+ # Extract the first action to execute
185
+ next_action = arm_action[:, 0, :] # Shape: (B, D)
186
+ ```
187
+
188
+ The method returns a tuple of:
189
+ - `action`: Dictionary of action arrays
190
+ - `info`: Dictionary of additional information (currently empty, reserved for future use)
191
+
192
+ ## Querying Modality Configurations
193
+
194
+ To understand what observations your policy expects and what actions it produces, query the modality configuration:
195
+
196
+ ```python
197
+ # Get modality configs for your embodiment
198
+ modality_configs = policy.get_modality_config()
199
+
200
+ # Check what camera keys are expected
201
+ video_keys = modality_configs["video"].modality_keys
202
+ print(f"Expected cameras: {video_keys}")
203
+
204
+ # Check video temporal horizon
205
+ video_horizon = len(modality_configs["video"].delta_indices)
206
+ print(f"Video frames needed: {video_horizon}")
207
+
208
+ # Check state keys and horizon
209
+ state_keys = modality_configs["state"].modality_keys
210
+ state_horizon = len(modality_configs["state"].delta_indices)
211
+ print(f"Expected states: {state_keys}, horizon: {state_horizon}")
212
+
213
+ # Check action keys and horizon
214
+ action_keys = modality_configs["action"].modality_keys
215
+ action_horizon = len(modality_configs["action"].delta_indices)
216
+ print(f"Action outputs: {action_keys}, horizon: {action_horizon}")
217
+ ```
218
+
219
+ This is especially useful when:
220
+ - You're unsure what observations your trained model expects
221
+ - You need to verify the temporal horizons for each modality
222
+ - You're debugging observation/action format mismatches
223
+
224
+ ## Resetting the Policy
225
+
226
+ Reset the policy between episodes:
227
+
228
+ ```python
229
+ # Reset policy state (if any) between episodes
230
+ info = policy.reset()
231
+ ```
232
+
233
+ Currently, the policy is stateless, but calling `reset()` is good practice for future compatibility.
234
+
235
+ ## Adapting the Policy to Your Environment
236
+
237
+ Most environments use different observation/action formats than the Policy API expects. You'll typically need to write a **policy wrapper** that:
238
+
239
+ 1. **Transforms observations**: Convert your environment's observation format to the Policy API format
240
+ 2. **Calls the policy**: Use `policy.get_action()` to compute actions
241
+ 3. **Transforms actions**: Convert the policy's actions back to your environment's format
242
+
243
+ ### Example Workflow
244
+
245
+ ```python
246
+ # In your environment loop
247
+ env_obs = env.reset() # Environment-specific format
248
+
249
+ # Transform to Policy API format
250
+ policy_obs = transform_observation(env_obs)
251
+
252
+ # Get action from policy
253
+ policy_action, _ = policy.get_action(policy_obs)
254
+
255
+ # Transform back to environment format
256
+ env_action = transform_action(policy_action)
257
+
258
+ # Execute in environment
259
+ env_obs, reward, done, info = env.step(env_action)
260
+ ```
261
+
262
+ ### Using Server-Client Architecture for Remote Inference
263
+
264
+ For many use cases, especially when working with real robots or distributed systems, you may want to run the policy on a separate machine (e.g., a GPU server) and send observations/actions over the network. GR00T provides a built-in server-client architecture using ZeroMQ for this purpose.
265
+
266
+ #### Why Use Server-Client Architecture?
267
+
268
+ - **Separate compute resources**: Run policy inference on a GPU server while controlling the robot from a different machine
269
+ - **Dependency isolation**: Avoid dependency issues with the client policy
270
+
271
+ ```mermaid
272
+ sequenceDiagram
273
+ participant Robot as Robot / Sim Client
274
+ participant Client as PolicyClient (ZMQ REQ)
275
+ participant Server as PolicyServer (ZMQ REP)
276
+ participant Policy as Gr00tPolicy (GPU)
277
+
278
+ Robot->>Client: observation dict
279
+ Client->>Server: msgpack(endpoint="get_action", data=obs)
280
+ Server->>Policy: policy.get_action(obs)
281
+ Policy-->>Server: (action_dict, info_dict)
282
+ Server-->>Client: msgpack(action, info)
283
+ Client-->>Robot: action dict
284
+ ```
285
+
286
+ #### Starting the Policy Server
287
+
288
+ Launch the server using the `run_gr00t_server.py` script:
289
+
290
+ ```bash
291
+ uv run python gr00t/eval/run_gr00t_server.py \
292
+ --embodiment-tag NEW_EMBODIMENT \
293
+ --model-path /path/to/your/checkpoint \
294
+ --device cuda:0 \
295
+ --host 0.0.0.0 \
296
+ --port 5555
297
+ ```
298
+
299
+ **Parameters:**
300
+ - `--embodiment-tag`: The embodiment tag for your robot (e.g., `NEW_EMBODIMENT`)
301
+ - `--model-path`: Path to your trained model checkpoint directory
302
+ - `--device`: Device to run inference on (`cuda:0`, `cuda:1`, `cpu`, etc.)
303
+ - `--host`: Host address (`127.0.0.1` for local only, `0.0.0.0` to accept external connections)
304
+ - `--port`: Port number (default: 5555)
305
+ - `--strict` / `--no-strict`: Enable or disable input/output validation (default: True)
306
+ - `--use-sim-policy-wrapper`: Whether to use `Gr00tSimPolicyWrapper` for GR00T simulation environments (default: False)
307
+
308
+ Once started, the server will display:
309
+ ```
310
+ Starting GR00T inference server...
311
+ Embodiment tag: NEW_EMBODIMENT
312
+ Model path: /path/to/your/checkpoint
313
+ Device: cuda:0
314
+ Host: 0.0.0.0
315
+ Port: 5555
316
+ Server is ready and listening on tcp://0.0.0.0:5555
317
+ ```
318
+
319
+ #### Using the Policy Client
320
+
321
+ On the client side (your environment/robot control code), use `PolicyClient` to connect to the server:
322
+
323
+ ```python
324
+ from gr00t.policy.server_client import PolicyClient
325
+
326
+ # Connect to the policy server
327
+ policy = PolicyClient(
328
+ host="localhost", # or IP address of your GPU server
329
+ port=5555,
330
+ timeout_ms=15000, # 15 second timeout for inference
331
+ strict=False, # leave the validation to the server
332
+ )
333
+
334
+ # Verify connection
335
+ if not policy.ping():
336
+ raise RuntimeError("Cannot connect to policy server!")
337
+
338
+ # Use just like a regular policy
339
+ observation = get_observation() # Your observation in Policy API format
340
+ action, info = policy.get_action(observation)
341
+ ```
342
+
343
+ **Parameters:**
344
+ - `host`: Hostname or IP address of the policy server
345
+ - `port`: Port number (must match server port)
346
+ - `timeout_ms`: Timeout in milliseconds for network requests (default: 15000)
347
+ - `api_token`: Optional API token for authentication (default: None)
348
+ - `strict`: Enable client-side validation (usually False since server validates)
349
+
350
+ #### Client API
351
+
352
+ The `PolicyClient` implements the same `BasePolicy` interface, so it's a drop-in replacement:
353
+
354
+ ```python
355
+ # Get modality configuration from the server
356
+ modality_configs = policy.get_modality_config()
357
+
358
+ # Get action — returns (action_dict, info_dict)
359
+ action, info = policy.get_action(observation, options=None)
360
+
361
+ # Reset policy state (e.g., switch episode in ReplayPolicy)
362
+ info = policy.reset(options=None)
363
+
364
+ # Check server health — returns True if server responds
365
+ is_alive = policy.ping()
366
+
367
+ # Shutdown the server remotely (optional)
368
+ policy.kill_server()
369
+ ```
370
+
371
+ #### Server API Reference
372
+
373
+ | Parameter | Type | Default | Description |
374
+ |-----------|------|---------|-------------|
375
+ | `policy` | `BasePolicy` | *(required)* | The policy instance to serve (e.g., `Gr00tPolicy`, `ReplayPolicy`) |
376
+ | `host` | `str` | `"*"` | Bind address. `"*"` accepts connections on all interfaces |
377
+ | `port` | `int` | `5555` | TCP port for ZMQ REP socket |
378
+ | `api_token` | `str` | `None` | If set, clients must include a matching token in every request |
379
+
380
+ **Built-in endpoints:** `get_action`, `reset`, `get_modality_config`, `ping`, `kill`. Custom endpoints can be added via `server.register_endpoint(name, handler)`.
381
+
382
+ #### Error Handling
383
+
384
+ The server-client uses ZeroMQ REQ/REP sockets over TCP with msgpack serialization.
385
+
386
+ - **Timeout:** If the server does not respond within `timeout_ms`, the ZMQ socket will raise `zmq.error.Again`. The default 15 s timeout accommodates cold-start model loading on the first call.
387
+ - **Connection loss:** If `ping()` returns `False`, the client automatically recreates its ZMQ socket for the next attempt. Your control loop should retry or halt.
388
+ - **Server-side errors:** Exceptions in the policy are caught, serialized as `{"error": "..."}`, and re-raised as `RuntimeError` on the client side.
389
+
390
+ #### Debugging with ReplayPolicy
391
+
392
+ When developing a new environment integration or debugging your inference loop, running a full model inference can be cumbersome. `ReplayPolicy` allows you to **replay recorded actions from an existing dataset**, helping you verify that:
393
+
394
+ - Your environment setup works correctly
395
+ - Observations are formatted properly
396
+ - Action execution behaves as expected
397
+ - The server-client communication is functioning
398
+
399
+ This eliminates the need for a trained model during the development phase.
400
+
401
+ ##### Starting the Server with ReplayPolicy
402
+
403
+ Instead of providing `--model-path`, use `--dataset-path` to start the server in replay mode:
404
+
405
+ ```bash
406
+ uv run python gr00t/eval/run_gr00t_server.py \
407
+ --dataset-path /path/to/lerobot_dataset \
408
+ --embodiment-tag NEW_EMBODIMENT \
409
+ --host 0.0.0.0 \
410
+ --port 5555 \
411
+ --execution-horizon 8 # should match the executed action horizon in the environment
412
+ ```
413
+
414
+ **Parameters:**
415
+ - `--dataset-path`: Path to a LeRobot-compatible dataset directory
416
+ - `--embodiment-tag`: The embodiment tag for modality configuration
417
+ - `--execution-horizon`: Number of steps to advance the dataset per `get_action()` call. Should match the number of executed action steps in the environment.
418
+ - `--modality-config-path`: (Optional) Path to custom modality config JSON file. If not provided, uses the config from `embodiment-tag`
419
+ - `--use-sim-policy-wrapper`: Apply `Gr00tSimPolicyWrapper` for GR00T simulation environments
420
+
421
+ ##### Using ReplayPolicy from the Client
422
+
423
+ On the client side, use `PolicyClient` exactly as you would with a real model:
424
+
425
+ ```python
426
+ from gr00t.policy.server_client import PolicyClient
427
+
428
+ # Connect to the replay policy server
429
+ policy = PolicyClient(host="localhost", port=5555)
430
+
431
+ # Use exactly like a regular policy
432
+ action, info = policy.get_action(observation)
433
+
434
+ # info contains replay metadata
435
+ print(f"Replaying step {info['current_step']} of episode {info['episode_index']}")
436
+ ```
437
+
438
+ ##### Switching Episodes
439
+
440
+ ReplayPolicy starts with episode 0 by default. To switch to a different episode:
441
+
442
+ ```python
443
+ # Reset to a specific episode
444
+ policy.reset(options={"episode_index": 5})
445
+
446
+ # Optionally start from a specific step within the episode
447
+ policy.reset(options={"episode_index": 5, "step_index": 10})
448
+ ```
449
+
450
+ The number of available episodes can be queried via the `info` dict returned from `reset()` or `get_action()`.
451
+
452
+ ##### Example: Validating a LIBERO Environment
453
+
454
+ Here's a complete example of using ReplayPolicy to validate a simulation setup:
455
+
456
+ ```bash
457
+ # Terminal 1: Start the replay server
458
+ uv run python gr00t/eval/run_gr00t_server.py \
459
+ --dataset-path <your_dataset_path> \
460
+ --embodiment-tag <YOUR_EMBODIMENT_TAG> \
461
+ --action-horizon 8 \
462
+ --use-sim-policy-wrapper
463
+
464
+ # Terminal 2: Run evaluation with the replay policy
465
+ uv run python gr00t/eval/rollout_policy.py \
466
+ --n-episodes 1 \
467
+ --policy-client-host 127.0.0.1 \
468
+ --policy-client-port 5555 \
469
+ --max-episode-steps 720 \
470
+ --env-name <env_prefix>/<task_name> \
471
+ --n-action-steps 8 \
472
+ --n-envs 1
473
+ ```
474
+
475
+ If your environment is set up correctly, replaying ground-truth actions should achieve high (often 100%) success rates. Low success rates indicate issues with:
476
+ - Environment reset state not matching the dataset
477
+ - Observation preprocessing differences
478
+ - Action space mismatches
479
+
480
+ > **Tip:** ReplayPolicy is an excellent first step when integrating a new environment. Debug with replay first, then switch to model inference once the pipeline is validated.
481
+
482
+ #### Integrating the GR00T N1.7 Client Into Your Deployment Pipeline
483
+
484
+ GR00T's server–client architecture allows you to keep the **client side extremely lightweight**, making it easy to embed into any custom deployment pipeline without pulling in the full dependency stack.
485
+
486
+ For a minimal working example, see
487
+ [`eval_so100.py`](../gr00t/eval/real_robot/SO100/eval_so100.py).
488
+
489
+ In most cases, your deployment environment only needs to install the local GR00T client code:
490
+
491
+ ```bash
492
+ uv pip install -e . --verbose --no-deps
493
+ ```
494
+
495
+ The client relies solely on a small set of interfaces:
496
+ - `gr00t/policy/server_client.py`
497
+ - `gr00t/policy/policy.py`
498
+ - `gr00t/data/types.py`
499
+ - `gr00t/data/embodiment_tags.py`
500
+
501
+ ## Common Patterns
502
+
503
+ ### Batched Inference
504
+
505
+ The policy supports batched inference for efficiency:
506
+
507
+ ```python
508
+ # Run 4 environments in parallel
509
+ batch_size = 4
510
+ observation = {
511
+ "video": {"wrist_cam": np.zeros((batch_size, T_video, H, W, 3), dtype=np.uint8)},
512
+ "state": {"joints": np.zeros((batch_size, T_state, D_state), dtype=np.float32)},
513
+ "language": {"task": [["pick up the cube"]] * batch_size},
514
+ }
515
+
516
+ action, _ = policy.get_action(observation)
517
+ # action["action_name"] has shape (batch_size, action_horizon, action_dim)
518
+ ```
519
+
520
+ ### Single Environment Inference
521
+
522
+ For single environments, use batch size of 1:
523
+
524
+ ```python
525
+ # Add batch dimension (B=1)
526
+ observation = {
527
+ "video": {"wrist_cam": video[np.newaxis, ...]}, # (1, T, H, W, 3)
528
+ "state": {"joints": state[np.newaxis, ...]}, # (1, T, D)
529
+ "language": {"task": [["pick up the cube"]]}, # List of length 1
530
+ }
531
+
532
+ action, _ = policy.get_action(observation)
533
+
534
+ # Remove batch dimension
535
+ single_action = action["action_name"][0] # (action_horizon, action_dim)
536
+ ```
537
+
538
+ ### Action Chunking
539
+
540
+ When the action horizon `T > 1`, you can use action chunking:
541
+
542
+ ```python
543
+ action, _ = policy.get_action(observation)
544
+ action_chunk = action["action_name"][:, :, :] # (B, T, D)
545
+
546
+ # Execute actions over multiple timesteps
547
+ for t in range(action_chunk.shape[1]):
548
+ env.step(action_chunk[:, t, :])
549
+ ```
550
+
551
+ ### Training Dataloading Optimization
552
+
553
+ When training a model, you can optimize the dataloading speed vs memory usage via various command line arguments.
554
+
555
+ examples:
556
+ ```bash
557
+ uv run python gr00t/experiment/launch_finetune.py \
558
+ .... \
559
+ --num-shards-per-epoch 100 \
560
+ --dataloader-num-workers 2
561
+ --shard-size 512 \
562
+ ```
563
+
564
+ If vram is limited, you can reduce the all the numbers above to reduce the memory usage.
565
+
566
+ To ensure more IID during sampling of shards, you can reduce the `episode_sampling_rate` to 0.05 or lower.
567
+
568
+ ## Troubleshooting
569
+
570
+ 1. **Enable strict mode** during development: `strict=True`
571
+ 2. **Print modality configs** to understand expected formats
572
+ 3. **Check shapes** of your observations before calling `get_action()`
573
+ 4. **Use the reference wrapper** (`Gr00tSimPolicyWrapper`) as a template
574
+ 5. **Validate incrementally**: Test with dummy observations first before connecting to real environments
getting_started/real_world_deployment.md ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GR00T Real-World Deployment Guide
2
+
3
+ This guide covers building an end-to-end real-world VLA pipeline—from data collection and training to deployment—with practical engineering recommendations.
4
+
5
+ ## Overview
6
+
7
+ A typical GR00T real-world deployment workflow includes:
8
+
9
+ 1. **[Hardware Preparation](#1-hardware-and-environment-preparation-device-requirements)**: Verify that the robot platform, sensors, and compute resources are ready.
10
+ 2. **[Data Collection](#2-data-collection)**: Choose an appropriate teleoperation setup and collect at least 100 valid episodes.
11
+ 3. **[Data Preprocessing](#3-data-preprocessing)**: Clean data, align timestamps, and convert to LeRobot format.
12
+ 4. **[Model Training](#4-vla-model-training)**: Fine-tune GR00T N1.*.
13
+ 5. **[Model Evaluation](#validation)**: Run open-loop evaluation to validate convergence and model quality.
14
+ 6. **[Deployment Setup](#5-deployment-and-closed-loop-control)**: Build a ZMQ Server-Client architecture.
15
+ 7. **[Closed-Loop Testing](#5-deployment-and-closed-loop-control)**: Run closed-loop control on real hardware and monitor jittering and stop-and-go behavior.
16
+ 8. **[Optimization](#6-common-issues-jittering-and-stop-and-go)**: Tune RTC parameters and trajectory smoothing strategies based on real-world performance.
17
+
18
+ ## 1. Hardware and Environment Preparation (Device Requirements)
19
+
20
+ Ensure your robot hardware, sensor pipeline, and control interfaces are stable and available.
21
+
22
+ ### Robot Platform
23
+
24
+ - **Recommended platforms**: Robotic arms with SDK-level control support (e.g., Franka, UR, Piper, SO101).
25
+ - **Basic requirements**:
26
+ - Real-time joint state feedback.
27
+ - High-frequency action execution (30 FPS recommended).
28
+ - Stable control interface.
29
+
30
+ ### Multimodal Sensors
31
+
32
+ | Sensor Type | Specification | Purpose |
33
+ |-------------|---------------|---------|
34
+ | **Wrist-mounted camera** | 30 FPS, RGB | Capture close-range manipulation visuals |
35
+ | **Third-person camera (3rd view)** | 30 FPS, RGB | Capture global scene context |
36
+ | **Robot proprioceptive state** | Real-time acquisition | Joint states and gripper state |
37
+
38
+ ### Compute Resources
39
+
40
+ - **Training phase**: NVIDIA GPU servers (e.g., H100 or H20) are recommended for larger batch sizes.
41
+ - **Deployment phase**: Edge hardware such as Jetson AGX Thor supports on-device inference.
42
+
43
+ > For details, see the [hardware recommendation guide](hardware_recommendation.md).
44
+
45
+ ### Teleoperation Devices
46
+
47
+ Teleoperation device selection is critical for data quality.
48
+
49
+ ### Teleoperation Device Comparison
50
+
51
+ In the table below:
52
+ - **Embodiment dependency**: how similar the teleoperation device and target robot must be in joint topology, degrees of freedom (DoF), and workspace. Higher dependency implies harder cross-embodiment transfer.
53
+ - **Operational intuition**: how naturally operator inputs map to robot motion. Higher intuition means faster onboarding and lower demonstration error.
54
+
55
+ | Device Type | Cost Level (Reference) | Embodiment Dependency | Operational Intuition | Notes |
56
+ |-------------|------------------------|-----------------------|-----------------------|-------|
57
+ | **Keyboard/Gamepad/SpaceMouse/Joylo** | Low | Low: command mapping via keys/controls | Medium: requires adaptation to key-motion mapping | Low entry cost; a good starting point and useful in mobile scenarios |
58
+ | **Master-Slave arm systems** | Medium | High: master/slave arms usually require similar kinematics and workspace | High: near one-to-one human-robot mapping | Suitable for single-robot setups; commonly used by robot OEMs; can reduce the risk of reaching joint limits during demonstrations |
59
+ | **UMI / Fast-UMI / Pika Sense** | Medium | Low: hardware-agnostic action representation reusable across arms | High: after calibration, end-effector (EEF) following is intuitive | Suitable for training general VLA models; low-DoF arms may still hit joint limits |
60
+ | **VR-based teleoperation** | Medium (headset + rendering + network) | Low: mainly depends on software integration | Medium: depends on immersive visual feedback and tracking quality | A flexible solution, but with higher integration overhead |
61
+ | **Glove / Motion Capture** | High (commercial mocap suite + data gloves) | Low: retarget through kinematic mapping to different embodiments | High: intuitive full-hand/full-body control | Suitable for full-body control and dexterous-hand tasks |
62
+ | **Exoskeleton** | High | High: usually requires matched joint structure | High: natural action correspondence | Extendable to multi-joint humanoid control |
63
+
64
+ ## 2. Data Collection
65
+
66
+ Key considerations for data collection:
67
+
68
+ ### Timestamp Synchronization
69
+
70
+ - The FPS of both camera streams should be strictly matched, and capture triggers should be as synchronized as possible.
71
+ - Joint state sampling frequency should exceed camera FPS to enable accurate downsampling.
72
+ - Record full timestamps during collection for downstream temporal alignment.
73
+
74
+ ### Action Representation
75
+
76
+ - If training and collection use the same embodiment (e.g., master-slave arms), log joint-space `Joint States` during collection. For task-space models, compute EEF pose via forward kinematics (FK) in post-processing.
77
+ - If embodiments differ (e.g., collect with UMI, deploy on Piper), directly record task-space EEF pose during collection.
78
+
79
+ ### Data Distribution
80
+
81
+ - Current imitation-learning-based models perform more reliably in previously seen scenarios. In early-stage experiments, start with data collection and validation in a limited domain.
82
+ - After pipeline validation, gradually expand the domain by varying lighting, object placement, and initial robot poses to improve generalization.
83
+
84
+ ### Scene Consistency
85
+
86
+ - Keep third-person camera extrinsics fixed and ensure a rigid wrist-camera mount.
87
+ - In early experiments, prioritize scene consistency; avoid varying lighting, object placement, or initial robot poses.
88
+
89
+ ### Joint Limits
90
+
91
+ - If collecting joint-space data, avoid operating near joint limits to reduce the number of samples in those regions.
92
+
93
+ ## 3. Data Preprocessing
94
+
95
+ Raw data must be cleaned, synchronized, and converted before training.
96
+
97
+ ### Trajectory Filtering
98
+
99
+ Data filtering is recommended in two stages: script-based filtering and manual review.
100
+
101
+ #### Script Filtering
102
+
103
+ - Check image timestamps and remove samples with:
104
+ 1. Excessive latency in a single camera stream.
105
+ 2. Excessive timestamp difference between the two camera streams.
106
+ - Detect and remove abnormal jumps in robot state sequences.
107
+
108
+ #### Manual Filtering
109
+
110
+ Replay trajectories with synchronized visualization to catch issues missed by scripts:
111
+
112
+ - Remove samples with poor synchronization between image and action sequences.
113
+ - Remove blurry frames.
114
+ - Remove failed task executions.
115
+ - Remove low-quality trajectories (e.g., redundant paths, discontinuous actions).
116
+
117
+ ### Trajectory Preprocessing
118
+
119
+ 1. Timestamp alignment: align camera frames and robot joint states to a shared time base.
120
+ 2. Head-tail trimming: remove idle segments at the start and end of trajectories.
121
+ 3. Split long trajectories (several minutes) into multiple subtasks.
122
+
123
+ ### Format Conversion
124
+
125
+ Convert all data to a standard format (e.g., LeRobot) for GR00T compatibility:
126
+
127
+ - See the [data preparation guide](data_preparation.md) for format requirements.
128
+ - Use the provided conversion scripts to convert data to GR00T LeRobot format.
129
+
130
+ ## 4. VLA Model Training
131
+
132
+ ### Training Parameter Configuration
133
+
134
+ **Dataset size recommendations**
135
+
136
+ For single-task `finetune`:
137
+
138
+ - **Minimum data size**: Prepare at least **100 valid episodes**. For very narrow task domains, ~30 episodes may suffice. A capture frequency of 20–50 Hz is recommended for manipulation tasks.
139
+ - **Episode length**: No hard limit, but each episode must contain a complete action cycle with idle frames removed. Split overly long episodes into subtasks.
140
+ - **Recommended data size**: 200+ episodes usually provide more stable performance.
141
+
142
+ **Core parameters**
143
+
144
+ - **Input/output mode**: Default to `State-relative Action Prediction`. Compared with `Absolute Action`, it converges more easily and improves inter-chunk consistency.
145
+ - **Training space**: Both joint space and task space are valid. For low-DoF arms, joint space is often preferred to reduce singularity-related risks.
146
+ - **Action Chunk Size**: Default is 16. If combined with RTC to mitigate stop-and-go, set it to at least 32.
147
+ - **Batch Size**: Increase the batch size as much as GPU memory allows.
148
+
149
+ > For additional training options, see the [fine-tuning guide](finetune_new_embodiment.md).
150
+
151
+ **Compute resources**
152
+
153
+ - Fine-tuning requires significantly less compute than pretraining.
154
+ - A single compute node (8 x H100 or 8 x H20) is usually sufficient.
155
+
156
+ ### Validation
157
+
158
+ After training, run open-loop validation to confirm convergence, then proceed to closed-loop deployment validation.
159
+
160
+ > Open-loop validation is only a preliminary check. Final performance must be verified with closed-loop testing on real robots. For details, see the [fine-tuning guide](finetune_new_embodiment.md).
161
+
162
+ ## 5. Deployment and Closed-Loop Control
163
+
164
+ ### System Architecture
165
+
166
+ GR00T supports two inference modes:
167
+
168
+ 1. **Direct `Gr00tPolicy` usage**: Suitable when model inference and robot control run on the same machine.
169
+ 2. **ZMQ Server-Client architecture**: Suitable for real-world deployment and decouples local robot control (`Local Client`) from remote inference (`Model Server`).
170
+
171
+ For real-world deployment, **ZMQ inference service** is recommended:
172
+
173
+ - Move compute-intensive inference to GPU servers.
174
+ - Keep robot-side control code lightweight.
175
+ - Avoid installing the full inference dependency stack on the robot side.
176
+
177
+ ### On-Device Deployment Logic
178
+
179
+ Deployment code has two phases: **initialization** and the **main control loop**.
180
+ The pseudo-code below uses a synchronous workflow, which may cause stop-and-go. See later sections for mitigation via asynchronous execution + RTC.
181
+
182
+ **Pseudo-code workflow:**
183
+
184
+ ```python
185
+ # ========== Initialization ==========
186
+ # 1. Initialize and test cameras
187
+ hand_camera = initialize_hand_camera() # e.g., OrbbecSDK
188
+ env_camera = initialize_env_camera() # e.g., RealSense
189
+ test_cameras() # Show preview and verify normal operation
190
+
191
+ # 2. Connect and test robot
192
+ robot = connect_robot() # e.g., Piper SDK
193
+ robot.enable()
194
+ robot.reset_to_initial_position()
195
+ test_robot() # Send test command and verify robot response
196
+
197
+ # 3. Connect and test GR00T model server
198
+ gr00t_client = connect_to_gr00t_server(host, port)
199
+ if not gr00t_client.ping():
200
+ raise ConnectionError("Failed to connect to model server")
201
+ test_model() # Send test observation and verify inference
202
+
203
+ # ========== Main control loop ==========
204
+ while True:
205
+ # 1. Acquire sensor data
206
+ hand_image = hand_camera.get_frame()
207
+ env_image = env_camera.get_frame()
208
+ joint_states = robot.get_joint_states()
209
+ gripper_state = robot.get_gripper_state()
210
+
211
+ # 2. Format observation
212
+ observation = format_observation(
213
+ hand_image,
214
+ env_image,
215
+ joint_states,
216
+ gripper_state,
217
+ task_description,
218
+ )
219
+
220
+ # 3. Model inference (via ZMQ)
221
+ actions = gr00t_client.get_action(observation)
222
+
223
+ # 4. Trajectory post-processing
224
+ actions_arm = actions["joint_states"]
225
+ actions_arm = smooth_trajectory(actions_arm) # smoothing
226
+ actions_arm = check_safety_limits(actions_arm) # safety checks
227
+
228
+ # 5. Execute actions
229
+ for action_step in actions_arm:
230
+ robot.execute_action(action_step)
231
+ sleep(1.0 / 30.0) # 30 FPS
232
+ ```
233
+
234
+ ### Key Implementation Notes
235
+
236
+ **Important notes:**
237
+
238
+ - **Image format**: Use compressed formats such as JPG to reduce transmission bandwidth.
239
+ - **Safe operation**:
240
+ - **Soft Limits**: Add joint-angle and EEF pose range checks. If a predicted action exceeds workspace bounds, raise an alarm and stop immediately.
241
+ - **E-Stop logic**: Bind an emergency stop hotkey (e.g., Space) on the control PC, or use a physical E-Stop switch.
242
+ - **Action smoothing**: Apply interpolation and smoothing to predicted action sequences.
243
+
244
+ > For more deployment details, see the [policy API guide](policy.md).
245
+
246
+ ## 6. Common Issues: Jittering and Stop-and-Go
247
+
248
+ The most common issues in real-world deployment are **jittering** and **stop-and-go**.
249
+
250
+ ### Fixing Jittering
251
+
252
+ **Jittering** here refers to visible shaking or vibration of the end-effector or joints during task execution.
253
+
254
+ Jittering typically originates from **inconsistent model outputs** or **insufficient robot-side control quality**. Analyze these two components separately to localize the issue. The suggestions below are general guidelines and may not apply to every robot platform or control stack — always verify against your own hardware and environment.
255
+
256
+ ```mermaid
257
+ flowchart TD
258
+ A[Jittering observed] --> B[Save & visualize action chunks in 3D]
259
+ B --> C{Where is the jitter?}
260
+ C -->|Inside each chunk| D[Case A: Model undertrained or poor data quality]
261
+ C -->|Between consecutive chunks| E[Case B: Inconsistent chunk predictions]
262
+ C -->|Chunks look smooth| F[Case C: Robot hardware / low-level control issue]
263
+ D --> D1[Add more data, train longer, check train/eval consistency]
264
+ E --> E1[Use state-relative actions + RTC chunking strategy]
265
+ F --> F1[Check drive control, interpolation, hardware status]
266
+ ```
267
+
268
+ **Diagnosis and mitigation**
269
+
270
+ 1. **Save and visualize Action Chunks**
271
+ - Save all predicted `Action Chunks`.
272
+ - Visualize continuous TCP (tool center point) trajectories in 3D.
273
+ - **Note**: Convert joint-space outputs to task space via FK before visualization.
274
+
275
+ 2. **Analyze visualization results**
276
+
277
+ **Case A: Significant jitter inside each chunk**
278
+ - **Cause**: The model is undertrained, or data quality is insufficient.
279
+ - **Solution**: Improve data quality, add more training data, or train longer. Keep training and validation environments consistent.
280
+
281
+ **Case B: Significant jitter between chunks**
282
+ - **Cause**: Inconsistent adjacent `Action Chunk` predictions.
283
+ - **Solution**:
284
+ - Use `State-relative Action Prediction`. Predicting actions relative to the current state produces a more uniform output distribution, making the network easier to train.
285
+ - Use RTC (`Real-Time Chunking`) or similar strategies.
286
+
287
+ **Case C: Little jitter after visualization**
288
+ - **Cause**: Likely a robot hardware or low-level control issue.
289
+ - **Solution**: Check drive control, interpolation, and hardware status.
290
+
291
+ **Quantitative diagnostic metrics**
292
+
293
+ Trajectory jitter can also be quantified using these three metrics:
294
+
295
+ **Metric 1: Mean intra-chunk acceleration magnitude**
296
+
297
+ Measures intra-chunk smoothness. Only valid under fixed sampling frequency.
298
+
299
+ Formula: $a_t = pos_{t+1} - 2 \cdot pos_t + pos_{t-1}$
300
+
301
+ ```python
302
+ def metric_intra_accel(chunks):
303
+ """
304
+ Args:
305
+ chunks: numpy array with shape (N_chunks, Chunk_Length, Joint_Dim)
306
+
307
+ Returns:
308
+ float: Mean acceleration magnitude
309
+ """
310
+ velocity = np.diff(chunks, axis=1) # first-order difference
311
+ acceleration = np.diff(velocity, axis=1) # second-order difference
312
+ acc_magnitude = np.linalg.norm(acceleration, axis=-1) # L2 norm per step
313
+ return np.mean(acc_magnitude)
314
+ ```
315
+
316
+ **Metric 2: Position jump at chunk boundary (L2 distance)**
317
+
318
+ Measures position continuity between chunks by comparing the last executed step of `Chunk[i]` with step 0 of `Chunk[i+1]`.
319
+
320
+ ```python
321
+ def metric_boundary_jump(chunks, execute_steps=None):
322
+ """
323
+ Args:
324
+ chunks: numpy array with shape (N_chunks, Chunk_Length, Joint_Dim)
325
+ execute_steps: number of executed steps per chunk; if None, use full chunk length
326
+
327
+ Returns:
328
+ float: Mean position jump
329
+ """
330
+ chunks = np.array(chunks)
331
+ exec_steps = chunks.shape[1] if execute_steps is None else execute_steps
332
+
333
+ last_frame_prev = chunks[:-1, exec_steps - 1, :] # last frame of previous chunk
334
+ first_frame_curr = chunks[1:, 0, :] # first frame of current chunk
335
+ jumps = np.linalg.norm(first_frame_curr - last_frame_prev, axis=-1) # Euclidean distance
336
+ return np.mean(jumps)
337
+ ```
338
+
339
+ **Metric 3: Cosine similarity of velocity direction at chunk boundary**
340
+
341
+ Measures velocity-direction consistency between chunks. Values closer to 1 indicate better consistency.
342
+
343
+ ```python
344
+ def metric_momentum_shift(chunks, execute_steps=None):
345
+ """
346
+ Args:
347
+ chunks: numpy array with shape (N_chunks, Chunk_Length, Joint_Dim)
348
+ execute_steps: number of executed steps per chunk; if None, use full chunk length
349
+
350
+ Returns:
351
+ float: Mean cosine similarity
352
+ """
353
+ chunks = np.array(chunks)
354
+ exec_steps = chunks.shape[1] if execute_steps is None else execute_steps
355
+
356
+ # velocity at the end of previous chunk
357
+ idx = exec_steps - 1
358
+ if idx < 1:
359
+ raise ValueError("execute_steps must be >= 2 to compute end velocity")
360
+ v_end = chunks[:-1, idx, :] - chunks[:-1, idx - 1, :]
361
+
362
+ # velocity at the start of current chunk
363
+ v_start = chunks[1:, 1, :] - chunks[1:, 0, :]
364
+
365
+ # cosine similarity
366
+ dot_product = np.sum(v_end * v_start, axis=-1)
367
+ norm_prev = np.linalg.norm(v_end, axis=-1)
368
+ norm_curr = np.linalg.norm(v_start, axis=-1)
369
+ epsilon = 1e-8
370
+ cosine_sim = dot_product / (norm_prev * norm_curr + epsilon)
371
+
372
+ return np.mean(cosine_sim)
373
+ ```
374
+
375
+ ### Fixing Stop-and-Go
376
+ Stop-and-Go here refers to a behavior in which the robot intermittently pauses during motion, producing periodic stop-and-go behavior.
377
+
378
+ #### Root Cause
379
+
380
+ In **synchronous single-step closed-loop** control, stop-and-go occurs when the **end-to-end latency** (observation capture → VLA inference → action conversion) exceeds control-frequency requirements.
381
+
382
+ - **Control-frequency requirement**: At 30 FPS, latency must stay below ~33 ms.
383
+ - **Typical latency sources**: Data capture, network transfer, model inference, and post-processing often exceed 33 ms combined.
384
+ - **Consequence**: The next prediction is not ready when the current action finishes, causing pauses.
385
+
386
+ #### Solutions
387
+
388
+ **Option 1: Optimize the inference pipeline (direct but difficult)**
389
+
390
+ Reduce full workflow latency below 33 ms:
391
+
392
+ - Optimize network bandwidth (reduce transfer time).
393
+ - Use edge inference (reduce network latency).
394
+ - Quantize the VLA model (speed up inference).
395
+ - Use a smaller model (e.g., ACT).
396
+
397
+ **Limitation**: For VLA models, meeting strict real-time requirements through optimization alone is often impractical.
398
+
399
+ **Option 2: Use algorithmic scheduling strategies (recommended)**
400
+
401
+ When direct optimization is insufficient, use one or more of the following:
402
+
403
+ - **Asynchronous Inference**: A background thread runs inference while the main thread executes actions.
404
+ - **Receding Horizon**: Execute only the first few steps of each `Action Chunk` before triggering a new inference.
405
+ - **Temporal Ensemble**: Aggregate predictions across multiple timesteps.
406
+ - **Real-Time Chunking (RTC)**: Overlap the start of the current prediction with unexecuted steps from the previous one.
407
+
408
+ **Recommended strategy**: `Asynchronous Inference + RTC` is usually the most effective.
409
+
410
+ #### Real-Time Chunking (RTC) Details
411
+
412
+ **Principle**
413
+
414
+ RTC treats action prediction as an inpainting problem: overlapping the start of the new prediction with unexecuted steps from the previous one ensures smooth transitions.
415
+
416
+ **Applicability**
417
+
418
+ - Validated for **diffusion / flow-based** VLA policies.
419
+ - Requires `Action Chunk` length ≥ 32 steps.
420
+ - Should be combined with asynchronous inference.
421
+
422
+ **Implementation essentials**
423
+
424
+ 1. **Predict longer Action Chunks**:
425
+ - Increase from the default 16 steps to at least 32.
426
+ - Provide a larger soft fusion window.
427
+
428
+ 2. **Asynchronous inference architecture**:
429
+ - **Background thread**: Continuously infer, capture observations, and prepare action batches.
430
+ - **Main thread**: Execute the current action sequence.
431
+ - Buffer predictions in a queue to avoid blocking.
432
+
433
+ 3. **Action fusion mechanism**:
434
+ - Use RTC for soft fusion in the overlap region.
435
+ - Ensure smooth transitions between adjacent chunks.
436
+
437
+ **Pseudocode: Async Inference + RTC**
438
+
439
+ In the RTC (Real-Time Chunking) framework, two key parameters control how adjacent action chunks overlap and transition:
440
+
441
+ - **`overlap`**: The number of action steps retained from the previous prediction to constrain the current one, ensuring temporal consistency between consecutive chunks.
442
+ - **`frozen`**: The number of steps that remain completely frozen (i.e., not updated by the new prediction), typically set to match the inference latency.
443
+
444
+ Below is a simplified async inference + RTC loop. Note that official RTC support for GR00T is coming soon; the current implementation may require manual adaptation.
445
+
446
+ ```
447
+ actions = policy.infer(obs) # blocking first call
448
+
449
+ loop:
450
+ for i in range(action_horizon):
451
+ if i == action_horizon - overlap - 1:
452
+ future = async policy.infer(new_obs) # non-blocking
453
+ robot.execute(actions[i])
454
+ if i == action_horizon - frozen - 1:
455
+ actions = future.get() # swap in next chunk
456
+ break # discard frozen tail
457
+ ```
458
+
459
+
gr00t/__init__.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+
19
+ def _patch_hf_local_first() -> None:
20
+ """Patch from_pretrained to prefer the local HF snapshot cache over network calls.
21
+
22
+ When a HF repo ID is passed we try snapshot_download(local_files_only=True)
23
+ first; if the model is not cached we fall through to the normal download path.
24
+ This avoids 429 rate-limit errors when many CI jobs run concurrently.
25
+
26
+ Covers: PreTrainedModel, PretrainedConfig, ProcessorMixin, AutoConfig,
27
+ AutoProcessor — every transformers from_pretrained entrypoint.
28
+
29
+ Triggered by GROOT_HF_LOCAL_FIRST (set by conftest.py, survives uv run) or
30
+ PYTEST_CURRENT_TEST (set automatically by pytest).
31
+ """
32
+
33
+ def _resolve(name_or_path: str) -> str:
34
+ hf_home = os.environ.get("HF_HOME")
35
+ hf_hub = os.environ.get("HUGGINGFACE_HUB_CACHE")
36
+ hf_cache_info = f"HF_HOME={hf_home} HUGGINGFACE_HUB_CACHE={hf_hub}"
37
+ if os.path.isdir(name_or_path):
38
+ print(f"[groot/hf] local path: {name_or_path} | {hf_cache_info}", flush=True)
39
+ return name_or_path
40
+ try:
41
+ from huggingface_hub import snapshot_download
42
+
43
+ resolved = snapshot_download(name_or_path, local_files_only=True)
44
+ print(
45
+ f"[groot/hf] cache hit: {name_or_path} -> {resolved} | {hf_cache_info}", flush=True
46
+ )
47
+ return resolved
48
+ except Exception:
49
+ print(
50
+ f"[groot/hf] cache miss (will download): {name_or_path} | {hf_cache_info}",
51
+ flush=True,
52
+ )
53
+ return name_or_path
54
+
55
+ def _wrap(cls: type) -> None:
56
+ if "from_pretrained" not in cls.__dict__:
57
+ return
58
+ original = cls.from_pretrained
59
+ if getattr(original, "_groot_hf_local_patched", False):
60
+ return
61
+
62
+ def _make_patched(orig):
63
+ @classmethod # type: ignore[misc]
64
+ def patched(klass, pretrained_model_name_or_path, *args, **kwargs):
65
+ resolved = _resolve(str(pretrained_model_name_or_path))
66
+
67
+ return orig.__func__(klass, resolved, *args, **kwargs)
68
+
69
+ patched._groot_hf_local_patched = True # type: ignore[attr-defined]
70
+ return patched
71
+
72
+ cls.from_pretrained = _make_patched(original)
73
+
74
+ try:
75
+ import transformers as _transformers
76
+
77
+ for _attr in (
78
+ "PreTrainedModel",
79
+ "PretrainedConfig",
80
+ "ProcessorMixin",
81
+ "AutoConfig",
82
+ "AutoProcessor",
83
+ ):
84
+ _cls = getattr(_transformers, _attr, None)
85
+ if _cls is not None:
86
+ _wrap(_cls)
87
+ except Exception:
88
+ pass
89
+
90
+
91
+ def _patch_mistral() -> None:
92
+ """Suppress 429 / connection errors from the HuggingFace Hub in mistral regex patching.
93
+
94
+ transformers calls model_info() inside a nested is_base_mistral() function
95
+ unconditionally even when loading from a fully local checkpoint. Qwen3VL /
96
+ Cosmos is never Mistral, so returning the tokenizer unchanged on any network
97
+ failure is correct.
98
+
99
+ NOTE: is_base_mistral is a *nested* function inside _patch_mistral_regex, so
100
+ it is not accessible as a module-level attribute — we must wrap the classmethod.
101
+
102
+ Triggered by GROOT_PATCH_MISTRAL (set by conftest.py, survives uv run) or
103
+ PYTEST_CURRENT_TEST (set automatically by pytest, belt-and-suspenders).
104
+ """
105
+ try:
106
+ import transformers.tokenization_utils_base as _tub
107
+
108
+ _cls = _tub.PreTrainedTokenizerBase
109
+ _orig = _cls._patch_mistral_regex.__func__
110
+ if getattr(_orig, "_groot_patched", False):
111
+ return
112
+
113
+ def _safe(cls, tokenizer, pretrained_model_name_or_path, **kwargs):
114
+ try:
115
+ return _orig(cls, tokenizer, pretrained_model_name_or_path, **kwargs)
116
+ except Exception:
117
+ return tokenizer
118
+
119
+ _safe._groot_patched = True # type: ignore[attr-defined]
120
+ _cls._patch_mistral_regex = classmethod(_safe)
121
+ except Exception:
122
+ pass
123
+
124
+
125
+ if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("GROOT_HF_LOCAL_FIRST"):
126
+ _patch_hf_local_first()
127
+
128
+ if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("GROOT_PATCH_MISTRAL"):
129
+ _patch_mistral()
gr00t/configs/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
gr00t/configs/base_config.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass, field
17
+ import json
18
+ from pathlib import Path
19
+ from typing import List, Optional
20
+
21
+ import yaml
22
+
23
+ from gr00t.data.types import ActionConfig, ActionFormat, ActionRepresentation, ActionType
24
+
25
+ from .data.data_config import DataConfig, SingleDatasetConfig
26
+ from .model import create_model_union_type
27
+ from .model.gr00t_n1d7 import Gr00tN1d7Config
28
+ from .training.training_config import TrainingConfig
29
+
30
+
31
+ ModelUnionType = create_model_union_type()
32
+
33
+
34
+ @dataclass
35
+ class Config:
36
+ """Complete configuration."""
37
+
38
+ load_config_path: Optional[str] = None
39
+ model: ModelUnionType = field(default_factory=lambda: Gr00tN1d7Config())
40
+ data: DataConfig = field(default_factory=DataConfig)
41
+ training: TrainingConfig = field(default_factory=TrainingConfig)
42
+
43
+ def save(self, path: Path):
44
+ """Save configuration to YAML file."""
45
+ path = Path(path)
46
+ path.parent.mkdir(parents=True, exist_ok=True)
47
+
48
+ with open(path, "w") as f:
49
+ yaml.dump(self, f)
50
+
51
+ def load(self, path: Path):
52
+ """Load configuration from YAML file."""
53
+ data = yaml.load(path.read_text(), Loader=yaml.Loader)
54
+ if isinstance(data, dict): # for training
55
+ self.load_dict(data)
56
+ elif isinstance(data, self.__class__):
57
+ self = data
58
+ else:
59
+ raise ValueError(f"Invalid config file: {path}")
60
+ # config = cls(**config) # if yaml.dump(self.__dict__, ...) is used
61
+ return self
62
+
63
+ def load_dict(self, data: dict):
64
+ if "model" in data:
65
+ self.model = self.model.__class__(**data["model"])
66
+ if "data" in data:
67
+ self.data = DataConfig(**data["data"])
68
+ # Ensure nested datasets are converted to dataclass instances
69
+ converted: List[SingleDatasetConfig] = []
70
+ for ds in self.data.datasets:
71
+ if isinstance(ds, dict):
72
+ converted.append(SingleDatasetConfig(**ds))
73
+ else:
74
+ converted.append(ds)
75
+ self.data.datasets = converted
76
+ if "training" in data:
77
+ self.training = TrainingConfig(**data["training"])
78
+ return self
79
+
80
+ @classmethod
81
+ def from_pretrained(cls, path: Path) -> "Config":
82
+ """Load configuration from YAML file."""
83
+ data = yaml.load(path.read_text(), Loader=yaml.Loader)
84
+ return data
85
+
86
+ def get_deepspeed_config(self) -> dict:
87
+ """Generate DeepSpeed configuration."""
88
+ stage = self.training.deepspeed_stage
89
+
90
+ gr00t_dir = Path(__file__).parent.parent
91
+ if stage == 2:
92
+ config = json.load(open(gr00t_dir / "configs/deepspeed/zero2_config.json"))
93
+ elif stage == 3:
94
+ config = json.load(open(gr00t_dir / "configs/deepspeed/zero3_config.json"))
95
+ else:
96
+ raise ValueError(f"Invalid DeepSpeed stage: {stage}")
97
+
98
+ return config
99
+
100
+ def validate(self):
101
+ """Validate configuration."""
102
+ # Check dataset path(s)
103
+ embodiment_tags = set()
104
+ for d_cfg in self.data.datasets:
105
+ # (Disable missing data check because we now support caching PDX data sources.)
106
+ # if not Path(d_cfg.dataset_path).exists():
107
+ # raise ValueError(f"Dataset path does not exist: {d_cfg.dataset_path}")
108
+ if d_cfg.dataset_type == "physical_embodiment" and not d_cfg.embodiment_tag:
109
+ raise ValueError(f"Embodiment tag is empty for dataset {d_cfg.dataset_path}")
110
+ if d_cfg.embodiment_tag is not None:
111
+ embodiment_tags.add(d_cfg.embodiment_tag)
112
+
113
+ stripped_modality_configs = {}
114
+ for embodiment_tag in embodiment_tags:
115
+ modality_cfg = self.data.modality_configs.get(embodiment_tag)
116
+ if modality_cfg is None:
117
+ raise ValueError(
118
+ f"No modality config registered for embodiment tag '{embodiment_tag}'. "
119
+ f"Available tags: {sorted(self.data.modality_configs.keys())}. "
120
+ f"Provide --modality-config-path to register a custom modality config, "
121
+ f"or use one of the pre-registered tags."
122
+ )
123
+ stripped_modality_configs[embodiment_tag] = modality_cfg
124
+ self.data.modality_configs = stripped_modality_configs
125
+
126
+ # ensure mix ratios are valid
127
+ total_ratio = sum(d.mix_ratio for d in self.data.datasets)
128
+ if total_ratio <= 0:
129
+ raise ValueError("Sum of mix_ratio must be greater than zero")
130
+
131
+ # Fill in default values for action configs
132
+ for embodiment_tag in self.data.modality_configs:
133
+ # Fill in default values for action representation, type and format
134
+ if self.data.modality_configs[embodiment_tag]["action"].action_configs is None:
135
+ self.data.modality_configs[embodiment_tag]["action"].action_configs = [
136
+ ActionConfig(
137
+ rep=ActionRepresentation.ABSOLUTE,
138
+ type=ActionType.NON_EEF,
139
+ format=ActionFormat.DEFAULT,
140
+ )
141
+ ] * len(self.data.modality_configs[embodiment_tag]["action"].modality_keys)
142
+
143
+ # Validate precision settings
144
+ if self.training.fp16 and self.training.bf16:
145
+ raise ValueError("Cannot use both fp16 and bf16")
146
+
147
+
148
+ def get_default_config() -> Config:
149
+ """Get default configuration."""
150
+ return Config()
gr00t/configs/data/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
gr00t/configs/data/data_config.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass, field
17
+ from typing import Any, List, Optional
18
+
19
+ from gr00t.data.types import ModalityConfig
20
+
21
+ from .embodiment_configs import MODALITY_CONFIGS
22
+
23
+
24
+ @dataclass
25
+ class SingleDatasetConfig:
26
+ """Configuration for a single dataset in a mixed-training setup.
27
+
28
+ A list of these objects can be supplied in ``DataConfig.datasets`` to mix
29
+ multiple datasets at arbitrary ratios. For convenience the *legacy*
30
+ single-dataset fields still exist; if ``datasets`` is non-empty they take
31
+ precedence.
32
+ """
33
+
34
+ # Path to the dataset root directory (can be strings or dicts for complex configs)
35
+ dataset_paths: List[Any]
36
+
37
+ # Robot embodiment identifier (e.g. "gr1", "franka")
38
+ embodiment_tag: Optional[str] = None
39
+
40
+ # Relative sampling probability (will be normalised across the list)
41
+ mix_ratio: float = 1.0
42
+
43
+ dataset_type: str = "physical_embodiment"
44
+
45
+ # Optional validation dataset path for open-loop evaluation
46
+ # If not provided, falls back to dataset_paths for evaluation
47
+ val_dataset_path: Optional[str] = None
48
+
49
+
50
+ @dataclass
51
+ class DataConfig:
52
+ """Dataset configuration (supports single or multiple datasets)."""
53
+
54
+ # Leave empty by default for backwards-compatibility with the original
55
+ # single-dataset workflow. Users can supply one or more configs via CLI or
56
+ # YAML when they need mixing.
57
+ datasets: List[SingleDatasetConfig] = field(default_factory=list)
58
+
59
+ # Modality configs
60
+ # There are three sources of modality configs:
61
+ # 1. Default modality configs in code: gr00t/configs/data/embodiment_configs.py
62
+ # 2. Modality configs supplied through command line: --data.modality_configs (although rare and inconvenient)
63
+ # 1 and 2 are unified through `config.data.modality_configs`.
64
+ # 3. modality configs saved in the pretrained checkpoint.
65
+ modality_configs: dict[str, dict[str, ModalityConfig]] = field(
66
+ default_factory=lambda: MODALITY_CONFIGS
67
+ )
68
+
69
+ # Sharded dataset configuration
70
+ download_cache: bool = False
71
+ shard_size: int = 2**10
72
+ episode_sampling_rate: float = 0.1
73
+ num_shards_per_epoch: int = int(1e5)
74
+
75
+ # Override statistics from the pretrained checkpoint
76
+ override_pretraining_statistics: bool = True
77
+
78
+ # General task / mode config (shared across datasets)
79
+ mode: str = "single_turn"
80
+ random_chop: float = 0.0
81
+ mock_dataset_mode: bool = False # if True, cache the first datapoint of each dataset and always return one of them to simulate best-case dataloading
82
+
83
+ # Data loading
84
+ shuffle: bool = True
85
+ seed: int = 42
86
+ multiprocessing_context: str = "fork" # Options: "fork", "spawn", and "forkserver"
87
+ allow_padding: bool = False
88
+
89
+ # Subsample ratio for the dataset
90
+ subsample_ratio: float = 1.0
91
+
92
+ # DP Image Config
93
+ image_crop_size: List[int] = field(default_factory=lambda: [244, 244])
94
+ image_target_size: List[int] = field(default_factory=lambda: [224, 224])
95
+ video_backend: str = "torchcodec"
gr00t/configs/data/embodiment_configs.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from gr00t.data.embodiment_tags import EmbodimentTag
17
+ from gr00t.data.types import (
18
+ ActionConfig,
19
+ ActionFormat,
20
+ ActionRepresentation,
21
+ ActionType,
22
+ ModalityConfig,
23
+ )
24
+
25
+
26
+ MODALITY_CONFIGS = {
27
+ ##### Pre-registered pretrain configurations #####
28
+ "oxe_droid_relative_eef_relative_joint": {
29
+ "video": ModalityConfig(
30
+ delta_indices=[-15, 0],
31
+ modality_keys=["exterior_image_1_left", "wrist_image_left"],
32
+ ),
33
+ "state": ModalityConfig(
34
+ delta_indices=[0],
35
+ modality_keys=["eef_9d", "gripper_position", "joint_position"],
36
+ ),
37
+ "action": ModalityConfig(
38
+ delta_indices=list(range(40)),
39
+ modality_keys=["eef_9d", "gripper_position", "joint_position"],
40
+ action_configs=[
41
+ ActionConfig(
42
+ rep=ActionRepresentation.RELATIVE,
43
+ type=ActionType.EEF,
44
+ format=ActionFormat.XYZ_ROT6D,
45
+ state_key="eef_9d",
46
+ ),
47
+ ActionConfig(
48
+ rep=ActionRepresentation.ABSOLUTE,
49
+ type=ActionType.NON_EEF,
50
+ format=ActionFormat.DEFAULT,
51
+ state_key="gripper_position",
52
+ ),
53
+ ActionConfig(
54
+ rep=ActionRepresentation.RELATIVE,
55
+ type=ActionType.NON_EEF,
56
+ format=ActionFormat.DEFAULT,
57
+ state_key="joint_position",
58
+ ),
59
+ ],
60
+ ),
61
+ "language": ModalityConfig(
62
+ delta_indices=[0],
63
+ modality_keys=["annotation.language.language_instruction"],
64
+ ),
65
+ },
66
+ ##### Pre-registered posttrain configurations #####
67
+ "unitree_g1_full_body_with_waist_height_nav_cmd": {
68
+ "video": ModalityConfig(
69
+ delta_indices=[0],
70
+ modality_keys=["ego_view"],
71
+ ),
72
+ "state": ModalityConfig(
73
+ delta_indices=[0],
74
+ modality_keys=[
75
+ "left_leg",
76
+ "right_leg",
77
+ "waist",
78
+ "left_arm",
79
+ "right_arm",
80
+ "left_hand",
81
+ "right_hand",
82
+ ],
83
+ ),
84
+ "action": ModalityConfig(
85
+ delta_indices=list(range(50)),
86
+ modality_keys=[
87
+ "left_arm",
88
+ "right_arm",
89
+ "left_hand",
90
+ "right_hand",
91
+ "waist",
92
+ "base_height_command",
93
+ "navigate_command",
94
+ ],
95
+ action_configs=[
96
+ # left_arm
97
+ ActionConfig(
98
+ rep=ActionRepresentation.RELATIVE,
99
+ type=ActionType.NON_EEF,
100
+ format=ActionFormat.DEFAULT,
101
+ ),
102
+ # right_arm
103
+ ActionConfig(
104
+ rep=ActionRepresentation.RELATIVE,
105
+ type=ActionType.NON_EEF,
106
+ format=ActionFormat.DEFAULT,
107
+ ),
108
+ # left_hand
109
+ ActionConfig(
110
+ rep=ActionRepresentation.ABSOLUTE, # G1 hand is controlled by binary signals like a gripper
111
+ type=ActionType.NON_EEF,
112
+ format=ActionFormat.DEFAULT,
113
+ ),
114
+ # right_hand
115
+ ActionConfig(
116
+ rep=ActionRepresentation.ABSOLUTE, # G1 hand is controlled by binary signals like a gripper
117
+ type=ActionType.NON_EEF,
118
+ format=ActionFormat.DEFAULT,
119
+ ),
120
+ # waist
121
+ ActionConfig(
122
+ rep=ActionRepresentation.ABSOLUTE,
123
+ type=ActionType.NON_EEF,
124
+ format=ActionFormat.DEFAULT,
125
+ ),
126
+ # base_height_command
127
+ ActionConfig(
128
+ rep=ActionRepresentation.ABSOLUTE,
129
+ type=ActionType.NON_EEF,
130
+ format=ActionFormat.DEFAULT,
131
+ ),
132
+ # navigate_command
133
+ ActionConfig(
134
+ rep=ActionRepresentation.ABSOLUTE,
135
+ type=ActionType.NON_EEF,
136
+ format=ActionFormat.DEFAULT,
137
+ ),
138
+ ],
139
+ ),
140
+ "language": ModalityConfig(
141
+ delta_indices=[0],
142
+ modality_keys=["annotation.human.task_description"],
143
+ ),
144
+ },
145
+ "libero_sim": {
146
+ "video": ModalityConfig(
147
+ delta_indices=[0],
148
+ modality_keys=["image", "wrist_image"],
149
+ ),
150
+ "state": ModalityConfig(
151
+ delta_indices=[0],
152
+ modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
153
+ ),
154
+ "action": ModalityConfig(
155
+ delta_indices=list(range(16)),
156
+ modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
157
+ ),
158
+ "language": ModalityConfig(
159
+ delta_indices=[0],
160
+ modality_keys=["annotation.human.action.task_description"],
161
+ ),
162
+ },
163
+ "simpler_env_widowx": {
164
+ "video": ModalityConfig(
165
+ delta_indices=[0],
166
+ modality_keys=["image_0"],
167
+ ),
168
+ "state": ModalityConfig(
169
+ delta_indices=[0],
170
+ modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"],
171
+ ),
172
+ "action": ModalityConfig(
173
+ delta_indices=list(range(8)),
174
+ modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
175
+ ),
176
+ "language": ModalityConfig(
177
+ delta_indices=[0],
178
+ modality_keys=["annotation.human.action.task_description"],
179
+ ),
180
+ },
181
+ "simpler_env_google": {
182
+ "video": ModalityConfig(
183
+ delta_indices=[0],
184
+ modality_keys=["image"],
185
+ ),
186
+ "state": ModalityConfig(
187
+ delta_indices=[0],
188
+ modality_keys=["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"],
189
+ ),
190
+ "action": ModalityConfig(
191
+ delta_indices=list(range(8)),
192
+ modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
193
+ ),
194
+ "language": ModalityConfig(
195
+ delta_indices=[0],
196
+ modality_keys=["annotation.human.action.task_description"],
197
+ ),
198
+ },
199
+ }
200
+
201
+
202
+ def register_modality_config(
203
+ config: dict, embodiment_tag: EmbodimentTag = EmbodimentTag.NEW_EMBODIMENT
204
+ ):
205
+ assert embodiment_tag.value not in MODALITY_CONFIGS, (
206
+ f"Embodiment tag {embodiment_tag} already registered"
207
+ )
208
+ MODALITY_CONFIGS[embodiment_tag.value] = config
gr00t/configs/deepspeed/zero2_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "checkpoint": {
3
+ "load_universal": false
4
+ },
5
+ "train_batch_size": "auto",
6
+ "train_micro_batch_size_per_gpu": "auto",
7
+ "gradient_accumulation_steps": "auto",
8
+ "gradient_clipping": "auto",
9
+ "zero_allow_untested_optimizer": true,
10
+ "fp16": {
11
+ "enabled": "auto",
12
+ "loss_scale": 0,
13
+ "loss_scale_window": 1000,
14
+ "initial_scale_power": 16,
15
+ "hysteresis": 2,
16
+ "min_loss_scale": 1
17
+ },
18
+ "bf16": {
19
+ "enabled": "auto"
20
+ },
21
+ "communication_data_type": "bf16",
22
+ "zero_optimization": {
23
+ "stage": 2,
24
+ "overlap_comm": true,
25
+ "reduce_scatter": true,
26
+ "allgather_partitions": true,
27
+ "contiguous_gradients": true,
28
+ "sub_group_size": 1e9,
29
+ "reduce_bucket_size": 1e8,
30
+ "allgather_bucket_size": 1e8
31
+ }
32
+ }
33
+
gr00t/configs/deepspeed/zero3_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train_batch_size": "auto",
3
+ "train_micro_batch_size_per_gpu": "auto",
4
+ "gradient_accumulation_steps": "auto",
5
+ "gradient_clipping": "auto",
6
+ "zero_allow_untested_optimizer": true,
7
+ "fp16": {
8
+ "enabled": "auto",
9
+ "loss_scale": 0,
10
+ "loss_scale_window": 1000,
11
+ "initial_scale_power": 16,
12
+ "hysteresis": 2,
13
+ "min_loss_scale": 1
14
+ },
15
+ "bf16": {
16
+ "enabled": "auto"
17
+ },
18
+ "zero_optimization": {
19
+ "stage": 3,
20
+ "overlap_comm": true,
21
+ "contiguous_gradients": true,
22
+ "sub_group_size": 1e9,
23
+ "reduce_bucket_size": "auto",
24
+ "stage3_prefetch_bucket_size": "auto",
25
+ "stage3_param_persistence_threshold": "auto",
26
+ "stage3_max_live_parameters": 1e9,
27
+ "stage3_max_reuse_distance": 1e9,
28
+ "stage3_gather_16bit_weights_on_model_save": true
29
+ }
30
+ }
31
+
gr00t/configs/finetune_config.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Finetune config used for single node post-training.
17
+ from dataclasses import dataclass
18
+
19
+
20
+ @dataclass
21
+ class FinetuneConfig:
22
+ """
23
+ Configuration for fine-tuning a Vision-Language-Action (VLA) model.
24
+
25
+ This dataclass defines all parameters needed to launch a fine-tuning job
26
+ on a pretrained base model using a custom dataset and embodiment-specific
27
+ modality configuration. It controls model tuning options, data augmentation,
28
+ and training hyperparameters.
29
+ """
30
+
31
+ # --- Data and Model Paths ---
32
+ base_model_path: str
33
+ """Path to the pretrained base model checkpoint (e.g., Hugging Face model hub or local directory)."""
34
+
35
+ dataset_path: str
36
+ """Path to the dataset root directory containing trajectory data for fine-tuning."""
37
+
38
+ embodiment_tag: str
39
+ """Embodiment tag (name or value, case-insensitive). See EmbodimentTag for known tags."""
40
+
41
+ modality_config_path: str | None = None
42
+ """
43
+ Path to a Python file defining the modality configuration for the given embodiment.
44
+ If None, use the pre-registered modality config in `gr00t/configs/data/embodiment_configs.py`.
45
+ """
46
+
47
+ # --- Model Tuning Flags ---
48
+ tune_llm: bool = False
49
+ """If True, fine-tune the language model (LLM) backbone during training."""
50
+
51
+ tune_visual: bool = False
52
+ """If True, fine-tune the visual encoder (e.g., ViT or CNN backbone)."""
53
+
54
+ tune_projector: bool = True
55
+ """If True, fine-tune the multimodal projector layers that map vision/language features to a shared space."""
56
+
57
+ tune_diffusion_model: bool = True
58
+ """If True, fine-tune the diffusion-based action decoder (if present in the model)."""
59
+
60
+ state_dropout_prob: float = 0.2
61
+ """
62
+ Dropout probability applied to state inputs for regularization during training.
63
+ """
64
+
65
+ # --- Data Augmentation ---
66
+ random_rotation_angle: int | None = None
67
+ """Maximum rotation angle (in degrees) for random rotation augmentation of input images."""
68
+
69
+ color_jitter_params: dict[str, float] | None = None
70
+ """
71
+ Parameters for color jitter augmentation on images.
72
+
73
+ Expected keys include:
74
+ - "brightness": float
75
+ - "contrast": float
76
+ - "saturation": float
77
+ - "hue": float
78
+ Example: {"brightness": 0.4, "contrast": 0.4, "saturation": 0.4, "hue": 0.1}
79
+
80
+ If None, applying the default color jitter augmentation from the pretrained model.
81
+ """
82
+ extra_augmentation_config: str | None = None
83
+ """
84
+ JSON string for extra image augmentations (mask-based and others).
85
+
86
+ Expected keys include:
87
+ - "background_noise_transforms": list of dicts for noise on mask regions
88
+ - "target_mask_values": list of int (e.g., [0])
89
+ - "p": float (probability of applying)
90
+ - "masked_region_transforms": list of dicts for color tint on mask regions
91
+ - "target_mask_values": list of int (e.g., [4] or [5])
92
+ - "p": float (probability of applying)
93
+ - "alpha_range": [min, max] for random_tint intensity
94
+
95
+ Example: {"background_noise_transforms": [{"target_mask_values": [0], "p": 0.9}],
96
+ "masked_region_transforms": [{"target_mask_values": [4], "p": 1.0, "alpha_range": [0, 1]}]}
97
+
98
+ If None, no extra augmentations are applied.
99
+ """
100
+
101
+ # --- Training Configuration ---
102
+ global_batch_size: int = 64
103
+ """Total effective batch size across all GPUs and accumulation steps."""
104
+
105
+ dataloader_num_workers: int = 2
106
+ """Number of parallel worker processes used for data loading."""
107
+
108
+ learning_rate: float = 1e-4
109
+ """Initial learning rate for optimizer."""
110
+
111
+ gradient_accumulation_steps: int = 1
112
+ """Number of forward passes to accumulate before performing a backward/update step."""
113
+
114
+ output_dir: str = "./outputs"
115
+ """Directory where model checkpoints, logs, and outputs are saved."""
116
+
117
+ experiment_name: str | None = None
118
+ """Optional experiment name used as the W&B run name. Defaults to the output directory basename."""
119
+
120
+ wandb_project: str = "finetune-gr00t-n1d7"
121
+ """W&B project name to log runs to."""
122
+
123
+ save_steps: int = 1000
124
+ """Frequency (in training steps) at which to save checkpoints."""
125
+
126
+ save_total_limit: int = 5
127
+ """Maximum number of checkpoints to keep before older ones are deleted."""
128
+
129
+ num_gpus: int = 1
130
+ """Number of GPUs available for distributed or single-node training."""
131
+
132
+ use_wandb: bool = False
133
+ """
134
+ If True, log metrics and artifacts to Weights & Biases (wandb).
135
+ The project is `finetune-gr00t-n1d7`.
136
+ You need to login to wandb to view the logs.
137
+ """
138
+
139
+ max_steps: int = 10000
140
+ """Total number of training steps to run before stopping."""
141
+
142
+ weight_decay: float = 1e-5
143
+ """Weight decay coefficient for optimizer (L2 regularization)."""
144
+
145
+ warmup_ratio: float = 0.05
146
+ """Proportion of total training steps used for learning rate warm-up."""
147
+
148
+ shard_size: int = 2**10
149
+ """Size of the shard to use for the dataset during preloading."""
150
+
151
+ episode_sampling_rate: float = 0.1
152
+ """Sampling rate for the episodes."""
153
+
154
+ num_shards_per_epoch: int = int(1e5)
155
+ """Number of shards to use for the dataset. reduce this number if vram is limited."""
156
+
157
+ save_only_model: bool = False
158
+ """If True, save only model weights (skip optimizer/scheduler/RNG states). Cannot resume training from these checkpoints."""
159
+
160
+ skip_weight_loading: bool = False
161
+ """If True, skip loading model weights from base_model_path (architecture only).
162
+ The processor (tokenizer/config) is still loaded from base_model_path.
163
+ Useful for CI/testing to skip the slow checkpoint shard loading."""
gr00t/configs/model/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ from pathlib import Path
18
+ import typing
19
+
20
+ import tyro
21
+
22
+
23
+ MODEL_CONFIG_TYPES: dict[str, type] = {}
24
+
25
+
26
+ def register_model_config(shortname: str, configtype: type):
27
+ MODEL_CONFIG_TYPES[shortname] = configtype
28
+
29
+
30
+ for file in Path(__file__).parent.glob("*.py"):
31
+ if file.stem.startswith("_"):
32
+ continue
33
+ try:
34
+ importlib.import_module(f".{file.stem}", __name__)
35
+ except KeyboardInterrupt:
36
+ raise
37
+ except Exception as e:
38
+ print(f"Error importing module gr00t.configs.model.{file.stem}: {e}")
39
+
40
+
41
+ def create_model_union_type():
42
+ if not MODEL_CONFIG_TYPES:
43
+ # A Union of no types is invalid, so just return None
44
+ return None
45
+
46
+ annotated_types = tuple(
47
+ typing.Annotated[model_type, tyro.conf.subcommand(model_shortname)]
48
+ for model_shortname, model_type in MODEL_CONFIG_TYPES.items()
49
+ )
50
+
51
+ # Create the Union dynamically
52
+ return typing.Union.__getitem__(annotated_types)
gr00t/configs/model/gr00t_n1d7.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import MISSING, asdict, dataclass, field, is_dataclass
17
+ from enum import Enum
18
+ import json
19
+ from pathlib import Path
20
+
21
+ import torch
22
+ from transformers import PretrainedConfig
23
+
24
+ from . import register_model_config
25
+
26
+
27
+ @dataclass
28
+ class Gr00tN1d7Config(PretrainedConfig):
29
+ """Unified configuration for Gr00tN1d7 model with backbone and action head.
30
+
31
+ Gr00tN1d7 uses the Cosmos-Reason2-2B (Qwen3-VL architecture) VLM backbone,
32
+ replacing the Eagle backbone used in Gr00tN1d6.
33
+ """
34
+
35
+ # Model identification
36
+ model_type: str = "Gr00tN1d7"
37
+ model_dtype: str = "bfloat16" # Use bfloat16 for Flash Attention compatibility
38
+
39
+ # Backbone configuration
40
+ model_name: str = "nvidia/Cosmos-Reason2-2B"
41
+ backbone_model_type: str = "qwen"
42
+ model_revision: str | None = None
43
+ tune_top_llm_layers: int = 0 # Number of top LLM layers to tune
44
+ backbone_embedding_dim: int = 2048 # project_to_dim; must match Cosmos-Reason2-2B hidden size
45
+ tune_llm: bool = False
46
+ tune_visual: bool = False
47
+ select_layer: int = 12
48
+ reproject_vision: bool = False
49
+ use_flash_attention: bool = True
50
+ load_bf16: bool = False # Enable BF16 loading
51
+ backbone_trainable_params_fp32: bool = True
52
+
53
+ ### Processing parameters
54
+ image_crop_size: tuple[int, int] | None = (230, 230)
55
+ image_target_size: tuple[int, int] | None = (256, 256)
56
+
57
+ shortest_image_edge: int | None = None
58
+ crop_fraction: float | None = None
59
+
60
+ random_rotation_angle: int | None = None
61
+ color_jitter_params: dict[str, float] | None = None
62
+ use_albumentations_transforms: bool = True
63
+ # Extra augmentation config (mask-based and others).
64
+ extra_augmentation_config: dict | None = None
65
+ formalize_language: bool = True
66
+ apply_sincos_state_encoding: bool = (
67
+ False # Global flag to enable per-embodiment sin/cos encoding
68
+ )
69
+ use_percentiles: bool = True
70
+ use_relative_action: bool = False
71
+
72
+ # Action head configuration parameters
73
+ max_state_dim: int = 132 # Default from state_shape
74
+ max_action_dim: int = 132 # Default from action_shape
75
+ action_horizon: int = 40
76
+ hidden_size: int = 1024
77
+ input_embedding_dim: int = 1536
78
+
79
+ # State history: number of consecutive state timesteps fed to the state encoder
80
+ state_history_length: int = 1
81
+
82
+ # Global parameters
83
+ add_pos_embed: bool = True
84
+ attn_dropout: float = 0.2
85
+ use_vlln: bool = True
86
+ max_seq_len: int = 1024
87
+ use_alternate_vl_dit: bool = True # True for AlternateVLDiT, False for DiT
88
+ attend_text_every_n_blocks: int = 2
89
+
90
+ diffusion_model_cfg: dict = field(
91
+ default_factory=lambda: {
92
+ "positional_embeddings": None,
93
+ "num_layers": 16,
94
+ "num_attention_heads": 32,
95
+ "attention_head_dim": 48,
96
+ "norm_type": "ada_norm",
97
+ "dropout": 0.2,
98
+ "final_dropout": True,
99
+ "output_dim": 1024,
100
+ "interleave_self_attention": True,
101
+ }
102
+ )
103
+
104
+ # Flow matching parameters
105
+ num_inference_timesteps: int = 4
106
+ noise_beta_alpha: float = 1.5
107
+ noise_beta_beta: float = 1.0
108
+ noise_s: float = 0.999
109
+ num_timestep_buckets: int = 1000
110
+
111
+ # Training parameters
112
+ tune_projector: bool = True
113
+ tune_diffusion_model: bool = True
114
+ tune_vlln: bool = True
115
+
116
+ # State augmentation parameters
117
+ state_dropout_prob: float = 0.8 # State dropout probability
118
+ exclude_state: bool = False # Zero out all state inputs (ablation)
119
+ use_mean_std: bool = False # Use mean/std normalization instead of min/max
120
+
121
+ # Multi-embodiment parameters
122
+ max_num_embodiments: int = 32
123
+
124
+ def __init__(self, **kwargs):
125
+ super().__init__(**kwargs)
126
+ for key, value in kwargs.items():
127
+ setattr(self, key, value)
128
+
129
+ # Ensures that all dataclass defaults (including those using default_factory)
130
+ # are explicitly assigned to the instance, even if dataclasses initialization or subclassing
131
+ # (PretrainedConfig) interferes with normal default injection.
132
+ for f in self.__dataclass_fields__.values():
133
+ if not hasattr(self, f.name):
134
+ if f.default is not MISSING:
135
+ setattr(self, f.name, f.default)
136
+ elif getattr(f, "default_factory", MISSING) is not MISSING:
137
+ setattr(self, f.name, f.default_factory())
138
+
139
+ def to_filtered_dict(self, exclude_augment: bool = True) -> dict:
140
+ """Return a dictionary representation of this config, optionally excluding augmentation keys."""
141
+ if is_dataclass(self):
142
+ cfg = asdict(self)
143
+ else:
144
+ cfg = dict(self.__dict__)
145
+
146
+ if exclude_augment:
147
+ exclude_keys = {
148
+ "random_rotation_angle",
149
+ "color_jitter_params",
150
+ "use_albumentations_transforms",
151
+ "formalize_language",
152
+ "image_crop_size",
153
+ "image_target_size",
154
+ "shortest_image_edge",
155
+ "crop_fraction",
156
+ }
157
+ cfg = {k: v for k, v in cfg.items() if k not in exclude_keys}
158
+
159
+ return cfg
160
+
161
+ def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
162
+ """Return a JSON string of this config, optionally excluding augmentation keys."""
163
+
164
+ def default(o):
165
+ if isinstance(o, (Path, torch.dtype, torch.device)):
166
+ return str(o)
167
+ if isinstance(o, Enum):
168
+ return o.value
169
+ return str(o)
170
+
171
+ return json.dumps(
172
+ self.to_filtered_dict(exclude_augment),
173
+ indent=2,
174
+ default=default,
175
+ **kwargs,
176
+ )
177
+
178
+
179
+ register_model_config("Gr00tN1d7", Gr00tN1d7Config)
gr00t/configs/training/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
gr00t/configs/training/training_config.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass, field
17
+ from typing import Optional
18
+
19
+
20
+ @dataclass
21
+ class TrainingConfig:
22
+ """Training configuration."""
23
+
24
+ # Output
25
+ output_dir: str = "./outputs"
26
+ experiment_name: Optional[str] = None
27
+
28
+ # Basic training
29
+ max_steps: int = 30000 # this will override num_epochs
30
+ global_batch_size: int = 1024
31
+ batch_size: Optional[int] = None
32
+ gradient_accumulation_steps: int = 1
33
+
34
+ # Optimization
35
+ learning_rate: float = 1e-4
36
+ lr_scheduler_type: str = "cosine"
37
+ weight_decay: float = 1e-5
38
+ warmup_ratio: float = 0.05
39
+ warmup_steps: int = 0 # this will override warmup_ratio
40
+ max_grad_norm: float = 1.0
41
+
42
+ # Optimizer choice (huggingface TrainingArguments.optim)
43
+ # Options include: 'adamw_torch', 'adamw_torch_fused', 'paged_adamw_32bit',
44
+ # 'paged_adamw_8bit' (requires bitsandbytes), 'adafactor', etc.
45
+ optim: str = "adamw_torch_fused"
46
+
47
+ start_from_checkpoint: Optional[str] = None
48
+ skip_weight_loading: bool = False # skip loading checkpoint weights (architecture only)
49
+
50
+ # Mixed precision
51
+ tf32: bool = True
52
+ fp16: bool = False
53
+ bf16: bool = True
54
+ eval_bf16: bool = True
55
+
56
+ # Logging and saving
57
+ logging_steps: int = 10
58
+ save_steps: int = 1000
59
+ save_total_limit: int = 5
60
+
61
+ # Model saving
62
+ save_vl_model: bool = False # Control whether to save VL model and processor in callbacks
63
+ save_only_model: bool = False # Skip optimizer/scheduler/RNG states — cannot resume training
64
+
65
+ # Checkpoint uploading
66
+ upload_checkpoints: bool = False
67
+ upload_every: int = 1000
68
+ upload_last_n_checkpoints: int = 5
69
+ max_concurrent_uploads: int = 2
70
+
71
+ # Evaluation
72
+ eval_strategy: str = "no" # no, steps, epoch
73
+ eval_steps: int = 500
74
+ eval_set_split_ratio: float = 0.1
75
+ eval_batch_size: int = 2
76
+ save_best_eval_metric_name: str = ""
77
+ save_best_eval_metric_greater_is_better: bool = True
78
+
79
+ # DeepSpeed (default)
80
+ deepspeed_stage: int = 2 # ZeRO stage (1, 2, or 3)
81
+ gradient_checkpointing: bool = False
82
+
83
+ # Transformers loading parameters
84
+ transformers_trust_remote_code: bool = True
85
+ transformers_local_files_only: bool = False
86
+ transformers_cache_dir: str | None = None
87
+ transformers_access_token: str | None = None # Access token for HuggingFace Hub
88
+
89
+ # DDP
90
+ use_ddp: bool = False
91
+ ddp_bucket_cap_mb: int = 100
92
+
93
+ # Hardware
94
+ num_gpus: int = 1
95
+ dataloader_num_workers: int = 2
96
+
97
+ # Data handling
98
+ remove_unused_columns: bool = False
99
+
100
+ # Experiment tracking
101
+ use_wandb: bool = False
102
+ wandb_project: str = "finetune-gr00t-n1d7"
103
+
104
+ # Profiling
105
+ enable_profiling: bool = False
106
+
107
+ # Max number of retries in training for fault tolerance
108
+ max_retries: int = 3
109
+
110
+ # For testing.
111
+ assert_loss_less_than: float | None = None
112
+
113
+ # RL
114
+ add_rl_callback: bool = False
115
+
116
+ # Open-loop evaluation
117
+ enable_open_loop_eval: bool = False
118
+ """Enable open-loop evaluation on saved checkpoints."""
119
+
120
+ open_loop_eval_traj_ids: list[int] = field(default_factory=lambda: [0])
121
+ """List of trajectory IDs to evaluate."""
122
+
123
+ open_loop_eval_steps_per_traj: int = 100
124
+ """Number of steps to evaluate per trajectory."""
125
+
126
+ open_loop_eval_plot_indices: Optional[list[int]] = None
127
+ """List of action indices to plot. If None, plots all indices."""
gr00t/data/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
gr00t/data/collator/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .collators import BasicDataCollator