# openvla-micro / model_wrapper.py
# Uploaded by theguy21 -- "Initial upload: base + distill checkpoints, model code, train_shim.py"
# Commit dd9b4af (verified), 11.2 kB
import sys
sys.path.insert(0, "/mnt/steamdrive/openvla-micro")
sys.path.insert(0, "/home/the_one1/OmniVLA")
import math
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from timm.models.vision_transformer import VisionTransformer
from torchvision.transforms import Compose, Resize, Normalize
from transformers.modeling_outputs import CausalLMOutputWithPast
from modeling_openvla_micro import (
DinoSigLIPEncoder,
CombinedProjector,
ShimMLP,
unpack_tuple,
monkey_patch_featurizer,
_build_timm_transform,
)
class VisionBackboneWrapper:
    """Expose the vision-backbone accessors the training loop expects.

    The number of patch tokens per image is measured once at construction by
    pushing a zero image through both featurizers of the wrapped encoder and
    adding their sequence lengths (``shape[1]`` of each output).

    Args:
        encoder: Object providing ``parameters()``, ``dino_featurizer`` and
            ``siglip_featurizer``.
        image_size: Side length of the square probe image used to measure the
            patch count. Defaults to 224, the previously hard-coded value.
        num_images_in_input: Initial value reported by
            ``get_num_images_in_input``. Defaults to 2, as before.
    """

    def __init__(
        self,
        encoder: "DinoSigLIPEncoder",
        *,
        image_size: int = 224,
        num_images_in_input: int = 2,
    ):
        self._encoder = encoder
        self._num_images_in_input = num_images_in_input

        # Build the probe on the encoder's own device/dtype so the featurizers
        # accept it without any casting.
        ref_param = next(encoder.parameters())
        with torch.inference_mode():
            probe = torch.zeros(
                1, 3, image_size, image_size,
                device=ref_param.device, dtype=ref_param.dtype,
            )
            dino_tokens = self._first_output(encoder.dino_featurizer(probe))
            siglip_tokens = self._first_output(encoder.siglip_featurizer(probe))

        # Only the sequence lengths matter here; the feature values are discarded.
        self._patches_per_img = dino_tokens.shape[1] + siglip_tokens.shape[1]
        print(f"[VisionBackboneWrapper] patches_per_img = {self._patches_per_img}")

    @staticmethod
    def _first_output(out):
        """Return ``out[0]`` when a featurizer yields a list/tuple, else ``out``."""
        if isinstance(out, (list, tuple)):
            return out[0]
        return out

    def get_num_patches(self) -> int:
        """Return the combined DINO + SigLIP token count measured for one image."""
        return self._patches_per_img

    def get_num_images_in_input(self) -> int:
        """Return the currently configured number of images per input."""
        return self._num_images_in_input

    def set_num_images_in_input(self, n: int) -> None:
        """Set the number of images per input reported by this wrapper."""
        self._num_images_in_input = n
# Per-channel normalization constants, laid out as (1, 3, 1, 1) so they
# broadcast against image batches with the channel axis at dim 1.

# SigLIP statistics: every channel uses mean 0.5 and std 0.5.
SIGLIP_MEAN = torch.full((1, 3, 1, 1), 0.5)
SIGLIP_STD = torch.full((1, 3, 1, 1), 0.5)

# ImageNet statistics (per the original note: used by the OpenVLA processor
# and DINOv2).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
def _convert_pixel_values_for_siglip(pixel_values: torch.Tensor) -> torch.Tensor:
    """Re-normalize ImageNet-normalized pixels with the SigLIP statistics.

    Multiplies by the ImageNet std and adds the ImageNet mean (undoing that
    normalization), then subtracts the SigLIP mean and divides by the SigLIP
    std.

    Args:
        pixel_values: Image tensor whose channel axis is dim 1 with three
            channels, so the (1, 3, 1, 1) constants broadcast against it.

    Returns:
        The re-normalized tensor on the same device as ``pixel_values``. For
        floating-point inputs the result keeps the input dtype.
    """
    device = pixel_values.device
    imagenet_mean = IMAGENET_MEAN.to(device)
    imagenet_std = IMAGENET_STD.to(device)
    siglip_mean = SIGLIP_MEAN.to(device)
    siglip_std = SIGLIP_STD.to(device)

    converted = (pixel_values * imagenet_std + imagenet_mean - siglip_mean) / siglip_std

    # The constants are float32, so a bf16/fp16 input is promoted to float32 by
    # the arithmetic above. Cast back so the result can feed a featurizer whose
    # weights share the input's dtype.
    if pixel_values.is_floating_point():
        converted = converted.to(pixel_values.dtype)
    return converted
