---
license: mit
language:
- en
library_name: transformers
pipeline_tag: reinforcement-learning
tags:
- robotics
- vla
- vision-language-action
- openvla
- omnivla
- robot
- qwen
- dinov2
- siglip
datasets:
- libero_90
- cast
model-index:
- name: openvla-micro
  results: []
---

# OpenVLA-Micro

**A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.**

OpenVLA-Micro is a compact Vision-Language-Action model that replaces OmniVLA 7B (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B) while preserving compatibility with OmniVLA's pretrained action head through a learned hidden-state shim (896→4096).

## Model Architecture

| Component | Module | Output Dim |
|---|---|---|
| Vision (DINO) | `DINOv2-S/14` (facebook/dinov2-small) | 256 tokens × 384d → MLP → 8704 |
| Vision (SigLIP) | `SigLIP-B/16` (google/siglip-base-patch16-224) | 196 tokens × 768d → MLP → 8704 |
| Projector | Linear(8704→896) + GELU + Linear(896→896) | 452 tokens × 896d |
| LLM | `Qwen2.5-0.5B` (24 layers, 896 hidden, 8 heads) | Variable × 896d |
| Shim | Linear(896→2048) + GELU + Linear(2048→4096) | 32 action tokens × 4096d |
| Action Head | OmniVLA's pretrained head (unchanged) | 8 chunks × 7-DoF |

Total parameters: **~0.6B** (vs 7B for OmniVLA/OpenVLA).
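
The token counts and shim size in the table can be checked with a little arithmetic. The sketch below assumes a 224×224 input, non-overlapping square patches with no class token counted, and dense layers that carry a bias term (the PyTorch default). The helper names `num_patches` and `linear_params` are illustrative and do not come from the repository.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    per_side = image_size // patch_size
    return per_side * per_side


def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of one dense layer (bias assumed present)."""
    return d_in * d_out + d_out


dino_tokens = num_patches(224, 14)    # DINOv2-S/14
siglip_tokens = num_patches(224, 16)  # SigLIP-B/16
visual_tokens = dino_tokens + siglip_tokens

# Shim: Linear(896 -> 2048) + GELU + Linear(2048 -> 4096); GELU has no parameters.
shim_params = linear_params(896, 2048) + linear_params(2048, 4096)

print(dino_tokens, siglip_tokens, visual_tokens)  # 256 196 452
print(f"shim parameters: {shim_params:,}")        # shim parameters: 10,229,760
```

Under these assumptions the shim adds roughly 10M parameters, a small fraction of the ~0.6B total.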

## Performance

### vs OmniVLA 7B (teacher)

| Metric | Value | Note |
|---|---|---|
| Hidden state cosine | **0.63** | Last-layer HS at action positions |
| Action cosine | **0.997** | After OmniVLA action head |
| Action MSE | ~0.001 | Effectively identical predictions |
| Inference speed | **~14× faster** | 0.5B vs 7B LLM |

Trained on 1000 CAST episodes (17k steps) via hidden-state distillation. The shim was trained for ~21k steps (plateaued at ~8k).
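
For readers who want to reproduce this kind of comparison, the sketch below shows one common way to compute cosine similarity and MSE between a teacher action and a student action. It is an assumption about the metric definitions, not the evaluation code used for the table, and the two 7-DoF vectors are made-up values for illustration only.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Made-up 7-DoF actions, not real model outputs.
teacher = [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15]
student = [0.11, -0.02, 0.46, -0.01, 0.21, 0.09, -0.14]

print(f"action cosine = {cosine_similarity(teacher, student):.4f}")
print(f"action MSE    = {mse(teacher, student):.6f}")
```

In practice these values would be averaged over many evaluation steps rather than computed on a single pair.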

### vs OpenVLA 7B (original)

OpenVLA-Micro is distilled from *OmniVLA*, which is itself a fine-tune of *OpenVLA* with 32 action tokens and a modified action head. A direct comparison with the original OpenVLA is not apples-to-apples because the action tokenization differs, but the ~14× speedup and the near-lossless action quality measured against OmniVLA should apply similarly to OpenVLA.

## Quick Start

```python
from PIL import Image
from modeling_openvla_micro import OpenVLAMicro

model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
model.eval()

image = Image.open("demo.jpg").convert("RGB")
action = model.predict_action(image, "pick up the red block")
print(action)  # [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15]
```

### CLI: GPU

```bash
python inference.py --image demo.jpg "pick up the red block"
```

### CLI: CPU / Edge

```bash
# Standard CPU (~6GB RAM, 3-5 sec/step)
python inference_cpu.py --image demo.jpg "pick up the red block"

# Low-RAM CPU (~2.5GB RAM, requires bitsandbytes)
python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
```

### As an OmniVLA drop-in replacement

Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:

```python
import torch

from model_wrapper import OpenVLAMicroWrapper
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP

# Load the distilled checkpoint on CPU and restore the vision backbone weights.
ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
ve = DinoSigLIPEncoder()
ve.load_state_dict(ckpt["model"]["vision_backbone"])
# ... (see model_wrapper.py for full example)

# `vla` is the wrapper instance assembled in the elided setup above.
# The inputs, `extract_actions`, `omnivla_action_head`, and `modality_id`
# come from your existing OmniVLA pipeline.
output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
```

## Architecture Diagram

```
Image (224×224)
  ├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
  └── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
       └── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
              └── Qwen2.5 0.5B (24 layers, 896 hidden)
                     └── Hidden State Shim (896→2048→4096)
                            └── OmniVLA Action Head (pretrained, frozen)
                                   └── 8 chunks × 7-DoF actions
```

## Files

| File | Size | Description |
|---|---|---|
| `modeling_openvla_micro.py` | 15 KB | Model definitions |
| `model_wrapper.py` | 11 KB | OmniVLA-compatible interface |
| `inference.py` | 1.5 KB | GPU/CPU CLI inference |
| `inference_cpu.py` | 2 KB | Edge device inference (with low-RAM mode) |
| `train_shim.py` | 15 KB | Reference shim training script |
| `config.json` | 1.2 KB | Model configuration |
| `openvla-micro-merged.pt` | 1.6 GB | Base checkpoint (no shim, 896-dim output) |
| `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint (with baked-in shim, 4096-dim) |

**Which checkpoint to use?**

- `openvla-micro-distill.pt`: **recommended**. Outputs 4096-dim hidden states that plug directly into OmniVLA's action head. One-step inference.
- `openvla-micro-merged.pt`: base model only (896-dim). Use if you want to train your own shim or action head.

## Requirements

```
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.38.0
timm>=0.9.0
Pillow>=10.0.0
numpy>=1.24.0
```

For low-RAM CPU: `bitsandbytes>=0.43.0`
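
Before running inference, it can help to confirm which of these packages are present. The sketch below uses only the standard library's `importlib.metadata`. The `REQUIREMENTS` mapping and the `installed_versions` helper are illustrative names, and the script only reports versions; comparing them against the minimums is left to the reader.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Minimum versions copied from the requirements list above.
REQUIREMENTS = {
    "torch": "2.0.0",
    "torchvision": "0.15.0",
    "transformers": "4.38.0",
    "timm": "0.9.0",
    "Pillow": "10.0.0",
    "numpy": "1.24.0",
}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


for name, version in installed_versions(REQUIREMENTS).items():
    status = version if version is not None else "not installed"
    print(f"{name:<14} need >= {REQUIREMENTS[name]:<8} found: {status}")
```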

## Training the Shim

```bash
python train_shim.py \
    --cache-dir ./teacher_cache \
    --data-dir ./dataset \
    --base-model openvla-micro-merged.pt \
    --teacher-dim 4096
```

See `train_shim.py` for full options. The script expects pre-cached teacher hidden states; adapt `DistillDataset` to your format.
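
As a rough picture of what shim training involves, the sketch below runs a single distillation step. It is not the code in `train_shim.py`: the `HiddenStateShim` class is a hypothetical stand-in built from the layer sizes in the architecture table, the MSE loss and AdamW optimizer are assumptions, and random tensors stand in for the student and cached teacher hidden states.

```python
import torch
from torch import nn


class HiddenStateShim(nn.Module):
    """Hypothetical stand-in: Linear(896->2048) + GELU + Linear(2048->4096)."""

    def __init__(self, student_dim: int = 896, mid_dim: int = 2048, teacher_dim: int = 4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(student_dim, mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, teacher_dim),
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return self.net(hidden)


torch.manual_seed(0)
shim = HiddenStateShim()
optimizer = torch.optim.AdamW(shim.parameters(), lr=1e-4)  # assumed optimizer and rate

# Random placeholders shaped (batch, 32 action tokens, hidden dim).
student_hidden = torch.randn(4, 32, 896)   # would come from the 0.5B student
teacher_hidden = torch.randn(4, 32, 4096)  # would come from the teacher cache

# One distillation step: project student states and regress onto teacher states.
pred = shim(student_hidden)
loss = nn.functional.mse_loss(pred, teacher_hidden)  # assumed loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(tuple(pred.shape), round(loss.item(), 4))
```

A real run would iterate this step over batches drawn from the cached teacher states, with the base model frozen so that only the shim's weights are updated.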

## License

MIT