"""Adapter that runs openvla-micro behind OmniVLA's forward interface."""
import sys

# NOTE(review): machine-specific paths; presumably required so that
# ``modeling_openvla_micro`` (imported below) can be resolved.
sys.path.insert(0, "/mnt/steamdrive/openvla-micro")
sys.path.insert(0, "/home/the_one1/OmniVLA")

import math
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from timm.models.vision_transformer import VisionTransformer
from torchvision.transforms import Compose, Resize, Normalize
from transformers.modeling_outputs import CausalLMOutputWithPast

from modeling_openvla_micro import (
    DinoSigLIPEncoder,
    CombinedProjector,
    ShimMLP,
    unpack_tuple,
    monkey_patch_featurizer,
    _build_timm_transform,
)


def _first_output(out):
    """Return ``out[0]`` when a featurizer returns a list/tuple, else ``out``."""
    if isinstance(out, (list, tuple)):
        return out[0]
    return out


def _pad_last_dim(features: torch.Tensor, target_dim: int, dtype: torch.dtype) -> torch.Tensor:
    """Cast ``features`` to ``dtype`` and zero-pad its last dim up to ``target_dim``.

    Raises:
        ValueError: if the feature dim already exceeds ``target_dim`` (padding
            would otherwise have to crop).
    """
    feat_dim = features.shape[-1]
    if feat_dim > target_dim:
        raise ValueError(
            f"feature dim {feat_dim} exceeds joint embed dim {target_dim}"
        )
    # F.pad's last-dim pair is (left, right); constant mode pads with zeros.
    return F.pad(features.to(dtype), (0, target_dim - feat_dim))


class VisionBackboneWrapper:
    """Wrapper exposing API the training loop expects, using real encoder shapes.

    The per-image patch count is measured once, at construction, by pushing a
    dummy 1x3x224x224 image through both featurizers.
    """

    def __init__(self, encoder: DinoSigLIPEncoder):
        self._encoder = encoder
        self._num_images_in_input = 2

        # Infer patch count from encoder output (only shapes matter here, so the
        # dummy image is not SigLIP-renormalized).
        first_param = next(encoder.parameters())
        with torch.inference_mode():
            dummy = torch.zeros(
                1, 3, 224, 224, device=first_param.device, dtype=first_param.dtype
            )
            dino_out = _first_output(encoder.dino_featurizer(dummy))
            siglip_out = _first_output(encoder.siglip_featurizer(dummy))
        self._patches_per_img = dino_out.shape[1] + siglip_out.shape[1]
        print(f"[VisionBackboneWrapper] patches_per_img = {self._patches_per_img}")

    def get_num_patches(self) -> int:
        """Number of projected tokens produced per input image."""
        return self._patches_per_img

    def get_num_images_in_input(self) -> int:
        return self._num_images_in_input

    def set_num_images_in_input(self, n: int) -> None:
        self._num_images_in_input = n


# SigLIP stats
SIGLIP_MEAN = torch.tensor([0.5, 0.5, 0.5]).reshape(1, 3, 1, 1)
SIGLIP_STD = torch.tensor([0.5, 0.5, 0.5]).reshape(1, 3, 1, 1)

# ImageNet stats (used by OpenVLA processor and DINOv2)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)


