---
license: apache-2.0
tags:
- robotics
- rdt
- libero
- diffusion
- transformers
---

# RDT-1B LIBERO Checkpoint

RDT-1B fine-tuned on the LIBERO Spatial benchmark. This is the best-performing checkpoint.

## Model Information
- Base Model: RDT-1B (Robotics Diffusion Transformer)
- Training Framework: DeepSpeed ZeRO Stage 2
- Precision: BF16
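
As a rough illustration, the two training settings listed above map onto a DeepSpeed configuration fragment like the one below. Only the ZeRO stage and the BF16 flag come from this card. The helper name `write_ds_config` and the output file name are hypothetical, and every other field (batch sizes, optimizer, offloading) is unknown here and has to come from your own training setup.

```python
import json

# Minimal DeepSpeed config fragment. Only these two settings are stated in
# this model card; all other fields are deliberately left out (assumption:
# your launcher or training script supplies them).
ds_config = {
    "zero_optimization": {"stage": 2},
    "bf16": {"enabled": True},
}


def write_ds_config(path="ds_config.json"):
    """Write the fragment to disk so a launcher can pick it up (hypothetical helper)."""
    with open(path, "w") as f:
        json.dump(ds_config, f, indent=2)
    return path


if __name__ == "__main__":
    print(json.dumps(ds_config, indent=2))
```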

## Checkpoint Contents

The repository contains files for both inference and continued training:

### For Inference
- `ema/model.safetensors` - EMA model weights (recommended for inference)
- `config.json` - Model configuration

### For Training
- `pytorch_model/` - DeepSpeed distributed training checkpoint
  - `bf16_zero_pp_rank_*_optim_states.pt` - Optimizer states (ZeRO Stage 2)
  - `mp_rank_00_model_states.pt` - Model states
- `scheduler.bin` - Learning rate scheduler state
- `random_states_*.pkl` - Random number generator states
- `zero_to_fp32.py` - Utility to convert DeepSpeed checkpoint to FP32
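
The conversion that `zero_to_fp32.py` performs can also be done from Python through DeepSpeed's `get_fp32_state_dict_from_zero_checkpoint`. The sketch below assumes the layout listed above, with `pytorch_model/` acting as the DeepSpeed checkpoint tag inside the downloaded directory. The helper names `find_zero_shards` and `consolidate_to_fp32` are illustrative and not part of this repository.

```python
from pathlib import Path


def find_zero_shards(checkpoint_root, tag="pytorch_model"):
    """Return the ZeRO optimizer-state shards found under <checkpoint_root>/<tag>."""
    return sorted(Path(checkpoint_root, tag).glob("*_optim_states.pt"))


def consolidate_to_fp32(checkpoint_root, tag="pytorch_model"):
    """Merge the sharded ZeRO checkpoint into a single FP32 state dict on CPU."""
    # Imported lazily so find_zero_shards works without DeepSpeed installed.
    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    if not find_zero_shards(checkpoint_root, tag):
        raise FileNotFoundError(f"no ZeRO shards under {Path(checkpoint_root, tag)}")
    # Assumption: the tag directory is named "pytorch_model", as listed above.
    return get_fp32_state_dict_from_zero_checkpoint(str(checkpoint_root), tag=tag)
```

Calling `consolidate_to_fp32("path/to/checkpoint")` needs enough CPU memory to hold the full FP32 weights. The result contains the training weights stored in the DeepSpeed checkpoint, which are separate from the EMA weights in `ema/model.safetensors`.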

## Usage

### For Inference

```python
import torch
from transformers import AutoModel

# Load the EMA weights for inference.
# trust_remote_code=True executes model code shipped in the repository.
model = AutoModel.from_pretrained(
    "TJ-chen/RDT-1B-LIBERO-Spatial",
    subfolder="ema",
    trust_remote_code=True,
)
model.eval()  # switch to evaluation mode (disables dropout)
```
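
If you only want to inspect the EMA weights, or plan to load them into your own copy of the RDT model class, the safetensors file can be fetched and read directly. This is a minimal sketch that assumes the file path `ema/model.safetensors` listed in this card. The helper names `load_ema_state_dict` and `count_parameters` are illustrative.

```python
def load_ema_state_dict(repo_id="TJ-chen/RDT-1B-LIBERO-Spatial"):
    """Download ema/model.safetensors and load it as a plain dict of tensors."""
    # Imported lazily so count_parameters can be used without these packages.
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    path = hf_hub_download(repo_id=repo_id, filename="ema/model.safetensors")
    return load_file(path)


def count_parameters(state_dict):
    """Total number of scalar parameters across all tensors in a state dict."""
    return sum(tensor.numel() for tensor in state_dict.values())
```

For example, `count_parameters(load_ema_state_dict())` gives a quick sanity check that the download is complete before the weights are passed to `load_state_dict` on a matching model.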

### For Continued Training

Download the complete checkpoint and use DeepSpeed to resume training:

```bash
# The checkpoint can be loaded with DeepSpeed ZeRO Stage 2
# Make sure your training script is configured with the same DeepSpeed settings
```
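
One way to fetch the full checkpoint and confirm that the training files arrived is sketched below. It uses `snapshot_download` from `huggingface_hub` and checks for the file patterns listed under "Checkpoint Contents". The helper names and the assumed directory layout are illustrative. The file names `scheduler.bin` and `random_states_*.pkl` match what Hugging Face Accelerate's `save_state` writes, so `Accelerator.load_state` may be able to restore this directory, but that is an assumption this card does not confirm.

```python
from pathlib import Path

# Training-file patterns taken from the "Checkpoint Contents" list.
# Assumption: the *.pt files live inside the pytorch_model/ directory.
EXPECTED_PATTERNS = [
    "pytorch_model/*_optim_states.pt",
    "pytorch_model/mp_rank_00_model_states.pt",
    "scheduler.bin",
    "random_states_*.pkl",
]


def download_checkpoint(local_dir, repo_id="TJ-chen/RDT-1B-LIBERO-Spatial"):
    """Download every file in the repository into local_dir."""
    from huggingface_hub import snapshot_download  # lazy import

    return snapshot_download(repo_id=repo_id, local_dir=local_dir)


def missing_files(checkpoint_root):
    """Return the expected patterns that match no file under checkpoint_root."""
    root = Path(checkpoint_root)
    return [pattern for pattern in EXPECTED_PATTERNS if not any(root.glob(pattern))]
```

An empty list from `missing_files(download_checkpoint("./rdt_libero_ckpt"))` means all listed training files are present. When resuming, launch with the same number of processes and the same ZeRO stage that produced the optimizer shards.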

## Citation

If you use this model, please cite the RDT-1B paper:

```bibtex
@article{liu2024rdt,
  title={RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation},
  author={Liu, Songming and Wu, Lingxuan and Li, Bangguo and Tan, Hengkai and Chen, Huayu and Wang, Zhengyi and Xu, Ke and Su, Hang and Zhu, Jun},
  journal={arXiv preprint arXiv:2410.07864},
  year={2024}
}
```

## License

Apache 2.0