---
license: mit
language:
- en
library_name: transformers
pipeline_tag: reinforcement-learning
tags:
- robotics
- vla
- vision-language-action
- openvla
- omnivla
- robot
- qwen
- dinov2
- siglip
datasets:
- libero_90
- cast
model-index:
- name: openvla-micro
  results: []
---

# OpenVLA-Micro

**A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.**

OpenVLA-Micro is a compact Vision-Language-Action model that replaces OmniVLA 7B (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B). It stays compatible with OmniVLA's pretrained action head through a learned hidden-state shim (896→4096).

## Model Architecture

| Component | Module | Output Shape |
|---|---|---|
| Vision (DINO) | `DINOv2-S/14` (facebook/dinov2-small) | 256 tokens × 384d → MLP → 8704d |
| Vision (SigLIP) | `SigLIP-B/16` (google/siglip-base-patch16-224) | 196 tokens × 768d → MLP → 8704d |
| Projector | Linear(8704→896) + GELU + Linear(896→896) | 452 tokens × 896d |
| LLM | `Qwen2.5-0.5B` (24 layers, 896 hidden, 14 attention heads) | sequence length × 896d |
| Shim | Linear(896→2048) + GELU + Linear(2048→4096) | 32 action tokens × 4096d |
| Action Head | OmniVLA's pretrained head (unchanged) | 8 chunks × 7-DoF |

Total parameters: **~0.6B** (vs 7B for OmniVLA/OpenVLA).

## Performance

### vs OmniVLA 7B (teacher)

| Metric | Value | Note |
|---|---|---|
| Hidden-state cosine similarity | **0.63** | Last-layer hidden states at action-token positions |
| Action cosine similarity | **0.997** | After the OmniVLA action head |
| Action MSE | ~0.001 | Effectively identical predictions |
| Inference speed | **~14× faster** | 0.5B vs 7B LLM |

OpenVLA-Micro was trained via hidden-state distillation on 1,000 CAST episodes (17k steps). The shim was trained for ~21k steps (plateaued at ~8k).

### vs OpenVLA 7B (original)

OpenVLA-Micro is distilled from *OmniVLA*, which is itself a fine-tune of *OpenVLA* with 32 action tokens and a modified action head.
Direct comparison with the original OpenVLA is not apples-to-apples because the action tokenization differs. The ~14× speedup applies equally, since OpenVLA and OmniVLA share the same Llama-2 7B backbone; near-lossless action quality is expected to carry over as well, but it has only been measured against OmniVLA.

## Quick Start

```python
import torch
from PIL import Image
from modeling_openvla_micro import OpenVLAMicro

# Load the model from the Hub onto the GPU
model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
model.eval()

image = Image.open("demo.jpg").convert("RGB")
with torch.no_grad():
    action = model.predict_action(image, "pick up the red block")
print(action)  # example output: [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15]
```

### CLI (GPU)

```bash
python inference.py --image demo.jpg "pick up the red block"
```

### CLI (CPU / Edge)

```bash
# Standard CPU (~6 GB RAM, 3-5 s/step)
python inference_cpu.py --image demo.jpg "pick up the red block"

# Low-RAM CPU (~2.5 GB RAM, requires bitsandbytes)
python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
```

### As an OmniVLA drop-in replacement

Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:

```python
import torch
from model_wrapper import OpenVLAMicroWrapper
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP

# Load the full checkpoint (shim baked in) on CPU
ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
# Rebuild the vision encoder and restore its weights
ve = DinoSigLIPEncoder()
ve.load_state_dict(ckpt["model"]["vision_backbone"])
# ... (see model_wrapper.py for full example)
# `vla`, `extract_actions`, `omnivla_action_head`, and `modality_id` come from that elided setup

output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
# Last-layer hidden states at the action-token positions
actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
# OmniVLA's pretrained action head maps them to action chunks
predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
```

## Architecture Diagram

```
Image (224×224)
├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
└── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
    └── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
        └── Qwen2.5 0.5B (24 layers, 896 hidden)
            └── Hidden State Shim (896→2048→4096)
                └── OmniVLA Action Head (pretrained, frozen)
                    └── 8 chunks × 7-DoF actions
```

## Files

| File | Size | Description |
|---|---|---|
| `modeling_openvla_micro.py` | 15 KB | Model definitions |
| `model_wrapper.py` | 11 KB | OmniVLA-compatible interface |
| `inference.py` | 1.5 KB | GPU/CPU CLI inference |
| `inference_cpu.py` | 2 KB | Edge device inference (with low-RAM mode) |
| `train_shim.py` | 15 KB | Reference shim training script |
| `config.json` | 1.2 KB | Model configuration |
| `openvla-micro-merged.pt` | 1.6 GB | Base checkpoint (no shim, 896-dim output) |
| `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint (with baked-in shim, 4096-dim output) |

**Which checkpoint to use?**

- `openvla-micro-distill.pt`: **recommended**. Outputs 4096-dim hidden states that plug directly into OmniVLA's action head. The shim is baked in, so inference needs no separate shim step.
- `openvla-micro-merged.pt`: base model only (896-dim output). Use it if you want to train your own shim or action head.

## Requirements

```
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.38.0
timm>=0.9.0
Pillow>=10.0.0
numpy>=1.24.0
```

For low-RAM CPU inference: `bitsandbytes>=0.43.0`

## Training the Shim

```bash
python train_shim.py \
  --cache-dir ./teacher_cache \
  --data-dir ./dataset \
  --base-model openvla-micro-merged.pt \
  --teacher-dim 4096
```

See `train_shim.py` for full options.
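
To make the shim's shapes and training signal concrete, here is a minimal PyTorch sketch. It assumes the shim is a plain Linear(896→2048) + GELU + Linear(2048→4096) stack applied to the 32 action-token hidden states, and it uses an assumed MSE + (1 − cosine) distillation loss. `HiddenStateShim` and `distill_loss` are hypothetical names, not the classes in `train_shim.py`, whose actual objective may differ. Random tensors stand in for the cached student and teacher hidden states.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HiddenStateShim(nn.Module):
    """Hypothetical shim: maps 896-dim student states to the 4096-dim teacher space."""

    def __init__(self, student_dim: int = 896, hidden_dim: int = 2048, teacher_dim: int = 4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(student_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, teacher_dim),
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (batch, num_action_tokens, student_dim) -> (batch, num_action_tokens, teacher_dim)
        return self.net(h)


def distill_loss(student: torch.Tensor, teacher: torch.Tensor, cos_weight: float = 1.0) -> torch.Tensor:
    """Assumed objective: MSE plus (1 - mean per-token cosine similarity)."""
    mse = F.mse_loss(student, teacher)
    cos = F.cosine_similarity(student, teacher, dim=-1).mean()
    return mse + cos_weight * (1.0 - cos)


if __name__ == "__main__":
    torch.manual_seed(0)
    shim = HiddenStateShim()
    opt = torch.optim.AdamW(shim.parameters(), lr=1e-4)

    # Random stand-ins for cached tensors: 32 action tokens per example
    student_hs = torch.randn(4, 32, 896)   # student last-layer hidden states (896-dim)
    teacher_hs = torch.randn(4, 32, 4096)  # teacher (OmniVLA) hidden states (4096-dim)

    for step in range(3):
        pred = shim(student_hs)
        loss = distill_loss(pred, teacher_hs)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(step, tuple(pred.shape), round(loss.item(), 4))
```

With these sizes the sketch's shim has 10,229,760 parameters (896·2048 + 2048 + 2048·4096 + 4096), small next to the 0.5B backbone.
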
The script expects pre-cached teacher hidden states; adapt `DistillDataset` to your format.

## License

MIT