class OpenVLAMicroWrapper(nn.Module):
    """
    Wraps openvla-micro (DinoSigLIPEncoder + CombinedProjector + Qwen2.5)
    + MLP hidden shim (896→2048→4096) into OmniVLA's forward interface.

    Forward signature matches ``PrismaticForConditionalGeneration_MMNv1``
    so that the existing training loop and ``run_forward_pass`` work as-is.
    """

    def __init__(
        self,
        vision_encoder: "DinoSigLIPEncoder",
        projector: "CombinedProjector",
        llm: nn.Module,
        hidden_shim: nn.Module,
        tokenizer,
        pad_token_id: int,
        action_token_begin_idx: int = 151679,
    ):
        """Store the sub-modules and derive the attributes read by callers.

        Args:
            vision_encoder: Encoder exposing ``dino_featurizer``,
                ``siglip_featurizer`` and ``total_embed_dim``.
            projector: Maps the padded vision features to the LLM input width.
            llm: Causal language model; must expose ``config.vocab_size``
                and ``dtype``.
            hidden_shim: Module applied to the last hidden state when hidden
                states are requested.
            tokenizer: Stored as-is; not used inside this class.
            pad_token_id: Stored as-is; not used inside this class.
            action_token_begin_idx: First token id treated as an action token
                by ``_action_mask``.
        """
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm
        self.hidden_shim = hidden_shim
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.action_token_begin_idx = action_token_begin_idx
        self.vision_backbone = VisionBackboneWrapper(vision_encoder)
        self.llm_dim = 4096  # after hidden shim
        self.vocab_size = llm.config.vocab_size

    def get_input_embeddings(self) -> nn.Module:
        """Return the LLM's token-embedding module.

        Prefers ``llm.get_input_embeddings()``; otherwise falls back to
        ``llm.base_model.model.embed_tokens`` and finally to
        ``llm.model.embed_tokens``.
        """
        if hasattr(self.llm, 'get_input_embeddings'):
            return self.llm.get_input_embeddings()
        if hasattr(self.llm, 'base_model') and hasattr(self.llm.base_model, 'model'):
            return self.llm.base_model.model.embed_tokens
        return self.llm.model.embed_tokens

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # MoE/novelty projection (unused but accepted for compat)
        proprio=None,
        proprio_projector=None,
        noisy_actions=None,
        noisy_action_projector=None,
        diffusion_timestep_embeddings=None,
        use_film: bool = False,
        # MMNv1-specific
        attention_mask_label=None,
        pixel_values_goal=None,
        img_hist=None,
        map_images=None,
        obs_img=None,
        modality_id=None,
        goal_pose=None,
        **kwargs,
    ) -> Union[Tuple, "CausalLMOutputWithPast"]:
        """Run vision encoding, splice visual tokens into the text, and call the LLM.

        Args:
            input_ids: Token ids, batch-first. Used when provided.
            attention_mask: Mask aligned with the text tokens; extended with
                ones for the inserted visual/proprio positions.
            pixel_values: Images stacked along dim 1, three channels each.
                ``None`` skips the vision path entirely.
            labels: Only used to locate action tokens whose input embeddings
                are zeroed. No loss is computed here (the LLM is called with
                ``labels=None``).
            inputs_embeds: Precomputed text embeddings, used only when
                ``input_ids`` is ``None``.
            past_key_values: Accepted for signature compatibility; not
                forwarded to the LLM.
            use_cache: Forwarded to the LLM in eval mode; forced to ``False``
                while ``self.training`` is set.
            output_attentions: Forwarded to the LLM.
            output_hidden_states: When truthy, the last returned hidden state
                is passed through ``hidden_shim``.
            return_dict: ``None`` is treated as ``True``.
            proprio, proprio_projector: When both are given together with
                ``pixel_values``, the projected proprio vector is appended as
                one extra token after the visual tokens.
            Remaining keyword arguments are accepted for interface
            compatibility and ignored.

        Returns:
            ``(logits, hidden_states, past_key_values)`` when ``return_dict``
            is false, otherwise a ``CausalLMOutputWithPast``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        output_hidden_states = output_hidden_states or False
        return_dict = return_dict if return_dict is not None else True
        # Never build a KV cache while training. ``use_cache and not training``
        # would let ``None`` through and fall back to the LLM's config default.
        use_cache = False if self.training else use_cache

        # ---------- 1. Input embeddings ----------
        if input_ids is not None:
            input_embeds = self.get_input_embeddings()(input_ids)
        elif inputs_embeds is not None:
            input_embeds = inputs_embeds
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")
        input_embeds = input_embeds.to(self.llm.dtype)
        batch_size = input_embeds.shape[0]

        # ---------- 2. Action mask ----------
        if labels is not None:
            # Zero the embeddings at action positions (and the token right
            # after the last action, see ``_action_mask``).
            keep = ~self._action_mask(labels).unsqueeze(-1)
            input_embeds = input_embeds * keep

        # ---------- 3. Vision features ----------
        projected = None
        if pixel_values is not None:
            # Images are stacked along the channel axis, three channels each.
            num_imgs = pixel_values.shape[1] // 3
            per_image = [
                self._encode_image(pixel_values[:, 3 * i:3 * (i + 1), :, :])
                for i in range(num_imgs)
            ]
            projected = torch.cat(per_image, dim=1)

            # Optionally append proprio as one extra token.
            if proprio_projector is not None and proprio is not None:
                proprio_feats = proprio_projector(proprio.reshape(batch_size, -1)).unsqueeze(1)
                projected = torch.cat([projected, proprio_feats.to(self.llm.dtype)], dim=1)

            projected = projected.to(self.llm.dtype)

        # ---------- 4. Build multimodal embeddings ----------
        if projected is not None:
            # Visual tokens go right after the first text position
            # (presumably BOS -- TODO confirm against the tokenizer setup).
            multimodal_embeds = torch.cat(
                [input_embeds[:, :1, :], projected, input_embeds[:, 1:, :]], dim=1
            )
            if attention_mask is not None:
                # All inserted positions are attended to.
                vis_mask = attention_mask.new_ones((batch_size, projected.shape[1]))
                multimodal_attn_mask = torch.cat(
                    [attention_mask[:, :1], vis_mask, attention_mask[:, 1:]], dim=1
                )
            else:
                multimodal_attn_mask = None
        else:
            multimodal_embeds = input_embeds
            multimodal_attn_mask = attention_mask

        # ---------- 5. Run LLM ----------
        llm_out = self.llm(
            input_ids=None,
            attention_mask=multimodal_attn_mask,
            inputs_embeds=multimodal_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # ---------- 6. Apply hidden shim ----------
        hidden = llm_out.hidden_states
        if output_hidden_states and hidden is not None:
            # Only the final layer's hidden state goes through the shim;
            # earlier layers are returned unchanged.
            last = len(hidden) - 1
            hidden = tuple(
                self.hidden_shim(hs) if i == last else hs
                for i, hs in enumerate(hidden)
            )

        # ---------- 7. Return ----------
        if not return_dict:
            return llm_out.logits, hidden, llm_out.past_key_values
        return CausalLMOutputWithPast(
            loss=llm_out.loss,
            logits=llm_out.logits,
            past_key_values=llm_out.past_key_values,
            hidden_states=hidden,
            attentions=llm_out.attentions,
        )

    def _action_mask(self, labels: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask over ``labels`` marking action-related positions.

        A position is marked when its label lies in the 256-id range starting
        at ``action_token_begin_idx``. In addition, for each row that has any
        such position, the position right after the last one is marked when
        it exists and its label is below ``action_token_begin_idx``.

        Per the original author's note this is intended to mirror
        ``PrismaticForConditionalGeneration_MMNv1`` (not verifiable from here).
        """
        begin = self.action_token_begin_idx
        current_mask = (labels >= begin) & (labels < begin + 256)
        next_mask = torch.zeros_like(labels, dtype=torch.bool)
        seq_len = labels.shape[1]
        for row in range(labels.shape[0]):
            action_positions = torch.where(current_mask[row])[0]
            if action_positions.numel() == 0:
                continue
            follow = int(action_positions[-1]) + 1
            if follow < seq_len and labels[row, follow] < begin:
                next_mask[row, follow] = True
        return current_mask | next_mask

    def _encode_image(self, img: torch.Tensor) -> torch.Tensor:
        """
        Encode a single three-channel image batch through both featurizers
        and the projector.

        Pixel values are assumed to be normalized with ImageNet stats; the
        SigLIP branch receives a re-normalized copy. Each featurizer output is
        zero-padded on the last axis to ``vision_encoder.total_embed_dim``,
        the two are concatenated along the sequence axis (DINO first), and
        the result is passed through ``self.projector``.
        """
        batch = img.shape[0]

        dino_out = self.vision_encoder.dino_featurizer(img)
        if isinstance(dino_out, (list, tuple)):
            dino_out = dino_out[0]

        # SigLIP needs different normalization. Cast back to the image dtype:
        # the float32 normalization constants would otherwise promote a
        # half-precision image to float32.
        siglip_input = _convert_pixel_values_for_siglip(img).to(img.dtype)
        siglip_out = self.vision_encoder.siglip_featurizer(siglip_input)
        if isinstance(siglip_out, (list, tuple)):
            siglip_out = siglip_out[0]

        # Zero-pad each stream to the joint width, then concat along sequence.
        joint_dim = self.vision_encoder.total_embed_dim
        padded = []
        for feats in (dino_out, siglip_out):
            buf = torch.zeros(
                batch, feats.shape[1], joint_dim, device=img.device, dtype=img.dtype
            )
            buf[:, :, :feats.shape[-1]] = feats
            padded.append(buf)
        combined = torch.cat(padded, dim=1)

        return self.projector(combined)