---
license: apache-2.0
tags:
- robotics
- rdt
- libero
- diffusion
- transformers
---

# RDT-1B LIBERO Checkpoint

RDT-1B, fully fine-tuned on the LIBERO-90 dataset (checkpoint-65000). It is intended as the base model for all LIBERO tasks.

## Model Information
- Base Model: RDT-1B (Robotics Diffusion Transformer)
- Training Framework: DeepSpeed ZeRO Stage 2
- Precision: BF16

## Checkpoint Contents

This checkpoint includes:

### For Inference
- `ema/model.safetensors` - EMA model weights (recommended for inference)
- `config.json` - Model configuration

### For Training
- `pytorch_model/` - DeepSpeed distributed training checkpoint
  - `bf16_zero_pp_rank_*_optim_states.pt` - Optimizer states (ZeRO Stage 2)
  - `mp_rank_00_model_states.pt` - Model states
- `scheduler.bin` - Learning rate scheduler state
- `random_states_*.pkl` - Random number generator states
- `zero_to_fp32.py` - Utility to convert DeepSpeed checkpoint to FP32

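Before loading anything, it can help to check which of the listed files are actually present in a download. The following is a minimal sketch, assuming the file names listed in this section. The helper `classify_checkpoint_files` and the two pattern tables are hypothetical and not part of the RDT codebase. Matching is done on base names only, so the sketch makes no assumption about which subfolder each file sits in.

```python
from fnmatch import fnmatch
from pathlib import PurePosixPath

# Patterns copied from the file list in this section; the grouping is illustrative.
INFERENCE_PATTERNS = ["model.safetensors", "config.json"]
TRAINING_PATTERNS = [
    "bf16_zero_pp_rank_*_optim_states.pt",
    "mp_rank_00_model_states.pt",
    "scheduler.bin",
    "random_states_*.pkl",
    "zero_to_fp32.py",
]


def classify_checkpoint_files(paths):
    """Group repo-relative paths into 'inference', 'training', and 'other'."""
    groups = {"inference": [], "training": [], "other": []}
    for path in paths:
        name = PurePosixPath(path).name  # match on the base name only
        if any(fnmatch(name, pat) for pat in INFERENCE_PATTERNS):
            groups["inference"].append(path)
        elif any(fnmatch(name, pat) for pat in TRAINING_PATTERNS):
            groups["training"].append(path)
        else:
            groups["other"].append(path)
    return groups


if __name__ == "__main__":
    # Hypothetical file listing, used only to demonstrate the grouping.
    example = [
        "ema/model.safetensors",
        "config.json",
        "pytorch_model/mp_rank_00_model_states.pt",
        "scheduler.bin",
        "random_states_0.pkl",
        "README.md",
    ]
    for group, files in classify_checkpoint_files(example).items():
        print(group, files)
```

If the `training` group comes back empty, only the inference files were downloaded and resuming training will not be possible from that copy.
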
## Usage

### For Inference

```python
from transformers import AutoModel
import torch

# Load the EMA model for inference
model = AutoModel.from_pretrained(
    "TJ-chen/RDT-1B-LIBERO-Base",
    subfolder="ema",
    trust_remote_code=True,
)
model.eval()
```

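Since inference only needs the EMA weights and the configuration, the optimizer and scheduler files can be skipped when downloading. The following is a minimal sketch, assuming the `huggingface_hub` package is installed and that the inference files sit at the paths listed under Checkpoint Contents. The function name `download_inference_files` and the constant `INFERENCE_ALLOW_PATTERNS` are hypothetical.

```python
# Glob patterns for the files named in the "For Inference" list (an assumption
# about the repo layout; adjust if the files are stored elsewhere).
INFERENCE_ALLOW_PATTERNS = ["ema/*", "config.json"]


def download_inference_files(repo_id="TJ-chen/RDT-1B-LIBERO-Base", local_dir=None):
    """Download only the inference subset of the repo and return its local path."""
    # Imported lazily so this module can be loaded without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id=repo_id,
        allow_patterns=INFERENCE_ALLOW_PATTERNS,
        local_dir=local_dir,
    )
```

Calling `download_inference_files()` fetches the matching files into the local Hugging Face cache, or into `local_dir` when one is given, and returns the resulting directory.
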
### For Continued Training

Download the complete checkpoint and use DeepSpeed to resume training:

```bash
# The checkpoint can be loaded with DeepSpeed ZeRO Stage 2
# Make sure your training script is configured with the same DeepSpeed settings
```

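The two settings this card does state, ZeRO Stage 2 and BF16 precision, map onto the `zero_optimization` and `bf16` sections of a DeepSpeed configuration. The following is a minimal sketch of such a configuration built as a Python dict. The helper `build_deepspeed_config` is hypothetical, and all other fields (batch size, gradient accumulation, optimizer) are not given here, so they are left for the caller to supply from the original training run.

```python
import json


def build_deepspeed_config(extra=None):
    """Return a DeepSpeed config dict with ZeRO Stage 2 and BF16 enabled."""
    config = {
        "zero_optimization": {"stage": 2},
        "bf16": {"enabled": True},
    }
    # Remaining fields are not stated in this card; the caller passes them in
    # so that they match the settings of the original run.
    if extra:
        config.update(extra)
    return config


if __name__ == "__main__":
    # Print the config as JSON, the format DeepSpeed reads from a config file.
    print(json.dumps(build_deepspeed_config(), indent=2))
```

The printed JSON can be saved to a file and passed to the training script in place of, or merged with, its existing DeepSpeed configuration.
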
## Citation

If you use this model, please cite:

```bibtex
@article{rdt2024,
  title={Residual Diffusion Transformer for Robotic Manipulation},
  author={Your Name},
  journal={arXiv preprint},
  year={2024}
}
```

## License

Apache 2.0