---
license: mit
tags:
- robotics
- act
- vla
- gearbox-assembly
- imitation-learning
metrics:
- l1_loss
- kl_divergence
---

# ACT Policy for Gearbox Assembly (Filtered Demos)

This model is an **Action Chunking with Transformers (ACT)** policy trained to perform gearbox assembly tasks. It was trained using behavior cloning on a combination of `rocochallenge2025` and additional collected datasets.

## Model Details

- **Model Type**: ACT (Action Chunking with Transformers)
- **Policy Class**: ACT
- **Backbone**: ResNet-18
- **Training Dataset**: Integrated dataset (`rocochallenge2025` + `temp_new_dataset`) containing **241 episodes**.
- **Episode Length**: Fixed at **12,600 steps**; variable-length recordings are padded or truncated to this length.
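
The fixed episode length can be illustrated with a minimal sketch. The card only states that episodes are padded or truncated to 12,600 steps, so the padding strategy here (repeating the final step and returning an `is_pad` mask) is an assumption, and `fix_length` is a hypothetical helper rather than the function used in training.

```python
from typing import List, Sequence, Tuple

EPISODE_LEN = 12_600  # fixed episode length stated in this card


def fix_length(
    steps: Sequence[Sequence[float]], target_len: int = EPISODE_LEN
) -> Tuple[List[List[float]], List[bool]]:
    """Truncate or pad an episode to exactly `target_len` steps.

    Returns the fixed-length episode and an `is_pad` mask that is True for
    padded steps. Padding repeats the final recorded step (an assumption;
    the actual training code may pad differently).
    """
    if not steps:
        raise ValueError("episode must contain at least one step")
    kept = [list(s) for s in steps[:target_len]]  # truncate long episodes
    n_pad = target_len - len(kept)                # 0 when nothing to pad
    padded = kept + [list(kept[-1]) for _ in range(n_pad)]
    is_pad = [False] * len(kept) + [True] * n_pad
    return padded, is_pad
```

A mask like `is_pad` lets a training loop exclude padded steps from the loss, so that repeated frames do not dominate short episodes.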

## Training Configuration

- **Task Name**: `sim_gearbox_assembly_demos_filtered`
- **Batch Size**: 32
- **Chunk Size (Action Horizon)**: 100
- **KL Weight**: 10
- **Hidden Dimension**: 512
- **Feedforward Dimension**: 3200
- **Learning Rate**: 1e-5
- **Num Epochs**: ~9,500 (training was stopped early or interrupted)
- **Seed**: 0
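
To show how the KL weight enters training, here is a minimal sketch of the objective in the standard ACT formulation: an L1 reconstruction term on the predicted action chunk plus a weighted KL term that pulls the latent posterior toward a standard normal. It is assumed, not confirmed by this card, that the run used exactly this form; the function names are illustrative.

```python
import math
from typing import Sequence

KL_WEIGHT = 10.0  # "KL Weight" from the configuration above


def l1_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error between predicted and demonstrated actions."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL( N(mu, exp(logvar)) || N(0, I) ) for a diagonal Gaussian."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar)
    )


def act_loss(
    pred: Sequence[float],
    target: Sequence[float],
    mu: Sequence[float],
    logvar: Sequence[float],
    kl_weight: float = KL_WEIGHT,
) -> float:
    """Total loss = L1 reconstruction + kl_weight * KL regularizer."""
    return l1_loss(pred, target) + kl_weight * kl_to_standard_normal(mu, logvar)
```

With a weight of 10, a small deviation of the latent from the prior costs as much as a much larger action error, which keeps the latent well regularized. The two terms correspond to the `l1_loss` and `kl_divergence` metrics listed in the card metadata.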

## Inputs and Outputs

- **Observations**:
  - `head_rgb` (240x320)
  - `left_hand_rgb` (240x320)
  - `right_hand_rgb` (240x320)
  - `qpos` (14-dim joint positions)
- **Actions**:
  - 14-dim combined action vector (7-dim left arm + 7-dim right arm)
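
As a small illustration of the action layout, the sketch below splits one 14-dim action into per-arm commands. The left-then-right ordering is assumed from the order in which the arms are listed above, and `split_action` and its dictionary keys are hypothetical names, not part of the released code.

```python
from typing import Dict, List, Sequence

ARM_DOF = 7               # per-arm action size stated above
ACTION_DIM = 2 * ARM_DOF  # 14-dim combined action vector


def split_action(action: Sequence[float]) -> Dict[str, List[float]]:
    """Split a combined action into left-arm and right-arm parts.

    Assumes indices 0-6 belong to the left arm and 7-13 to the right arm.
    """
    if len(action) != ACTION_DIM:
        raise ValueError(f"expected {ACTION_DIM} values, got {len(action)}")
    return {
        "left_arm": list(action[:ARM_DOF]),
        "right_arm": list(action[ARM_DOF:]),
    }
```

Checking the length up front catches a mismatched checkpoint or environment early, before a malformed command reaches either arm.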

## Usage

Load the model with the `ACTPolicy` class. Also load `dataset_stats.pkl`, which is needed to normalize observations and unnormalize predicted actions correctly.

```python
import pickle

import torch

from policy import ACTPolicy

# Load the dataset statistics used for normalization
with open('dataset_stats.pkl', 'rb') as f:
    stats = pickle.load(f)

# Load the policy. `config` is not defined here: it must be the policy
# configuration used for training (see "Training Configuration").
policy = ACTPolicy(config)
policy.load_state_dict(torch.load('policy_best.ckpt'))
policy.cuda()
policy.eval()
```
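
The normalization step mentioned above can be sketched as follows. The key names `qpos_mean`, `qpos_std`, `action_mean`, and `action_std` follow the convention of the original ACT codebase and are an assumption here; inspect the keys of your own `dataset_stats.pkl` before relying on them. The helper names are hypothetical.

```python
from typing import Dict, List, Sequence

Stats = Dict[str, Sequence[float]]


def normalize_qpos(qpos: Sequence[float], stats: Stats) -> List[float]:
    """Standardize joint positions before passing them to the policy."""
    return [
        (q - m) / s
        for q, m, s in zip(qpos, stats["qpos_mean"], stats["qpos_std"])
    ]


def unnormalize_action(action: Sequence[float], stats: Stats) -> List[float]:
    """Map a normalized policy output back to robot action units."""
    return [
        a * s + m
        for a, m, s in zip(action, stats["action_mean"], stats["action_std"])
    ]
```

Skipping either step usually produces actions at the wrong scale, so both directions should use the statistics file saved alongside the checkpoint.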