Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Instructions to use theguy21/openvla-micro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use theguy21/openvla-micro with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("theguy21/openvla-micro", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import sys | |
| sys.path.insert(0, "/mnt/steamdrive/openvla-micro") | |
| sys.path.insert(0, "/home/the_one1/OmniVLA") | |
| import math | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import timm | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from timm.models.vision_transformer import VisionTransformer | |
| from torchvision.transforms import Compose, Resize, Normalize | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from modeling_openvla_micro import ( | |
| DinoSigLIPEncoder, | |
| CombinedProjector, | |
| ShimMLP, | |
| unpack_tuple, | |
| monkey_patch_featurizer, | |
| _build_timm_transform, | |
| ) | |
class VisionBackboneWrapper:
    """Lightweight stand-in for the vision-backbone object the training loop queries.

    Provides ``get_num_patches``, ``get_num_images_in_input`` and
    ``set_num_images_in_input``. The per-image token count is measured once,
    at construction time, by pushing an all-zeros 1x3x224x224 probe (on the
    device/dtype of the encoder's first parameter) through the DINO featurizer
    and then the SigLIP featurizer and summing their sequence lengths.
    """

    def __init__(self, encoder: DinoSigLIPEncoder):
        self._encoder = encoder
        self._num_images_in_input = 2

        reference_param = next(encoder.parameters())
        probe = torch.zeros(
            1, 3, 224, 224, device=reference_param.device, dtype=reference_param.dtype
        )
        with torch.inference_mode():
            # DINO first, then SigLIP; only dim 1 (sequence length) is used.
            seq_lengths = [
                self._leading_output(featurizer(probe)).shape[1]
                for featurizer in (encoder.dino_featurizer, encoder.siglip_featurizer)
            ]
        self._patches_per_img = sum(seq_lengths)
        print(f"[VisionBackboneWrapper] patches_per_img = {self._patches_per_img}")

    @staticmethod
    def _leading_output(output):
        """Return the first element when a featurizer returns a list/tuple, else the output itself."""
        return output[0] if isinstance(output, (list, tuple)) else output

    def get_num_patches(self) -> int:
        """Return the number of tokens one image contributes (DINO + SigLIP)."""
        return self._patches_per_img

    def get_num_images_in_input(self) -> int:
        """Return how many images are stacked per sample (defaults to 2)."""
        return self._num_images_in_input

    def set_num_images_in_input(self, n: int) -> None:
        """Override how many images are stacked per sample."""
        self._num_images_in_input = n
| # SigLIP stats | |
| SIGLIP_MEAN = torch.tensor([0.5, 0.5, 0.5]).reshape(1, 3, 1, 1) | |
| SIGLIP_STD = torch.tensor([0.5, 0.5, 0.5]).reshape(1, 3, 1, 1) | |
| # ImageNet stats (used by OpenVLA processor and DINOv2) | |
| IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) | |
| IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1) | |
| def _convert_pixel_values_for_siglip(pixel_values: torch.Tensor) -> torch.Tensor: | |
| """Undo ImageNet norm, apply SigLIP norm.""" | |
| device = pixel_values.device | |
| m = IMAGENET_MEAN.to(device) | |
| s = IMAGENET_STD.to(device) | |
| sm = SIGLIP_MEAN.to(device) | |
| ss = SIGLIP_STD.to(device) | |
| return (pixel_values * s + m - sm) / ss | |
class OpenVLAMicroWrapper(nn.Module):
    """
    Wraps openvla-micro (DinoSigLIPEncoder + CombinedProjector + Qwen2.5)
    + MLP hidden shim (896→2048→4096) into OmniVLA's forward interface.

    Forward signature matches ``PrismaticForConditionalGeneration_MMNv1``
    so that the existing training loop and ``run_forward_pass`` work as-is.

    Sequence layout fed to the LLM:
    ``[first token, vision tokens (+ optional proprio token), remaining tokens]``.
    """

    def __init__(
        self,
        vision_encoder: DinoSigLIPEncoder,
        projector: CombinedProjector,
        llm: nn.Module,
        hidden_shim: nn.Module,
        tokenizer,
        pad_token_id: int,
        action_token_begin_idx: int = 151679,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm
        self.hidden_shim = hidden_shim
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        # First id of the 256-token action range used by ``_action_mask``.
        self.action_token_begin_idx = action_token_begin_idx
        # Plain helper object (not an nn.Module); it runs a dummy forward
        # pass at construction to count patches per image.
        self.vision_backbone = VisionBackboneWrapper(vision_encoder)
        self.llm_dim = 4096  # after hidden shim
        self.vocab_size = llm.config.vocab_size

    def get_input_embeddings(self) -> nn.Module:
        """Return the LLM's token-embedding module, unwrapping a ``base_model`` wrapper if needed."""
        if hasattr(self.llm, 'get_input_embeddings'):
            return self.llm.get_input_embeddings()
        if hasattr(self.llm, 'base_model') and hasattr(self.llm.base_model, 'model'):
            return self.llm.base_model.model.embed_tokens
        return self.llm.model.embed_tokens

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # MoE/novelty projection (unused but accepted for compat)
        proprio=None,
        proprio_projector=None,
        noisy_actions=None,
        noisy_action_projector=None,
        diffusion_timestep_embeddings=None,
        use_film: bool = False,
        # MMNv1-specific
        attention_mask_label=None,
        pixel_values_goal=None,
        img_hist=None,
        map_images=None,
        obs_img=None,
        modality_id=None,
        goal_pose=None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the multimodal forward pass.

        Inputs:
            - ``input_ids`` are embedded when given. Otherwise the provided
              ``inputs_embeds`` are used.
            - When ``labels`` is given, embeddings at action-token positions
              (see ``_action_mask``) are zeroed.
            - ``pixel_values`` stacks images along dim 1, 3 channels per
              image. Each image is encoded by ``_encode_image``, which assumes
              ImageNet-normalized pixels.

        Ignored arguments:
            - ``past_key_values`` is accepted only for signature compatibility
              and is NOT forwarded. Vision tokens are re-inserted on every
              call, so each call recomputes the full sequence.
            - The MMNv1 and diffusion/FiLM arguments are accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast``, or ``(logits, hidden_states,
            past_key_values)`` when ``return_dict=False``.
            - Labels are not passed to the LLM, so ``loss`` is whatever the LLM
              returns without labels.
            - When hidden states are requested, only the last one is passed
              through ``hidden_shim``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        _ = (attention_mask_label, pixel_values_goal, img_hist, map_images, obs_img, modality_id, goal_pose,
             noisy_actions, noisy_action_projector, diffusion_timestep_embeddings, use_film, past_key_values)
        output_hidden_states = output_hidden_states or False
        return_dict = return_dict if return_dict is not None else True
        use_cache = use_cache and not self.training

        # ---------- 1. Input embeddings ----------
        if input_ids is not None:
            input_embeds = self.get_input_embeddings()(input_ids)  # (B, seq_len, 896)
        elif inputs_embeds is not None:
            input_embeds = inputs_embeds
        else:
            raise ValueError("OpenVLAMicroWrapper.forward requires `input_ids` or `inputs_embeds`.")
        input_embeds = input_embeds.to(self.llm.dtype)
        B = input_embeds.shape[0]

        # ---------- 2. Action mask ----------
        if labels is not None:
            # Zero the embeddings at action positions so the LLM cannot read them.
            input_embeds = input_embeds * ~self._action_mask(labels).unsqueeze(-1)

        # ---------- 3. Vision features ----------
        projected = None
        if pixel_values is not None:
            num_imgs = pixel_values.shape[1] // 3  # 3 channels per stacked image
            patch_feats = [
                self._encode_image(pixel_values[:, i * 3:(i + 1) * 3, :, :])
                for i in range(num_imgs)
            ]
            projected = torch.cat(patch_feats, dim=1).to(self.llm.dtype)  # (B, N*num_imgs, 896)
            # Optionally append one proprio token after the vision tokens.
            if proprio_projector is not None and proprio is not None:
                proprio_feats = proprio_projector(proprio.reshape(B, -1)).unsqueeze(1)
                projected = torch.cat([projected, proprio_feats.to(self.llm.dtype)], dim=1)

        # ---------- 4. Build multimodal embeddings ----------
        if projected is None:
            multimodal_embeds = input_embeds
            multimodal_attn_mask = attention_mask
        else:
            multimodal_embeds = torch.cat(
                [input_embeds[:, :1, :], projected, input_embeds[:, 1:, :]], dim=1
            )
            multimodal_attn_mask = None
            if attention_mask is not None:
                # Inserted tokens are always attended to.
                vis_mask = attention_mask.new_ones((B, projected.shape[1]))
                multimodal_attn_mask = torch.cat(
                    [attention_mask[:, :1], vis_mask, attention_mask[:, 1:]], dim=1
                )

        # ---------- 5. Run LLM ----------
        llm_out = self.llm(
            input_ids=None,
            attention_mask=multimodal_attn_mask,
            inputs_embeds=multimodal_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # ---------- 6. Apply hidden shim (last layer only) ----------
        hidden = llm_out.hidden_states
        if output_hidden_states and hidden:
            *earlier, last = hidden
            hidden = (*earlier, self.hidden_shim(last))

        # ---------- 7. Return ----------
        if not return_dict:
            return llm_out.logits, hidden, llm_out.past_key_values
        return CausalLMOutputWithPast(
            loss=llm_out.loss,
            logits=llm_out.logits,
            past_key_values=llm_out.past_key_values,
            hidden_states=hidden,
            attentions=llm_out.attentions,
        )

    def _action_mask(self, labels: torch.Tensor) -> torch.Tensor:
        """
        Return a boolean mask, shaped like ``labels``, of action-token positions.

        Intended to mirror PrismaticForConditionalGeneration_MMNv1.

        - Action tokens are labels in
          ``[action_token_begin_idx, action_token_begin_idx + 256)``.
        - In each row, the position right after the last action token is
          also masked when its label is below ``action_token_begin_idx``.
          That includes negative ignore values.
        """
        begin = self.action_token_begin_idx
        current_mask = (labels >= begin) & (labels < begin + 256)
        next_mask = torch.zeros_like(current_mask)
        seq_len = labels.shape[1]
        for b in range(labels.shape[0]):
            action_positions = torch.where(current_mask[b])[0]
            if action_positions.numel() == 0:
                continue
            after_last = action_positions[-1] + 1
            if after_last < seq_len and labels[b, after_last] < begin:
                next_mask[b, after_last] = True
        return current_mask | next_mask

    @staticmethod
    def _zero_pad_features(
        feats: torch.Tensor, dim: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Zero-pad the last axis of (B, N, d) ``feats`` to ``dim`` channels, on ``device`` in ``dtype``."""
        padded = torch.zeros(feats.shape[0], feats.shape[1], dim, device=device, dtype=dtype)
        padded[:, :, :feats.shape[-1]] = feats
        return padded

    def _encode_image(self, img: torch.Tensor) -> torch.Tensor:
        """
        Encode one (B, 3, H, W) image batch through DinoSigLIPEncoder and project it.

        Pixel values are assumed to be normalized with ImageNet stats.

        - DINO and SigLIP token sequences are each zero-padded to
          ``vision_encoder.total_embed_dim`` channels.
        - They are then concatenated along the sequence axis (DINO first)
          and passed through ``self.projector``.
        """
        device = img.device
        dtype = img.dtype

        dino_out = self.vision_encoder.dino_featurizer(img)
        if isinstance(dino_out, (list, tuple)):
            dino_out = dino_out[0]

        # SigLIP needs different normalization. Cast back to the image dtype:
        # the conversion can return float32, which would mismatch
        # half-precision SigLIP weights when autocast is not active.
        siglip_input = _convert_pixel_values_for_siglip(img).to(dtype)
        siglip_out = self.vision_encoder.siglip_featurizer(siglip_input)
        if isinstance(siglip_out, (list, tuple)):
            siglip_out = siglip_out[0]

        # Zero-pad to joint dim and concat along sequence
        D_total = self.vision_encoder.total_embed_dim  # 1152
        combined = torch.cat(
            [
                self._zero_pad_features(dino_out, D_total, device, dtype),
                self._zero_pad_features(siglip_out, D_total, device, dtype),
            ],
            dim=1,
        )
        return self.projector(combined)