---
license: mit
tags:
- vision-language-action
- robotics
- trajectory-prediction
- diffusion-transformer
- qwen2-vl
datasets:
- TESS-Computer/quickdraw-circles-delta
---

# Qwen-DiT-Draw (Delta Version)

Vision-Language-Action (VLA) model that predicts continuous mouse trajectories as **delta (relative) movements**.

## Model Description

This is the DiT (Diffusion Transformer) action head trained on top of a frozen Qwen2.5-VL-3B backbone. It predicts mouse trajectories as sequences of (dx, dy, state) deltas.

- **Base VLM**: Qwen/Qwen2.5-VL-3B-Instruct (frozen)
- **Action Head**: 6-layer DiT with 512 hidden size (~37M trainable params)
- **Training**: Flow matching loss on 16-point chunks
- **Coordinate Mode**: Delta (relative movements)

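The card does not say which flow-matching formulation the training code uses. As a minimal sketch, assuming the common linear-interpolation (rectified-flow) variant, one training example for a 16-point chunk could be built as follows. The function names `flow_matching_pair` and `flow_matching_loss` are hypothetical and are not taken from the repository.

```python
import numpy as np

def flow_matching_pair(x1, rng):
    """Build one flow-matching training example from a clean action chunk.

    x1: array of shape (chunk_size, 3) holding (dx, dy, state) rows.
    Returns the noisy input x_t, the sampled time t, and the target velocity.
    """
    x0 = rng.standard_normal(x1.shape)  # Gaussian noise sample
    t = rng.uniform(0.0, 1.0)           # interpolation time in [0, 1)
    x_t = (1.0 - t) * x0 + t * x1       # point on the straight path noise -> data
    v_target = x1 - x0                  # constant velocity along that path
    return x_t, t, v_target

def flow_matching_loss(v_pred, v_target):
    """Mean squared error between predicted and target velocity."""
    return float(np.mean((v_pred - v_target) ** 2))

rng = np.random.default_rng(0)
chunk = np.zeros((16, 3))               # placeholder chunk, not real data
x_t, t, v_target = flow_matching_pair(chunk, rng)
perfect = flow_matching_loss(v_target, v_target)  # loss of an ideal prediction
```

In this formulation the action head would receive `x_t`, `t`, and the VLM features, and would be trained to output `v_target`.
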
## Training Details

| Setting | Value |
|---------|-------|
| Dataset | TESS-Computer/quickdraw-circles-delta |
| Samples | 21,207 chunks |
| Epochs | 3 |
| Final Loss | 0.111 |
| Batch Size | 4 |
| Learning Rate | 1e-4 |

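The settings in the table imply a fairly small training budget. The arithmetic below is a rough sketch that assumes one optimizer step per batch, no gradient accumulation, and that the last partial batch of each epoch is kept. None of these details are stated in the card.

```python
import math

num_chunks = 21_207  # "Samples" row
batch_size = 4       # "Batch Size" row
epochs = 3           # "Epochs" row

# Assumption: one optimizer step per batch, last partial batch kept.
steps_per_epoch = math.ceil(num_chunks / batch_size)
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, total_steps)  # 5302 15906
```
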
## Key Learning: VLA Models Need More Epochs!

This model was trained for only 3 epochs as a proof of concept. Work on OpenVLA, GR00T, and pi0 indicates that **VLA models need 20-30+ epochs** to learn good action patterns:

> *"Typical LLM or VLM training runs complete at most one or two epochs... In contrast, we found it important for VLA training to iterate through the training dataset significantly more times."* (OpenVLA paper)

The model produces partial arcs instead of complete circles, most likely because it has not seen enough training iterations.

## Usage

```python
from src.model import Qwen2_5_VL_Draw, TrajectoryConfig
import torch

# Load config
config = TrajectoryConfig(
    chunk_size=16,
    dit_hidden_size=512,
    dit_num_layers=6,
)

# Create model and load weights
model = Qwen2_5_VL_Draw(
    model_id="Qwen/Qwen2.5-VL-3B-Instruct",
    config=config,
    freeze_backbone=True,
)
model.trajectory_head.load_state_dict(
    torch.load("trajectory_head.pt", map_location="cpu")
)
```

## Delta vs Absolute

This model uses **delta coordinates** (GR00T N1.6 style):
- Chunk 0: First point is absolute start, rest are deltas
- Chunk 1+: All points are deltas from previous chunk's last point

To reconstruct absolute positions:
```python
import numpy as np

# `deltas` holds only the (dx, dy) columns, shape (chunk_size, 2).
# Chunk 0: the first row is the absolute start point
abs_positions = np.cumsum(deltas, axis=0)

# Chunk 1+: offset by the last absolute point of the previous chunk
abs_positions = np.cumsum(deltas, axis=0) + prev_chunk_last_point
```

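The two cases above can be combined into one helper that stitches several predicted chunks into a single absolute path. This is a minimal sketch, not code from the repository: `reconstruct_trajectory` is a hypothetical name, and it assumes each chunk is a `(chunk_size, 3)` array of `(dx, dy, state)` rows in which `state` is a per-point value that is copied through rather than accumulated.

```python
import numpy as np

def reconstruct_trajectory(chunks):
    """Turn a list of delta chunks into absolute (x, y, state) rows.

    chunks: list of arrays shaped (chunk_size, 3) with (dx, dy, state) columns.
    The first row of chunk 0 is assumed to hold the absolute start point.
    """
    out = []
    last_xy = np.zeros(2)  # zero offset for chunk 0, whose first row is absolute
    for chunk in chunks:
        chunk = np.asarray(chunk, dtype=float)
        xy = np.cumsum(chunk[:, :2], axis=0) + last_xy  # accumulate dx, dy only
        out.append(np.column_stack([xy, chunk[:, 2]]))  # state is copied, not summed
        last_xy = xy[-1]
    return np.concatenate(out, axis=0)

# Toy chunks (shorter than 16 points, values made up for illustration)
chunk0 = np.array([[100.0, 50.0, 1.0], [2.0, 0.0, 1.0], [2.0, 1.0, 1.0]])
chunk1 = np.array([[1.0, 1.0, 1.0], [0.0, 2.0, 0.0]])
path = reconstruct_trajectory([chunk0, chunk1])
```

Carrying `last_xy` forward keeps the chunks continuous, so the first point of each later chunk starts where the previous chunk ended.
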
## Links

- [GitHub: qwen-dit-draw](https://github.com/TESS-Computer/qwen-dit-draw)
- [Dataset: quickdraw-circles-delta](https://huggingface.co/datasets/TESS-Computer/quickdraw-circles-delta)
- [Absolute coordinates version](https://huggingface.co/TESS-Computer/qwen-dit-draw)

## License

MIT