Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Add CPU inference script, update README with model details and perf stats
- .gitattributes +0 -34
- .gitignore +3 -0
- HF_README.md +167 -0
- README.md +3 -139
- inference.py +28 -18
- inference_cpu.py +61 -0
- modeling_openvla_micro.py +5 -2
- pyproject.toml +2 -1
- train_shim.py +205 -292
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
@@ -11,4 +11,7 @@ dist/
|
|
| 11 |
*.egg-info/
|
| 12 |
.venv/
|
| 13 |
venv/
|
| 14 |
|
|
|
|
| 11 |
*.egg-info/
|
| 12 |
.venv/
|
| 13 |
venv/
|
| 14 |
+
*.pt
|
| 15 |
+
*.bin
|
| 16 |
+
*.safetensors
|
| 17 |
|
HF_README.md
ADDED
|
@@ -0,0 +1,167 @@
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: transformers
|
| 6 |
+
pipeline_tag: reinforcement-learning
|
| 7 |
+
tags:
|
| 8 |
+
- robotics
|
| 9 |
+
- vla
|
| 10 |
+
- vision-language-action
|
| 11 |
+
- openvla
|
| 12 |
+
- omnivla
|
| 13 |
+
- robot
|
| 14 |
+
- qwen
|
| 15 |
+
- dinov2
|
| 16 |
+
- siglip
|
| 17 |
+
datasets:
|
| 18 |
+
- libero_90
|
| 19 |
+
- cast
|
| 20 |
+
model-index:
|
| 21 |
+
- name: openvla-micro
|
| 22 |
+
results: []
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
# OpenVLA-Micro
|
| 26 |
+
|
| 27 |
+
**A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.**
|
| 28 |
+
|
| 29 |
+
OpenVLA-Micro is a compact Vision-Language-Action model that replaces OmniVLA 7B (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B) while preserving compatibility with OmniVLA's pretrained action head through a learned hidden-state shim (896→4096).
|
| 30 |
+
|
| 31 |
+
## Model Architecture
|
| 32 |
+
|
| 33 |
+
| Component | Encoder | Output Dim |
|
| 34 |
+
|---|---|---|
|
| 35 |
+
| Vision (DINO) | `DINOv2-S/14` (facebook/dinov2-small) | 256 tokens × 384d → MLP → 8704 |
|
| 36 |
+
| Vision (SigLIP) | `SigLIP-B/16` (google/siglip-base-patch16-224) | 196 tokens × 768d → MLP → 8704 |
|
| 37 |
+
| Projector | Linear(8704→896) + GELU + Linear(896→896) | 452 tokens × 896d |
|
| 38 |
+
| LLM | `Qwen2.5-0.5B` (24 layers, 896 hidden, 14 attention heads) | Variable × 896d |
|
| 39 |
+
| Shim | Linear(896→2048) + GELU + Linear(2048→4096) | 32 action tokens × 4096d |
|
| 40 |
+
| Action Head | OmniVLA's pretrained head (unchanged) | 8 chunks × 7-DoF |
|
| 41 |
+
|
| 42 |
+
Total parameters: **~0.6B** (vs 7B for OmniVLA/OpenVLA).
|
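The token counts and the shim size in the table above can be checked with a little arithmetic. The sketch below assumes a 224×224 input, non-overlapping square patches, no extra class or register tokens, and bias terms on both shim layers. None of these details are confirmed by the table itself, and the function names are illustrative.

```python
def patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches covering a square image."""
    per_side = image_size // patch_size
    return per_side * per_side


def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Parameter count of one fully connected layer."""
    return d_in * d_out + (d_out if bias else 0)


dino_tokens = patch_tokens(224, 14)    # DINOv2-S/14
siglip_tokens = patch_tokens(224, 16)  # SigLIP-B/16
total_tokens = dino_tokens + siglip_tokens

# Shim: Linear(896->2048) + GELU + Linear(2048->4096); GELU has no parameters.
shim_params = linear_params(896, 2048) + linear_params(2048, 4096)

print(dino_tokens, siglip_tokens, total_tokens)
print(f"shim parameters: {shim_params:,}")
```

Under these assumptions the shim adds roughly 10M parameters, a small fraction of the ~0.6B total.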
| 43 |
+
|
| 44 |
+
## Performance
|
| 45 |
+
|
| 46 |
+
### vs OmniVLA 7B (teacher)
|
| 47 |
+
|
| 48 |
+
| Metric | Value | Note |
|
| 49 |
+
|---|---|---|
|
| 50 |
+
| Hidden state cosine | **0.63** | Last-layer HS at action positions |
|
| 51 |
+
| Action cosine | **0.997** | After OmniVLA action head |
|
| 52 |
+
| Action MSE | ~0.001 | Effectively identical predictions |
|
| 53 |
+
| Inference speed | **~14× faster** | 0.5B vs 7B LLM |
|
| 54 |
+
|
| 55 |
+
Trained on 1000 CAST episodes (17k steps) via hidden-state distillation. The shim was trained for ~21k training steps and plateaued at around step 8k.
|
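For readers who want to compute the same kind of similarity figures on their own data, the two action metrics can be sketched as below. This is a minimal plain-Python version, assuming each action is a flat list of floats. The function names and the sample vectors are illustrative and are not taken from the repository's evaluation code.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def mean_squared_error(a: Sequence[float], b: Sequence[float]) -> float:
    """Average squared element-wise difference."""
    if len(a) != len(b) or len(a) == 0:
        raise ValueError("vectors must be non-empty and of equal length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Made-up 7-DoF actions, for illustration only.
teacher = [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15]
student = [0.11, -0.02, 0.46, -0.01, 0.21, 0.09, -0.14]

print(f"cosine: {cosine_similarity(teacher, student):.4f}")
print(f"mse:    {mean_squared_error(teacher, student):.6f}")
```

Cosine similarity ignores overall scale, so it is worth reporting MSE alongside it, as the table does.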
| 56 |
+
|
| 57 |
+
### vs OpenVLA 7B (original)
|
| 58 |
+
|
| 59 |
+
OpenVLA-Micro is distilled from *OmniVLA*, which is itself a fine-tune of *OpenVLA* with 32 action tokens and a modified action head. A direct comparison with the original OpenVLA is not like-for-like because the action tokenization differs, but the ~14× speedup and near-lossless action quality measured against OmniVLA are expected to carry over to OpenVLA.
|
| 60 |
+
|
| 61 |
+
## Quick Start
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
from PIL import Image
|
| 65 |
+
from modeling_openvla_micro import OpenVLAMicro
|
| 66 |
+
|
| 67 |
+
model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
|
| 68 |
+
model.eval()
|
| 69 |
+
|
| 70 |
+
image = Image.open("demo.jpg").convert("RGB")
|
| 71 |
+
action = model.predict_action(image, "pick up the red block")
|
| 72 |
+
print(action) # [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15]
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### CLI — GPU
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
python inference.py --image demo.jpg "pick up the red block"
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
### CLI — CPU / Edge
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
# Standard CPU (~6GB RAM, 3-5 sec/step)
|
| 85 |
+
python inference_cpu.py --image demo.jpg "pick up the red block"
|
| 86 |
+
|
| 87 |
+
# Low-RAM CPU (~2.5GB RAM, requires bitsandbytes)
|
| 88 |
+
python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### As an OmniVLA drop-in replacement
|
| 92 |
+
|
| 93 |
+
Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:
|
| 94 |
+
|
| 95 |
+
```python
|
| 96 |
+
import torch
from model_wrapper import OpenVLAMicroWrapper
|
| 97 |
+
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
|
| 98 |
+
|
| 99 |
+
ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
|
| 100 |
+
ve = DinoSigLIPEncoder()
|
| 101 |
+
ve.load_state_dict(ckpt["model"]["vision_backbone"])
|
| 102 |
+
# ... (see model_wrapper.py for full example)
|
| 103 |
+
|
| 104 |
+
output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
|
| 105 |
+
actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
|
| 106 |
+
predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Architecture Diagram
|
| 110 |
+
|
| 111 |
+
```
|
| 112 |
+
Image (224×224)
|
| 113 |
+
├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
|
| 114 |
+
└── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
|
| 115 |
+
└── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
|
| 116 |
+
└── Qwen2.5 0.5B (24 layers, 896 hidden)
|
| 117 |
+
└── Hidden State Shim (896→2048→4096)
|
| 118 |
+
└── OmniVLA Action Head (pretrained, frozen)
|
| 119 |
+
└── 8 chunks × 7-DoF actions
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Files
|
| 123 |
+
|
| 124 |
+
| File | Size | Description |
|
| 125 |
+
|---|---|---|
|
| 126 |
+
| `modeling_openvla_micro.py` | 15 KB | Model definitions |
|
| 127 |
+
| `model_wrapper.py` | 11 KB | OmniVLA-compatible interface |
|
| 128 |
+
| `inference.py` | 1.5 KB | GPU/CPU CLI inference |
|
| 129 |
+
| `inference_cpu.py` | 2 KB | Edge device inference (with low-RAM mode) |
|
| 130 |
+
| `train_shim.py` | 15 KB | Reference shim training script |
|
| 131 |
+
| `config.json` | 1.2 KB | Model configuration |
|
| 132 |
+
| `openvla-micro-merged.pt` | 1.6 GB | Base checkpoint (no shim, 896-dim output) |
|
| 133 |
+
| `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint (with baked-in shim, 4096-dim) |
|
| 134 |
+
|
| 135 |
+
**Which checkpoint to use?**
|
| 136 |
+
|
| 137 |
+
- `openvla-micro-distill.pt` — **recommended**. Outputs 4096-dim hidden states that plug directly into OmniVLA's action head. One-step inference.
|
| 138 |
+
- `openvla-micro-merged.pt` — base model only (896-dim). Use if you want to train your own shim or action head.
|
| 139 |
+
|
| 140 |
+
## Requirements
|
| 141 |
+
|
| 142 |
+
```
|
| 143 |
+
torch>=2.0.0
|
| 144 |
+
torchvision>=0.15.0
|
| 145 |
+
transformers>=4.38.0
|
| 146 |
+
timm>=0.9.0
|
| 147 |
+
Pillow>=10.0.0
|
| 148 |
+
numpy>=1.24.0
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
For low-RAM CPU: `bitsandbytes>=0.43.0`
|
| 152 |
+
|
| 153 |
+
## Training the Shim
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
python train_shim.py \
|
| 157 |
+
--cache-dir ./teacher_cache \
|
| 158 |
+
--data-dir ./dataset \
|
| 159 |
+
--base-model openvla-micro-merged.pt \
|
| 160 |
+
--teacher-dim 4096
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
See `train_shim.py` for full options. The script expects pre-cached teacher hidden states; adapt `DistillDataset` to your format.
|
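The dataset class in `train_shim.py` splits the cached episodes into train and validation sets by position in the file list. A minimal sketch of that split follows, assuming a sorted list of cache file names. The helper name `split_episodes`, the file names, and the `val_ratio` of 0.1 are assumptions made for illustration; the script's actual default is not shown here.

```python
from typing import List, Tuple


def split_episodes(cache_files: List[str], val_ratio: float = 0.1) -> Tuple[List[str], List[str]]:
    """Return (train, val); the trailing `val_ratio` share goes to validation."""
    n = len(cache_files)
    split_idx = int(n * (1 - val_ratio))
    return cache_files[:split_idx], cache_files[split_idx:]


# Hypothetical cache file names, one per episode.
files = [f"episode_{i:04d}.pt" for i in range(1000)]
train_files, val_files = split_episodes(files)
print(len(train_files), len(val_files))
```

Because the split is positional, shuffling or re-sorting the file list between runs would move episodes between the two sets.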
| 164 |
+
|
| 165 |
+
## License
|
| 166 |
+
|
| 167 |
+
MIT
|
README.md
CHANGED
|
@@ -1,143 +1,7 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
language:
|
| 4 |
-
- en
|
| 5 |
-
library_name: transformers
|
| 6 |
-
pipeline_tag: reinforcement-learning
|
| 7 |
-
tags:
|
| 8 |
-
- robotics
|
| 9 |
-
- vla
|
| 10 |
-
- vision-language-action
|
| 11 |
-
- openvla
|
| 12 |
-
- omnivla
|
| 13 |
-
- robot
|
| 14 |
-
- qwen
|
| 15 |
-
- dinov2
|
| 16 |
-
- siglip
|
| 17 |
-
datasets:
|
| 18 |
-
- libero_90
|
| 19 |
-
- cast
|
| 20 |
-
model-index:
|
| 21 |
-
- name: openvla-micro
|
| 22 |
-
results: []
|
| 23 |
-
---
|
| 24 |
-
|
| 25 |
# OpenVLA-Micro
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
OpenVLA-Micro is a compact Vision-Language-Action model that replaces the bulky OmniVLA 7B architecture (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S + SigLIP-B + Qwen2.5 0.5B) while preserving compatibility with OmniVLA's action head through a learned hidden state shim (896→4096).
|
| 30 |
-
|
| 31 |
-
| Property | OpenVLA-Micro | OmniVLA 7B |
|
| 32 |
-
|---|---|---|
|
| 33 |
-
| Vision encoder | DINOv2-S (384d) + SigLIP-B/16 (768d) | DINOv2-L (1024d) + SigLIP-so400m (1152d) |
|
| 34 |
-
| LLM | Qwen2.5 0.5B (896 hidden) | Llama-2 7B (4096 hidden) |
|
| 35 |
-
| Total params | ~0.6B | ~7B |
|
| 36 |
-
| Hidden state dim | 4096 (via learned shim) | 4096 (native) |
|
| 37 |
-
| Action head | OmniVLA compatible | Native |
|
| 38 |
-
| Action cos similarity | **0.997** vs teacher | 1.0 (reference) |
|
| 39 |
-
|
| 40 |
-
## Performance
|
| 41 |
-
|
| 42 |
-
Trained on 1000 CAST episodes (17k steps) via hidden-state distillation from OmniVLA 7B:
|
| 43 |
-
|
| 44 |
-
- **Hidden state cosine**: 0.63 vs teacher
|
| 45 |
-
- **Action cosine**: 0.997 vs teacher
|
| 46 |
-
- **Action MSE**: near-zero (effectively identical predictions)
|
| 47 |
-
|
| 48 |
-
## Quick Start
|
| 49 |
-
|
| 50 |
-
```python
|
| 51 |
-
from PIL import Image
|
| 52 |
-
from modeling_openvla_micro import OpenVLAMicro
|
| 53 |
-
|
| 54 |
-
model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
|
| 55 |
-
model.eval()
|
| 56 |
-
|
| 57 |
-
image = Image.open("demo.jpg").convert("RGB")
|
| 58 |
-
action = model.predict_action(image, "pick up the red block")
|
| 59 |
-
print(action)
|
| 60 |
-
```
|
| 61 |
-
|
| 62 |
-
### CLI
|
| 63 |
-
|
| 64 |
-
```bash
|
| 65 |
-
python inference.py --checkpoint theguy21/openvla-micro --image demo.jpg "pick up the red block"
|
| 66 |
-
```
|
| 67 |
-
|
| 68 |
-
### As an OmniVLA drop-in replacement
|
| 69 |
-
|
| 70 |
-
Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:
|
| 71 |
-
|
| 72 |
-
```python
|
| 73 |
-
from model_wrapper import OpenVLAMicroWrapper
|
| 74 |
-
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
|
| 75 |
-
|
| 76 |
-
# Load components from checkpoint
|
| 77 |
-
ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
|
| 78 |
-
ve = DinoSigLIPEncoder()
|
| 79 |
-
ve.load_state_dict(ckpt["model"]["vision_backbone"])
|
| 80 |
-
# ... (see model_wrapper.py for full example)
|
| 81 |
-
|
| 82 |
-
# Forward pass compatible with OmniVLA action head
|
| 83 |
-
output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
|
| 84 |
-
actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
|
| 85 |
-
predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
|
| 86 |
-
```
|
| 87 |
-
|
| 88 |
-
## Architecture
|
| 89 |
-
|
| 90 |
-
```
|
| 91 |
-
Image (224×224)
|
| 92 |
-
├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
|
| 93 |
-
└── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
|
| 94 |
-
└── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
|
| 95 |
-
└── Qwen2.5 0.5B (24 layers, 896 hidden)
|
| 96 |
-
└── Hidden State Shim (896→2048→4096)
|
| 97 |
-
└── OmniVLA Action Head (pretrained)
|
| 98 |
-
└── 8 chunks × 4-DoF actions
|
| 99 |
-
```
|
| 100 |
-
|
| 101 |
-
## Files
|
| 102 |
-
|
| 103 |
-
| File | Size | Description |
|
| 104 |
-
|---|---|---|
|
| 105 |
-
| `modeling_openvla_micro.py` | 15 KB | Model definitions (DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro) |
|
| 106 |
-
| `model_wrapper.py` | 11 KB | OmniVLA-compatible wrapper (OpenVLAMicroWrapper) |
|
| 107 |
-
| `inference.py` | 1.5 KB | Standalone CLI inference |
|
| 108 |
-
| `config.json` | 1.2 KB | Model configuration |
|
| 109 |
-
| `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint with baked-in shim |
|
| 110 |
-
|
| 111 |
-
## Requirements
|
| 112 |
-
|
| 113 |
-
```
|
| 114 |
-
torch>=2.0.0
|
| 115 |
-
torchvision>=0.15.0
|
| 116 |
-
transformers>=4.38.0
|
| 117 |
-
timm>=0.9.0
|
| 118 |
-
Pillow>=10.0.0
|
| 119 |
-
numpy>=1.24.0
|
| 120 |
-
```
|
| 121 |
-
|
| 122 |
-
## Training Details
|
| 123 |
-
|
| 124 |
-
The shim (896→2048→4096 MLP) was trained to minimize MSE between the Qwen2.5 0.5B last hidden state and the cached OmniVLA 7B hidden states, keeping all other components frozen. Training used 1000 CAST episodes (17k steps) with bf16 precision on a single 24GB GPU.
|
| 125 |
-
|
| 126 |
-
- **Optimizer**: AdamW (lr=5e-5, cosine schedule)
|
| 127 |
-
- **Batch**: 8 micro-batch × 4 grad accum = 32 effective
|
| 128 |
-
- **Training steps**: 21k (plateaued at step ~8k)
|
| 129 |
-
- **Precision**: bfloat16
|
| 130 |
-
|
| 131 |
-
## Citation
|
| 132 |
-
|
| 133 |
-
```bibtex
|
| 134 |
-
@misc{openvla-micro-2026,
|
| 135 |
-
title = {OpenVLA-Micro: A Compact Drop-in Replacement for OmniVLA 7B},
|
| 136 |
-
author = {},
|
| 137 |
-
year = {2026},
|
| 138 |
-
}
|
| 139 |
-
```
|
| 140 |
|
| 141 |
-
|
| 142 |
|
| 143 |
-
|
| 1 |
# OpenVLA-Micro
|
| 2 |
|
| 3 |
+
A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.
|
| 4 |
|
| 5 |
+
Swaps DINOv2-L + SigLIP-so400m + Llama-2 7B → DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B (~0.6B total), with a learned 896→4096 shim to stay compatible with OmniVLA's action head.
|
| 6 |
|
| 7 |
+
See [HF_README.md](HF_README.md) or https://huggingface.co/theguy21/openvla-micro for full details.
|
inference.py
CHANGED
|
@@ -1,38 +1,48 @@
|
|
| 1 |
"""
|
| 2 |
Standalone inference script for OpenVLA-Micro.
|
| 3 |
Usage:
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
import argparse
|
| 7 |
from PIL import Image
|
| 8 |
-
|
| 9 |
from modeling_openvla_micro import OpenVLAMicro
|
| 10 |
|
| 11 |
|
| 12 |
def main():
|
| 13 |
parser = argparse.ArgumentParser(description="OpenVLA-Micro inference")
|
| 14 |
-
parser.add_argument("--checkpoint", type=str, default="openvla-micro
|
| 15 |
-
help="
|
| 16 |
-
parser.add_argument("--image", type=str, required=True,
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
default="pick up the red block",
|
| 22 |
-
help="Task instruction")
|
| 23 |
args = parser.parse_args()
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
model.eval()
|
| 29 |
-
|
|
|
|
| 30 |
|
| 31 |
-
# Load image
|
| 32 |
image = Image.open(args.image).convert("RGB")
|
| 33 |
print(f"Image: {image.size}")
|
| 34 |
-
|
| 35 |
-
# Run inference
|
| 36 |
print(f"Instruction: {args.instruction}")
|
| 37 |
action = model.predict_action(image, args.instruction)
|
| 38 |
print(f"Action (7-DoF): {action}")
|
|
|
|
| 1 |
"""
|
| 2 |
Standalone inference script for OpenVLA-Micro.
|
| 3 |
Usage:
|
| 4 |
+
# GPU inference (HF hub)
|
| 5 |
+
python inference.py --image demo.jpg "pick up the red block"
|
| 6 |
+
|
| 7 |
+
# From a local .pt file
|
| 8 |
+
python inference.py --checkpoint openvla-micro-distill.pt --image demo.jpg "pick up the red block"
|
| 9 |
+
|
| 10 |
+
# CPU inference
|
| 11 |
+
python inference.py --device cpu --image demo.jpg "pick up the red block"
|
| 12 |
"""
|
| 13 |
import argparse
|
| 14 |
from PIL import Image
|
|
|
|
| 15 |
from modeling_openvla_micro import OpenVLAMicro
|
| 16 |
|
| 17 |
|
| 18 |
def main():
|
| 19 |
parser = argparse.ArgumentParser(description="OpenVLA-Micro inference")
|
| 20 |
+
parser.add_argument("--checkpoint", type=str, default="theguy21/openvla-micro",
|
| 21 |
+
help="HF repo ID or path to local .pt checkpoint")
|
| 22 |
+
parser.add_argument("--image", type=str, required=True, help="Input image path")
|
| 23 |
+
parser.add_argument("--device", type=str, default="auto",
|
| 24 |
+
help="Device: auto, cuda, or cpu")
|
| 25 |
+
parser.add_argument("instruction", type=str, nargs="?", default="pick up the red block",
|
| 26 |
+
help="Task instruction (positional, optional)")
|
|
|
|
|
|
|
| 27 |
args = parser.parse_args()
|
| 28 |
|
| 29 |
+
device = args.device
|
| 30 |
+
if device == "auto":
|
| 31 |
+
import torch
|
| 32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
+
|
| 34 |
+
llm_kwargs = {}
|
| 35 |
+
if device == "cpu":
|
| 36 |
+
llm_kwargs["torch_dtype"] = "float32"
|
| 37 |
+
|
| 38 |
+
print(f"Loading OpenVLA-Micro from {args.checkpoint} on {device}...")
|
| 39 |
+
model = OpenVLAMicro.from_pretrained(args.checkpoint, device=device, llm_kwargs=llm_kwargs)
|
| 40 |
model.eval()
|
| 41 |
+
n_params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 42 |
+
print(f"Model loaded ({n_params:.0f}M params)")
|
| 43 |
|
|
|
|
| 44 |
image = Image.open(args.image).convert("RGB")
|
| 45 |
print(f"Image: {image.size}")
|
|
|
|
|
|
|
| 46 |
print(f"Instruction: {args.instruction}")
|
| 47 |
action = model.predict_action(image, args.instruction)
|
| 48 |
print(f"Action (7-DoF): {action}")
|
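The updated CLI takes the instruction as an optional positional argument. The sketch below isolates that argparse pattern so it can be tried without loading the model. `build_parser` is an illustrative helper name, not part of the repository, and the `choices` restriction on `--device` is an addition in this sketch; the script itself accepts any string.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="OpenVLA-Micro inference (argument parsing only)")
    parser.add_argument("--checkpoint", type=str, default="theguy21/openvla-micro",
                        help="HF repo ID or path to local .pt checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
                        help="Device: auto, cuda, or cpu")
    # nargs="?" makes the positional optional; the default applies when it is omitted.
    parser.add_argument("instruction", type=str, nargs="?", default="pick up the red block",
                        help="Task instruction (positional, optional)")
    return parser


args = build_parser().parse_args(["--device", "cpu", "--image", "demo.jpg", "open the drawer"])
print(args.device, args.image, args.instruction)
```

Keeping the instruction positional lets the command read naturally, while the default keeps quick smoke tests short.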
inference_cpu.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
| 1 |
+
"""
|
| 2 |
+
Edge device / CPU inference for OpenVLA-Micro.
|
| 3 |
+
|
| 4 |
+
This script is optimized for resource-constrained environments.
|
| 5 |
+
Two modes:
|
| 6 |
+
|
| 7 |
+
1. Standard CPU – float32, ~3-5 sec/step on modern x86, ~6GB RAM
|
| 8 |
+
2. Low-RAM (4-bit) – uses bitsandbytes 4-bit quantization, ~2.5GB RAM,
|
| 9 |
+
slightly slower but usable on 4GB devices like RPi 5
|
| 10 |
+
with sufficient swap.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python inference_cpu.py --image demo.jpg "pick up the red block"
|
| 14 |
+
python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
|
| 15 |
+
python inference_cpu.py --checkpoint ./openvla-micro-distill.pt --image demo.jpg "pick up the red block"
|
| 16 |
+
"""
|
| 17 |
+
import argparse
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from modeling_openvla_micro import OpenVLAMicro
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
|
| 23 |
+
parser = argparse.ArgumentParser(description="OpenVLA-Micro CPU/edge inference")
|
| 24 |
+
parser.add_argument("--checkpoint", type=str, default="theguy21/openvla-micro",
|
| 25 |
+
help="HF repo ID or local .pt path")
|
| 26 |
+
parser.add_argument("--image", type=str, required=True, help="Input image path")
|
| 27 |
+
parser.add_argument("--low-ram", action="store_true",
|
| 28 |
+
help="4-bit quantized LLM (~2.5GB peak, requires bitsandbytes)")
|
| 29 |
+
parser.add_argument("instruction", type=str, nargs="?", default="pick up the red block",
|
| 30 |
+
help="Task instruction (positional, optional)")
|
| 31 |
+
args = parser.parse_args()
|
| 32 |
+
|
| 33 |
+
device = "cpu"
|
| 34 |
+
llm_kwargs = {}
|
| 35 |
+
|
| 36 |
+
if args.low_ram:
|
| 37 |
+
print("Low-RAM mode: 4-bit quantization (requires bitsandbytes)")
|
| 38 |
+
llm_kwargs = {
|
| 39 |
+
"load_in_4bit": True,
|
| 40 |
+
"bnb_4bit_compute_dtype": "float32",
|
| 41 |
+
"bnb_4bit_use_double_quant": True,
|
| 42 |
+
}
|
| 43 |
+
else:
|
| 44 |
+
print("Standard CPU mode: float32 (~6GB RAM)")
|
| 45 |
+
llm_kwargs["torch_dtype"] = "float32"
|
| 46 |
+
|
| 47 |
+
print(f"Loading OpenVLA-Micro from {args.checkpoint} on CPU...")
|
| 48 |
+
model = OpenVLAMicro.from_pretrained(args.checkpoint, device=device, llm_kwargs=llm_kwargs)
|
| 49 |
+
model.eval()
|
| 50 |
+
n_params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 51 |
+
print(f"Model loaded ({n_params:.0f}M params)")
|
| 52 |
+
|
| 53 |
+
image = Image.open(args.image).convert("RGB")
|
| 54 |
+
print(f"Image: {image.size}")
|
| 55 |
+
print(f"Instruction: {args.instruction}")
|
| 56 |
+
action = model.predict_action(image, args.instruction)
|
| 57 |
+
print(f"Action (7-DoF): {action}")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
main()
|
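The branching on `--low-ram` above reduces to choosing one of two keyword dictionaries for the LLM loader. A self-contained sketch of that selection follows. `build_llm_kwargs` is a hypothetical helper introduced here for illustration, and the dictionaries simply mirror the values in the script.

```python
from typing import Any, Dict


def build_llm_kwargs(low_ram: bool) -> Dict[str, Any]:
    """Mirror the two CPU loading modes of inference_cpu.py."""
    if low_ram:
        # 4-bit quantized weights; the script notes this path requires bitsandbytes.
        return {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": "float32",
            "bnb_4bit_use_double_quant": True,
        }
    # Standard mode: full-precision float32 weights.
    return {"torch_dtype": "float32"}


for flag in (False, True):
    print(flag, build_llm_kwargs(flag))
```

Returning a fresh dictionary per call keeps the two modes from leaking settings into each other.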
modeling_openvla_micro.py
CHANGED
|
@@ -286,7 +286,8 @@ class OpenVLAMicro(nn.Module):
|
|
| 286 |
)
|
| 287 |
|
| 288 |
@classmethod
|
| 289 |
-
def from_pretrained(cls, checkpoint_path: Union[str, Path], device: str = "cpu"
|
|
|
|
| 290 |
checkpoint_path = cls._resolve_checkpoint_path(checkpoint_path)
|
| 291 |
ckpt = torch.load(checkpoint_path, map_location="cpu")
|
| 292 |
|
|
@@ -304,10 +305,12 @@ class OpenVLAMicro(nn.Module):
|
|
| 304 |
llm_id = "Qwen/Qwen2.5-0.5B"
|
| 305 |
config = AutoConfig.from_pretrained(llm_id)
|
| 306 |
config.use_flash_attention_2 = False
|
|
|
|
|
|
|
| 307 |
llm = AutoModelForCausalLM.from_pretrained(
|
| 308 |
llm_id,
|
| 309 |
config=config,
|
| 310 |
-
|
| 311 |
)
|
| 312 |
|
| 313 |
# --- Tokenizer ---
|
|
|
|
| 286 |
)
|
| 287 |
|
| 288 |
@classmethod
|
| 289 |
+
def from_pretrained(cls, checkpoint_path: Union[str, Path], device: str = "cpu",
|
| 290 |
+
**kwargs):
|
| 291 |
checkpoint_path = cls._resolve_checkpoint_path(checkpoint_path)
|
| 292 |
ckpt = torch.load(checkpoint_path, map_location="cpu")
|
| 293 |
|
|
|
|
| 305 |
llm_id = "Qwen/Qwen2.5-0.5B"
|
| 306 |
config = AutoConfig.from_pretrained(llm_id)
|
| 307 |
config.use_flash_attention_2 = False
|
| 308 |
+
llm_kwargs = kwargs.pop("llm_kwargs", {})
|
| 309 |
+
llm_kwargs.setdefault("torch_dtype", torch.bfloat16)
|
| 310 |
llm = AutoModelForCausalLM.from_pretrained(
|
| 311 |
llm_id,
|
| 312 |
config=config,
|
| 313 |
+
**llm_kwargs,
|
| 314 |
)
|
| 315 |
|
| 316 |
# --- Tokenizer ---
|
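The change above forwards an optional `llm_kwargs` dictionary to the LLM loader and fills in a default dtype. The sketch below shows the same pop-and-default pattern with a stand-in loader. `fake_from_pretrained` and the string dtype are placeholders so the example runs without PyTorch, and copying the dictionary is a defensive choice of this sketch rather than something the diff does.

```python
from typing import Any, Dict


def fake_from_pretrained(checkpoint: str, device: str = "cpu", **kwargs: Any) -> Dict[str, Any]:
    """Stand-in loader: returns the keyword arguments the LLM would receive."""
    # Copy so that setdefault below never mutates the caller's dictionary.
    llm_kwargs = dict(kwargs.pop("llm_kwargs", None) or {})
    llm_kwargs.setdefault("torch_dtype", "bfloat16")  # placeholder for torch.bfloat16
    return llm_kwargs


caller_kwargs: Dict[str, Any] = {}
print(fake_from_pretrained("ckpt.pt", llm_kwargs=caller_kwargs))
print(fake_from_pretrained("ckpt.pt", llm_kwargs={"torch_dtype": "float32"}))
print(caller_kwargs)
```

An explicit `torch_dtype` from the caller wins over the default, which is how the CPU scripts switch the LLM to float32.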
pyproject.toml
CHANGED
|
@@ -19,6 +19,7 @@ dependencies = [
|
|
| 19 |
|
| 20 |
[project.scripts]
|
| 21 |
openvla-micro = "inference:main"
|
|
|
|
| 22 |
|
| 23 |
[tool.setuptools]
|
| 24 |
-
py-modules = ["inference", "modeling_openvla_micro"]
|
|
|
|
| 19 |
|
| 20 |
[project.scripts]
|
| 21 |
openvla-micro = "inference:main"
|
| 22 |
+
openvla-micro-cpu = "inference_cpu:main"
|
| 23 |
|
| 24 |
[tool.setuptools]
|
| 25 |
+
py-modules = ["inference", "inference_cpu", "modeling_openvla_micro", "model_wrapper"]
|
train_shim.py
CHANGED
|
@@ -1,8 +1,22 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
-
import argparse, json, os
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
import numpy as np
|
|
@@ -13,50 +27,50 @@ from torch.utils.data import Dataset, DataLoader
|
|
| 13 |
from PIL import Image
|
| 14 |
from tqdm import tqdm
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
from model_wrapper import IMAGENET_MEAN as IMAGENET_MEAN_4D, IMAGENET_STD as IMAGENET_STD_4D, SIGLIP_MEAN, SIGLIP_STD
|
| 20 |
-
IMAGENET_MEAN = IMAGENET_MEAN_4D.view(3, 1, 1)
|
| 21 |
-
IMAGENET_STD = IMAGENET_STD_4D.view(3, 1, 1)
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
CAST_DIR = Path("/mnt/steamdrive/cast_converted")
|
| 26 |
-
CKPT_DIR = Path("/mnt/steamdrive/omnivla_checkpoints")
|
| 27 |
-
RUN_BASE = Path("/mnt/steamdrive/openvla-micro/runs_distill_shim")
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
|
| 31 |
-
ACTION_DIM = 4
|
| 32 |
-
NUM_ACTIONS_CHUNK = 8
|
| 33 |
|
| 34 |
|
| 35 |
-
def to_siglip(pv
|
| 36 |
-
return (pv * IMAGENET_STD.to(pv.device) + IMAGENET_MEAN.to(pv.device)
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
-
# ──
|
| 40 |
-
|
|
|
|
| 41 |
class DistillDataset(Dataset):
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
n = len(cache_files)
|
| 47 |
-
if max_episodes > 0:
|
| 48 |
-
cache_files = cache_files[:max_episodes]
|
| 49 |
-
n = len(cache_files)
|
| 50 |
split_idx = int(n * (1 - val_ratio))
|
| 51 |
files = cache_files[:split_idx] if split == "train" else cache_files[split_idx:]
|
| 52 |
self.index = []
|
| 53 |
for cf in files:
|
| 54 |
d = torch.load(cf, weights_only=True)
|
| 55 |
-
|
| 56 |
-
for t in range(T):
|
| 57 |
self.index.append((cf, t))
|
| 58 |
self._cache = {}
|
| 59 |
-
self.
|
| 60 |
print(f" [{split}] {len(self.index)} steps from {len(files)} episodes", flush=True)
|
| 61 |
|
| 62 |
def __len__(self):
|
|
@@ -67,345 +81,244 @@ class DistillDataset(Dataset):
|
|
| 67 |
cf_str = str(cf_path)
|
| 68 |
if cf_str not in self._cache:
|
| 69 |
self._cache[cf_str] = torch.load(cf_path, weights_only=True)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
ep_id = ep_data["episode_id"]
|
| 75 |
-
hs_target = ep_data["hidden_states"][t].float()
|
| 76 |
-
|
| 77 |
-
if ep_id not in self._instruction_cache:
|
| 78 |
-
instr_path = self.cast_dir / ep_id / "instructions.json"
|
| 79 |
-
with open(instr_path, "r") as f:
|
| 80 |
-
self._instruction_cache[ep_id] = json.load(f)
|
| 81 |
-
instr = self._instruction_cache[ep_id][t]
|
| 82 |
-
if isinstance(instr, list):
|
| 83 |
-
instr = instr[0]
|
| 84 |
-
instr = str(instr).strip()
|
| 85 |
-
|
| 86 |
-
ep_dir = self.cast_dir / ep_id
|
| 87 |
from torchvision.transforms.functional import resize as tv_resize
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
"
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
hs_target = torch.stack(hs_targets) # (B, 32, 4096)
|
| 115 |
-
|
| 116 |
-
chat = []
|
| 117 |
-
for t in texts:
|
| 118 |
-
chat.append([
|
| 119 |
-
{"role": "system", "content": "You are a helpful assistant."},
|
| 120 |
-
{"role": "user", "content": f"What action should the robot take to {t.lower()}?"},
|
| 121 |
-
{"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(self.num_action_tokens)])},
|
| 122 |
-
])
|
| 123 |
-
tok = self.tokenizer.apply_chat_template(
|
| 124 |
-
chat, tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt", padding=True,
|
| 125 |
-
)
|
| 126 |
-
input_ids = tok["input_ids"]
|
| 127 |
-
attention_mask = tok["attention_mask"]
|
| 128 |
-
|
| 129 |
-
return {
|
| 130 |
-
"cur_img": cur,
|
| 131 |
-
"input_ids": input_ids,
|
| 132 |
-
"attention_mask": attention_mask,
|
| 133 |
-
"hs_target": hs_target,
|
| 134 |
-
}
|
| 135 |
|
| 136 |
|
| 137 |
def main():
|
| 138 |
parser = argparse.ArgumentParser()
|
| 139 |
-
parser.add_argument("--
|
| 140 |
parser.add_argument("--batch-size", type=int, default=8)
|
| 141 |
parser.add_argument("--lr", type=float, default=5e-5)
|
| 142 |
-
parser.add_argument("--weight-decay", type=float, default=0.01)
|
| 143 |
parser.add_argument("--grad-accum", type=int, default=4)
|
| 144 |
-
parser.add_argument("--log-every", type=int, default=50)
|
| 145 |
parser.add_argument("--val-every", type=int, default=500)
|
| 146 |
parser.add_argument("--save-every", type=int, default=5000)
|
| 147 |
-
parser.add_argument("--
|
| 148 |
-
parser.add_argument("--
|
| 149 |
-
|
| 150 |
-
parser.add_argument("--run-name", type=str, default="shim_continued")
|
| 151 |
args = parser.parse_args()
|
| 152 |
|
-    device = torch.device(
-
     print(f"Device: {device}")

-    run_dir =
-    run_dir.mkdir(exist_ok=True
-
-    # ── 1. Models ──
-    from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
-    from transformers import AutoModelForCausalLM, AutoTokenizer

-
-
-
-    msd =

-    ve = DinoSigLIPEncoder()
     ve.load_state_dict(msd["vision_backbone"])
-    ve.to(device
-    for p in ve.parameters():
-        p.requires_grad_(False)

-    projector = CombinedProjector(ShimMLP(384), ShimMLP(768),
-                                  nn.Linear(8704, 896), nn.Linear(896, 896))
     projector.load_state_dict(msd["projector"])
-    projector.to(device
-    for p in projector.parameters():
-        p.requires_grad_(False)

-
-
-
-
-    for p in
-        p.requires_grad_(False)

     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
     tokenizer.add_tokens([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])
-
-
-    print("
-
-
-
-
         print(f" Resumed from {args.resume}")
-
-        print(" Starting from scratch (random init)!")
-    shim.to(device=device, dtype=torch.bfloat16).train()
-
-    # Pose projector (trainable)
-    pose_proj = nn.Sequential(
-        nn.Linear(4, 896), nn.GELU(),
-    ).to(device=device, dtype=torch.bfloat16).train()

     # ── 3. Data ──
-    print("\n[3]
-
-
-    val_ds = DistillDataset(CACHE_DIR, CAST_DIR, split="val")
-    collator = DistillCollator(tokenizer, action_token_ids)

-
-
-
-

     # ── 4. Optimizer ──
-    opt = torch.optim.AdamW(
-        {"params": shim.parameters(), "lr": args.lr},
-        {"params": pose_proj.parameters(), "lr": args.lr},
-    ], weight_decay=args.weight_decay)
     sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.max_steps)

     # ── 5. Training ──
-    print("\n[4] Training
-
-    dino = ve.dino_featurizer
-    siglip = ve.siglip_featurizer
-    dino.eval(); siglip.eval()
-    projector.eval(); qwen.eval()

     def encode_image(cur):
-        """Encode a single image → (B, 452, 896) vision features."""
         with torch.no_grad():
-
-            if isinstance(
-
-
-
-
-
-
-
-
-
-
-
-                return p
-            combined = torch.cat([pad(dino_f, 384), pad(siglip_f, 768)], dim=1)
-            return projector(combined)  # (B, 452, 896)
-
-    # Find action token offset in template
-    dummy_tok = tokenizer.apply_chat_template(
-        [{"role": "system", "content": "You are a helpful assistant."},
-         {"role": "user", "content": "test"},
-         {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])}],
-        tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt",
-    )
-    dummy_ids = dummy_tok["input_ids"].squeeze(0)
-    action_pos = torch.where(
-        (dummy_ids >= action_token_ids[0]) & (dummy_ids <= action_token_ids[-1])
-    )[0]
-    action_offset_in_text = action_pos[0].item()
-    NUM_VIS = 452
-    print(f" Action tokens start at position {action_offset_in_text} in text sequence")
-    print(f" Vision tokens: {NUM_VIS}")
-
     global_step = 0
-    best_val_loss = float("inf")
     train_iter = iter(train_loader)
-
     pbar = tqdm(total=args.max_steps, desc="Train")
     while global_step < args.max_steps:
         shim.train()
-        pose_proj.train()
         opt.zero_grad()
         accum_loss = 0.0

-        for
             try:
                 batch = next(train_iter)
             except StopIteration:
                 train_iter = iter(train_loader)
                 batch = next(train_iter)

-            cur_img = batch["cur_img"].to(device, dtype=
             inp = batch["input_ids"].to(device)
             am = batch["attention_mask"].to(device)
-            hs_target = batch["hs_target"].to(device, dtype=
             B = cur_img.shape[0]

-            # Encode image
             vis = encode_image(cur_img)
-
-
-
-
-
-            mm_embeds = torch.cat([embed[:, :1, :], vis, embed[:, 1:, :]], dim=1)
-            mm_attn = torch.cat([am[:, :1],
-                                 torch.ones(B, vis.shape[1], dtype=am.dtype, device=device),
-                                 am[:, 1:]], dim=1)
-
-            # Action mask
-            act_start = 1 + NUM_VIS + action_offset_in_text - 1
-            action_mask_mm = torch.zeros(mm_embeds.shape[:2], dtype=torch.bool, device=device)
             for i in range(B):
-
-                end =
-
-
-            mm_embeds = mm_embeds * ~action_mask_mm.unsqueeze(-1)
-
-            # Qwen forward
-            try:
-                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-                    out = qwen(inputs_embeds=mm_embeds, attention_mask=mm_attn,
-                               labels=None, output_hidden_states=True, return_dict=True)
-            except Exception as exc:
-                print(f"\n[train] forward failed at step {global_step} micro_step {micro_step}: {exc}", flush=True)
-                raise

             hs_all = out.hidden_states[-1]
-
-            hs_shimmed = shim(
-
-            # State MSE (shim gradients)
             loss = F.mse_loss(hs_shimmed, hs_target)
             (loss / args.grad_accum).backward()
             accum_loss += loss.item()

-        # Gradient clipping & step
         torch.nn.utils.clip_grad_norm_(shim.parameters(), 1.0)
-        torch.nn.utils.clip_grad_norm_(pose_proj.parameters(), 1.0)
         opt.step()
         sched.step()
         global_step += 1

-
-        if global_step % args.log_every == 0:
-            lr_cur = opt.param_groups[0]["lr"]
             with torch.no_grad():
-                cos = F.cosine_similarity(hs_shimmed.float().reshape(-1,
-                                          hs_target.float().reshape(-1,
-
-
-                "loss": f"{accum_loss/args.grad_accum:.5f}",
-                "cos": f"{cos:.4f}",
-                "nd": f"{nd:.2f}",
-                "lr": f"{lr_cur:.1e}",
-            })

         # Validation
         if global_step % args.val_every == 0:
             shim.eval()
-
             with torch.no_grad():
-                for
-
-
-                    am =
-
-                    Bv =
-
-
-
-
-
-                                         am[:, 1:]], dim=1)
-                    act_start_v = 1 + NUM_VIS + action_offset_in_text - 1
-                    act_mask_v = torch.zeros(mm_embeds.shape[:2], dtype=torch.bool, device=device)
                     for i in range(Bv):
-
-                        e =
-
-
-
-
-
-
-
-
-                    val_loss += F.mse_loss(hs_shimmed, hs_target).item()
-                    val_cos += F.cosine_similarity(hs_shimmed.float().reshape(-1, 4096),
-                                                   hs_target.float().reshape(-1, 4096), dim=-1).mean().item()
-                    val_nd += (hs_shimmed.float() - hs_target.float()).norm(dim=-1).mean().item()
                     nv += 1
-
-            print(f"\n─── Val @
-            if
-
                 torch.save(shim.state_dict(), run_dir / "shim_best.pt")
-
-                print(f" → New best saved (loss={val_loss:.5f})")

         if global_step % args.save_every == 0:
-
-
-            torch.save(shim.state_dict(), ckpt_dir / "shim.pt")
-            torch.save(pose_proj.state_dict(), ckpt_dir / "pose_projector.pt")
-
-        pbar.update(1)

     pbar.close()
-    print(f"\nDone! Best val loss: {


 if __name__ == "__main__":
|
|
|
 """
+Train the hidden state shim (896→4096) for OpenVLA-Micro.
+
+The shim maps Qwen2.5 0.5B's 896-dim hidden states to match a teacher
+LLM's 4096-dim space (e.g., Llama-2, Llama-3). This lets the small model
+drive OmniVLA's pretrained action head with near-zero accuracy loss.
+
+Workflow:
+    1. Cache your teacher's hidden states on your dataset
+    2. Run this script to train the shim
+    3. Bake the shim into the checkpoint with bake_shim.py
+
+Usage:
+    python train_shim.py --cache-dir ./my_cache --base-model theguy21/openvla-micro
+
+For the full training pipeline used in openvla-micro-distill, see:
+    https://huggingface.co/theguy21/openvla-micro
 """
+import argparse, json, os
 from pathlib import Path

 import numpy as np
|
|
|
 from PIL import Image
 from tqdm import tqdm

+from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
+from model_wrapper import IMAGENET_MEAN as IM4D, IMAGENET_STD as IS4D, SIGLIP_MEAN, SIGLIP_STD
+from transformers import AutoModelForCausalLM, AutoTokenizer

+IMAGENET_MEAN = IM4D.view(3, 1, 1)
+IMAGENET_STD = IS4D.view(3, 1, 1)

+NUM_ACTION_TOKENS = 32  # OmniVLA uses 8 chunks × 4 DoF
+NUM_VIS = 452  # 256 dino patches + 196 siglip patches


+def to_siglip(pv):
+    return (pv * IMAGENET_STD.to(pv.device) + IMAGENET_MEAN.to(pv.device)
+            - SIGLIP_MEAN.to(pv.device)) / SIGLIP_STD.to(pv.device)


+# ─────────────────────────────────────────────────────────────
+# Dataset — ADAPT THE IMAGE/INSTRUCTION LOGIC TO YOUR FORMAT
+# ─────────────────────────────────────────────────────────────
 class DistillDataset(Dataset):
+    """
+    Each episode_*.pt is expected to contain:
+        episode_id: str
+        num_steps: int
+        hidden_states: Tensor[T, 32, teacher_dim]
+        (optional) instructions: list[str] of length T
+
+    Image paths are constructed as {data_dir}/{episode_id}/img/step_{t:04d}.png
+    Override _load_image / _get_instruction for custom formats.
+    """
+
+    def __init__(self, cache_dir, data_dir, split="train", val_ratio=0.1):
+        self.data_dir = Path(data_dir)
+        cache_files = sorted(Path(cache_dir).glob("episode_*.pt"))
         n = len(cache_files)
         split_idx = int(n * (1 - val_ratio))
         files = cache_files[:split_idx] if split == "train" else cache_files[split_idx:]
         self.index = []
         for cf in files:
             d = torch.load(cf, weights_only=True)
+            for t in range(d["num_steps"]):
                 self.index.append((cf, t))
         self._cache = {}
+        self._instr_cache = {}
         print(f" [{split}] {len(self.index)} steps from {len(files)} episodes", flush=True)

     def __len__(self):
|
|
|
         cf_str = str(cf_path)
         if cf_str not in self._cache:
             self._cache[cf_str] = torch.load(cf_path, weights_only=True)
+        ep = self._cache[cf_str]
+        ep_id = ep["episode_id"]
+
+        # Image
         from torchvision.transforms.functional import resize as tv_resize
+        img = tv_resize(Image.open(self.data_dir / ep_id / "img" / f"step_{t:04d}.png").convert("RGB"), 224)
+        img = torch.tensor(np.array(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
+        img = (img - IMAGENET_MEAN) / IMAGENET_STD
+
+        # Instruction
+        if "instructions" in ep:
+            instr = ep["instructions"][t]
+            if isinstance(instr, list):
+                instr = instr[0]
+        else:
+            instr = "move forward"
+
+        return {"cur_img": img, "hs_target": ep["hidden_states"][t].float(), "instruction": str(instr).strip()}
+
+
+def find_action_offset(tokenizer, action_token_ids):
+    """Determine where action tokens start in the chat template."""
+    dummy = tokenizer.apply_chat_template(
+        [{"role": "system", "content": "You are a helpful assistant."},
+         {"role": "user", "content": "test"},
+         {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])}],
+        tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt",
+    )
+    ids = dummy["input_ids"].squeeze(0)
+    pos = torch.where((ids >= action_token_ids[0]) & (ids <= action_token_ids[-1]))[0]
+    return pos[0].item()


 def main():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--cache-dir", type=str, required=True)
+    parser.add_argument("--data-dir", type=str, required=True,
+                        help="Dataset root with {episode_id}/img/step_*.png")
+    parser.add_argument("--base-model", type=str, default="theguy21/openvla-micro")
+    parser.add_argument("--teacher-dim", type=int, default=4096)
+    parser.add_argument("--max-steps", type=int, default=10000)
     parser.add_argument("--batch-size", type=int, default=8)
     parser.add_argument("--lr", type=float, default=5e-5)
     parser.add_argument("--grad-accum", type=int, default=4)
     parser.add_argument("--val-every", type=int, default=500)
     parser.add_argument("--save-every", type=int, default=5000)
+    parser.add_argument("--resume", type=str, default=None)
+    parser.add_argument("--run-name", type=str, default="shim_run")
+    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
     args = parser.parse_args()

+    device = torch.device(args.device)
+    dtype = torch.bfloat16
     print(f"Device: {device}")

+    run_dir = Path(args.run_name)
+    run_dir.mkdir(exist_ok=True)

+    # ── 1. Load base model ──
+    print("\n[1] Loading base model...")
+    ckpt = torch.load(os.path.expanduser(args.base_model), map_location="cpu", weights_only=False)
+    msd = ckpt["model"]

+    ve = DinoSigLIPEncoder().eval()
     ve.load_state_dict(msd["vision_backbone"])
+    ve.to(device, dtype=dtype)
+    for p in ve.parameters(): p.requires_grad_(False)

+    projector = CombinedProjector(ShimMLP(384), ShimMLP(768), nn.Linear(8704, 896), nn.Linear(896, 896))
     projector.load_state_dict(msd["projector"])
+    projector.to(device, dtype=dtype).eval()
+    for p in projector.parameters(): p.requires_grad_(False)

+    llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=dtype)
+    llm_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()}
+    llm.load_state_dict(llm_sd)
+    llm.to(device, dtype=dtype).eval()
+    for p in llm.parameters(): p.requires_grad_(False)

     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
     tokenizer.add_tokens([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])
+    action_token_ids = tokenizer.convert_tokens_to_ids([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])
+    action_offset = find_action_offset(tokenizer, action_token_ids)
+    print(f" Action tokens at position {action_offset}")
+
+    # ── 2. Shim ──
+    print("\n[2] Building shim...")
+    shim = nn.Sequential(nn.Linear(896, 2048), nn.GELU(), nn.Linear(2048, args.teacher_dim))
+    if args.resume:
+        shim.load_state_dict(torch.load(args.resume, map_location="cpu"))
         print(f" Resumed from {args.resume}")
+    shim.to(device, dtype=dtype).train()

     # ── 3. Data ──
+    print("\n[3] Loading data...")
+    train_ds = DistillDataset(args.cache_dir, args.data_dir, split="train")
+    val_ds = DistillDataset(args.cache_dir, args.data_dir, split="val")

+    def collate(batch):
+        from torchvision.transforms.functional import resize as tv_resize
+        texts, imgs, hs = [], [], []
+        for b in batch:
+            texts.append(b["instruction"])
+            imgs.append(b["cur_img"])
+            hs.append(b["hs_target"])
+        cur = torch.stack(imgs)
+        hs_target = torch.stack(hs)
+        chat = [[{"role": "system", "content": "You are a helpful assistant."},
+                 {"role": "user", "content": f"What action should the robot take to {t.lower()}?"},
+                 {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])}]
+                for t in texts]
+        tok = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False,
+                                            return_dict=True, return_tensors="pt", padding=True)
+        return {"cur_img": cur, "input_ids": tok["input_ids"], "attention_mask": tok["attention_mask"],
+                "hs_target": hs_target}
+
+    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate, num_workers=0)
+    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate, num_workers=0)

     # ── 4. Optimizer ──
+    opt = torch.optim.AdamW(shim.parameters(), lr=args.lr, weight_decay=0.01)
     sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.max_steps)

     # ── 5. Training ──
+    print(f"\n[4] Training...")
+    dino, siglip = ve.dino_featurizer, ve.siglip_featurizer

     def encode_image(cur):
         with torch.no_grad():
+            df = dino(cur)
+            if isinstance(df, (list, tuple)): df = df[0]
+            df = df[:, 1:]
+            sf = siglip(to_siglip(cur))
+            if isinstance(sf, (list, tuple)): sf = sf[0]
+            sf = sf[:, 1:]
+            B = cur.shape[0]; D = 1152
+            def pad(f, ed):
+                p = torch.zeros(B, f.shape[1], D, device=device, dtype=dtype)
+                p[..., :ed] = f[..., :ed]; return p
+            return projector(torch.cat([pad(df, 384), pad(sf, 768)], dim=1))
+
+    best_loss = float("inf")
     global_step = 0
     train_iter = iter(train_loader)
     pbar = tqdm(total=args.max_steps, desc="Train")
+
     while global_step < args.max_steps:
         shim.train()
         opt.zero_grad()
         accum_loss = 0.0

+        for _ in range(args.grad_accum):
             try:
                 batch = next(train_iter)
             except StopIteration:
                 train_iter = iter(train_loader)
                 batch = next(train_iter)

+            cur_img = batch["cur_img"].to(device, dtype=dtype)
             inp = batch["input_ids"].to(device)
             am = batch["attention_mask"].to(device)
+            hs_target = batch["hs_target"].to(device, dtype=dtype)
             B = cur_img.shape[0]

             vis = encode_image(cur_img)
+            embed = llm.get_input_embeddings()(inp)
+            mm = torch.cat([embed[:, :1, :], vis, embed[:, 1:, :]], dim=1)
+            mm_attn = torch.cat([am[:, :1], torch.ones(B, NUM_VIS, dtype=am.dtype, device=device), am[:, 1:]], dim=1)
+            act_start = 1 + NUM_VIS + action_offset - 1
+            mask = torch.zeros(B, mm.shape[1], dtype=torch.bool, device=device)
             for i in range(B):
+                end = act_start + NUM_ACTION_TOKENS
+                if end <= mm.shape[1]:
+                    mask[i, act_start:end] = True
+            mm = mm * ~mask.unsqueeze(-1)

+            with torch.autocast(device_type=device.type, dtype=dtype):
+                out = llm(inputs_embeds=mm, attention_mask=mm_attn, labels=None, output_hidden_states=True, return_dict=True)
             hs_all = out.hidden_states[-1]
+            hs_act = torch.stack([hs_all[i, mask[i]] for i in range(B)], dim=0)
+            hs_shimmed = shim(hs_act)
             loss = F.mse_loss(hs_shimmed, hs_target)
             (loss / args.grad_accum).backward()
             accum_loss += loss.item()

         torch.nn.utils.clip_grad_norm_(shim.parameters(), 1.0)
         opt.step()
         sched.step()
         global_step += 1

+        if global_step % 100 == 0:
             with torch.no_grad():
+                cos = F.cosine_similarity(hs_shimmed.float().reshape(-1, args.teacher_dim),
+                                          hs_target.float().reshape(-1, args.teacher_dim), dim=-1).mean().item()
+            pbar.set_postfix({"loss": f"{accum_loss/args.grad_accum:.5f}", "cos": f"{cos:.4f}"})
+        pbar.update(1)

         # Validation
         if global_step % args.val_every == 0:
             shim.eval()
+            v_loss, v_cos, nv = 0.0, 0.0, 0
             with torch.no_grad():
+                for vb in val_loader:
+                    ci = vb["cur_img"].to(device, dtype=dtype)
+                    ip = vb["input_ids"].to(device)
+                    am = vb["attention_mask"].to(device)
+                    ht = vb["hs_target"].to(device, dtype=dtype)
+                    Bv = ci.shape[0]
+                    vi = encode_image(ci)
+                    em = llm.get_input_embeddings()(ip)
+                    mm = torch.cat([em[:, :1, :], vi, em[:, 1:, :]], dim=1)
+                    ma = torch.cat([am[:, :1], torch.ones(Bv, NUM_VIS, dtype=am.dtype, device=device), am[:, 1:]], dim=1)
+                    mk = torch.zeros(Bv, mm.shape[1], dtype=torch.bool, device=device)
                     for i in range(Bv):
+                        e = 1 + NUM_VIS + action_offset - 1 + NUM_ACTION_TOKENS
+                        if e <= mm.shape[1]:
+                            mk[i, 1 + NUM_VIS + action_offset - 1:e] = True
+                    mm = mm * ~mk.unsqueeze(-1)
+                    o = llm(inputs_embeds=mm, attention_mask=ma, labels=None, output_hidden_states=True, return_dict=True)
+                    ha = torch.stack([o.hidden_states[-1][i, mk[i]] for i in range(Bv)], dim=0)
+                    hs = shim(ha)
+                    v_loss += F.mse_loss(hs, ht).item()
+                    v_cos += F.cosine_similarity(hs.float().reshape(-1, args.teacher_dim),
+                                                 ht.float().reshape(-1, args.teacher_dim), dim=-1).mean().item()
                     nv += 1
+            v_loss /= nv; v_cos /= nv
+            print(f"\n─── Val @ {global_step}: loss={v_loss:.5f} cos={v_cos:.4f} ───", flush=True)
+            if v_loss < best_loss:
+                best_loss = v_loss
                 torch.save(shim.state_dict(), run_dir / "shim_best.pt")
+                print(f" → Saved best (loss={v_loss:.5f})")

         if global_step % args.save_every == 0:
+            d = run_dir / f"step_{global_step}"; d.mkdir(exist_ok=True)
+            torch.save(shim.state_dict(), d / "shim.pt")

     pbar.close()
+    print(f"\nDone! Best val loss: {best_loss:.5f}")


 if __name__ == "__main__":