---
license: mit
tags:
- vision-language-action
- robotics
- trajectory-prediction
- diffusion-transformer
- qwen2-vl
datasets:
- TESS-Computer/quickdraw-circles-delta
---
# Qwen-DiT-Draw (Delta Version)
Vision-Language-Action model for continuous mouse trajectory prediction using **delta (relative) movements**.
## Model Description
This repository contains the DiT action head, trained on top of a frozen Qwen2.5-VL-3B backbone. It predicts mouse trajectories as sequences of (dx, dy, state) deltas.
- **Base VLM**: Qwen/Qwen2.5-VL-3B-Instruct (frozen)
- **Action Head**: 6-layer DiT with 512 hidden size (~37M trainable params)
- **Training**: Flow matching loss on 16-point chunks
- **Coordinate Mode**: Delta (relative movements)
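As a rough illustration of the flow matching objective listed above, the sketch below builds training targets for a batch of 16-point chunks. It is not taken from this repository's training code: the function names, the linear noise-to-data path, and the uniform time sampling are assumptions chosen for clarity, and the actual implementation may use a different path direction or time distribution.

```python
import numpy as np


def flow_matching_targets(actions, rng):
    """Build flow-matching inputs and targets for a batch of action chunks.

    actions: array of shape (batch, 16, 3) holding (dx, dy, state) per point.
    Returns the noisy input x_t, the sampled times t, and the velocity target.
    """
    batch = actions.shape[0]
    noise = rng.standard_normal(actions.shape)
    # One time value per sample, broadcast over the chunk and feature axes.
    t = rng.uniform(0.0, 1.0, size=(batch, 1, 1))
    # Assumed convention: straight-line path from noise (t=0) to data (t=1).
    x_t = (1.0 - t) * noise + t * actions
    # Along that path the velocity is constant: d x_t / dt = actions - noise.
    target_velocity = actions - noise
    return x_t, t, target_velocity


def flow_matching_loss(predicted_velocity, target_velocity):
    """Mean squared error between predicted and target velocities."""
    return float(np.mean((predicted_velocity - target_velocity) ** 2))
```

In a real training step, the action head would receive `x_t`, `t`, and the VLM features, and its output would be passed to `flow_matching_loss` as `predicted_velocity`.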
## Training Details
| Setting | Value |
|---------|-------|
| Dataset | TESS-Computer/quickdraw-circles-delta |
| Samples | 21,207 chunks |
| Epochs | 3 |
| Final Loss | 0.111 |
| Batch Size | 4 |
| Learning Rate | 1e-4 |
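To put the table in terms of optimizer steps, the following back-of-the-envelope calculation may help. It assumes one optimizer step per batch (no gradient accumulation) and that the final partial batch of each epoch is kept; neither detail is stated in this card, and `optimizer_steps` is a hypothetical helper.

```python
import math


def optimizer_steps(num_samples, batch_size, epochs):
    """Total optimizer steps, assuming one step per batch and no dropped batch."""
    return math.ceil(num_samples / batch_size) * epochs


# Values from the table above: 21,207 chunks, batch size 4, 3 epochs.
print(optimizer_steps(21_207, 4, 3))  # 15906
```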
## Key Learning: VLA Models Need More Epochs!
This model was trained for only 3 epochs as a proof-of-concept. Research from OpenVLA, GR00T, and pi0 shows that **VLA models need 20-30+ epochs** to learn good action patterns:
> *"Typical LLM or VLM training runs complete at most one or two epochs... In contrast, we found it important for VLA training to iterate through the training dataset significantly more times."* — OpenVLA Paper
The model produces partial arcs instead of complete circles, most likely because 3 epochs is far fewer passes over the data than the 20-30+ these works report.
## Usage
```python
import torch

from src.model import Qwen2_5_VL_Draw, TrajectoryConfig

# Load config (values match the action head described above)
config = TrajectoryConfig(
    chunk_size=16,
    dit_hidden_size=512,
    dit_num_layers=6,
)

# Create the model with a frozen backbone, then load the trained action head weights
model = Qwen2_5_VL_Draw(
    model_id="Qwen/Qwen2.5-VL-3B-Instruct",
    config=config,
    freeze_backbone=True,
)
model.trajectory_head.load_state_dict(
    torch.load("trajectory_head.pt", map_location="cpu")
)
```
## Delta vs Absolute
This model uses **delta coordinates** (GR00T N1.6 style):
- Chunk 0: First point is absolute start, rest are deltas
- Chunk 1+: All points are deltas from previous chunk's last point
To reconstruct absolute positions:
```python
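import numpy as np

# `deltas` is assumed here to hold only the (dx, dy) columns of one chunk; the state column is not accumulated.
# Chunk 0: the first row is the absolute start, so the running sum is already absolute.
abs_positions = np.cumsum(deltas, axis=0)
# Chunk 1+: offset the running sum by the previous chunk's last absolute point.
abs_positions = np.cumsum(deltas, axis=0) + prev_chunk_last_point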
```
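The two cases above can be combined into one helper that stitches several chunks into a full trajectory. This is a minimal sketch rather than code from the repository: `reconstruct_trajectory` is a hypothetical name, and it assumes each chunk is an array with columns (dx, dy, state) where only the first two columns are accumulated.

```python
import numpy as np


def reconstruct_trajectory(chunks):
    """Convert a list of delta chunks into absolute (x, y, state) rows.

    chunks: list of arrays of shape (chunk_size, 3) with columns (dx, dy, state).
    Chunk 0's first row is assumed to carry the absolute start position.
    """
    pieces = []
    last_xy = np.zeros(2)  # zero offset, so chunk 0 is a plain cumulative sum
    for chunk in chunks:
        chunk = np.asarray(chunk, dtype=float)
        # Accumulate only the spatial columns; the state column is not a delta.
        xy = np.cumsum(chunk[:, :2], axis=0) + last_xy
        pieces.append(np.column_stack([xy, chunk[:, 2]]))
        last_xy = xy[-1]  # the next chunk continues from this point
    return np.concatenate(pieces, axis=0)
```

Carrying `last_xy` forward keeps the chunks continuous, so an error in one chunk shifts every later chunk by the same offset. This is a known trade-off of delta encoding.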
## Links
- [GitHub: qwen-dit-draw](https://github.com/TESS-Computer/qwen-dit-draw)
- [Dataset: quickdraw-circles-delta](https://huggingface.co/datasets/TESS-Computer/quickdraw-circles-delta)
- [Absolute coordinates version](https://huggingface.co/TESS-Computer/qwen-dit-draw)
## License
MIT