def _convert_pixel_values_for_siglip(pixel_values: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet norm, apply SigLIP norm.

    Computes ``(x * std_in + mean_in - mean_siglip) / std_siglip``. The float32
    stat tensors promote the arithmetic to at least float32; the result is cast
    back to ``pixel_values.dtype`` so the SigLIP branch receives the same dtype
    as the DINO branch (previously a bf16/fp16 input came back as float32).

    Args:
        pixel_values: image tensor, assumed (B, 3, H, W) ImageNet-normalized.
    """
    device = pixel_values.device
    m = IMAGENET_MEAN.to(device)
    s = IMAGENET_STD.to(device)
    sm = SIGLIP_MEAN.to(device)
    ss = SIGLIP_STD.to(device)
    return ((pixel_values * s + m - sm) / ss).to(pixel_values.dtype)


class OpenVLAMicroWrapper(nn.Module):
    """
    Wraps openvla-micro (DinoSigLIPEncoder + CombinedProjector + Qwen2.5)
    + MLP hidden shim (896→2048→4096) into OmniVLA's forward interface.

    Forward signature matches ``PrismaticForConditionalGeneration_MMNv1`` so
    that the existing training loop and ``run_forward_pass`` work as-is.
    """

    def __init__(
        self,
        vision_encoder: DinoSigLIPEncoder,
        projector: CombinedProjector,
        llm: nn.Module,
        hidden_shim: nn.Module,
        tokenizer,
        pad_token_id: int,
        action_token_begin_idx: int = 151679,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm
        self.hidden_shim = hidden_shim
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.action_token_begin_idx = action_token_begin_idx
        self.vision_backbone = VisionBackboneWrapper(vision_encoder)
        self.llm_dim = 4096  # after hidden shim
        self.vocab_size = llm.config.vocab_size

    def get_input_embeddings(self) -> nn.Module:
        """Return the LLM's token-embedding module (plain or PEFT-wrapped LLM)."""
        if hasattr(self.llm, 'get_input_embeddings'):
            return self.llm.get_input_embeddings()
        if hasattr(self.llm, 'base_model') and hasattr(self.llm.base_model, 'model'):
            return self.llm.base_model.model.embed_tokens
        return self.llm.model.embed_tokens

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # MoE/novelty projection (unused but accepted for compat)
        proprio=None,
        proprio_projector=None,
        noisy_actions=None,
        noisy_action_projector=None,
        diffusion_timestep_embeddings=None,
        use_film: bool = False,
        # MMNv1-specific
        attention_mask_label=None,
        pixel_values_goal=None,
        img_hist=None,
        map_images=None,
        obs_img=None,
        modality_id=None,
        goal_pose=None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Embed text, splice in visual (+ optional proprio) tokens, run the LLM.

        Visual tokens are inserted right after the first text token. The LLM is
        always called with ``labels=None``, so ``loss`` in the output is whatever
        the LLM returns without labels. ``hidden_shim`` is applied to the last
        hidden state only, and only when ``output_hidden_states`` is truthy.

        The noisy-action / diffusion / FiLM / MMNv1-specific arguments,
        ``past_key_values`` and ``**kwargs`` are accepted for signature
        compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is None/True, else
            the tuple ``(logits, hidden_states, past_key_values)``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given,
                or if ``pixel_values`` has a channel count not divisible by 3.
        """
        output_hidden_states = output_hidden_states or False
        return_dict = return_dict if return_dict is not None else True
        use_cache = use_cache and not self.training

        # ---------- 1. Input embeddings ----------
        if input_ids is not None:
            input_embeds = self.get_input_embeddings()(input_ids)
        elif inputs_embeds is not None:
            input_embeds = inputs_embeds
        else:
            raise ValueError(
                "OpenVLAMicroWrapper.forward requires `input_ids` or `inputs_embeds`"
            )
        input_embeds = input_embeds.to(self.llm.dtype)
        B = input_embeds.shape[0]

        # ---------- 2. Action mask ----------
        # Zero the embeddings of action tokens (and the token right after them)
        # so the model cannot read the ground-truth actions it must predict.
        if labels is not None:
            input_embeds = input_embeds * ~self._action_mask(labels).unsqueeze(-1)

        # ---------- 3. Vision features ----------
        projected = None
        if pixel_values is not None:
            num_channels = pixel_values.shape[1]
            if num_channels % 3 != 0:
                raise ValueError(
                    f"pixel_values must stack 3-channel images; got {num_channels} channels"
                )
            # Each consecutive 3-channel slice is one image (e.g. 6 channels → 2 images).
            patch_feats = [
                self._encode_image(pixel_values[:, start:start + 3])
                for start in range(0, num_channels, 3)
            ]
            projected = torch.cat(patch_feats, dim=1).to(self.llm.dtype)

            # Optionally append proprio as one extra token.
            if proprio_projector is not None and proprio is not None:
                proprio_feats = proprio_projector(proprio.reshape(B, -1)).unsqueeze(1)
                projected = torch.cat(
                    [projected, proprio_feats.to(self.llm.dtype)], dim=1
                )

        # ---------- 4. Build multimodal embeddings ----------
        if projected is None:
            multimodal_embeds = input_embeds
            multimodal_attn_mask = attention_mask
        else:
            multimodal_embeds = torch.cat(
                [input_embeds[:, :1, :], projected, input_embeds[:, 1:, :]], dim=1
            )
            multimodal_attn_mask = None
            if attention_mask is not None:
                # Visual/proprio tokens are always attended to.
                vis_mask = torch.full(
                    (B, projected.shape[1]),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                multimodal_attn_mask = torch.cat(
                    [attention_mask[:, :1], vis_mask, attention_mask[:, 1:]], dim=1
                )

        # ---------- 5. Run LLM ----------
        llm_out = self.llm(
            input_ids=None,
            attention_mask=multimodal_attn_mask,
            inputs_embeds=multimodal_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # ---------- 6. Apply hidden shim (last layer only) ----------
        hidden = llm_out.hidden_states
        if output_hidden_states and hidden:
            hidden = (*hidden[:-1], self.hidden_shim(hidden[-1]))

        # ---------- 7. Return ----------
        if not return_dict:
            return llm_out.logits, hidden, llm_out.past_key_values
        return CausalLMOutputWithPast(
            loss=llm_out.loss,
            logits=llm_out.logits,
            past_key_values=llm_out.past_key_values,
            hidden_states=hidden,
            attentions=llm_out.attentions,
        )

    def _action_mask(self, labels: torch.Tensor) -> torch.Tensor:
        """Same logic as PrismaticForConditionalGeneration_MMNv1.

        Marks every label in ``[action_token_begin_idx, action_token_begin_idx + 256)``,
        plus, per row, the single position right after the LAST such token when
        that position exists and its label is ``< action_token_begin_idx``.
        Vectorized over the batch, so it does not force per-row host syncs.

        Args:
            labels: (B, L) integer label tensor.

        Returns:
            (B, L) bool mask.
        """
        begin = self.action_token_begin_idx
        current_mask = (labels >= begin) & (labels < begin + 256)
        seq_len = labels.shape[1]
        if seq_len == 0:
            return current_mask

        positions = torch.arange(seq_len, device=labels.device).expand_as(labels)
        # Index of the last action token per row, or -1 if the row has none.
        last = positions.masked_fill(~current_mask, -1).amax(dim=1)
        follow = last + 1
        in_range = (last >= 0) & (follow < seq_len)
        # Clamp so gather/scatter stay in bounds; out-of-range rows are masked
        # out by `in_range` and only ever write False.
        follow = follow.clamp(max=seq_len - 1)
        follow_label = labels.gather(1, follow.unsqueeze(1)).squeeze(1)
        mark = in_range & (follow_label < begin)

        next_mask = torch.zeros_like(current_mask)
        next_mask.scatter_(1, follow.unsqueeze(1), mark.unsqueeze(1))
        return current_mask | next_mask

    def _encode_image(self, img: torch.Tensor) -> torch.Tensor:
        """
        Encode a single (B, 3, 224, 224) image through DinoSigLIPEncoder.
        Pixel values are assumed to be normalized with ImageNet stats.

        DINO and SigLIP patch features are each zero-padded on the feature dim
        to ``vision_encoder.total_embed_dim``, cast to ``img.dtype``,
        concatenated along the sequence dim, and passed through the projector.

        Returns:
            projector output for the (B, N_dino + N_siglip, total_embed_dim) input.
        """
        dtype = img.dtype
        total_dim = self.vision_encoder.total_embed_dim

        dino_out = _first_output(self.vision_encoder.dino_featurizer(img))
        # SigLIP needs different normalization
        siglip_out = _first_output(
            self.vision_encoder.siglip_featurizer(_convert_pixel_values_for_siglip(img))
        )

        combined = torch.cat(
            [
                _pad_last_dim(dino_out, total_dim, dtype),
                _pad_last_dim(siglip_out, total_dim, dtype),
            ],
            dim=1,
        )
        return self.projector(combined)