Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Instructions to use theguy21/openvla-micro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use theguy21/openvla-micro with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("theguy21/openvla-micro", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Fast(er) model loader for OpenVLA-Micro on CPU. | |
| Uses the cached distill checkpoint directly (~21s load). | |
| """ | |
| import os, time | |
| from pathlib import Path | |
| import torch | |
| def _find_cached_distill(): | |
| """Find the cached distill checkpoint in HF hub cache.""" | |
| import glob | |
| pattern = os.path.expanduser("~/.cache/huggingface/hub/models--theguy21--openvla-micro/snapshots/*/openvla-micro-distill.pt") | |
| matches = sorted(glob.glob(pattern)) | |
| return matches[-1] if matches else None | |
def load_openvla_micro(device: str = "cpu"):
    """Assemble OpenVLA-Micro from the distilled checkpoint and return it in eval mode.

    The checkpoint is taken from the local Hugging Face cache when
    ``_find_cached_distill`` finds it, and fetched with ``hf_hub_download``
    otherwise. The process runs in three steps:

    1. Tensors are deserialized on the CPU.
    2. They are loaded into the vision encoder, the projector and a
       ``Qwen/Qwen2.5-0.5B`` causal LM.
    3. The combined model is moved to ``device``.

    Args:
        device: Target device for the returned model, e.g. ``"cpu"`` or
            ``"cuda"``.

    Returns:
        An ``OpenVLAMicro`` instance, in eval mode, placed on ``device``.
    """
    # Project-local and heavy third-party imports are deferred to call time.
    from modeling_openvla_micro import DinoSigLIPEncoder, CombinedProjector, ShimMLP, OpenVLAMicro
    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt_path = _find_cached_distill()
    if ckpt_path is None:
        print("Distill checkpoint not cached. Downloading from HF...", flush=True)
        from huggingface_hub import hf_hub_download
        ckpt_path = hf_hub_download("theguy21/openvla-micro", "openvla-micro-distill.pt")

    # Timing starts after any download, so it measures model construction only.
    t0 = time.time()
    print("Loading checkpoint...", flush=True)
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only load checkpoints from a trusted source.
    # NOTE(review): presumably required because the checkpoint stores
    # non-tensor data such as "norm_stats" — confirm before switching to
    # weights_only=True.
    # Everything is materialized on CPU first; the assembled model is moved to
    # `device` at the end.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    msd = ckpt["model"]  # dict of per-component state dicts

    print("Building vision encoder...", flush=True)
    ve = DinoSigLIPEncoder().eval()
    ve.load_state_dict(msd["vision_backbone"])
    # NOTE(review): widths are hard-coded to match the checkpoint. 896 presumably
    # equals the Qwen2.5-0.5B hidden size — confirm against modeling_openvla_micro.
    proj = CombinedProjector(
        ShimMLP(384),
        ShimMLP(768),
        torch.nn.Linear(8704, 896),
        torch.nn.Linear(896, 896),
    )
    proj.load_state_dict(msd["projector"])

    print("Loading Qwen...", flush=True)
    llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", dtype=torch.float32)
    # Remove the first "llm." in each key (presumably a wrapper prefix) so the
    # keys match the bare Hugging Face model; strict=True catches any mismatch.
    llm_sd = {k.replace("llm.", "", 1): v for k, v in msd["llm_backbone"].items()}
    llm.load_state_dict(llm_sd, strict=True)
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

    model = OpenVLAMicro(ve, proj, llm, tok, norm_stats=ckpt.get("norm_stats")).eval()
    # Honor the `device` argument; without this the model always stayed on CPU.
    model = model.to(device)
    print(f"Loaded in {time.time()-t0:.0f}s", flush=True)
    return model