Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Instructions to use theguy21/openvla-micro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use theguy21/openvla-micro with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("theguy21/openvla-micro", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| OpenVLA-Micro: A small-vision VLA for CPU robot deployment. | |
| =========================================================== | |
| Architecture: DINOv2-S (384-dim) + SigLIP-B/16 (768-dim) | |
| → ShimMLPs (→ 8704-dim each) → Concat → Linear(896) → GELU → Linear(896) | |
| → Qwen2.5-0.5B LLM → 7-DoF action. | |
| Trained by freezing DINOv2, SigLIP, Qwen2.5, and lm_head, and only training | |
| the ShimMLPs + LoRA adapters on the projector (38.1M params). | |
| This module provides a self-contained ``OpenVLAMicro`` class that loads the | |
| merged checkpoint and runs ``predict_action(image, instruction)``. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from functools import partial | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import timm | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from timm.models.vision_transformer import VisionTransformer | |
| from torchvision.transforms import Compose, Resize | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| PreTrainedModel, | |
| Qwen2TokenizerFast, | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def unpack_tuple(fn):
    """Wrap ``fn`` so that a tuple result is reduced to its first element.

    Results of any other type (lists included) are returned untouched.
    Positional and keyword arguments are forwarded to ``fn`` unchanged.
    """

    def first_if_tuple(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, tuple):
            return out[0]
        return out

    return first_if_tuple
def monkey_patch_featurizer(vit: VisionTransformer) -> None:
    """Rebind ``vit.forward`` so it yields the second-to-last block's features.

    After patching, calling ``vit.forward(x)`` invokes
    ``vit.get_intermediate_layers(x, n={len(vit.blocks) - 2})`` and, when that
    returns a tuple, hands back only its first element.  The instance is
    modified in place; nothing is returned.
    """
    penultimate_idx = len(vit.blocks) - 2
    # A one-element set selects exactly that block index.
    take_penultimate = partial(vit.get_intermediate_layers, n={penultimate_idx})
    vit.forward = unpack_tuple(take_penultimate)
def _build_timm_transform(timm_path: str, img_size: int = 224) -> Compose:
    """Build the eval-time image transform for a TIMM model at ``img_size``.

    The model's default (non-training) TIMM transform is created for an input
    size of ``(3, img_size, img_size)``; its leading ``Resize`` is then swapped
    for one that targets ``(img_size, img_size)`` directly, keeping the same
    interpolation mode.  All remaining steps are reused as-is.

    Args:
        timm_path: TIMM model identifier whose data config should be used.
        img_size: Side length of the square output image.

    Returns:
        A ``Compose`` of the replacement ``Resize`` followed by the rest of the
        default pipeline.

    Raises:
        AssertionError: If the default transform is not a ``Compose`` or does
            not start with a ``Resize``.
    """
    # Randomly initialised weights are fine: only the data config is consulted.
    reference_model = timm.create_model(timm_path, pretrained=False, num_classes=0)
    data_cfg = timm.data.resolve_model_data_config(reference_model)
    data_cfg["input_size"] = (3, img_size, img_size)
    stock = timm.data.create_transform(**data_cfg, is_training=False)

    assert isinstance(stock, Compose)
    stock_resize = stock.transforms[0]
    assert isinstance(stock_resize, Resize)

    square_resize = Resize(
        (img_size, img_size), interpolation=stock_resize.interpolation
    )
    remaining_steps = stock.transforms[1:]
    return Compose([square_resize, *remaining_steps])
| # --------------------------------------------------------------------------- | |
| # Components | |
| # --------------------------------------------------------------------------- | |
class ShimMLP(nn.Module):
    """Two-layer GELU MLP that lifts vision features to the projector width.

    With the defaults it maps ``in_dim`` (384 or 768 in this file) through a
    2048-wide hidden layer to 8704 output features.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden: Width of the hidden layer.
        out_dim: Size of the last dimension of the output.
    """

    def __init__(self, in_dim: int, hidden: int = 2048, out_dim: int = 8704):
        super().__init__()
        # Order matters: sub-module indices (0, 1, 2) are part of the
        # state-dict key names.
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        projected = self.net(x)
        return projected
class CombinedProjector(nn.Module):
    """Project concatenated, zero-padded dual-backbone features to the LLM width.

    The input is expected to hold the DINO tokens first, followed by the
    SigLIP tokens, along the sequence dimension, with each token zero-padded
    on the feature dimension.  The projector:

    1. slices the first ``num_dino_tokens`` tokens and keeps their first
       ``dino_dim`` features; slices the remaining tokens and keeps their
       first ``siglip_dim`` features;
    2. runs each slice through its own MLP;
    3. concatenates the two results along the sequence dimension;
    4. applies ``proj2`` → GELU → ``proj4``.

    Args:
        dino_mlp: Module applied to the DINO slice.
        siglip_mlp: Module applied to the SigLIP slice.
        proj2: First shared linear layer.
        proj4: Second shared linear layer (applied after GELU).
        num_dino_tokens: Number of leading tokens that belong to DINO.
        dino_dim: Number of un-padded DINO features per token.
        siglip_dim: Number of un-padded SigLIP features per token.

    The three keyword-only sizes default to the values that were previously
    hard-coded (256, 384, 768), so existing call sites are unaffected.
    """

    def __init__(self, dino_mlp: nn.Module, siglip_mlp: nn.Module,
                 proj2: nn.Linear, proj4: nn.Linear, *,
                 num_dino_tokens: int = 256,
                 dino_dim: int = 384,
                 siglip_dim: int = 768):
        super().__init__()
        self.dino_mlp = dino_mlp
        self.siglip_mlp = siglip_mlp
        self.proj2 = proj2
        self.proj4 = proj4
        # Plain ints: not registered as parameters/buffers, so checkpoints
        # saved before these attributes existed still load.
        self.num_dino_tokens = num_dino_tokens
        self.dino_dim = dino_dim
        self.siglip_dim = siglip_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return projected features with the same token count as ``x``."""
        split = self.num_dino_tokens
        # Drop the zero padding before handing each slice to its MLP.
        dino_feats = x[:, :split, : self.dino_dim]
        siglip_feats = x[:, split:, : self.siglip_dim]

        dino_out = self.dino_mlp(dino_feats)
        siglip_out = self.siglip_mlp(siglip_feats)

        # Token order (DINO first, SigLIP second) is preserved.
        combined = torch.cat([dino_out, siglip_out], dim=1)
        hidden = nn.functional.gelu(self.proj2(combined))
        return self.proj4(hidden)
| # --------------------------------------------------------------------------- | |
| # Vision Encoder | |
| # --------------------------------------------------------------------------- | |
# Maps an encoder variant name to the TIMM model identifiers of its two
# backbones ("dino" and "siglip").
DINOSigLIP_REGISTRY = {
    "dinosiglip-vit-s-b-224px": dict(
        dino="vit_small_patch14_reg4_dinov2.lvd142m",
        siglip="vit_base_patch16_siglip_224",
    ),
}
class DinoSigLIPEncoder(nn.Module):
    """Dual vision backbone: a DINO ViT and a SigLIP ViT loaded from TIMM.

    Both backbones are created with pretrained weights, switched to eval mode
    and patched via ``monkey_patch_featurizer``.  ``forward`` runs both,
    zero-pads each output on the feature dimension up to the sum of the two
    embedding widths, and concatenates them along the sequence dimension
    (DINO tokens first).

    Args:
        variant: Key into ``DINOSigLIP_REGISTRY`` selecting the TIMM models.
        img_size: Square input resolution passed to both backbones and used
            for both image transforms.
    """

    def __init__(self, variant: str = "dinosiglip-vit-s-b-224px", img_size: int = 224):
        super().__init__()
        timm_ids = DINOSigLIP_REGISTRY[variant]
        dino_id, siglip_id = timm_ids["dino"], timm_ids["siglip"]

        self.dino_featurizer: VisionTransformer = timm.create_model(
            dino_id, pretrained=True, num_classes=0, img_size=img_size
        )
        self.dino_featurizer.eval()
        monkey_patch_featurizer(self.dino_featurizer)

        self.siglip_featurizer: VisionTransformer = timm.create_model(
            siglip_id, pretrained=True, num_classes=0, img_size=img_size
        )
        self.siglip_featurizer.eval()
        monkey_patch_featurizer(self.siglip_featurizer)

        self.dino_transform = _build_timm_transform(dino_id, img_size)
        self.siglip_transform = _build_timm_transform(siglip_id, img_size)

        # Common (padded) feature width shared by both token streams.
        self.total_embed_dim = (
            self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
        )

    def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode ``pixel_values["dino"]`` and ``pixel_values["siglip"]``.

        Returns a tensor whose sequence dimension holds the DINO tokens
        followed by the SigLIP tokens and whose last dimension is
        ``self.total_embed_dim`` (real features first, zeros after).
        """
        dino_feats = self.dino_featurizer(pixel_values["dino"])
        siglip_feats = self.siglip_featurizer(pixel_values["siglip"])

        # The featurizers may hand back a list/tuple; keep the first entry.
        if isinstance(dino_feats, (list, tuple)):
            dino_feats = dino_feats[0]
        if isinstance(siglip_feats, (list, tuple)):
            siglip_feats = siglip_feats[0]

        batch = dino_feats.shape[0]
        width = self.total_embed_dim

        def zero_pad(feats: torch.Tensor) -> torch.Tensor:
            # new_zeros inherits dtype and device from ``feats``.
            canvas = feats.new_zeros(batch, feats.shape[1], width)
            canvas[:, :, : feats.shape[-1]] = feats
            return canvas

        return torch.cat([zero_pad(dino_feats), zero_pad(siglip_feats)], dim=1)

    def get_image_transform(self):
        """Return a callable mapping a PIL image to ``{"dino": ..., "siglip": ...}``.

        Each value is the result of applying the corresponding backbone's
        transform to the same input image.
        """

        def transform(img: Image.Image) -> Dict[str, torch.Tensor]:
            dino_px = self.dino_transform(img)
            siglip_px = self.siglip_transform(img)
            return {"dino": dino_px, "siglip": siglip_px}

        return transform
| # --------------------------------------------------------------------------- | |
| # Action De-Tokenization | |
| # --------------------------------------------------------------------------- | |
class ActionTokenizer:
    """Decode action token IDs into normalized continuous actions.

    The interval ``[-1, 1]`` is divided by 256 evenly spaced edges; the 255
    midpoints between neighbouring edges are the decodable action values.
    Token IDs are interpreted relative to the end of the vocabulary.

    Args:
        tokenizer: Text tokenizer exposing ``__len__`` and ``vocab_size``.
        use_extra: When true, the vocabulary length is ``len(tokenizer)``;
            otherwise ``tokenizer.vocab_size`` is used.
    """

    def __init__(self, tokenizer, use_extra: bool = True):
        self.tokenizer = tokenizer
        self.n_bins = 256
        self.min_action, self.max_action = -1, 1

        edges = np.linspace(self.min_action, self.max_action, self.n_bins)
        self.bins = edges
        self.bin_centers = (edges[:-1] + edges[1:]) / 2.0

        if use_extra:
            tokenizer_len = len(tokenizer)
        else:
            tokenizer_len = tokenizer.vocab_size

        # Action tokens occupy the tail of the vocabulary.
        self.action_token_begin_idx = int(tokenizer_len - (self.n_bins + 1))
        self.action_token_end_idx = int(tokenizer_len)
        self.tokenizer_len = tokenizer_len

    def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
        """Map token IDs to bin-centre values.

        ``tokenizer_len - id - 1`` is used as the bin index and clipped into
        the valid range of ``bin_centers`` before the lookup.
        """
        distance_from_end = self.tokenizer_len - action_token_ids
        last_bin = self.bin_centers.shape[0] - 1
        bin_idx = np.clip(distance_from_end - 1, a_min=0, a_max=last_bin)
        return self.bin_centers[bin_idx]
def unnormalize_actions(
    normalized_actions: np.ndarray, norm_stats: dict, unnorm_key: str
) -> np.ndarray:
    """Rescale normalized actions back to dataset units.

    For every dimension where the mask is true the value is mapped linearly
    from ``[-1, 1]`` onto ``[q01, q99]``; dimensions where the mask is false
    are returned unchanged.

    Args:
        normalized_actions: Array of normalized action values.
        norm_stats: Mapping ``{key: {"action": {"q01": ..., "q99": ...,
            "mask": ...}}}``; ``"mask"`` is optional and defaults to all-true.
        unnorm_key: Which entry of ``norm_stats`` to use.

    Returns:
        Array of un-normalized actions.

    Raises:
        KeyError: If ``unnorm_key`` is not present in ``norm_stats`` (the
            message lists the keys that are available), or if the selected
            entry lacks ``"action"``, ``"q01"`` or ``"q99"``.
    """
    if unnorm_key not in norm_stats:
        # A bare KeyError('libero_90') gives no hint about what *is* loaded.
        raise KeyError(
            f"Unknown unnorm_key {unnorm_key!r}; "
            f"available keys: {list(norm_stats)}"
        )
    stats = norm_stats[unnorm_key]["action"]

    low = np.asarray(stats["q01"])
    high = np.asarray(stats["q99"])

    mask = stats.get("mask")
    if mask is None:
        mask = np.ones_like(low, dtype=bool)
    else:
        # Stats may arrive as plain lists (e.g. decoded from JSON).
        mask = np.asarray(mask, dtype=bool)

    rescaled = 0.5 * (normalized_actions + 1) * (high - low) + low
    return np.where(mask, rescaled, normalized_actions)
| # --------------------------------------------------------------------------- | |
| # OpenVLA-Micro | |
| # --------------------------------------------------------------------------- | |
class OpenVLAMicro(nn.Module):
    """Self-contained OpenVLA-Micro model: vision encoder + projector + LLM.

    Usage::

        model = OpenVLAMicro.from_pretrained("/path/to/openvla-micro-merged.pt")
        model.to("cuda")
        action = model.predict_action(pil_image, "pick up the red block")

    Attributes:
        vision_encoder: Dual-backbone image encoder.
        projector: Maps encoder output to the LLM embedding width.
        llm: Causal language model that emits the action tokens.
        tokenizer: Text tokenizer used for the prompt.
        norm_stats: Per-dataset action statistics used for un-normalization.
        unnorm_key: Default key into ``norm_stats``.
        action_dim: Number of action tokens decoded per prediction (7).
        device: Device that inputs are moved to before inference.
    """

    def __init__(
        self,
        vision_encoder: DinoSigLIPEncoder,
        projector: CombinedProjector,
        llm: PreTrainedModel,
        tokenizer,
        norm_stats: dict,
        unnorm_key: str = "libero_90",
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm
        self.tokenizer = tokenizer
        self.norm_stats = norm_stats
        self.unnorm_key = unnorm_key
        self.action_dim = 7
        self.image_transform = vision_encoder.get_image_transform()
        self.action_tokenizer = ActionTokenizer(tokenizer, use_extra=True)
        self.device = next(self.llm.parameters()).device

    @classmethod
    def _resolve_checkpoint_path(cls, checkpoint_path: Union[str, Path]) -> Path:
        """Turn a local path or a Hugging Face repo ID into a local file path.

        An existing local path is returned as-is.  Otherwise the argument is
        treated as a repo ID and the known checkpoint filenames are tried in
        order.

        Raises:
            FileNotFoundError: If neither a local file nor any of the known
                filenames in the repo could be obtained.  The last download
                error is attached as ``__cause__``.
        """
        path = Path(checkpoint_path)
        if path.exists():
            return path
        # Treat a non-path input as a Hugging Face repo ID and fetch the default artifact.
        last_error = None
        for filename in ("openvla-micro-distill.pt", "openvla-micro-merged.pt"):
            try:
                return Path(hf_hub_download(repo_id=str(checkpoint_path), filename=filename))
            except Exception as err:
                # Best effort: a missing file, bad repo ID or network failure
                # all mean "try the next candidate"; keep the reason around.
                last_error = err
        raise FileNotFoundError(
            f"Could not resolve checkpoint '{checkpoint_path}'. "
            "Pass a local .pt file or a Hugging Face repo ID containing "
            "'openvla-micro-merged.pt' or 'openvla-micro-distill.pt'."
        ) from last_error

    @classmethod
    def from_pretrained(cls, checkpoint_path: Union[str, Path], device: str = "cpu",
                        **kwargs):
        """Build the model and load weights from a merged checkpoint.

        Args:
            checkpoint_path: Local ``.pt`` file or Hugging Face repo ID.
            device: Device the assembled model is moved to.
            **kwargs: ``llm_kwargs`` (dict) is forwarded to
                ``AutoModelForCausalLM.from_pretrained``; ``torch_dtype``
                defaults to ``torch.bfloat16`` when not given.

        Returns:
            The assembled model on ``device``.
        """
        checkpoint_path = cls._resolve_checkpoint_path(checkpoint_path)
        # NOTE(security): torch.load may unpickle arbitrary objects depending
        # on the installed torch version's ``weights_only`` default; only load
        # checkpoints from sources you trust.
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        # --- Build vision encoder ---
        vision_encoder = DinoSigLIPEncoder()

        # --- Build projector ---
        dino_mlp = ShimMLP(384)
        siglip_mlp = ShimMLP(768)
        proj2 = nn.Linear(8704, 896, bias=True)
        proj4 = nn.Linear(896, 896, bias=True)
        projector = CombinedProjector(dino_mlp, siglip_mlp, proj2, proj4)

        # --- Build LLM ---
        llm_id = "Qwen/Qwen2.5-0.5B"
        config = AutoConfig.from_pretrained(llm_id)
        config.use_flash_attention_2 = False
        # Copy so the caller's dict is not modified by setdefault below.
        llm_kwargs = dict(kwargs.pop("llm_kwargs", None) or {})
        llm_kwargs.setdefault("torch_dtype", torch.bfloat16)
        llm = AutoModelForCausalLM.from_pretrained(
            llm_id,
            config=config,
            **llm_kwargs,
        )

        # --- Tokenizer ---
        tokenizer = AutoTokenizer.from_pretrained(llm_id, use_fast=True)
        tokenizer.add_tokens([f"<|extra_{i}|>" for i in range(256)])

        # --- Load weights ---
        model_sd = ckpt["model"]
        vision_encoder.load_state_dict(model_sd["vision_backbone"])
        projector.load_state_dict(model_sd["projector"])
        llm_raw_sd = model_sd["llm_backbone"]
        # Checkpoint keys carry a leading "llm." that the bare LLM lacks.
        llm_clean_sd = {k.replace("llm.", "", 1): v for k, v in llm_raw_sd.items()}
        llm.load_state_dict(llm_clean_sd)

        norm_stats = ckpt.get("norm_stats", {})
        model = cls(vision_encoder, projector, llm, tokenizer, norm_stats)
        return model.to(device)

    def to(self, *args, **kwargs):
        """Move/cast the model like ``nn.Module.to`` and refresh ``self.device``.

        All ``nn.Module.to`` call forms are accepted (device, dtype, tensor,
        keyword arguments).  ``self.device`` is re-read from the LLM's
        parameters afterwards, so a dtype-only call does not corrupt it.
        """
        moved = super().to(*args, **kwargs)
        self.device = next(self.llm.parameters()).device
        return moved

    def _build_prompt(self, instruction: str) -> str:
        """Match the prompt format used during training (QwenPromptBuilder, openvla family)."""
        system = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
        return (
            f"<|im_start|>system\n{system}<|im_end|>\n"
            f"<|im_start|>user\nWhat action should the robot take to {instruction.lower()}?<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )

    def predict_action(
        self,
        image: Image.Image,
        instruction: str,
        unnorm_key: Optional[str] = None,
    ) -> np.ndarray:
        """Run inference: image + instruction → 7-DoF action (unnormalized).

        Args:
            image: Input camera image.
            instruction: Natural-language task description (lower-cased
                before being inserted into the prompt).
            unnorm_key: Key into ``self.norm_stats``; defaults to
                ``self.unnorm_key``.

        Returns:
            Array with ``self.action_dim`` un-normalized action values.
        """
        if unnorm_key is None:
            unnorm_key = self.unnorm_key

        # Inference only: skip autograd bookkeeping for every forward pass.
        with torch.no_grad():
            # --- Build prompt ---
            prompt_text = self._build_prompt(instruction)
            input_ids = self.tokenizer(prompt_text, return_tensors="pt").input_ids.to(self.device)

            # --- Transform image (add a batch dimension of 1) ---
            pixel_values = self.image_transform(image)
            pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}

            # --- Encode vision ---
            patch_features = self.vision_encoder(pixel_values)
            projected_patches = self.projector(patch_features)
            # Match the LLM dtype so the embeddings can be concatenated.
            projected_patches = projected_patches.to(dtype=self.llm.dtype)

            # --- Build multimodal input embeddings ---
            # Image patches are spliced in right after the first prompt token.
            input_embeds = self.llm.model.embed_tokens(input_ids)
            multimodal_embeds = torch.cat(
                [input_embeds[:, :1, :], projected_patches, input_embeds[:, 1:, :]], dim=1
            )

            # --- Greedy auto-regressive decode of the action tokens ---
            # First step consumes the full multimodal prefix; later steps feed
            # only the previously chosen token and reuse the KV cache.
            generated_ids = []
            step_inputs = {"inputs_embeds": multimodal_embeds}
            past_key_values = None
            for _ in range(self.action_dim):
                outputs = self.llm(
                    **step_inputs, use_cache=True, past_key_values=past_key_values
                )
                next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
                generated_ids.append(next_token)
                past_key_values = outputs.past_key_values
                step_inputs = {"input_ids": next_token}
            action_token_ids = torch.cat(generated_ids, dim=1)

        # --- Decode to continuous actions ---
        normalized = self.action_tokenizer.decode_token_ids_to_actions(
            action_token_ids[0].cpu().numpy()
        )
        actions = unnormalize_actions(normalized, self.norm_stats, unnorm_key)
        return actions