Update README.md
README.md
CHANGED
|
@@ -19,9 +19,9 @@ The models released are the following:
|
|
| 19 |
|
| 20 |
| Benchmark | Embedding dimension | Diffusion timestep |
|
| 21 |
|------|------|------|
|
| 22 |
-
| [RLBench (PerAct)]() | 120 | 100 |
|
| 23 |
-
| [RLBench (GNFactor)]() | 120| 100 |
|
| 24 |
-
| [CALVIN]() | 192 | 25 |
|
| 25 |
|
| 26 |
### Model Description
|
| 27 |
|
|
@@ -46,13 +46,53 @@ The models released are the following:
|
|
| 46 |
|
| 47 |
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 48 |
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
-
|
|
|
|
| 52 |
|
| 53 |
-
|
|
|
|
| 54 |
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
## Evaluation
|
|
|
|
| 19 |
|
| 20 |
| Benchmark | Embedding dimension | Diffusion timestep |
|
| 21 |
|------|------|------|
|
| 22 |
+
| [RLBench (PerAct)](https://huggingface.co/katefgroup/3d_diffuser_actor/blob/main/diffuser_actor_peract.pth) | 120 | 100 |
|
| 23 |
+
| [RLBench (GNFactor)](https://huggingface.co/katefgroup/3d_diffuser_actor/blob/main/diffuser_actor_gnfactor.pth) | 120| 100 |
|
| 24 |
+
| [CALVIN](https://huggingface.co/katefgroup/3d_diffuser_actor/blob/main/diffuser_actor_calvin.pth) | 192 | 25 |
|
| 25 |
|
| 26 |
### Model Description
|
| 27 |
|
|
|
|
| 46 |
|
| 47 |
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 48 |
|
| 49 |
+
### Input format
|
| 50 |
+
3D Diffuser Actor takes the following inputs:
|
| 51 |
|
| 52 |
+
1. `RGB observations`: a tensor of shape (batch_size, num_cameras, 3, H, W). Pixel values are in the range [0, 1].
|
| 53 |
+
2. `Point cloud observation`: a tensor of shape (batch_size, num_cameras, 3, H, W).
|
| 54 |
+
3. `Instruction encodings`: a tensor of shape (batch_size, max_instruction_length, C). In this codebase, the embedding dimension `C` is set to 512.
|
| 55 |
+
4. `curr_gripper`: a tensor of shape (batch_size, history_length, 7), where the last dimension holds the xyz-action (3D) and the quaternion (4D).
|
| 56 |
+
5. `trajectory_mask`: a tensor of shape (batch_size, trajectory_length), used only to indicate the length of each trajectory. To predict keyposes, set its shape to (batch_size, 1).
|
| 57 |
+
6. `gt_trajectory`: a tensor of shape (batch_size, trajectory_length, 7), where the last dimension holds the xyz-action (3D) and the quaternion (4D). This input is only used during training.
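
The shapes listed above can be checked before calling the model. The following is a minimal sketch, not part of the released code: `expected_input_shapes` and `find_shape_mismatches` are hypothetical helper names, and the dictionary keys are assumed to follow the argument names used in the usage examples of this README (`rgb_obs`, `pcd_obs`, `instruction`). It works with any object exposing a `.shape` attribute, such as a PyTorch tensor.

```python
from types import SimpleNamespace


def expected_input_shapes(batch_size, num_cameras, height, width,
                          max_instruction_length, history_length,
                          trajectory_length, instruction_dim=512):
    """Return the input shapes described in the list above, keyed by argument name.

    Hypothetical helper: the key names are an assumption based on the
    usage examples, not an API of the released code.
    """
    return {
        "rgb_obs": (batch_size, num_cameras, 3, height, width),
        "pcd_obs": (batch_size, num_cameras, 3, height, width),
        "instruction": (batch_size, max_instruction_length, instruction_dim),
        "curr_gripper": (batch_size, history_length, 7),
        "trajectory_mask": (batch_size, trajectory_length),
        "gt_trajectory": (batch_size, trajectory_length, 7),
    }


def find_shape_mismatches(inputs, expected):
    """Return {name: (actual, expected)} for every input whose shape differs."""
    mismatches = {}
    for name, want in expected.items():
        got = tuple(inputs[name].shape)
        if got != want:
            mismatches[name] = (got, want)
    return mismatches


# Stand-in objects with a .shape attribute; real code would pass tensors.
# trajectory_length=1 corresponds to keypose prediction.
expected = expected_input_shapes(batch_size=2, num_cameras=4, height=256,
                                 width=256, max_instruction_length=53,
                                 history_length=3, trajectory_length=1)
inputs = {name: SimpleNamespace(shape=shape) for name, shape in expected.items()}
print(find_shape_mismatches(inputs, expected))
```

The concrete sizes in the example (4 cameras, 256x256 images, 53 instruction tokens, history length 3) are placeholders chosen for illustration only.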
|
| 58 |
|
| 59 |
+
### Output format
|
| 60 |
+
When `run_inference=False`, the model returns the diffusion loss. When `run_inference=True`, it returns a pose trajectory of shape (batch_size, trajectory_length, 8).
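
To consume a predicted pose, the 8 values of one trajectory step can be split into named parts. This is a rough sketch under stated assumptions: the channel order (xyz first, then the quaternion) is assumed to mirror the 7-channel inputs described above, and the meaning of the eighth channel is not stated in this README, so it is kept as an unlabeled `extra` value. `Pose` and `split_pose` are hypothetical names.

```python
from typing import NamedTuple, Sequence


class Pose(NamedTuple):
    position: tuple    # xyz, 3 values (assumed to come first)
    quaternion: tuple  # 4 values (assumed to follow the position)
    extra: float       # 8th channel; its meaning is not stated in this README


def split_pose(pose: Sequence[float]) -> Pose:
    """Split one 8-value pose into position, quaternion, and the extra channel."""
    if len(pose) != 8:
        raise ValueError(f"expected 8 values, got {len(pose)}")
    return Pose(tuple(pose[0:3]), tuple(pose[3:7]), pose[7])


# One step of a predicted trajectory, e.g. trajectory[0, 0].tolist() for a tensor.
step = [0.1, 0.2, 0.3, 0.0, 0.0, 0.0, 1.0, 1.0]
print(split_pose(step))
```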
|
| 61 |
|
| 62 |
+
### Usage
|
| 63 |
+
For training, call the forward pass of 3D Diffuser Actor with `run_inference=False`:
|
| 64 |
+
```
|
| 65 |
+
> loss = model.forward(gt_trajectory,
|
| 66 |
+
trajectory_mask,
|
| 67 |
+
rgb_obs,
|
| 68 |
+
pcd_obs,
|
| 69 |
+
instruction,
|
| 70 |
+
curr_gripper,
|
| 71 |
+
run_inference=False)
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
For evaluation, call the forward pass of 3D Diffuser Actor with `run_inference=True`:
|
| 75 |
+
```
|
| 76 |
+
> fake_gt_trajectory = torch.full((1, trajectory_length, 7), 0).to(device)
|
| 77 |
+
> trajectory_mask = torch.full((1, trajectory_length), False).to(device)
|
| 78 |
+
> trajectory = model.forward(fake_gt_trajectory,
|
| 79 |
+
trajectory_mask,
|
| 80 |
+
rgb_obs,
|
| 81 |
+
pcd_obs,
|
| 82 |
+
instruction,
|
| 83 |
+
curr_gripper,
|
| 84 |
+
run_inference=True)
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
Alternatively, call the `compute_trajectory` function directly:
|
| 88 |
+
```
|
| 89 |
+
> trajectory_mask = torch.full((1, trajectory_length), False).to(device)
|
| 90 |
+
> trajectory = model.compute_trajectory(trajectory_mask,
|
| 91 |
+
rgb_obs,
|
| 92 |
+
pcd_obs,
|
| 93 |
+
instruction,
|
| 94 |
+
curr_gripper)
|
| 95 |
+
```
|
| 96 |
|
| 97 |
|
| 98 |
## Evaluation
|