---
license: mit
language:
- en
library_name: transformers
pipeline_tag: reinforcement-learning
tags:
- robotics
- vla
- vision-language-action
- openvla
- omnivla
- robot
- qwen
- dinov2
- siglip
datasets:
- libero_90
- cast
model-index:
- name: openvla-micro
  results: []
---

# OpenVLA-Micro

**A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.**

OpenVLA-Micro is a compact Vision-Language-Action model that replaces OmniVLA 7B (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B). It stays compatible with OmniVLA's pretrained action head through a learned hidden-state shim (896→4096).

## Model Architecture

| Component | Module | Output Shape |
|---|---|---|
| Vision (DINO) | `DINOv2-S/14` (facebook/dinov2-small) | 256 tokens × 384d → MLP → 8704 |
| Vision (SigLIP) | `SigLIP-B/16` (google/siglip-base-patch16-224) | 196 tokens × 768d → MLP → 8704 |
| Projector | Linear(8704→896) + GELU + Linear(896→896) | 452 tokens × 896d |
| LLM | `Qwen2.5-0.5B` (24 layers, 896 hidden, 14 attention heads) | Variable × 896d |
| Shim | Linear(896→2048) + GELU + Linear(2048→4096) | 32 action tokens × 4096d |
| Action Head | OmniVLA's pretrained head (unchanged) | 8 chunks × 7-DoF |

Total parameters: **~0.6B** (vs 7B for OmniVLA/OpenVLA).
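
As a rough sanity check on where the parameters live, the Linear layers listed for the projector and shim can be counted directly (weights plus biases). This is a minimal sketch: it assumes every `Linear(in→out)` has a bias, and it leaves out the per-encoder vision MLPs, whose exact layout the table does not specify.

```python
def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Parameter count of a dense layer: weight matrix plus optional bias."""
    return d_in * d_out + (d_out if bias else 0)


# Projector: Linear(8704→896) + GELU + Linear(896→896); GELU has no parameters
projector = linear_params(8704, 896) + linear_params(896, 896)

# Hidden-state shim: Linear(896→2048) + GELU + Linear(2048→4096)
shim = linear_params(896, 2048) + linear_params(2048, 4096)

print(f"projector: {projector:,}")  # 8,603,392
print(f"shim:      {shim:,}")       # 10,229,760
print(f"share of ~0.6B total: {(projector + shim) / 0.6e9:.1%}")  # 3.1%
```

Together the projector and shim come to under 20M parameters, so most of the ~0.6B sits in the pretrained backbones (plus the vision MLPs not counted here).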
## Performance

### vs OmniVLA 7B (teacher)

| Metric | Value | Note |
|---|---|---|
| Hidden-state cosine | **0.63** | Last-layer hidden states at action positions |
| Action cosine | **0.997** | After OmniVLA action head |
| Action MSE | ~0.001 | Effectively identical predictions |
| Inference speed | **~14× faster** | 0.5B vs 7B LLM |

Trained on 1,000 CAST episodes (17k steps) via hidden-state distillation. The shim was trained for ~21k steps, plateauing at ~8k.
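
The table reports cosine similarity and MSE between student and teacher outputs. The card does not say exactly how these are aggregated, so the sketch below assumes each prediction's 8 × 7 action chunk is flattened to one vector, cosine is averaged per sample, and MSE is averaged over all elements. The `student` and `teacher` arrays are synthetic placeholders, not real model outputs.

```python
import numpy as np


def mean_cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    """Mean per-sample cosine similarity between two (N, ...) arrays."""
    a = a.reshape(len(a), -1)  # flatten each sample to one vector
    b = b.reshape(len(b), -1)
    num = (a * b).sum(axis=1)
    den = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + eps
    return float((num / den).mean())


def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared error over every element."""
    return float(((a - b) ** 2).mean())


# Synthetic stand-ins: 16 predictions, each an 8-chunk x 7-DoF action block
rng = np.random.default_rng(0)
teacher = rng.normal(size=(16, 8, 7))
student = teacher + 0.03 * rng.normal(size=teacher.shape)  # small perturbation

print(f"action cosine: {mean_cosine(student, teacher):.3f}")
print(f"action MSE:    {mse(student, teacher):.4f}")
```

The same `mean_cosine` helper would apply to the hidden-state row if fed last-layer hidden states at the action positions instead of decoded actions.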
### vs OpenVLA 7B (original)

OpenVLA-Micro is distilled from *OmniVLA*, which is itself a fine-tune of *OpenVLA* with 32 action tokens and a modified action head. A direct comparison with the original OpenVLA is not apples-to-apples because the action tokenization differs. Since both 7B models use the same Llama-2 7B backbone, the ~14× speedup applies similarly vs OpenVLA; the near-lossless action quality was measured relative to OmniVLA.

## Quick Start

```python
from PIL import Image

from modeling_openvla_micro import OpenVLAMicro

model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
model.eval()

image = Image.open("demo.jpg").convert("RGB")
action = model.predict_action(image, "pick up the red block")
print(action)  # e.g. [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15] (illustrative values)
```

### CLI (GPU)

```bash
python inference.py --image demo.jpg "pick up the red block"
```

### CLI (CPU / Edge)

```bash
# Standard CPU (~6 GB RAM, 3-5 s/step)
python inference_cpu.py --image demo.jpg "pick up the red block"

# Low-RAM CPU (~2.5 GB RAM, requires bitsandbytes)
python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
```

### As an OmniVLA drop-in replacement

Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:

```python
import torch

from model_wrapper import OpenVLAMicroWrapper
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP

# Full checkpoint with the baked-in shim (4096-dim hidden states)
ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")

# Restore the DINOv2 + SigLIP vision backbone
ve = DinoSigLIPEncoder()
ve.load_state_dict(ckpt["model"]["vision_backbone"])
# ... (see model_wrapper.py for full example)

# Same forward signature as OmniVLA
output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
# Last-layer hidden states at the action-token positions
actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
# Decode with OmniVLA's pretrained action head
predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
```
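
The snippet leaves `extract_actions` to `model_wrapper.py`. For readers wiring this up themselves, here is a hypothetical NumPy sketch of what such a helper might do: pick out the last-layer hidden states at action-token positions and stack them into a `(batch, 32, hidden)` array for the action head. It assumes action positions are those whose label is not the ignore index `-100` (a common Hugging Face convention) and that every sample has exactly 32 of them; it ignores any one-position label shift the real model may apply, and `extract_actions_sketch` is an illustrative name, not the repository's function.

```python
import numpy as np

IGNORE_INDEX = -100     # assumed label value for non-action positions
NUM_ACTION_TOKENS = 32  # action tokens per sample, per the architecture table


def extract_actions_sketch(hidden: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Hypothetical: gather hidden states at action positions.

    hidden: (B, T, D) last-layer hidden states
    labels: (B, T) token labels; action positions are != IGNORE_INDEX
    returns: (B, NUM_ACTION_TOKENS, D)
    """
    mask = labels != IGNORE_INDEX
    counts = mask.sum(axis=1)
    if not np.all(counts == NUM_ACTION_TOKENS):
        raise ValueError(f"expected {NUM_ACTION_TOKENS} action tokens per sample, got {counts}")
    # Boolean indexing walks the batch in row-major order, so reshape regroups per sample
    return hidden[mask].reshape(hidden.shape[0], NUM_ACTION_TOKENS, hidden.shape[-1])


# Toy batch: 2 samples, 40 tokens, 4096-dim (post-shim) hidden states
B, T, D = 2, 40, 4096
hidden = np.random.default_rng(0).normal(size=(B, T, D)).astype(np.float32)
labels = np.full((B, T), IGNORE_INDEX)
labels[:, 5:37] = 1  # pretend positions 5..36 hold the 32 action tokens
print(extract_actions_sketch(hidden, labels).shape)  # (2, 32, 4096)
```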

## Architecture Diagram

```
Image (224×224)
├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
├── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
└── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
    └── Qwen2.5 0.5B (24 layers, 896 hidden)
        └── Hidden State Shim (896→2048→4096)
            └── OmniVLA Action Head (pretrained, frozen)
                └── 8 chunks × 7-DoF actions
```
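
The token and feature sizes in the diagram can be traced with a shape-only NumPy walk-through. This is a sketch with random weights, not the model: it assumes the per-encoder MLPs have already produced 8704-dim features, uses the tanh approximation of GELU, and substitutes a random array for Qwen2.5's output at the 32 action positions.

```python
import numpy as np

rng = np.random.default_rng(0)


def gelu(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def linear(d_in: int, d_out: int):
    """Random dense layer returning a callable x -> x @ W + b."""
    w = rng.normal(scale=d_in**-0.5, size=(d_in, d_out)).astype(np.float32)
    b = np.zeros(d_out, dtype=np.float32)
    return lambda x: x @ w + b


image_size = 224
dino_tokens = (image_size // 14) ** 2    # DINOv2-S/14 -> 16 x 16 = 256 patches
siglip_tokens = (image_size // 16) ** 2  # SigLIP-B/16 -> 14 x 14 = 196 patches

# Placeholder 8704-dim features after each encoder's MLP, concatenated along tokens
dino_feats = rng.normal(size=(dino_tokens, 8704)).astype(np.float32)
siglip_feats = rng.normal(size=(siglip_tokens, 8704)).astype(np.float32)
visual = np.concatenate([dino_feats, siglip_feats], axis=0)  # (452, 8704)

# Projector: Linear(8704→896) → GELU → Linear(896→896)
fc1, fc2 = linear(8704, 896), linear(896, 896)
visual_tokens = fc2(gelu(fc1(visual)))  # (452, 896), fed to the LLM with text tokens

# Stand-in for Qwen2.5 last-layer hidden states at the 32 action positions
llm_action_states = rng.normal(size=(32, 896)).astype(np.float32)

# Hidden-state shim: Linear(896→2048) → GELU → Linear(2048→4096)
s1, s2 = linear(896, 2048), linear(2048, 4096)
shimmed = s2(gelu(s1(llm_action_states)))  # (32, 4096), action-head input

print(visual.shape, visual_tokens.shape, shimmed.shape)
# (452, 8704) (452, 896) (32, 4096)
```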

## Files

| File | Size | Description |
|---|---|---|
| `modeling_openvla_micro.py` | 15 KB | Model definitions |
| `model_wrapper.py` | 11 KB | OmniVLA-compatible interface |
| `inference.py` | 1.5 KB | GPU/CPU CLI inference |
| `inference_cpu.py` | 2 KB | Edge device inference (with low-RAM mode) |
| `train_shim.py` | 15 KB | Reference shim training script |
| `config.json` | 1.2 KB | Model configuration |
| `openvla-micro-merged.pt` | 1.6 GB | Base checkpoint (no shim, 896-dim output) |
| `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint (with baked-in shim, 4096-dim output) |

**Which checkpoint to use?**

- `openvla-micro-distill.pt` (**recommended**): outputs 4096-dim hidden states that plug directly into OmniVLA's action head, so no separate shim step is needed at inference.
- `openvla-micro-merged.pt`: base model only (896-dim outputs). Use it if you want to train your own shim or action head.

## Requirements

```
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.38.0
timm>=0.9.0
Pillow>=10.0.0
numpy>=1.24.0
```

For low-RAM CPU inference: `bitsandbytes>=0.43.0`

## Training the Shim

```bash
python train_shim.py \
    --cache-dir ./teacher_cache \
    --data-dir ./dataset \
    --base-model openvla-micro-merged.pt \
    --teacher-dim 4096
```

See `train_shim.py` for the full set of options. The script expects pre-cached teacher hidden states; adapt `DistillDataset` to your data format.
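
The card does not document the objective `train_shim.py` optimizes. As an illustration only, a common choice for hidden-state distillation is an MSE term plus a cosine term between the shim's 4096-dim outputs and the cached teacher hidden states; the function name `distill_loss` and the weighting `alpha` below are hypothetical.

```python
import numpy as np


def distill_loss(student: np.ndarray, teacher: np.ndarray, alpha: float = 0.5) -> float:
    """Hypothetical hidden-state distillation loss.

    student, teacher: (N, 32, 4096) hidden states at action positions.
    Returns MSE + alpha * (1 - mean per-token cosine similarity).
    """
    mse = ((student - teacher) ** 2).mean()
    num = (student * teacher).sum(axis=-1)
    den = np.linalg.norm(student, axis=-1) * np.linalg.norm(teacher, axis=-1) + 1e-8
    cos = (num / den).mean()
    return float(mse + alpha * (1.0 - cos))


rng = np.random.default_rng(0)
teacher = rng.normal(size=(4, 32, 4096)).astype(np.float32)  # stand-in for a cached batch
far = rng.normal(size=teacher.shape).astype(np.float32)      # untrained student
near = teacher + 0.01 * rng.normal(size=teacher.shape).astype(np.float32)

print(f"untrained: {distill_loss(far, teacher):.3f}")
print(f"close:     {distill_loss(near, teacher):.5f}")
```

In a PyTorch training loop the same formula maps onto `torch.nn.functional.mse_loss` and `torch.nn.functional.cosine_similarity` (with `dim=-1`).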

## License

MIT