theguy21 commited on
Commit
2e0542d
Β·
verified Β·
1 Parent(s): bd89217

Set proper README with YAML metadata

Browse files
Files changed (1) hide show
  1. README.md +163 -3
README.md CHANGED
@@ -1,7 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # OpenVLA-Micro
2
 
3
- A drop-in replacement for OmniVLA 7B that runs ~14Γ— faster with 0.997 action cosine similarity.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- Swaps DINOv2-L + SigLIP-so400m + Llama-2 7B β†’ DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B (~0.6B total), with a learned 896β†’4096 shim to stay compatible with OmniVLA's action head.
6
 
7
- See [HF_README.md](HF_README.md) or https://huggingface.co/theguy21/openvla-micro for full details.
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ pipeline_tag: reinforcement-learning
7
+ tags:
8
+ - robotics
9
+ - vla
10
+ - vision-language-action
11
+ - openvla
12
+ - omnivla
13
+ - robot
14
+ - qwen
15
+ - dinov2
16
+ - siglip
17
+ datasets:
18
+ - libero_90
19
+ - cast
20
+ model-index:
21
+ - name: openvla-micro
22
+ results: []
23
+ ---
24
+
25
  # OpenVLA-Micro
26
 
27
+ **A drop-in replacement for OmniVLA 7B that runs ~14Γ— faster with 0.997 action cosine similarity.**
28
+
29
+ OpenVLA-Micro is a compact Vision-Language-Action model that replaces OmniVLA 7B (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B) while preserving compatibility with OmniVLA's pretrained action head through a learned hidden-state shim (896β†’4096).
30
+
31
+ ## Model Architecture
32
+
33
+ | Component | Encoder | Output Dim |
34
+ |---|---|---|
35
+ | Vision (DINO) | `DINOv2-S/14` (facebook/dinov2-small) | 256 tokens Γ— 384d β†’ MLP β†’ 8704 |
36
+ | Vision (SigLIP) | `SigLIP-B/16` (google/siglip-base-patch16-224) | 196 tokens Γ— 768d β†’ MLP β†’ 8704 |
37
+ | Projector | Linear(8704β†’896) + GELU + Linear(896β†’896) | 452 tokens Γ— 896d |
38
+ | LLM | `Qwen2.5-0.5B` (24 layers, 896 hidden, 8 heads) | Variable Γ— 896d |
39
+ | Shim | Linear(896β†’2048) + GELU + Linear(2048β†’4096) | 32 action tokens Γ— 4096d |
40
+ | Action Head | OmniVLA's pretrained head (unchanged) | 8 chunks Γ— 7-DoF |
41
+
42
+ Total parameters: **~0.6B** (vs 7B for OmniVLA/OpenVLA).
43
+
44
+ ## Performance
45
+
46
+ ### vs OmniVLA 7B (teacher)
47
+
48
+ | Metric | Value | Note |
49
+ |---|---|---|
50
+ | Hidden state cosine | **0.63** | Last-layer HS at action positions |
51
+ | Action cosine | **0.997** | After OmniVLA action head |
52
+ | Action MSE | ~0.001 | Effectively identical predictions |
53
+ | Inference speed | **~14Γ— faster** | 0.5B vs 7B LLM |
54
+
55
+ Trained on 1000 CAST episodes (17k steps) via hidden-state distillation. The shim was trained for ~21k steps (plateaued at ~8k).
56
+
57
+ ### vs OpenVLA 7B (original)
58
+
59
+ OpenVLA-Micro is distilled from *OmniVLA*, which itself fine-tuned *OpenVLA* with 32 action tokens and a modified action head. Direct comparison with the original OpenVLA is not apples-to-apples due to different action tokenization, but the ~14Γ— speedup and near-lossless action quality relative to OmniVLA apply similarly vs OpenVLA.
60
+
61
+ ## Quick Start
62
+
63
+ ```python
64
+ from PIL import Image
65
+ from modeling_openvla_micro import OpenVLAMicro
66
+
67
+ model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
68
+ model.eval()
69
+
70
+ image = Image.open("demo.jpg").convert("RGB")
71
+ action = model.predict_action(image, "pick up the red block")
72
+ print(action) # [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15]
73
+ ```
74
+
75
+ ### CLI β€” GPU
76
+
77
+ ```bash
78
+ python inference.py --image demo.jpg "pick up the red block"
79
+ ```
80
+
81
+ ### CLI β€” CPU / Edge
82
+
83
+ ```bash
84
+ # Standard CPU (~6GB RAM, 3-5 sec/step)
85
+ python inference_cpu.py --image demo.jpg "pick up the red block"
86
+
87
+ # Low-RAM CPU (~2.5GB RAM, requires bitsandbytes)
88
+ python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
89
+ ```
90
+
91
+ ### As an OmniVLA drop-in replacement
92
+
93
+ Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:
94
+
95
+ ```python
96
+ from model_wrapper import OpenVLAMicroWrapper
97
+ from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
98
+
99
+ ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
100
+ ve = DinoSigLIPEncoder()
101
+ ve.load_state_dict(ckpt["model"]["vision_backbone"])
102
+ # ... (see model_wrapper.py for full example)
103
+
104
+ output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
105
+ actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
106
+ predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
107
+ ```
108
+
109
+ ## Architecture Diagram
110
+
111
+ ```
112
+ Image (224Γ—224)
113
+ β”œβ”€β”€ DINOv2-S/14 β†’ 256 patches Γ— 384d β†’ ShimMLP(384β†’8704)
114
+ └── SigLIP-B/16 β†’ 196 patches Γ— 768d β†’ ShimMLP(768β†’8704)
115
+ └── Concat (452 tokens) β†’ Linear(8704β†’896) β†’ GELU β†’ Linear(896β†’896)
116
+ └── Qwen2.5 0.5B (24 layers, 896 hidden)
117
+ └── Hidden State Shim (896β†’2048β†’4096)
118
+ └── OmniVLA Action Head (pretrained, frozen)
119
+ └── 8 chunks Γ— 7-DoF actions
120
+ ```
121
+
122
+ ## Files
123
+
124
+ | File | Size | Description |
125
+ |---|---|---|
126
+ | `modeling_openvla_micro.py` | 15 KB | Model definitions |
127
+ | `model_wrapper.py` | 11 KB | OmniVLA-compatible interface |
128
+ | `inference.py` | 1.5 KB | GPU/CPU CLI inference |
129
+ | `inference_cpu.py` | 2 KB | Edge device inference (with low-RAM mode) |
130
+ | `train_shim.py` | 15 KB | Reference shim training script |
131
+ | `config.json` | 1.2 KB | Model configuration |
132
+ | `openvla-micro-merged.pt` | 1.6 GB | Base checkpoint (no shim, 896-dim output) |
133
+ | `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint (with baked-in shim, 4096-dim) |
134
+
135
+ **Which checkpoint to use?**
136
+
137
+ - `openvla-micro-distill.pt` β€” **recommended**. Outputs 4096-dim hidden states that plug directly into OmniVLA's action head. One-step inference.
138
+ - `openvla-micro-merged.pt` β€” base model only (896-dim). Use if you want to train your own shim or action head.
139
+
140
+ ## Requirements
141
+
142
+ ```
143
+ torch>=2.0.0
144
+ torchvision>=0.15.0
145
+ transformers>=4.38.0
146
+ timm>=0.9.0
147
+ Pillow>=10.0.0
148
+ numpy>=1.24.0
149
+ ```
150
+
151
+ For low-RAM CPU: `bitsandbytes>=0.43.0`
152
+
153
+ ## Training the Shim
154
+
155
+ ```bash
156
+ python train_shim.py \
157
+ --cache-dir ./teacher_cache \
158
+ --data-dir ./dataset \
159
+ --base-model openvla-micro-merged.pt \
160
+ --teacher-dim 4096
161
+ ```
162
+
163
+ See `train_shim.py` for full options. The script expects pre-cached teacher hidden states; adapt `DistillDataset` to your format.
164
 
165
+ ## License
166
 
167
+ MIT