Adrian Li committed
Commit 070759f
Parent(s): 6915737
Update references to GCP assets

Files changed:
- README.md +9 -9
- docs/norm_stats.md +3 -3
- docs/remote_inference.md +1 -1
- examples/aloha_real/README.md +40 -40
- examples/droid/README.md +1 -1
- examples/inference.ipynb +135 -135
- examples/ur5/README.md +2 -2
- scripts/serve_policy.py +4 -4
- src/openpi/models/model_test.py +1 -1
- src/openpi/policies/policy_test.py +2 -2
- src/openpi/shared/download.py +3 -135
- src/openpi/shared/download_test.py +4 -4
- src/openpi/training/config.py +8 -8
- src/openpi/training/weight_loaders.py +1 -1
README.md
CHANGED
@@ -59,22 +59,22 @@ We provide multiple base VLA model checkpoints. These checkpoints have been pre-
 
 | Model | Use Case | Description | Checkpoint Path |
 | ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |
-| $\pi_0$ | Fine-Tuning | Base diffusion [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning | `
-| $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `
+| $\pi_0$ | Fine-Tuning | Base diffusion [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_base` |
+| $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` |
 
 ### Fine-Tuned Models
 We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice.
 
 | Model | Use Case | Description | Checkpoint Path |
 | ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- |
-| $\pi_0$-FAST-DROID | Inference | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `
-| $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `
-| $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on [ALOHA](https://tonyzhaozh.github.io/aloha/) robot platforms | `
-| $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can unpack food from a tupperware container | `
-| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](https://dit-policy.github.io/), can uncap a pen | `
+| $\pi_0$-FAST-DROID | Inference | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid` |
+| $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `gs://openpi-assets/checkpoints/pi0_droid` |
+| $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on [ALOHA](https://tonyzhaozh.github.io/aloha/) robot platforms | `gs://openpi-assets/checkpoints/pi0_aloha_towel` |
+| $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can unpack food from a tupperware container | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` |
+| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](https://dit-policy.github.io/), can uncap a pen | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
 
 
-By default, checkpoints are automatically downloaded from `
+By default, checkpoints are automatically downloaded from `gs://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can overwrite the download path by setting the `OPENPI_DATA_HOME` environment variable.
 
 
 
@@ -88,7 +88,7 @@ from openpi.policies import policy_config
 from openpi.shared import download
 
 config = config.get_config("pi0_fast_droid")
-checkpoint_dir = download.maybe_download("
+checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_fast_droid")
 
 # Create a trained policy.
 policy = policy_config.create_trained_policy(config, checkpoint_dir)
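For context on the README snippet above: the checkpoint is now resolved from the `gs://openpi-assets` bucket with `download.maybe_download`, cached locally, and then used to build a policy; the cache root can be redirected with `OPENPI_DATA_HOME`. A minimal sketch of how these pieces fit together (assuming the openpi package from this repo is installed; setting the environment variable before importing the download module is a conservative choice, and the cache path is hypothetical):

```python
import os

# Redirect the openpi download cache (default would be ~/.cache/openpi).
os.environ["OPENPI_DATA_HOME"] = "/data/openpi_cache"  # hypothetical cache root

from openpi.policies import policy_config
from openpi.shared import download
from openpi.training import config as _config

# Fetch (or reuse from the cache) the released pi0-FAST DROID checkpoint from GCS.
train_config = _config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy from the cached checkpoint directory.
policy = policy_config.create_trained_policy(train_config, checkpoint_dir)
```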
docs/norm_stats.md
CHANGED
@@ -14,7 +14,7 @@ TrainConfig(
     data=LeRobotAlohaDataConfig(
         ...
         assets=AssetsConfig(
-            assets_dir="
+            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
             asset_id="trossen",
         ),
     ),
@@ -30,7 +30,7 @@ For an example of a full training config that reloads normalization statistics,
 
 ## Provided Pre-training Normalization Statistics
 
-Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `
+Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
 | Robot | Description | Asset ID |
 |-------|-------------|----------|
 | ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
@@ -59,7 +59,7 @@ Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action
 
 The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
 
-For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
+For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
 
 General info for Pi robots:
 - Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
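To make the 7-DoF convention in the norm_stats excerpt above concrete, here is a small illustrative sketch (not code from the repo; the joint and gripper values are hypothetical) of assembling an 8-dimensional action vector where the first 7 dimensions carry joint actions and the 8th carries the gripper action:

```python
import numpy as np

# Hypothetical Franka command: 7 joint targets in radians plus one gripper value.
joint_actions = np.zeros(7)       # dims 0-6: joint actions
gripper_action = np.array([1.0])  # dim 7: gripper action (e.g. open)

action = np.concatenate([joint_actions, gripper_action])
assert action.shape == (8,)
```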
docs/remote_inference.md
CHANGED
@@ -14,7 +14,7 @@ uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
 The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
 
 ```bash
-uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
 ```
 
 This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
examples/aloha_real/README.md
CHANGED
@@ -53,74 +53,74 @@ While we’ve found the policies to work in unseen conditions across multiple AL
 
 ---
 
 ### **Toast Task**
 
 This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
 
-- **Checkpoint path**: `
+- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
 - **Prompt**: "take the toast out of the toaster"
 - **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
 - **Object Distribution**:
   - Works on both real toast and rubber fake toast
   - Compatible with standard 2-slice toasters
   - Works with plates of varying colors
 
 ### **Scene Setup Guidelines**
 <img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
 
 - The toaster should be positioned in the top-left quadrant of the workspace.
 - Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
 - The plate should be placed roughly in the lower-center of the workspace.
 - Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
 
 
 ### **Towel Task**
 
 This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
 
-- **Checkpoint path**: `
+- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
 - **Prompt**: "fold the towel"
 - **Object Distribution**:
   - Works on towels of varying solid colors
   - Performance is worse on heavily textured or striped towels
 
 ### **Scene Setup Guidelines**
 <img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
 
 - The towel should be flattened and roughly centered on the table.
 - Choose a towel that does not blend in with the table surface.
 
 
 ### **Tupperware Task**
 
 This task involves opening a tupperware filled with food and pouring the contents onto a plate.
 
-- **Checkpoint path**: `
+- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
 - **Prompt**: "open the tupperware and put the food on the plate"
 - **Objects needed**: Tupperware, food (or food-like items), and a plate.
 - **Object Distribution**:
   - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
   - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
   - The policy has seen plates of varying solid colors.
 
 ### **Scene Setup Guidelines**
 <img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
 
 - Best performance observed when both the tupperware and plate are roughly centered in the workspace.
 - Positioning:
   - Tupperware should be on the left.
   - Plate should be on the right or bottom.
   - The tupperware flap should point toward the plate.
 
 ## Training on your own Aloha dataset
 
 1. Convert the dataset to the LeRobot dataset v2.0 format.
 
 We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
 
+2. Define a training config that uses the custom dataset.
 
 We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
 
 IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
examples/droid/README.md
CHANGED
@@ -11,7 +11,7 @@ Since the DROID control laptop does not have a powerful GPU, we will start a rem
 2. Start the OpenPI server via the following command:
 
 ```bash
-uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
 ```
 
 You can also run the equivalent command below:
examples/inference.ipynb
CHANGED
@@ -1,137 +1,137 @@
 {
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dataclasses\n",
+    "\n",
+    "import jax\n",
+    "\n",
+    "from openpi.models import model as _model\n",
+    "from openpi.policies import droid_policy\n",
+    "from openpi.policies import policy_config as _policy_config\n",
+    "from openpi.shared import download\n",
+    "from openpi.training import config as _config\n",
+    "from openpi.training import data_loader as _data_loader"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Policy inference\n",
+    "\n",
+    "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = _config.get_config(\"pi0_fast_droid\")\n",
+    "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
+    "\n",
+    "# Create a trained policy.\n",
+    "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
+    "\n",
+    "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
+    "example = droid_policy.make_droid_example()\n",
+    "result = policy.infer(example)\n",
+    "\n",
+    "# Delete the policy to free up memory.\n",
+    "del policy\n",
+    "\n",
+    "print(\"Actions shape:\", result[\"actions\"].shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Working with a live model\n",
+    "\n",
+    "\n",
+    "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = _config.get_config(\"pi0_aloha_sim\")\n",
+    "\n",
+    "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
+    "key = jax.random.key(0)\n",
+    "\n",
+    "# Create a model from the checkpoint.\n",
+    "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
+    "\n",
+    "# We can create fake observations and actions to test the model.\n",
+    "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
+    "\n",
+    "# Sample actions from the model.\n",
+    "loss = model.compute_loss(key, obs, act)\n",
+    "print(\"Loss shape:\", loss.shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Reduce the batch size to reduce memory usage.\n",
+    "config = dataclasses.replace(config, batch_size=2)\n",
+    "\n",
+    "# Load a single batch of data. This is the same data that will be used during training.\n",
+    "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
+    "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
+    "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
+    "obs, act = next(iter(loader))\n",
+    "\n",
+    "# Sample actions from the model.\n",
+    "loss = model.compute_loss(key, obs, act)\n",
+    "\n",
+    "# Delete the model to free up memory.\n",
+    "del model\n",
+    "\n",
+    "print(\"Loss shape:\", loss.shape)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
 }
examples/ur5/README.md
CHANGED
@@ -129,7 +129,7 @@ TrainConfig(
         # Reloading normalization stats can help transfer pre-trained models to new environments.
         # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
         assets=AssetsConfig(
-            assets_dir="
+            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
             asset_id="ur5e",
         ),
         base_config=DataConfig(
@@ -139,7 +139,7 @@ TrainConfig(
         ),
     ),
     # Load the pi0 base model checkpoint.
-    weight_loader=weight_loaders.CheckpointWeightLoader("
+    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
     num_train_steps=30_000,
 )
 ```
scripts/serve_policy.py
CHANGED
@@ -59,19 +59,19 @@ class Args:
 DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
     EnvMode.ALOHA: Checkpoint(
         config="pi0_aloha",
-        dir="
+        dir="gs://openpi-assets/checkpoints/pi0_base",
     ),
     EnvMode.ALOHA_SIM: Checkpoint(
         config="pi0_aloha_sim",
-        dir="
+        dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
     ),
     EnvMode.DROID: Checkpoint(
         config="pi0_fast_droid",
-        dir="
+        dir="gs://openpi-assets/checkpoints/pi0_fast_droid",
     ),
     EnvMode.LIBERO: Checkpoint(
         config="pi0_fast_libero",
-        dir="
+        dir="gs://openpi-assets/checkpoints/pi0_fast_libero",
     ),
 }
 
src/openpi/models/model_test.py
CHANGED
@@ -84,7 +84,7 @@ def test_model_restore():
     obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
 
     model = config.load(
-        _model.restore_params(download.maybe_download("
+        _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params"))
     )
 
     loss = model.compute_loss(key, obs, act)
src/openpi/policies/policy_test.py
CHANGED
@@ -9,7 +9,7 @@ from openpi.training import config as _config
 @pytest.mark.manual
 def test_infer():
     config = _config.get_config("pi0_aloha_sim")
-    policy = _policy_config.create_trained_policy(config, "
+    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
 
     example = aloha_policy.make_aloha_example()
     result = policy.infer(example)
@@ -20,7 +20,7 @@ def test_infer():
 @pytest.mark.manual
 def test_broker():
     config = _config.get_config("pi0_aloha_sim")
-    policy = _policy_config.create_trained_policy(config, "
+    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
 
     broker = action_chunk_broker.ActionChunkBroker(
         policy,
src/openpi/shared/download.py
CHANGED
@@ -10,15 +10,10 @@ import stat
 import time
 import urllib.parse
 
-import boto3
-import boto3.s3.transfer as s3_transfer
-import botocore
 import filelock
 import fsspec
 import fsspec.generic
-import s3transfer.futures as s3_transfer_futures
 import tqdm_loggable.auto as tqdm
-from types_boto3_s3.service_resource import ObjectSummary
 
 # Environment variable to control cache directory path, ~/.cache/openpi will be used by default.
 _OPENPI_DATA_HOME = "OPENPI_DATA_HOME"
@@ -92,22 +87,7 @@ def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathl
     # Download the data to a local cache.
     logger.info(f"Downloading {url} to {local_path}")
     scratch_path = local_path.with_suffix(".partial")
-
-    if _is_openpi_url(url):
-        # Download without credentials.
-        _download_boto3(
-            url,
-            scratch_path,
-            boto_session=boto3.Session(
-                region_name="us-west-1",
-            ),
-            botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
-        )
-    elif url.startswith("s3://"):
-        # Download with default boto3 credentials.
-        _download_boto3(url, scratch_path)
-    else:
-        _download_fsspec(url, scratch_path, **kwargs)
+    _download_fsspec(url, scratch_path, **kwargs)
 
     shutil.move(scratch_path, local_path)
     _ensure_permissions(local_path)
@@ -126,7 +106,8 @@ def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
     """Download a file from a remote filesystem to the local cache, and return the local path."""
     fs, _ = fsspec.core.url_to_fs(url, **kwargs)
     info = fs.info(url)
+    # Folders are represented by 0-byte objects with a trailing forward slash.
+    if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))):
         total_size = fs.du(url)
     else:
         total_size = info["size"]
@@ -140,114 +121,6 @@ def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
             pbar.update(total_size - pbar.n)
 
 
-def _download_boto3(
-    url: str,
-    local_path: pathlib.Path,
-    *,
-    boto_session: boto3.Session | None = None,
-    botocore_config: botocore.config.Config | None = None,
-    workers: int = 16,
-) -> None:
-    """Download a file from the OpenPI S3 bucket using boto3. This is a more performant version of download but can
-    only handle s3 urls. In openpi repo, this is mainly used to access assets in S3 with higher throughput.
-
-    Input:
-        url: URL to openpi checkpoint path.
-        local_path: local path to the downloaded file.
-        boto_session: Optional boto3 session, will create by default if not provided.
-        botocore_config: Optional botocore config.
-        workers: number of workers for downloading.
-    """
-
-    def validate_and_parse_url(maybe_s3_url: str) -> tuple[str, str]:
-        parsed = urllib.parse.urlparse(maybe_s3_url)
-        if parsed.scheme != "s3":
-            raise ValueError(f"URL must be an S3 URL (s3://), got: {maybe_s3_url}")
-        bucket_name = parsed.netloc
-        prefix = parsed.path.strip("/")
-        return bucket_name, prefix
-
-    bucket_name, prefix = validate_and_parse_url(url)
-    session = boto_session or boto3.Session()
-
-    s3api = session.resource("s3", config=botocore_config)
-    bucket = s3api.Bucket(bucket_name)
-
-    # Check if prefix points to an object and if not, assume that it's a directory and add a trailing slash.
-    try:
-        bucket.Object(prefix).load()
-    except botocore.exceptions.ClientError:
-        # Make sure to append a "/" to prevent getting objects from a different directory that shares the same prefix.
-        # For example, if we are downloading from s3://bucket/foo, we don't want to also download from s3://bucket/foobar.
-        if not prefix.endswith("/"):
-            prefix = prefix + "/"
-
-    # Get all candidate objects, filter out directories.
-    objects = [x for x in bucket.objects.filter(Prefix=prefix) if not x.key.endswith("/")]
-    if not objects:
-        raise FileNotFoundError(f"No objects found at {url}")
-
-    total_size = sum(obj.size for obj in objects)
-
-    s3t = _get_s3_transfer_manager(session, workers, botocore_config=botocore_config)
-
-    def transfer(
-        s3obj: ObjectSummary, dest_path: pathlib.Path, progress_func
-    ) -> s3_transfer_futures.TransferFuture | None:
-        if dest_path.exists():
-            dest_stat = dest_path.stat()
-            if s3obj.size == dest_stat.st_size:
-                progress_func(s3obj.size)
-                return None
-        dest_path.parent.mkdir(parents=True, exist_ok=True)
-        return s3t.download(
-            bucket_name,
-            s3obj.key,
-            str(dest_path),
-            subscribers=[
-                s3_transfer.ProgressCallbackInvoker(progress_func),
-            ],
-        )
-
-    try:
-        with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
-            if os.getenv("IS_DOCKER", "false").lower() == "true":
-                # tqdm is bugged when using docker-compose. See https://github.com/tqdm/tqdm/issues/771
-                def update_progress(size: int) -> None:
-                    pbar.update(size)
-                    print(pbar)
-            else:
-
-                def update_progress(size: int) -> None:
-                    pbar.update(size)
-
-            futures = []
-            for obj in objects:
-                relative_path = pathlib.Path(obj.key).relative_to(prefix)
-                dest_path = local_path / relative_path
-                if future := transfer(obj, dest_path, update_progress):
-                    futures.append(future)
-            for future in futures:
-                future.result()
-    finally:
-        s3t.shutdown()
-
-
-def _get_s3_transfer_manager(
-    session: boto3.Session, workers: int, botocore_config: botocore.config.Config | None = None
-) -> s3_transfer.TransferManager:
-    # Add a few extra connections to prevent exceeding the pool size.
-    config = botocore.config.Config(max_pool_connections=workers + 2)
-    if botocore_config is not None:
-        config = config.merge(botocore_config)
-    s3client = session.client("s3", config=config)
-    transfer_config = s3_transfer.TransferConfig(
-        use_threads=True,
-        max_concurrency=workers,
-    )
-    return s3_transfer.create_transfer_manager(s3client, transfer_config)
-
-
 def _set_permission(path: pathlib.Path, target_permission: int):
     """chmod requires executable permission to be set, so we skip if the permission is already match with the target."""
     if path.stat().st_mode & target_permission == target_permission:
@@ -295,11 +168,6 @@ def _ensure_permissions(path: pathlib.Path) -> None:
             _set_folder_permission(dir_path)
 
 
-def _is_openpi_url(url: str) -> bool:
-    """Check if the url is an OpenPI S3 bucket url."""
-    return url.startswith("s3://openpi-assets/")
-
-
 def _get_mtime(year: int, month: int, day: int) -> float:
     """Get the mtime of a given date at midnight UTC."""
     date = datetime.datetime(year, month, day, tzinfo=datetime.UTC)
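For context on the download.py change above: the boto3/S3 fast path is removed and every remote URL, including the new `gs://openpi-assets` paths, now goes through `_download_fsspec`. A rough sketch of what an fsspec-based fetch of a public GCS prefix looks like, assuming the `gcsfs` backend is installed and the bucket allows anonymous reads (the helper name and local target path are hypothetical, not part of the repo):

```python
import pathlib

import fsspec


def fetch_public_gs(url: str, dest: pathlib.Path) -> pathlib.Path:
    """Copy a public gs:// object or prefix into a local directory (sketch only)."""
    # url_to_fs returns the filesystem (gcsfs here) plus the bucket-relative path;
    # token="anon" requests unauthenticated access to a public bucket.
    fs, path = fsspec.core.url_to_fs(url, token="anon")
    dest.mkdir(parents=True, exist_ok=True)
    # recursive=True handles both single objects and whole checkpoint folders.
    fs.get(path, str(dest), recursive=True)
    return dest


# Example usage (hypothetical local target):
# fetch_public_gs("gs://openpi-assets/checkpoints/pi0_fast_droid", pathlib.Path("/tmp/pi0_fast_droid"))
```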
src/openpi/shared/download_test.py
CHANGED
@@ -24,8 +24,8 @@ def test_download_local(tmp_path: pathlib.Path):
         download.maybe_download("bogus")
 
 
-def 
-    remote_path = "
+def test_download_gs_dir():
+    remote_path = "gs://openpi-assets/testdata/random"
 
     local_path = download.maybe_download(remote_path)
     assert local_path.exists()
@@ -34,8 +34,8 @@ def test_download_s3_dir():
     assert new_local_path == local_path
 
 
-def 
-    remote_path = "
+def test_download_gs():
+    remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin"
 
     local_path = download.maybe_download(remote_path)
     assert local_path.exists()
src/openpi/training/config.py
CHANGED
@@ -44,7 +44,7 @@ class AssetsConfig:
 
     ```
     AssetsConfig(
-        assets_dir="
+        assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
         asset_id="trossen",
     )
     ```
@@ -556,7 +556,7 @@ _CONFIGS = [
         ),
         # Here you define which pre-trained checkpoint you want to load to initialize the model.
        # This should match the model config you chose above -- i.e. in this case we use the pi0 base model.
-        weight_loader=weight_loaders.CheckpointWeightLoader("
+        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
         # Below you can define other hyperparameters like the learning rate, number of training steps, etc.
         # Check the base TrainConfig class for a full list of available hyperparameters.
         num_train_steps=30_000,
@@ -569,7 +569,7 @@ _CONFIGS = [
             repo_id="physical-intelligence/libero",
             base_config=DataConfig(prompt_from_task=True),
         ),
-        weight_loader=weight_loaders.CheckpointWeightLoader("
+        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
         num_train_steps=30_000,
         # The freeze filter defines which parameters should be frozen during training.
         # We have a convenience function in the model config that returns the default freeze filter
@@ -599,7 +599,7 @@ _CONFIGS = [
             base_config=DataConfig(prompt_from_task=True),
         ),
         # Note that we load the pi0-FAST base model checkpoint here.
-        weight_loader=weight_loaders.CheckpointWeightLoader("
+        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
         num_train_steps=30_000,
     ),
     TrainConfig(
@@ -613,7 +613,7 @@ _CONFIGS = [
             repo_id="physical-intelligence/libero",
             base_config=DataConfig(prompt_from_task=True),
         ),
-        weight_loader=weight_loaders.CheckpointWeightLoader("
+        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
         num_train_steps=30_000,
         # Again, make sure to match the model config above when extracting the freeze filter
         # that specifies which parameters should be frozen during LoRA finetuning.
@@ -634,7 +634,7 @@ _CONFIGS = [
         data=LeRobotAlohaDataConfig(
             repo_id="physical-intelligence/aloha_pen_uncap_diverse",
             assets=AssetsConfig(
-                assets_dir="
+                assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
                 asset_id="trossen",
             ),
             default_prompt="uncap the pen",
@@ -654,7 +654,7 @@ _CONFIGS = [
                 ]
             ),
         ),
-        weight_loader=weight_loaders.CheckpointWeightLoader("
+        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
         num_train_steps=20_000,
     ),
     #
@@ -698,7 +698,7 @@ _CONFIGS = [
             default_prompt="Transfer cube",
             use_delta_joint_actions=False,
         ),
-        weight_loader=weight_loaders.CheckpointWeightLoader("
+        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
         num_train_steps=20_000,
     ),
     #
src/openpi/training/weight_loaders.py
CHANGED
@@ -42,7 +42,7 @@ class CheckpointWeightLoader(WeightLoader):
     trained checkpoints:
         example: "./checkpoints/<config>/<exp>/<step>/params"
     released checkpoints:
-        example: "
+        example: "gs://openpi-assets/checkpoints/<model>/params"
     """
 
     params_path: str