Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Instructions to use theguy21/openvla-micro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use theguy21/openvla-micro with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("theguy21/openvla-micro", dtype="auto") - Notebooks
- Google Colab
- Kaggle
Add fast model loader (cached distill, ~21s load)
Browse files- fast_load.py +48 -0
fast_load.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fast(er) model loader for OpenVLA-Micro on CPU.
|
| 3 |
+
Uses the cached distill checkpoint directly (~21s load).
|
| 4 |
+
"""
|
| 5 |
+
import os, time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _find_cached_distill():
|
| 12 |
+
"""Find the cached distill checkpoint in HF hub cache."""
|
| 13 |
+
import glob
|
| 14 |
+
pattern = os.path.expanduser("~/.cache/huggingface/hub/models--theguy21--openvla-micro/snapshots/*/openvla-micro-distill.pt")
|
| 15 |
+
matches = sorted(glob.glob(pattern))
|
| 16 |
+
return matches[-1] if matches else None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_openvla_micro(device: str = "cpu"):
|
| 20 |
+
from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
|
| 23 |
+
ckpt_path = _find_cached_distill()
|
| 24 |
+
if ckpt_path is None:
|
| 25 |
+
print("Distill checkpoint not cached. Downloading from HF...", flush=True)
|
| 26 |
+
from huggingface_hub import hf_hub_download
|
| 27 |
+
ckpt_path = hf_hub_download("theguy21/openvla-micro", "openvla-micro-distill.pt")
|
| 28 |
+
|
| 29 |
+
t0 = time.time()
|
| 30 |
+
print(f"Loading checkpoint...", flush=True)
|
| 31 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 32 |
+
msd = ckpt["model"]
|
| 33 |
+
|
| 34 |
+
print(f"Building vision encoder...", flush=True)
|
| 35 |
+
ve = DinoSigLIPEncoder().eval()
|
| 36 |
+
ve.load_state_dict(msd["vision_backbone"])
|
| 37 |
+
proj = CombinedProjector(ShimMLP(384), ShimMLP(768), torch.nn.Linear(8704, 896), torch.nn.Linear(896, 896))
|
| 38 |
+
proj.load_state_dict(msd["projector"])
|
| 39 |
+
|
| 40 |
+
print(f"Loading Qwen...", flush=True)
|
| 41 |
+
llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", dtype=torch.float32)
|
| 42 |
+
llm_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()}
|
| 43 |
+
llm.load_state_dict(llm_sd, strict=True)
|
| 44 |
+
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
| 45 |
+
|
| 46 |
+
model = OpenVLAMicro(ve, proj, llm, tok, norm_stats=ckpt.get("norm_stats")).eval()
|
| 47 |
+
print(f"Loaded in {time.time()-t0:.0f}s", flush=True)
|
| 48 |
+
return model
|