---
license: mit
language:
- en
library_name: transformers
pipeline_tag: reinforcement-learning
tags:
- robotics
- vla
- vision-language-action
- openvla
- omnivla
- robot
- qwen
- dinov2
- siglip
datasets:
- libero_90
- cast
model-index:
- name: openvla-micro
results: []
---
# OpenVLA-Micro
**A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity to the teacher.**
OpenVLA-Micro is a compact Vision-Language-Action model that replaces OmniVLA 7B (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B) while staying compatible with OmniVLA's pretrained action head through a learned hidden-state shim (896→4096).
## Model Architecture
| Component | Encoder | Output Dim |
|---|---|---|
| Vision (DINO) | `DINOv2-S/14` (facebook/dinov2-small) | 256 tokens × 384d → MLP → 8704 |
| Vision (SigLIP) | `SigLIP-B/16` (google/siglip-base-patch16-224) | 196 tokens × 768d → MLP → 8704 |
| Projector | Linear(8704→896) + GELU + Linear(896→896) | 452 tokens × 896d |
| LLM | `Qwen2.5-0.5B` (24 layers, 896 hidden, 14 query / 2 KV heads) | Variable × 896d |
| Shim | Linear(896→2048) + GELU + Linear(2048→4096) | 32 action tokens × 4096d |
| Action Head | OmniVLA's pretrained head (unchanged) | 8 chunks × 7-DoF |
Total parameters: **~0.6B** (vs 7B for OmniVLA/OpenVLA).
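As a rough check on where the ~0.6B goes, the glue layers whose sizes the table fully specifies (projector and shim) can be counted directly. A minimal sketch, assuming every Linear layer has a bias (the card does not say); the per-encoder "MLP → 8704" layers are left out because their exact structure is not given:

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of a dense layer: weight matrix plus optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)


def mlp_params(dims: list[int]) -> int:
    """Parameter count of a stack of Linear layers (GELU has no parameters)."""
    return sum(linear_params(a, b) for a, b in zip(dims, dims[1:]))


# Layer sizes from the architecture table; biases are an assumption.
projector = mlp_params([8704, 896, 896])  # Linear(8704→896) + GELU + Linear(896→896)
shim = mlp_params([896, 2048, 4096])      # Linear(896→2048) + GELU + Linear(2048→4096)

print(f"projector: {projector:,}")
print(f"shim:      {shim:,}")
```

Together these come to under 20M parameters, so the total is dominated by the Qwen2.5-0.5B backbone.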
## Performance
### vs OmniVLA 7B (teacher)
| Metric | Value | Note |
|---|---|---|
| Hidden-state cosine | **0.63** | Last-layer hidden states at action positions |
| Action cosine | **0.997** | After OmniVLA action head |
| Action MSE | ~0.001 | Effectively identical predictions |
| Inference speed | **~14× faster** | 0.5B vs 7B LLM |
The model was trained on 1,000 CAST episodes (17k steps) via hidden-state distillation. The shim was trained separately for ~21k steps, although it plateaued at ~8k.
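The card does not say exactly how the action metrics are computed. The sketch below assumes each predicted 8 × 7 action chunk is flattened and compared against OmniVLA's prediction for the same input; `compare_actions` and the toy values are illustrative, not part of this repository:

```python
import math
from typing import Sequence


def flatten(chunk: Sequence[Sequence[float]]) -> list[float]:
    """Flatten an (n_chunks x dof) action chunk into a single vector."""
    return [x for step in chunk for x in step]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def compare_actions(student, teacher) -> tuple[float, float]:
    """Return (cosine, MSE) between a student and a teacher action chunk."""
    s, t = flatten(student), flatten(teacher)
    if len(s) != len(t):
        raise ValueError("action chunks must have the same shape")
    return cosine_similarity(s, t), mse(s, t)


# Toy 8 x 7 chunks: the student is the teacher shifted by a small constant.
teacher = [[0.1 * (i + j) for j in range(7)] for i in range(8)]
student = [[x + 0.01 for x in step] for step in teacher]
cos, err = compare_actions(student, teacher)
print(f"cosine={cos:.4f}  mse={err:.6f}")
```

Averaging these per-sample values over an evaluation set would give table-style numbers; the same cosine applies to 4096-d hidden-state vectors.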
### vs OpenVLA 7B (original)
OpenVLA-Micro is distilled from *OmniVLA*, which is itself a fine-tune of *OpenVLA* with 32 action tokens and a modified action head. A direct comparison with the original OpenVLA is not apples-to-apples because the action tokenization differs. Since OpenVLA uses the same Llama-2 7B backbone, the ~14× speedup applies similarly; the near-lossless action quality, however, is measured against OmniVLA.
## Quick Start
```python
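from PIL import Image

# modeling_openvla_micro.py ships with this repository: run from a local clone
# (or put the file on your PYTHONPATH) so the import resolves.
from modeling_openvla_micro import OpenVLAMicro

model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
model.eval()  # inference mode (disables dropout)

image = Image.open("demo.jpg").convert("RGB")
action = model.predict_action(image, "pick up the red block")
print(action)  # e.g. [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15] (one 7-DoF action)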
```
### CLI (GPU)
```bash
python inference.py --image demo.jpg "pick up the red block"
```
### CLI (CPU / Edge)
```bash
# Standard CPU (~6GB RAM, 3-5 sec/step)
python inference_cpu.py --image demo.jpg "pick up the red block"
# Low-RAM CPU (~2.5GB RAM, requires bitsandbytes)
python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
```
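For sizing a device, a lower bound on RAM is the memory needed by the weights alone. A minimal sketch using the card's ~0.6B parameter count; bitsandbytes supports both 8-bit and 4-bit (NF4) quantization, but which format `--low-ram` actually uses is not stated, so the `int8` and `nf4` rows are assumptions:

```python
# Bytes per parameter for common storage formats.
BYTES_PER_PARAM = {"fp32": 4.0, "bf16": 2.0, "int8": 1.0, "nf4": 0.5}


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Weights only, in decimal GB; ignores activations, the action head, and runtime overhead."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>5}: {weight_memory_gb(0.6e9, dtype):.2f} GB")
```

The measured figures above (~6 GB and ~2.5 GB) sit well above these bounds, presumably because they also cover activations, the OmniVLA action head, and the PyTorch runtime.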
### As an OmniVLA drop-in replacement
Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:
```python
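import torch
from model_wrapper import OpenVLAMicroWrapper
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP

# Full checkpoint (with the baked-in 896→4096 shim), loaded on CPU first
ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
ve = DinoSigLIPEncoder()  # DINOv2-S/14 + SigLIP-B/16 vision backbone
ve.load_state_dict(ckpt["model"]["vision_backbone"])
# ... (see model_wrapper.py for full example)
# `vla`, `extract_actions`, `omnivla_action_head`, and the batch tensors below
# come from that elided setup.
output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
# Last-layer hidden states at the action positions, fed to OmniVLA's head
actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)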
```
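`extract_actions` is not shown above. As a hypothetical illustration of the idea, a mask-based selection over a single sequence might look like this. It assumes hidden states and labels are aligned position-for-position and that the action tokens are exactly the positions whose label is not the Hugging Face ignore index (-100); the real helper may differ, for example by offsetting for the 452 vision tokens or by working on batched tensors:

```python
from typing import Sequence

IGNORE_INDEX = -100  # Hugging Face convention for positions excluded from the loss


def extract_action_hidden_states(
    hidden_states: Sequence[Sequence[float]],
    labels: Sequence[int],
    num_action_tokens: int = 32,
) -> list[Sequence[float]]:
    """Hypothetical helper: keep hidden states at supervised (action) positions."""
    if len(hidden_states) != len(labels):
        raise ValueError("hidden_states and labels must be aligned")
    selected = [h for h, y in zip(hidden_states, labels) if y != IGNORE_INDEX]
    if len(selected) != num_action_tokens:
        raise ValueError(
            f"expected {num_action_tokens} action positions, got {len(selected)}"
        )
    return selected


# Toy sequence: 8 positions with 2-d states; the last 3 are action tokens.
hidden = [[float(i), float(-i)] for i in range(8)]
labels = [IGNORE_INDEX] * 5 + [101, 102, 103]
print(extract_action_hidden_states(hidden, labels, num_action_tokens=3))
```

Checking the count against the expected 32 action tokens catches misaligned masks early, before they reach the action head.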
## Architecture Diagram
```
Image (224×224)
├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
├── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
└── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
    └── Qwen2.5 0.5B (24 layers, 896 hidden)
        └── Hidden State Shim (896→2048→4096)
            └── OmniVLA Action Head (pretrained, frozen)
                └── 8 chunks × 7-DoF actions
```
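The token counts in the diagram follow from the patch sizes at a 224×224 input. A small sketch tracing the vision path's (tokens, dim) shapes, assuming CLS/register tokens are dropped (consistent with the 256 and 196 counts above):

```python
def patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch tokens for a square image (no CLS token)."""
    if image_size % patch_size:
        raise ValueError("image size must be divisible by patch size")
    side = image_size // patch_size
    return side * side


def trace_shapes(image_size: int = 224) -> list[tuple[str, tuple[int, int]]]:
    """Return (stage, (tokens, dim)) pairs for the vision path in the diagram."""
    dino = patch_tokens(image_size, 14)    # DINOv2-S/14
    siglip = patch_tokens(image_size, 16)  # SigLIP-B/16
    return [
        ("DINOv2-S/14 patches", (dino, 384)),
        ("SigLIP-B/16 patches", (siglip, 768)),
        ("DINO after MLP", (dino, 8704)),
        ("SigLIP after MLP", (siglip, 8704)),
        ("concat along tokens", (dino + siglip, 8704)),  # tokens add, width stays 8704
        ("projector output", (dino + siglip, 896)),      # matches Qwen2.5's hidden size
    ]


for stage, (tokens, dim) in trace_shapes():
    print(f"{stage:<22} {tokens:>4} tokens x {dim}d")
```

Concatenating along the token axis (rather than the channel axis) is what lets both encoders share one 8704→896 projector.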
## Files
| File | Size | Description |
|---|---|---|
| `modeling_openvla_micro.py` | 15 KB | Model definitions |
| `model_wrapper.py` | 11 KB | OmniVLA-compatible interface |
| `inference.py` | 1.5 KB | GPU/CPU CLI inference |
| `inference_cpu.py` | 2 KB | Edge device inference (with low-RAM mode) |
| `train_shim.py` | 15 KB | Reference shim training script |
| `config.json` | 1.2 KB | Model configuration |
| `openvla-micro-merged.pt` | 1.6 GB | Base checkpoint (no shim, 896-dim output) |
| `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint (with baked-in shim, 4096-dim) |
**Which checkpoint to use?**
- `openvla-micro-distill.pt`: **recommended**. Outputs 4096-dim hidden states that plug directly into OmniVLA's action head, so no separate shim step is needed at inference.
- `openvla-micro-merged.pt`: base model only (896-dim). Use it if you want to train your own shim or action head.
## Requirements
```
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.38.0
timm>=0.9.0
Pillow>=10.0.0
numpy>=1.24.0
```
For low-RAM CPU: `bitsandbytes>=0.43.0`
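To confirm an environment meets these minimums before running inference, a small stdlib-only checker can read installed versions via `importlib.metadata`. This is a sketch: the comparison only looks at the leading numeric release segment (so `2.1.0+cu118` counts as `2.1.0`); for full PEP 440 semantics, use `packaging.version.Version` instead:

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the requirements list above (distribution names).
MINIMUMS = {
    "torch": "2.0.0",
    "torchvision": "0.15.0",
    "transformers": "4.38.0",
    "timm": "0.9.0",
    "Pillow": "10.0.0",
    "numpy": "1.24.0",
}


def parse_version(text: str) -> tuple[int, ...]:
    """Leading numeric release segment only: '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unparseable version: {text!r}")
    return tuple(int(part) for part in match.group().split("."))


def at_least(installed: str, minimum: str) -> bool:
    """Compare release tuples after zero-padding them to the same length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_requirements(minimums: dict[str, str] = MINIMUMS) -> dict[str, str]:
    """Map each package to 'ok', 'too old (<installed>)', or 'missing'."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if at_least(installed, minimum) else f"too old ({installed})"
    return report


if __name__ == "__main__":
    for name, status in check_requirements().items():
        print(f"{name:<13} {status}")
```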
## Training the Shim
```bash
python train_shim.py \
--cache-dir ./teacher_cache \
--data-dir ./dataset \
--base-model openvla-micro-merged.pt \
--teacher-dim 4096
```
See `train_shim.py` for full options. The script expects pre-cached teacher hidden states; adapt `DistillDataset` to your format.
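The loss used by `train_shim.py` is not documented here. As a toy illustration, assuming plain mean-squared error against cached teacher hidden states, the sketch below runs a Linear → GELU → Linear shim (the same shape as the 896 → 2048 → 4096 shim, shrunk to 4 → 8 → 16) and evaluates that loss. It only computes the forward pass and loss; real training would use PyTorch autograd and an optimizer. `ToyShim` and `distill_loss` are hypothetical names:

```python
import math
import random


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def linear(x, weight, bias):
    """y = W x + b, with the weight stored as a list of output rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def init_linear(n_in: int, n_out: int, rng: random.Random):
    bound = 1.0 / math.sqrt(n_in)
    weight = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
    return weight, [0.0] * n_out


class ToyShim:
    """Linear -> GELU -> Linear mapping student states to teacher width."""

    def __init__(self, d_student: int, d_hidden: int, d_teacher: int, seed: int = 0):
        rng = random.Random(seed)
        self.fc1 = init_linear(d_student, d_hidden, rng)
        self.fc2 = init_linear(d_hidden, d_teacher, rng)

    def __call__(self, h):
        return linear([gelu(v) for v in linear(h, *self.fc1)], *self.fc2)


def distill_loss(student_hs, teacher_hs, shim) -> float:
    """Mean-squared error between shimmed student states and cached teacher states."""
    total, count = 0.0, 0
    for s, t in zip(student_hs, teacher_hs):
        pred = shim(s)
        total += sum((p - q) ** 2 for p, q in zip(pred, t))
        count += len(t)
    return total / count


rng = random.Random(1)
shim = ToyShim(d_student=4, d_hidden=8, d_teacher=16)
student = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(3)]
teacher = [[rng.gauss(0.0, 1.0) for _ in range(16)] for _ in range(3)]
print(f"toy distillation MSE: {distill_loss(student, teacher, shim):.4f}")
```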
## License
MIT