theguy21 commited on
Commit
bd89217
·
verified ·
1 Parent(s): 0021e33

Add CPU inference script, update README with model details and perf stats

Browse files
Files changed (9) hide show
  1. .gitattributes +0 -34
  2. .gitignore +3 -0
  3. HF_README.md +167 -0
  4. README.md +3 -139
  5. inference.py +28 -18
  6. inference_cpu.py +61 -0
  7. modeling_openvla_micro.py +5 -2
  8. pyproject.toml +2 -1
  9. train_shim.py +205 -292
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -11,4 +11,7 @@ dist/
11
  *.egg-info/
12
  .venv/
13
  venv/
 
 
 
14
 
 
11
  *.egg-info/
12
  .venv/
13
  venv/
14
+ *.pt
15
+ *.bin
16
+ *.safetensors
17
 
HF_README.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ pipeline_tag: reinforcement-learning
7
+ tags:
8
+ - robotics
9
+ - vla
10
+ - vision-language-action
11
+ - openvla
12
+ - omnivla
13
+ - robot
14
+ - qwen
15
+ - dinov2
16
+ - siglip
17
+ datasets:
18
+ - libero_90
19
+ - cast
20
+ model-index:
21
+ - name: openvla-micro
22
+ results: []
23
+ ---
24
+
25
+ # OpenVLA-Micro
26
+
27
+ **A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.**
28
+
29
+ OpenVLA-Micro is a compact Vision-Language-Action model that replaces OmniVLA 7B (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B) while preserving compatibility with OmniVLA's pretrained action head through a learned hidden-state shim (896→4096).
30
+
31
+ ## Model Architecture
32
+
33
+ | Component | Encoder | Output Dim |
34
+ |---|---|---|
35
+ | Vision (DINO) | `DINOv2-S/14` (facebook/dinov2-small) | 256 tokens × 384d → MLP → 8704 |
36
+ | Vision (SigLIP) | `SigLIP-B/16` (google/siglip-base-patch16-224) | 196 tokens × 768d → MLP → 8704 |
37
+ | Projector | Linear(8704→896) + GELU + Linear(896→896) | 452 tokens × 896d |
38
+ | LLM | `Qwen2.5-0.5B` (24 layers, 896 hidden, 8 heads) | Variable × 896d |
39
+ | Shim | Linear(896→2048) + GELU + Linear(2048→4096) | 32 action tokens × 4096d |
40
+ | Action Head | OmniVLA's pretrained head (unchanged) | 8 chunks × 7-DoF |
41
+
42
+ Total parameters: **~0.6B** (vs 7B for OmniVLA/OpenVLA).
43
+
44
+ ## Performance
45
+
46
+ ### vs OmniVLA 7B (teacher)
47
+
48
+ | Metric | Value | Note |
49
+ |---|---|---|
50
+ | Hidden state cosine | **0.63** | Last-layer HS at action positions |
51
+ | Action cosine | **0.997** | After OmniVLA action head |
52
+ | Action MSE | ~0.001 | Effectively identical predictions |
53
+ | Inference speed | **~14× faster** | 0.5B vs 7B LLM |
54
+
55
+ Trained on 1000 CAST episodes (17k steps) via hidden-state distillation. The shim was trained for ~21k steps (plateaued at ~8k).
56
+
57
+ ### vs OpenVLA 7B (original)
58
+
59
+ OpenVLA-Micro is distilled from *OmniVLA*, which itself fine-tuned *OpenVLA* with 32 action tokens and a modified action head. Direct comparison with the original OpenVLA is not apples-to-apples due to different action tokenization, but the ~14× speedup and near-lossless action quality relative to OmniVLA apply similarly vs OpenVLA.
60
+
61
+ ## Quick Start
62
+
63
+ ```python
64
+ from PIL import Image
65
+ from modeling_openvla_micro import OpenVLAMicro
66
+
67
+ model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
68
+ model.eval()
69
+
70
+ image = Image.open("demo.jpg").convert("RGB")
71
+ action = model.predict_action(image, "pick up the red block")
72
+ print(action) # [0.12, -0.03, 0.45, -0.01, 0.22, 0.08, -0.15]
73
+ ```
74
+
75
+ ### CLI — GPU
76
+
77
+ ```bash
78
+ python inference.py --image demo.jpg "pick up the red block"
79
+ ```
80
+
81
+ ### CLI — CPU / Edge
82
+
83
+ ```bash
84
+ # Standard CPU (~6GB RAM, 3-5 sec/step)
85
+ python inference_cpu.py --image demo.jpg "pick up the red block"
86
+
87
+ # Low-RAM CPU (~2.5GB RAM, requires bitsandbytes)
88
+ python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
89
+ ```
90
+
91
+ ### As an OmniVLA drop-in replacement
92
+
93
+ Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:
94
+
95
+ ```python
96
+ from model_wrapper import OpenVLAMicroWrapper
97
+ from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
98
+
99
+ ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
100
+ ve = DinoSigLIPEncoder()
101
+ ve.load_state_dict(ckpt["model"]["vision_backbone"])
102
+ # ... (see model_wrapper.py for full example)
103
+
104
+ output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
105
+ actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
106
+ predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
107
+ ```
108
+
109
+ ## Architecture Diagram
110
+
111
+ ```
112
+ Image (224×224)
113
+ ├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
114
+ └── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
115
+ └── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
116
+ └── Qwen2.5 0.5B (24 layers, 896 hidden)
117
+ └── Hidden State Shim (896→2048→4096)
118
+ └── OmniVLA Action Head (pretrained, frozen)
119
+ └── 8 chunks × 7-DoF actions
120
+ ```
121
+
122
+ ## Files
123
+
124
+ | File | Size | Description |
125
+ |---|---|---|
126
+ | `modeling_openvla_micro.py` | 15 KB | Model definitions |
127
+ | `model_wrapper.py` | 11 KB | OmniVLA-compatible interface |
128
+ | `inference.py` | 1.5 KB | GPU/CPU CLI inference |
129
+ | `inference_cpu.py` | 2 KB | Edge device inference (with low-RAM mode) |
130
+ | `train_shim.py` | 15 KB | Reference shim training script |
131
+ | `config.json` | 1.2 KB | Model configuration |
132
+ | `openvla-micro-merged.pt` | 1.6 GB | Base checkpoint (no shim, 896-dim output) |
133
+ | `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint (with baked-in shim, 4096-dim) |
134
+
135
+ **Which checkpoint to use?**
136
+
137
+ - `openvla-micro-distill.pt` — **recommended**. Outputs 4096-dim hidden states that plug directly into OmniVLA's action head. One-step inference.
138
+ - `openvla-micro-merged.pt` — base model only (896-dim). Use if you want to train your own shim or action head.
139
+
140
+ ## Requirements
141
+
142
+ ```
143
+ torch>=2.0.0
144
+ torchvision>=0.15.0
145
+ transformers>=4.38.0
146
+ timm>=0.9.0
147
+ Pillow>=10.0.0
148
+ numpy>=1.24.0
149
+ ```
150
+
151
+ For low-RAM CPU: `bitsandbytes>=0.43.0`
152
+
153
+ ## Training the Shim
154
+
155
+ ```bash
156
+ python train_shim.py \
157
+ --cache-dir ./teacher_cache \
158
+ --data-dir ./dataset \
159
+ --base-model openvla-micro-merged.pt \
160
+ --teacher-dim 4096
161
+ ```
162
+
163
+ See `train_shim.py` for full options. The script expects pre-cached teacher hidden states; adapt `DistillDataset` to your format.
164
+
165
+ ## License
166
+
167
+ MIT
README.md CHANGED
@@ -1,143 +1,7 @@
1
- ---
2
- license: mit
3
- language:
4
- - en
5
- library_name: transformers
6
- pipeline_tag: reinforcement-learning
7
- tags:
8
- - robotics
9
- - vla
10
- - vision-language-action
11
- - openvla
12
- - omnivla
13
- - robot
14
- - qwen
15
- - dinov2
16
- - siglip
17
- datasets:
18
- - libero_90
19
- - cast
20
- model-index:
21
- - name: openvla-micro
22
- results: []
23
- ---
24
-
25
  # OpenVLA-Micro
26
 
27
- **A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.**
28
-
29
- OpenVLA-Micro is a compact Vision-Language-Action model that replaces the bulky OmniVLA 7B architecture (DINOv2-L + SigLIP-so400m + Llama-2 7B) with a much smaller stack (DINOv2-S + SigLIP-B + Qwen2.5 0.5B) while preserving compatibility with OmniVLA's action head through a learned hidden state shim (896→4096).
30
-
31
- | Property | OpenVLA-Micro | OmniVLA 7B |
32
- |---|---|---|
33
- | Vision encoder | DINOv2-S (384d) + SigLIP-B/16 (768d) | DINOv2-L (1024d) + SigLIP-so400m (1152d) |
34
- | LLM | Qwen2.5 0.5B (896 hidden) | Llama-2 7B (4096 hidden) |
35
- | Total params | ~0.6B | ~7B |
36
- | Hidden state dim | 4096 (via learned shim) | 4096 (native) |
37
- | Action head | OmniVLA compatible | Native |
38
- | Action cos similarity | **0.997** vs teacher | 1.0 (reference) |
39
-
40
- ## Performance
41
-
42
- Trained on 1000 CAST episodes (17k steps) via hidden-state distillation from OmniVLA 7B:
43
-
44
- - **Hidden state cosine**: 0.63 vs teacher
45
- - **Action cosine**: 0.997 vs teacher
46
- - **Action MSE**: near-zero (effectively identical predictions)
47
-
48
- ## Quick Start
49
-
50
- ```python
51
- from PIL import Image
52
- from modeling_openvla_micro import OpenVLAMicro
53
-
54
- model = OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device="cuda")
55
- model.eval()
56
-
57
- image = Image.open("demo.jpg").convert("RGB")
58
- action = model.predict_action(image, "pick up the red block")
59
- print(action)
60
- ```
61
-
62
- ### CLI
63
-
64
- ```bash
65
- python inference.py --checkpoint theguy21/openvla-micro --image demo.jpg "pick up the red block"
66
- ```
67
-
68
- ### As an OmniVLA drop-in replacement
69
-
70
- Use `OpenVLAMicroWrapper` (from `model_wrapper.py`) to expose the same forward interface as OmniVLA's `VLAForActionPrediction`:
71
-
72
- ```python
73
- from model_wrapper import OpenVLAMicroWrapper
74
- from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
75
-
76
- # Load components from checkpoint
77
- ckpt = torch.load("openvla-micro-distill.pt", map_location="cpu")
78
- ve = DinoSigLIPEncoder()
79
- ve.load_state_dict(ckpt["model"]["vision_backbone"])
80
- # ... (see model_wrapper.py for full example)
81
-
82
- # Forward pass compatible with OmniVLA action head
83
- output = vla(input_ids, attention_mask, pixel_values, labels=labels, output_hidden_states=True)
84
- actions_hidden_states = extract_actions(output.hidden_states[-1], labels)
85
- predicted_actions = omnivla_action_head.predict_action(actions_hidden_states, modality_id)
86
- ```
87
-
88
- ## Architecture
89
-
90
- ```
91
- Image (224×224)
92
- ├── DINOv2-S/14 → 256 patches × 384d → ShimMLP(384→8704)
93
- └── SigLIP-B/16 → 196 patches × 768d → ShimMLP(768→8704)
94
- └── Concat (452 tokens) → Linear(8704→896) → GELU → Linear(896→896)
95
- └── Qwen2.5 0.5B (24 layers, 896 hidden)
96
- └── Hidden State Shim (896→2048→4096)
97
- └── OmniVLA Action Head (pretrained)
98
- └── 8 chunks × 4-DoF actions
99
- ```
100
-
101
- ## Files
102
-
103
- | File | Size | Description |
104
- |---|---|---|
105
- | `modeling_openvla_micro.py` | 15 KB | Model definitions (DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro) |
106
- | `model_wrapper.py` | 11 KB | OmniVLA-compatible wrapper (OpenVLAMicroWrapper) |
107
- | `inference.py` | 1.5 KB | Standalone CLI inference |
108
- | `config.json` | 1.2 KB | Model configuration |
109
- | `openvla-micro-distill.pt` | 1.6 GB | Full checkpoint with baked-in shim |
110
-
111
- ## Requirements
112
-
113
- ```
114
- torch>=2.0.0
115
- torchvision>=0.15.0
116
- transformers>=4.38.0
117
- timm>=0.9.0
118
- Pillow>=10.0.0
119
- numpy>=1.24.0
120
- ```
121
-
122
- ## Training Details
123
-
124
- The shim (896→2048→4096 MLP) was trained to minimize MSE between the Qwen2.5 0.5B last hidden state and the cached OmniVLA 7B hidden states, keeping all other components frozen. Training used 1000 CAST episodes (17k steps) with bf16 precision on a single 24GB GPU.
125
-
126
- - **Optimizer**: AdamW (lr=5e-5, cosine schedule)
127
- - **Batch**: 8 micro-batch × 4 grad accum = 32 effective
128
- - **Training steps**: 21k (plateaued at step ~8k)
129
- - **Precision**: bfloat16
130
-
131
- ## Citation
132
-
133
- ```bibtex
134
- @misc{openvla-micro-2026,
135
- title = {OpenVLA-Micro: A Compact Drop-in Replacement for OmniVLA 7B},
136
- author = {},
137
- year = {2026},
138
- }
139
- ```
140
 
141
- ## License
142
 
143
- MIT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # OpenVLA-Micro
2
 
3
+ A drop-in replacement for OmniVLA 7B that runs ~14× faster with 0.997 action cosine similarity.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ Swaps DINOv2-L + SigLIP-so400m + Llama-2 7B → DINOv2-S/14 + SigLIP-B/16 + Qwen2.5 0.5B (~0.6B total), with a learned 896→4096 shim to stay compatible with OmniVLA's action head.
6
 
7
+ See [HF_README.md](HF_README.md) or https://huggingface.co/theguy21/openvla-micro for full details.
inference.py CHANGED
@@ -1,38 +1,48 @@
1
  """
2
  Standalone inference script for OpenVLA-Micro.
3
  Usage:
4
- python inference.py --checkpoint openvla-micro-merged.pt --image demo.jpg "pick up the red block"
 
 
 
 
 
 
 
5
  """
6
  import argparse
7
  from PIL import Image
8
-
9
  from modeling_openvla_micro import OpenVLAMicro
10
 
11
 
12
  def main():
13
  parser = argparse.ArgumentParser(description="OpenVLA-Micro inference")
14
- parser.add_argument("--checkpoint", type=str, default="openvla-micro-merged.pt",
15
- help="Path to a local checkpoint or a Hugging Face repo ID")
16
- parser.add_argument("--image", type=str, required=True,
17
- help="Path to input image")
18
- parser.add_argument("--device", type=str, default="cpu",
19
- help="Device: cpu or cuda")
20
- parser.add_argument("instruction", type=str, nargs="?",
21
- default="pick up the red block",
22
- help="Task instruction")
23
  args = parser.parse_args()
24
 
25
- # Load model
26
- print(f"Loading OpenVLA-Micro from {args.checkpoint}...")
27
- model = OpenVLAMicro.from_pretrained(args.checkpoint, device=args.device)
 
 
 
 
 
 
 
 
28
  model.eval()
29
- print(f"Model loaded on {args.device}")
 
30
 
31
- # Load image
32
  image = Image.open(args.image).convert("RGB")
33
  print(f"Image: {image.size}")
34
-
35
- # Run inference
36
  print(f"Instruction: {args.instruction}")
37
  action = model.predict_action(image, args.instruction)
38
  print(f"Action (7-DoF): {action}")
 
1
  """
2
  Standalone inference script for OpenVLA-Micro.
3
  Usage:
4
+ # GPU inference (HF hub)
5
+ python inference.py --image demo.jpg "pick up the red block"
6
+
7
+ # From a local .pt file
8
+ python inference.py --checkpoint openvla-micro-distill.pt --image demo.jpg "pick up the red block"
9
+
10
+ # CPU inference
11
+ python inference.py --device cpu --image demo.jpg "pick up the red block"
12
  """
13
  import argparse
14
  from PIL import Image
 
15
  from modeling_openvla_micro import OpenVLAMicro
16
 
17
 
18
  def main():
19
  parser = argparse.ArgumentParser(description="OpenVLA-Micro inference")
20
+ parser.add_argument("--checkpoint", type=str, default="theguy21/openvla-micro",
21
+ help="HF repo ID or path to local .pt checkpoint")
22
+ parser.add_argument("--image", type=str, required=True, help="Input image path")
23
+ parser.add_argument("--device", type=str, default="auto",
24
+ help="Device: auto, cuda, or cpu")
25
+ parser.add_argument("instruction", type=str, nargs="?", default="pick up the red block",
26
+ help="Task instruction (positional, optional)")
 
 
27
  args = parser.parse_args()
28
 
29
+ device = args.device
30
+ if device == "auto":
31
+ import torch
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+
34
+ llm_kwargs = {}
35
+ if device == "cpu":
36
+ llm_kwargs["torch_dtype"] = "float32"
37
+
38
+ print(f"Loading OpenVLA-Micro from {args.checkpoint} on {device}...")
39
+ model = OpenVLAMicro.from_pretrained(args.checkpoint, device=device, llm_kwargs=llm_kwargs)
40
  model.eval()
41
+ n_params = sum(p.numel() for p in model.parameters()) / 1e6
42
+ print(f"Model loaded ({n_params:.0f}M params)")
43
 
 
44
  image = Image.open(args.image).convert("RGB")
45
  print(f"Image: {image.size}")
 
 
46
  print(f"Instruction: {args.instruction}")
47
  action = model.predict_action(image, args.instruction)
48
  print(f"Action (7-DoF): {action}")
inference_cpu.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Edge device / CPU inference for OpenVLA-Micro.
3
+
4
+ This script is optimized for resource-constrained environments.
5
+ Two modes:
6
+
7
+ 1. Standard CPU – float32, ~3-5 sec/step on modern x86, ~6GB RAM
8
+ 2. Low-RAM (4-bit) – uses bitsandbytes 4-bit quantization, ~2.5GB RAM,
9
+ slightly slower but usable on 4GB devices like RPi 5
10
+ with sufficient swap.
11
+
12
+ Usage:
13
+ python inference_cpu.py --image demo.jpg "pick up the red block"
14
+ python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
15
+ python inference_cpu.py --checkpoint ./openvla-micro-distill.pt --image demo.jpg "pick up the red block"
16
+ """
17
+ import argparse
18
+ from PIL import Image
19
+ from modeling_openvla_micro import OpenVLAMicro
20
+
21
+
22
+ def main():
23
+ parser = argparse.ArgumentParser(description="OpenVLA-Micro CPU/edge inference")
24
+ parser.add_argument("--checkpoint", type=str, default="theguy21/openvla-micro",
25
+ help="HF repo ID or local .pt path")
26
+ parser.add_argument("--image", type=str, required=True, help="Input image path")
27
+ parser.add_argument("--low-ram", action="store_true",
28
+ help="4-bit quantized LLM (~2.5GB peak, requires bitsandbytes)")
29
+ parser.add_argument("instruction", type=str, nargs="?", default="pick up the red block",
30
+ help="Task instruction (positional, optional)")
31
+ args = parser.parse_args()
32
+
33
+ device = "cpu"
34
+ llm_kwargs = {}
35
+
36
+ if args.low_ram:
37
+ print("Low-RAM mode: 4-bit quantization (requires bitsandbytes)")
38
+ llm_kwargs = {
39
+ "load_in_4bit": True,
40
+ "bnb_4bit_compute_dtype": "float32",
41
+ "bnb_4bit_use_double_quant": True,
42
+ }
43
+ else:
44
+ print("Standard CPU mode: float32 (~6GB RAM)")
45
+ llm_kwargs["torch_dtype"] = "float32"
46
+
47
+ print(f"Loading OpenVLA-Micro from {args.checkpoint} on CPU...")
48
+ model = OpenVLAMicro.from_pretrained(args.checkpoint, device=device, llm_kwargs=llm_kwargs)
49
+ model.eval()
50
+ n_params = sum(p.numel() for p in model.parameters()) / 1e6
51
+ print(f"Model loaded ({n_params:.0f}M params)")
52
+
53
+ image = Image.open(args.image).convert("RGB")
54
+ print(f"Image: {image.size}")
55
+ print(f"Instruction: {args.instruction}")
56
+ action = model.predict_action(image, args.instruction)
57
+ print(f"Action (7-DoF): {action}")
58
+
59
+
60
+ if __name__ == "__main__":
61
+ main()
modeling_openvla_micro.py CHANGED
@@ -286,7 +286,8 @@ class OpenVLAMicro(nn.Module):
286
  )
287
 
288
  @classmethod
289
- def from_pretrained(cls, checkpoint_path: Union[str, Path], device: str = "cpu"):
 
290
  checkpoint_path = cls._resolve_checkpoint_path(checkpoint_path)
291
  ckpt = torch.load(checkpoint_path, map_location="cpu")
292
 
@@ -304,10 +305,12 @@ class OpenVLAMicro(nn.Module):
304
  llm_id = "Qwen/Qwen2.5-0.5B"
305
  config = AutoConfig.from_pretrained(llm_id)
306
  config.use_flash_attention_2 = False
 
 
307
  llm = AutoModelForCausalLM.from_pretrained(
308
  llm_id,
309
  config=config,
310
- torch_dtype=torch.bfloat16,
311
  )
312
 
313
  # --- Tokenizer ---
 
286
  )
287
 
288
  @classmethod
289
+ def from_pretrained(cls, checkpoint_path: Union[str, Path], device: str = "cpu",
290
+ **kwargs):
291
  checkpoint_path = cls._resolve_checkpoint_path(checkpoint_path)
292
  ckpt = torch.load(checkpoint_path, map_location="cpu")
293
 
 
305
  llm_id = "Qwen/Qwen2.5-0.5B"
306
  config = AutoConfig.from_pretrained(llm_id)
307
  config.use_flash_attention_2 = False
308
+ llm_kwargs = kwargs.pop("llm_kwargs", {})
309
+ llm_kwargs.setdefault("torch_dtype", torch.bfloat16)
310
  llm = AutoModelForCausalLM.from_pretrained(
311
  llm_id,
312
  config=config,
313
+ **llm_kwargs,
314
  )
315
 
316
  # --- Tokenizer ---
pyproject.toml CHANGED
@@ -19,6 +19,7 @@ dependencies = [
19
 
20
  [project.scripts]
21
  openvla-micro = "inference:main"
 
22
 
23
  [tool.setuptools]
24
- py-modules = ["inference", "modeling_openvla_micro"]
 
19
 
20
  [project.scripts]
21
  openvla-micro = "inference:main"
22
+ openvla-micro-cpu = "inference_cpu:main"
23
 
24
  [tool.setuptools]
25
+ py-modules = ["inference", "inference_cpu", "modeling_openvla_micro", "model_wrapper"]
train_shim.py CHANGED
@@ -1,8 +1,22 @@
1
  """
2
- Continue shim-only training (no action head) on cached distill data.
3
- Resumes from the previous best shim checkpoint.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
- import argparse, json, os, sys
6
  from pathlib import Path
7
 
8
  import numpy as np
@@ -13,50 +27,50 @@ from torch.utils.data import Dataset, DataLoader
13
  from PIL import Image
14
  from tqdm import tqdm
15
 
16
- sys.path.insert(0, os.path.expanduser("~/openvla-micro"))
17
- sys.path.insert(0, "/mnt/steamdrive/openvla-micro")
18
-
19
- from model_wrapper import IMAGENET_MEAN as IMAGENET_MEAN_4D, IMAGENET_STD as IMAGENET_STD_4D, SIGLIP_MEAN, SIGLIP_STD
20
- IMAGENET_MEAN = IMAGENET_MEAN_4D.view(3, 1, 1)
21
- IMAGENET_STD = IMAGENET_STD_4D.view(3, 1, 1)
22
 
23
- # ── Paths ──
24
- CACHE_DIR = Path("/mnt/steamdrive/openvla_cache")
25
- CAST_DIR = Path("/mnt/steamdrive/cast_converted")
26
- CKPT_DIR = Path("/mnt/steamdrive/omnivla_checkpoints")
27
- RUN_BASE = Path("/mnt/steamdrive/openvla-micro/runs_distill_shim")
28
 
29
- # ── Constants ──
30
- NUM_ACTION_TOKENS = 32
31
- ACTION_DIM = 4
32
- NUM_ACTIONS_CHUNK = 8
33
 
34
 
35
- def to_siglip(pv: torch.Tensor) -> torch.Tensor:
36
- return (pv * IMAGENET_STD.to(pv.device) + IMAGENET_MEAN.to(pv.device) - SIGLIP_MEAN.to(pv.device)) / SIGLIP_STD.to(pv.device)
 
37
 
38
 
39
- # ── Dataset (same as original) ──
40
-
 
41
  class DistillDataset(Dataset):
42
- def __init__(self, cache_dir: Path, cast_dir: Path, split="train", val_ratio=0.1, max_episodes=0):
43
- self.cast_dir = cast_dir
44
- self.cache_dir = cache_dir
45
- cache_files = sorted(cache_dir.glob("episode_*.pt"))
 
 
 
 
 
 
 
 
 
 
46
  n = len(cache_files)
47
- if max_episodes > 0:
48
- cache_files = cache_files[:max_episodes]
49
- n = len(cache_files)
50
  split_idx = int(n * (1 - val_ratio))
51
  files = cache_files[:split_idx] if split == "train" else cache_files[split_idx:]
52
  self.index = []
53
  for cf in files:
54
  d = torch.load(cf, weights_only=True)
55
- T = d["num_steps"]
56
- for t in range(T):
57
  self.index.append((cf, t))
58
  self._cache = {}
59
- self._instruction_cache = {}
60
  print(f" [{split}] {len(self.index)} steps from {len(files)} episodes", flush=True)
61
 
62
  def __len__(self):
@@ -67,345 +81,244 @@ class DistillDataset(Dataset):
67
  cf_str = str(cf_path)
68
  if cf_str not in self._cache:
69
  self._cache[cf_str] = torch.load(cf_path, weights_only=True)
70
- if len(self._cache) > 2:
71
- oldest = next(iter(self._cache))
72
- del self._cache[oldest]
73
- ep_data = self._cache[cf_str]
74
- ep_id = ep_data["episode_id"]
75
- hs_target = ep_data["hidden_states"][t].float()
76
-
77
- if ep_id not in self._instruction_cache:
78
- instr_path = self.cast_dir / ep_id / "instructions.json"
79
- with open(instr_path, "r") as f:
80
- self._instruction_cache[ep_id] = json.load(f)
81
- instr = self._instruction_cache[ep_id][t]
82
- if isinstance(instr, list):
83
- instr = instr[0]
84
- instr = str(instr).strip()
85
-
86
- ep_dir = self.cast_dir / ep_id
87
  from torchvision.transforms.functional import resize as tv_resize
88
- cur = tv_resize(Image.open(ep_dir / "img" / f"step_{t:04d}.png").convert("RGB"), 224)
89
- cur = torch.tensor(np.array(cur, dtype=np.float32) / 255.0).permute(2, 0, 1)
90
- cur = (cur - IMAGENET_MEAN) / IMAGENET_STD
91
- return {
92
- "cur_img": cur,
93
- "hs_target": hs_target,
94
- "instruction": instr,
95
- }
96
-
97
-
98
- class DistillCollator:
99
- def __init__(self, tokenizer, action_token_ids, num_action_tokens=NUM_ACTION_TOKENS):
100
- self.tokenizer = tokenizer
101
- self.action_token_ids = action_token_ids
102
- self.num_action_tokens = num_action_tokens
103
-
104
- def __call__(self, batch):
105
- texts = []
106
- cur_imgs = []
107
- hs_targets = []
108
- for item in batch:
109
- texts.append(item["instruction"])
110
- cur_imgs.append(item["cur_img"])
111
- hs_targets.append(item["hs_target"])
112
-
113
- cur = torch.stack(cur_imgs) # (B, 3, 224, 224)
114
- hs_target = torch.stack(hs_targets) # (B, 32, 4096)
115
-
116
- chat = []
117
- for t in texts:
118
- chat.append([
119
- {"role": "system", "content": "You are a helpful assistant."},
120
- {"role": "user", "content": f"What action should the robot take to {t.lower()}?"},
121
- {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(self.num_action_tokens)])},
122
- ])
123
- tok = self.tokenizer.apply_chat_template(
124
- chat, tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt", padding=True,
125
- )
126
- input_ids = tok["input_ids"]
127
- attention_mask = tok["attention_mask"]
128
-
129
- return {
130
- "cur_img": cur,
131
- "input_ids": input_ids,
132
- "attention_mask": attention_mask,
133
- "hs_target": hs_target,
134
- }
135
 
136
 
137
  def main():
138
  parser = argparse.ArgumentParser()
139
- parser.add_argument("--max-steps", type=int, default=50000)
 
 
 
 
 
140
  parser.add_argument("--batch-size", type=int, default=8)
141
  parser.add_argument("--lr", type=float, default=5e-5)
142
- parser.add_argument("--weight-decay", type=float, default=0.01)
143
  parser.add_argument("--grad-accum", type=int, default=4)
144
- parser.add_argument("--log-every", type=int, default=50)
145
  parser.add_argument("--val-every", type=int, default=500)
146
  parser.add_argument("--save-every", type=int, default=5000)
147
- parser.add_argument("--num-workers", type=int, default=0)
148
- parser.add_argument("--resume", type=str,
149
- default=str(RUN_BASE / "shim_best.pt"))
150
- parser.add_argument("--run-name", type=str, default="shim_continued")
151
  args = parser.parse_args()
152
 
153
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
- assert device.type == "cuda"
155
  print(f"Device: {device}")
156
 
157
- run_dir = RUN_BASE.parent / args.run_name
158
- run_dir.mkdir(exist_ok=True, parents=True)
159
-
160
- # ── 1. Models ──
161
- from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
162
- from transformers import AutoModelForCausalLM, AutoTokenizer
163
 
164
- print("\n[1] Loading base model (vision encoder, projector, Qwen)...")
165
- distill_ckpt = torch.load(os.path.expanduser("~/openvla-micro/openvla-micro-distill.pt"),
166
- map_location="cpu", weights_only=False)
167
- msd = distill_ckpt["model"]
168
 
169
- ve = DinoSigLIPEncoder()
170
  ve.load_state_dict(msd["vision_backbone"])
171
- ve.to(device=device, dtype=torch.bfloat16).eval()
172
- for p in ve.parameters():
173
- p.requires_grad_(False)
174
 
175
- projector = CombinedProjector(ShimMLP(384), ShimMLP(768),
176
- nn.Linear(8704, 896), nn.Linear(896, 896))
177
  projector.load_state_dict(msd["projector"])
178
- projector.to(device=device, dtype=torch.bfloat16).eval()
179
- for p in projector.parameters():
180
- p.requires_grad_(False)
181
 
182
- qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=torch.bfloat16)
183
- qwen_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()}
184
- qwen.load_state_dict(qwen_sd)
185
- qwen.to(device=device, dtype=torch.bfloat16).eval()
186
- for p in qwen.parameters():
187
- p.requires_grad_(False)
188
 
189
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
190
  tokenizer.add_tokens([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])
191
-
192
- # ── 2. Shim (trainable) ──
193
- print("\n[2] Loading shim...")
194
- shim = nn.Sequential(nn.Linear(896, 2048), nn.GELU(), nn.Linear(2048, 4096))
195
- if args.resume and os.path.exists(args.resume):
196
- shim_sd = torch.load(args.resume, map_location="cpu", weights_only=True)
197
- shim.load_state_dict(shim_sd)
 
 
198
  print(f" Resumed from {args.resume}")
199
- else:
200
- print(" Starting from scratch (random init)!")
201
- shim.to(device=device, dtype=torch.bfloat16).train()
202
-
203
- # Pose projector (trainable)
204
- pose_proj = nn.Sequential(
205
- nn.Linear(4, 896), nn.GELU(),
206
- ).to(device=device, dtype=torch.bfloat16).train()
207
 
208
  # ── 3. Data ──
209
- print("\n[3] Setting up data...")
210
- action_token_ids = tokenizer.convert_tokens_to_ids([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])
211
- train_ds = DistillDataset(CACHE_DIR, CAST_DIR, split="train")
212
- val_ds = DistillDataset(CACHE_DIR, CAST_DIR, split="val")
213
- collator = DistillCollator(tokenizer, action_token_ids)
214
 
215
- train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
216
- collate_fn=collator, num_workers=args.num_workers, pin_memory=True)
217
- val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
218
- collate_fn=collator, num_workers=0, pin_memory=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  # ── 4. Optimizer ──
221
- opt = torch.optim.AdamW([
222
- {"params": shim.parameters(), "lr": args.lr},
223
- {"params": pose_proj.parameters(), "lr": args.lr},
224
- ], weight_decay=args.weight_decay)
225
  sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.max_steps)
226
 
227
  # ── 5. Training ──
228
- print("\n[4] Training (shim only, no action head)...")
229
-
230
- dino = ve.dino_featurizer
231
- siglip = ve.siglip_featurizer
232
- dino.eval(); siglip.eval()
233
- projector.eval(); qwen.eval()
234
 
235
  def encode_image(cur):
236
- """Encode a single image → (B, 452, 896) vision features."""
237
  with torch.no_grad():
238
- dino_f = dino(cur.to(dtype=torch.bfloat16))
239
- if isinstance(dino_f, (list, tuple)):
240
- dino_f = dino_f[0]
241
- dino_f = dino_f[:, 1:] # drop cls
242
- siglip_f = siglip(to_siglip(cur).to(dtype=torch.bfloat16))
243
- if isinstance(siglip_f, (list, tuple)):
244
- siglip_f = siglip_f[0]
245
- siglip_f = siglip_f[:, 1:] # drop cls
246
- B = cur.shape[0]
247
- D = 1152
248
- def pad(feat, ed):
249
- p = torch.zeros(B, feat.shape[1], D, device=device, dtype=torch.bfloat16)
250
- p[..., :ed] = feat[..., :ed]
251
- return p
252
- combined = torch.cat([pad(dino_f, 384), pad(siglip_f, 768)], dim=1)
253
- return projector(combined) # (B, 452, 896)
254
-
255
- # Find action token offset in template
256
- dummy_tok = tokenizer.apply_chat_template(
257
- [{"role": "system", "content": "You are a helpful assistant."},
258
- {"role": "user", "content": "test"},
259
- {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])}],
260
- tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt",
261
- )
262
- dummy_ids = dummy_tok["input_ids"].squeeze(0)
263
- action_pos = torch.where(
264
- (dummy_ids >= action_token_ids[0]) & (dummy_ids <= action_token_ids[-1])
265
- )[0]
266
- action_offset_in_text = action_pos[0].item()
267
- NUM_VIS = 452
268
- print(f" Action tokens start at position {action_offset_in_text} in text sequence")
269
- print(f" Vision tokens: {NUM_VIS}")
270
-
271
  global_step = 0
272
- best_val_loss = float("inf")
273
  train_iter = iter(train_loader)
274
-
275
  pbar = tqdm(total=args.max_steps, desc="Train")
 
276
  while global_step < args.max_steps:
277
  shim.train()
278
- pose_proj.train()
279
  opt.zero_grad()
280
  accum_loss = 0.0
281
 
282
- for micro_step in range(args.grad_accum):
283
  try:
284
  batch = next(train_iter)
285
  except StopIteration:
286
  train_iter = iter(train_loader)
287
  batch = next(train_iter)
288
 
289
- cur_img = batch["cur_img"].to(device, dtype=torch.bfloat16)
290
  inp = batch["input_ids"].to(device)
291
  am = batch["attention_mask"].to(device)
292
- hs_target = batch["hs_target"].to(device, dtype=torch.bfloat16)
293
  B = cur_img.shape[0]
294
 
295
- # Encode image
296
  vis = encode_image(cur_img)
297
-
298
- # Text embeddings
299
- embed = qwen.get_input_embeddings()(inp)
300
-
301
- # Multimodal input
302
- mm_embeds = torch.cat([embed[:, :1, :], vis, embed[:, 1:, :]], dim=1)
303
- mm_attn = torch.cat([am[:, :1],
304
- torch.ones(B, vis.shape[1], dtype=am.dtype, device=device),
305
- am[:, 1:]], dim=1)
306
-
307
- # Action mask
308
- act_start = 1 + NUM_VIS + action_offset_in_text - 1
309
- action_mask_mm = torch.zeros(mm_embeds.shape[:2], dtype=torch.bool, device=device)
310
  for i in range(B):
311
- start = act_start
312
- end = start + NUM_ACTION_TOKENS
313
- if end <= mm_embeds.shape[1]:
314
- action_mask_mm[i, start:end] = True
315
- mm_embeds = mm_embeds * ~action_mask_mm.unsqueeze(-1)
316
-
317
- # Qwen forward
318
- try:
319
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
320
- out = qwen(inputs_embeds=mm_embeds, attention_mask=mm_attn,
321
- labels=None, output_hidden_states=True, return_dict=True)
322
- except Exception as exc:
323
- print(f"\n[train] forward failed at step {global_step} micro_step {micro_step}: {exc}", flush=True)
324
- raise
325
 
 
 
326
  hs_all = out.hidden_states[-1]
327
- hs_act_qwen = torch.stack([hs_all[i, action_mask_mm[i]] for i in range(B)], dim=0)
328
- hs_shimmed = shim(hs_act_qwen) # (B, 32, 4096)
329
-
330
- # State MSE (shim gradients)
331
  loss = F.mse_loss(hs_shimmed, hs_target)
332
  (loss / args.grad_accum).backward()
333
  accum_loss += loss.item()
334
 
335
- # Gradient clipping & step
336
  torch.nn.utils.clip_grad_norm_(shim.parameters(), 1.0)
337
- torch.nn.utils.clip_grad_norm_(pose_proj.parameters(), 1.0)
338
  opt.step()
339
  sched.step()
340
  global_step += 1
341
 
342
- # Logging
343
- if global_step % args.log_every == 0:
344
- lr_cur = opt.param_groups[0]["lr"]
345
  with torch.no_grad():
346
- cos = F.cosine_similarity(hs_shimmed.float().reshape(-1, 4096),
347
- hs_target.float().reshape(-1, 4096), dim=-1).mean().item()
348
- nd = (hs_shimmed.float() - hs_target.float()).norm(dim=-1).mean().item()
349
- pbar.set_postfix({
350
- "loss": f"{accum_loss/args.grad_accum:.5f}",
351
- "cos": f"{cos:.4f}",
352
- "nd": f"{nd:.2f}",
353
- "lr": f"{lr_cur:.1e}",
354
- })
355
 
356
  # Validation
357
  if global_step % args.val_every == 0:
358
  shim.eval()
359
- val_loss, val_cos, val_nd, nv = 0.0, 0.0, 0.0, 0
360
  with torch.no_grad():
361
- for vbatch in val_loader:
362
- cur_img = vbatch["cur_img"].to(device, dtype=torch.bfloat16)
363
- inp = vbatch["input_ids"].to(device)
364
- am = vbatch["attention_mask"].to(device)
365
- hs_target = vbatch["hs_target"].to(device, dtype=torch.bfloat16)
366
- Bv = cur_img.shape[0]
367
- vis = encode_image(cur_img)
368
- embed = qwen.get_input_embeddings()(inp)
369
- mm_embeds = torch.cat([embed[:, :1, :], vis, embed[:, 1:, :]], dim=1)
370
- mm_attn = torch.cat([am[:, :1],
371
- torch.ones(Bv, NUM_VIS, dtype=am.dtype, device=device),
372
- am[:, 1:]], dim=1)
373
- act_start_v = 1 + NUM_VIS + action_offset_in_text - 1
374
- act_mask_v = torch.zeros(mm_embeds.shape[:2], dtype=torch.bool, device=device)
375
  for i in range(Bv):
376
- s = act_start_v
377
- e = s + NUM_ACTION_TOKENS
378
- if e <= mm_embeds.shape[1]:
379
- act_mask_v[i, s:e] = True
380
- mm_embeds = mm_embeds * ~act_mask_v.unsqueeze(-1)
381
- out = qwen(inputs_embeds=mm_embeds, attention_mask=mm_attn,
382
- labels=None, output_hidden_states=True, return_dict=True)
383
- hs_all = out.hidden_states[-1]
384
- hs_act = torch.stack([hs_all[i, act_mask_v[i]] for i in range(Bv)], dim=0)
385
- hs_shimmed = shim(hs_act)
386
- val_loss += F.mse_loss(hs_shimmed, hs_target).item()
387
- val_cos += F.cosine_similarity(hs_shimmed.float().reshape(-1, 4096),
388
- hs_target.float().reshape(-1, 4096), dim=-1).mean().item()
389
- val_nd += (hs_shimmed.float() - hs_target.float()).norm(dim=-1).mean().item()
390
  nv += 1
391
- val_loss /= nv; val_cos /= nv; val_nd /= nv
392
- print(f"\n─── Val @ step {global_step}: loss={val_loss:.5f} cos={val_cos:.4f} nd={val_nd:.2f} ───", flush=True)
393
- if val_loss < best_val_loss:
394
- best_val_loss = val_loss
395
  torch.save(shim.state_dict(), run_dir / "shim_best.pt")
396
- torch.save(pose_proj.state_dict(), run_dir / "pose_projector_best.pt")
397
- print(f" → New best saved (loss={val_loss:.5f})")
398
 
399
  if global_step % args.save_every == 0:
400
- ckpt_dir = run_dir / f"step_{global_step}"
401
- ckpt_dir.mkdir(exist_ok=True)
402
- torch.save(shim.state_dict(), ckpt_dir / "shim.pt")
403
- torch.save(pose_proj.state_dict(), ckpt_dir / "pose_projector.pt")
404
-
405
- pbar.update(1)
406
 
407
  pbar.close()
408
- print(f"\nDone! Best val loss: {best_val_loss:.5f}")
409
 
410
 
411
  if __name__ == "__main__":
 
1
  """
2
+ Train the hidden state shim (896→4096) for OpenVLA-Micro.
3
+
4
+ The shim maps Qwen2.5 0.5B's 896-dim hidden states to match a teacher
5
+ LLM's 4096-dim space (e.g., Llama-2, Llama-3). This lets the small model
6
+ drive OmniVLA's pretrained action head with near-zero accuracy loss.
7
+
8
+ Workflow:
9
+ 1. Cache your teacher's hidden states on your dataset
10
+ 2. Run this script to train the shim
11
+ 3. Bake the shim into the checkpoint with bake_shim.py
12
+
13
+ Usage:
14
+ python train_shim.py --cache-dir ./my_cache --base-model theguy21/openvla-micro
15
+
16
+ For the full training pipeline used in openvla-micro-distill, see:
17
+ https://huggingface.co/theguy21/openvla-micro
18
  """
19
+ import argparse, json, os
20
  from pathlib import Path
21
 
22
  import numpy as np
 
27
  from PIL import Image
28
  from tqdm import tqdm
29
 
30
+ from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP
31
+ from model_wrapper import IMAGENET_MEAN as IM4D, IMAGENET_STD as IS4D, SIGLIP_MEAN, SIGLIP_STD
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
33
 
34
+ IMAGENET_MEAN = IM4D.view(3, 1, 1)
35
+ IMAGENET_STD = IS4D.view(3, 1, 1)
 
 
 
36
 
37
+ NUM_ACTION_TOKENS = 32 # OmniVLA uses 8 chunks × 4 DoF
38
+ NUM_VIS = 452 # 256 dino patches + 196 siglip patches
 
 
39
 
40
 
41
+ def to_siglip(pv):
42
+ return (pv * IMAGENET_STD.to(pv.device) + IMAGENET_MEAN.to(pv.device)
43
+ - SIGLIP_MEAN.to(pv.device)) / SIGLIP_STD.to(pv.device)
44
 
45
 
46
+ # ─────────────────────────────────────────────────────────────
47
+ # Dataset — ADAPT THE IMAGE/INSTRUCTION LOGIC TO YOUR FORMAT
48
+ # ─────────────────────────────────────────────────────────────
49
  class DistillDataset(Dataset):
50
+ """
51
+ Each episode_*.pt is expected to contain:
52
+ episode_id: str
53
+ num_steps: int
54
+ hidden_states: Tensor[T, 32, teacher_dim]
55
+ (optional) instructions: list[str] of length T
56
+
57
+ Image paths are constructed as {data_dir}/{episode_id}/img/step_{t:04d}.png
58
+ Override _load_image / _get_instruction for custom formats.
59
+ """
60
+
61
+ def __init__(self, cache_dir, data_dir, split="train", val_ratio=0.1):
62
+ self.data_dir = Path(data_dir)
63
+ cache_files = sorted(Path(cache_dir).glob("episode_*.pt"))
64
  n = len(cache_files)
 
 
 
65
  split_idx = int(n * (1 - val_ratio))
66
  files = cache_files[:split_idx] if split == "train" else cache_files[split_idx:]
67
  self.index = []
68
  for cf in files:
69
  d = torch.load(cf, weights_only=True)
70
+ for t in range(d["num_steps"]):
 
71
  self.index.append((cf, t))
72
  self._cache = {}
73
+ self._instr_cache = {}
74
  print(f" [{split}] {len(self.index)} steps from {len(files)} episodes", flush=True)
75
 
76
  def __len__(self):
 
81
  cf_str = str(cf_path)
82
  if cf_str not in self._cache:
83
  self._cache[cf_str] = torch.load(cf_path, weights_only=True)
84
+ ep = self._cache[cf_str]
85
+ ep_id = ep["episode_id"]
86
+
87
+ # Image
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  from torchvision.transforms.functional import resize as tv_resize
89
+ img = tv_resize(Image.open(self.data_dir / ep_id / "img" / f"step_{t:04d}.png").convert("RGB"), 224)
90
+ img = torch.tensor(np.array(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
91
+ img = (img - IMAGENET_MEAN) / IMAGENET_STD
92
+
93
+ # Instruction
94
+ if "instructions" in ep:
95
+ instr = ep["instructions"][t]
96
+ if isinstance(instr, list):
97
+ instr = instr[0]
98
+ else:
99
+ instr = "move forward"
100
+
101
+ return {"cur_img": img, "hs_target": ep["hidden_states"][t].float(), "instruction": str(instr).strip()}
102
+
103
+
104
+ def find_action_offset(tokenizer, action_token_ids):
105
+ """Determine where action tokens start in the chat template."""
106
+ dummy = tokenizer.apply_chat_template(
107
+ [{"role": "system", "content": "You are a helpful assistant."},
108
+ {"role": "user", "content": "test"},
109
+ {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])}],
110
+ tokenize=True, add_generation_prompt=False, return_dict=True, return_tensors="pt",
111
+ )
112
+ ids = dummy["input_ids"].squeeze(0)
113
+ pos = torch.where((ids >= action_token_ids[0]) & (ids <= action_token_ids[-1]))[0]
114
+ return pos[0].item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
  def main():
118
  parser = argparse.ArgumentParser()
119
+ parser.add_argument("--cache-dir", type=str, required=True)
120
+ parser.add_argument("--data-dir", type=str, required=True,
121
+ help="Dataset root with {episode_id}/img/step_*.png")
122
+ parser.add_argument("--base-model", type=str, default="theguy21/openvla-micro")
123
+ parser.add_argument("--teacher-dim", type=int, default=4096)
124
+ parser.add_argument("--max-steps", type=int, default=10000)
125
  parser.add_argument("--batch-size", type=int, default=8)
126
  parser.add_argument("--lr", type=float, default=5e-5)
 
127
  parser.add_argument("--grad-accum", type=int, default=4)
 
128
  parser.add_argument("--val-every", type=int, default=500)
129
  parser.add_argument("--save-every", type=int, default=5000)
130
+ parser.add_argument("--resume", type=str, default=None)
131
+ parser.add_argument("--run-name", type=str, default="shim_run")
132
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
 
133
  args = parser.parse_args()
134
 
135
+ device = torch.device(args.device)
136
+ dtype = torch.bfloat16
137
  print(f"Device: {device}")
138
 
139
+ run_dir = Path(args.run_name)
140
+ run_dir.mkdir(exist_ok=True)
 
 
 
 
141
 
142
+ # ── 1. Load base model ──
143
+ print("\n[1] Loading base model...")
144
+ ckpt = torch.load(os.path.expanduser(args.base_model), map_location="cpu", weights_only=False)
145
+ msd = ckpt["model"]
146
 
147
+ ve = DinoSigLIPEncoder().eval()
148
  ve.load_state_dict(msd["vision_backbone"])
149
+ ve.to(device, dtype=dtype)
150
+ for p in ve.parameters(): p.requires_grad_(False)
 
151
 
152
+ projector = CombinedProjector(ShimMLP(384), ShimMLP(768), nn.Linear(8704, 896), nn.Linear(896, 896))
 
153
  projector.load_state_dict(msd["projector"])
154
+ projector.to(device, dtype=dtype).eval()
155
+ for p in projector.parameters(): p.requires_grad_(False)
 
156
 
157
+ llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=dtype)
158
+ llm_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()}
159
+ llm.load_state_dict(llm_sd)
160
+ llm.to(device, dtype=dtype).eval()
161
+ for p in llm.parameters(): p.requires_grad_(False)
 
162
 
163
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
164
  tokenizer.add_tokens([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])
165
+ action_token_ids = tokenizer.convert_tokens_to_ids([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])
166
+ action_offset = find_action_offset(tokenizer, action_token_ids)
167
+ print(f" Action tokens at position {action_offset}")
168
+
169
+ # ── 2. Shim ──
170
+ print("\n[2] Building shim...")
171
+ shim = nn.Sequential(nn.Linear(896, 2048), nn.GELU(), nn.Linear(2048, args.teacher_dim))
172
+ if args.resume:
173
+ shim.load_state_dict(torch.load(args.resume, map_location="cpu"))
174
  print(f" Resumed from {args.resume}")
175
+ shim.to(device, dtype=dtype).train()
 
 
 
 
 
 
 
176
 
177
  # ── 3. Data ──
178
+ print("\n[3] Loading data...")
179
+ train_ds = DistillDataset(args.cache_dir, args.data_dir, split="train")
180
+ val_ds = DistillDataset(args.cache_dir, args.data_dir, split="val")
 
 
181
 
182
+ def collate(batch):
183
+ from torchvision.transforms.functional import resize as tv_resize
184
+ texts, imgs, hs = [], [], []
185
+ for b in batch:
186
+ texts.append(b["instruction"])
187
+ imgs.append(b["cur_img"])
188
+ hs.append(b["hs_target"])
189
+ cur = torch.stack(imgs)
190
+ hs_target = torch.stack(hs)
191
+ chat = [[{"role": "system", "content": "You are a helpful assistant."},
192
+ {"role": "user", "content": f"What action should the robot take to {t.lower()}?"},
193
+ {"role": "assistant", "content": " ".join([f"<ACTION_{i}>" for i in range(NUM_ACTION_TOKENS)])}]
194
+ for t in texts]
195
+ tok = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False,
196
+ return_dict=True, return_tensors="pt", padding=True)
197
+ return {"cur_img": cur, "input_ids": tok["input_ids"], "attention_mask": tok["attention_mask"],
198
+ "hs_target": hs_target}
199
+
200
+ train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate, num_workers=0)
201
+ val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate, num_workers=0)
202
 
203
  # ── 4. Optimizer ──
204
+ opt = torch.optim.AdamW(shim.parameters(), lr=args.lr, weight_decay=0.01)
 
 
 
205
  sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.max_steps)
206
 
207
  # ── 5. Training ──
208
+ print(f"\n[4] Training...")
209
+ dino, siglip = ve.dino_featurizer, ve.siglip_featurizer
 
 
 
 
210
 
211
  def encode_image(cur):
 
212
  with torch.no_grad():
213
+ df = dino(cur)
214
+ if isinstance(df, (list, tuple)): df = df[0]
215
+ df = df[:, 1:]
216
+ sf = siglip(to_siglip(cur))
217
+ if isinstance(sf, (list, tuple)): sf = sf[0]
218
+ sf = sf[:, 1:]
219
+ B = cur.shape[0]; D = 1152
220
+ def pad(f, ed):
221
+ p = torch.zeros(B, f.shape[1], D, device=device, dtype=dtype)
222
+ p[..., :ed] = f[..., :ed]; return p
223
+ return projector(torch.cat([pad(df, 384), pad(sf, 768)], dim=1))
224
+
225
+ best_loss = float("inf")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  global_step = 0
 
227
  train_iter = iter(train_loader)
 
228
  pbar = tqdm(total=args.max_steps, desc="Train")
229
+
230
  while global_step < args.max_steps:
231
  shim.train()
 
232
  opt.zero_grad()
233
  accum_loss = 0.0
234
 
235
+ for _ in range(args.grad_accum):
236
  try:
237
  batch = next(train_iter)
238
  except StopIteration:
239
  train_iter = iter(train_loader)
240
  batch = next(train_iter)
241
 
242
+ cur_img = batch["cur_img"].to(device, dtype=dtype)
243
  inp = batch["input_ids"].to(device)
244
  am = batch["attention_mask"].to(device)
245
+ hs_target = batch["hs_target"].to(device, dtype=dtype)
246
  B = cur_img.shape[0]
247
 
 
248
  vis = encode_image(cur_img)
249
+ embed = llm.get_input_embeddings()(inp)
250
+ mm = torch.cat([embed[:, :1, :], vis, embed[:, 1:, :]], dim=1)
251
+ mm_attn = torch.cat([am[:, :1], torch.ones(B, NUM_VIS, dtype=am.dtype, device=device), am[:, 1:]], dim=1)
252
+ act_start = 1 + NUM_VIS + action_offset - 1
253
+ mask = torch.zeros(B, mm.shape[1], dtype=torch.bool, device=device)
 
 
 
 
 
 
 
 
254
  for i in range(B):
255
+ end = act_start + NUM_ACTION_TOKENS
256
+ if end <= mm.shape[1]:
257
+ mask[i, act_start:end] = True
258
+ mm = mm * ~mask.unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
259
 
260
+ with torch.autocast(device_type=device.type, dtype=dtype):
261
+ out = llm(inputs_embeds=mm, attention_mask=mm_attn, labels=None, output_hidden_states=True, return_dict=True)
262
  hs_all = out.hidden_states[-1]
263
+ hs_act = torch.stack([hs_all[i, mask[i]] for i in range(B)], dim=0)
264
+ hs_shimmed = shim(hs_act)
 
 
265
  loss = F.mse_loss(hs_shimmed, hs_target)
266
  (loss / args.grad_accum).backward()
267
  accum_loss += loss.item()
268
 
 
269
  torch.nn.utils.clip_grad_norm_(shim.parameters(), 1.0)
 
270
  opt.step()
271
  sched.step()
272
  global_step += 1
273
 
274
+ if global_step % 100 == 0:
 
 
275
  with torch.no_grad():
276
+ cos = F.cosine_similarity(hs_shimmed.float().reshape(-1, args.teacher_dim),
277
+ hs_target.float().reshape(-1, args.teacher_dim), dim=-1).mean().item()
278
+ pbar.set_postfix({"loss": f"{accum_loss/args.grad_accum:.5f}", "cos": f"{cos:.4f}"})
279
+ pbar.update(1)
 
 
 
 
 
280
 
281
  # Validation
282
  if global_step % args.val_every == 0:
283
  shim.eval()
284
+ v_loss, v_cos, nv = 0.0, 0.0, 0
285
  with torch.no_grad():
286
+ for vb in val_loader:
287
+ ci = vb["cur_img"].to(device, dtype=dtype)
288
+ ip = vb["input_ids"].to(device)
289
+ am = vb["attention_mask"].to(device)
290
+ ht = vb["hs_target"].to(device, dtype=dtype)
291
+ Bv = ci.shape[0]
292
+ vi = encode_image(ci)
293
+ em = llm.get_input_embeddings()(ip)
294
+ mm = torch.cat([em[:, :1, :], vi, em[:, 1:, :]], dim=1)
295
+ ma = torch.cat([am[:, :1], torch.ones(Bv, NUM_VIS, dtype=am.dtype, device=device), am[:, 1:]], dim=1)
296
+ mk = torch.zeros(Bv, mm.shape[1], dtype=torch.bool, device=device)
 
 
 
297
  for i in range(Bv):
298
+ e = 1 + NUM_VIS + action_offset - 1 + NUM_ACTION_TOKENS
299
+ if e <= mm.shape[1]:
300
+ mk[i, 1 + NUM_VIS + action_offset - 1:e] = True
301
+ mm = mm * ~mk.unsqueeze(-1)
302
+ o = llm(inputs_embeds=mm, attention_mask=ma, labels=None, output_hidden_states=True, return_dict=True)
303
+ ha = torch.stack([o.hidden_states[-1][i, mk[i]] for i in range(Bv)], dim=0)
304
+ hs = shim(ha)
305
+ v_loss += F.mse_loss(hs, ht).item()
306
+ v_cos += F.cosine_similarity(hs.float().reshape(-1, args.teacher_dim),
307
+ ht.float().reshape(-1, args.teacher_dim), dim=-1).mean().item()
 
 
 
 
308
  nv += 1
309
+ v_loss /= nv; v_cos /= nv
310
+ print(f"\n─── Val @ {global_step}: loss={v_loss:.5f} cos={v_cos:.4f} ───", flush=True)
311
+ if v_loss < best_loss:
312
+ best_loss = v_loss
313
  torch.save(shim.state_dict(), run_dir / "shim_best.pt")
314
+ print(f" → Saved best (loss={v_loss:.5f})")
 
315
 
316
  if global_step % args.save_every == 0:
317
+ d = run_dir / f"step_{global_step}"; d.mkdir(exist_ok=True)
318
+ torch.save(shim.state_dict(), d / "shim.pt")
 
 
 
 
319
 
320
  pbar.close()
321
+ print(f"\nDone! Best val loss: {best_loss:.5f}")
322
 
323
 
324
  if __name__ == "__main__":