# Copyright 2025 InternVLA-M1. All rights reserved.
# Modified by [Jinhui YE / HKUST University] in [2025].
# Modification: [add fake sample and predict_action to match with starVLA].
"""
InternVLA M1 framework: Vision-Language-Action diffusion model integrating:
- Qwen2.5 vision-language backbone
- Layer-wise QFormer aggregation
- DINO multi-view visual encoder
- DiT diffusion head for future action sequence prediction

Primary goal: predict continuous future actions conditioned on multi-view images + instruction.
"""
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from qwen_vl_utils import process_vision_info
from tqdm import tqdm

from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.model.tools import FRAMEWORK_REGISTRY

logger = initialize_overwatch(__name__)

# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100

from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.projector.QFormer import get_layerwise_qformer
from starVLA.model.modules.action_model.DiTActionHeader import get_action_model
from starVLA.model.modules.dino_model.dino import get_dino_model
from starVLA.training.trainer_utils.trainer_tools import resize_images


@FRAMEWORK_REGISTRY.register("InternVLA-M1")
class InternVLA_M1(baseframework):
    """
    Multimodal vision-language-action model.

    Components:
    - Qwen2.5 VL interface for fused language/vision token embeddings
    - Layer-wise QFormer for multi-layer feature aggregation
    - DINO encoder for dense multi-view spatial tokens
    - DiT diffusion head for future action sequence modeling

    Focus: predict future continuous actions conditioned on images + instruction.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Construct all submodules and cache key configuration values.

        Args:
            config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
            **kwargs: Reserved for future overrides (unused).
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)
        self.layer_qformer = get_layerwise_qformer(config=self.config)
        self.action_model = get_action_model(config=self.config)
        self.dino_encoder = get_dino_model(
            backone_name=getattr(self.config.framework.dino, "dino_backbone", "dinov2_vits14")
        )
        # Project DINO tokens into the VLM hidden space so they can be concatenated with language tokens.
        self.dino_pro = nn.Linear(
            in_features=self.dino_encoder.num_channels,
            out_features=self.qwen_vl_interface.model.config.hidden_size,
        )

        self.future_action_window_size = config.framework.action_model.future_action_window_size
        self.past_action_window_size = config.framework.action_model.past_action_window_size

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Forward pass for training (diffusion objective).

        Flow:
            1. Build QwenVL inputs (images + instruction tokens)
            2. Extract hidden states from the configured layer range
            3. Encode images with DINO, flatten multi-view tokens, and project
            4. Concatenate per-layer language tokens with visual tokens
            5. Fuse via layer-wise QFormer -> action condition embeddings
            6. Prepare repeated future action windows (for diffusion efficiency)
            7. Predict noise and compute the diffusion loss

        Args:
            examples: List[dict], each dict requires:
                - image: List[PIL.Image] (multi-view)
                - lang: str instruction
                - action: np.ndarray or list shaped [T, action_dim]
            **kwargs: Reserved.

        Returns:
            dict:
                action_loss (torch.Tensor): Scalar diffusion noise prediction loss.
        """
        batch_images = [example["image"] for example in examples]  # [B, List[PIL.Image]]
        instructions = [example["lang"] for example in examples]  # [B, str]
        actions = [example["action"] for example in examples]  # labels, [B, len, 7]

        # Step 1: build QwenVL inputs and run the VLM backbone.
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

        # Step 2: DINO forward over all views.
        image_tensors = self.dino_encoder.prepare_dino_input(batch_images)
        B = len(batch_images)
        dino_features = self.dino_encoder(image_tensors)  # DINO output is [B * num_view, token, dim]
        dino_encoded_features = dino_features.reshape(B, -1, dino_features.shape[-1])  # [B, num_view * token, dim]
        dino_encoded_features = self.dino_pro(dino_encoded_features)  # [B, num_view * token, hidden_size]

        # Step 3: aggregate the condition for the action expert.
        start_layer = self.config.framework.layer_qformer.qformer_start_layer
        end_layer = self.config.framework.layer_qformer.qformer_end_layer
        condition_features = qwenvl_outputs.hidden_states[start_layer:end_layer]

        cat_conditions = []
        for layer_index in range(len(condition_features)):
            layer_features = condition_features[layer_index]  # [B, n_qformer_token, D]
            layer_features = torch.cat(
                [layer_features, dino_encoded_features], dim=1
            )  # [B, n_qformer_token + num_view * token, D]
            cat_conditions.append(layer_features)

        action_condition = self.layer_qformer(cat_conditions)  # [B, 64, D_action]

        # Step 4: action expert forward and diffusion loss.
        with torch.autocast("cuda", dtype=torch.float32):
            # Tip to speed up training: repeat each sample several times per step (cf. CogACT).
            actions = torch.tensor(np.array(actions), device=action_condition.device)  # [B, chunk, 7]
            actions_future = actions[:, -(self.future_action_window_size + 1) :, :]

            # Repeat 'actions' 'repeated_diffusion_steps' times, resulting in [repeated_diffusion_steps * B, T, D].
            repeated_diffusion_steps = (
                self.config.trainer.get("repeated_diffusion_steps", 4) if self.config and self.config.trainer else 4
            )
            actions_repeated = actions_future.repeat(repeated_diffusion_steps, 1, 1)
            action_condition = action_condition.repeat(
                repeated_diffusion_steps, 1, 1
            )  # [repeated_diffusion_steps * B, T, D_action]

            # DiT: add noise and predict it.
            noise_pred, noise, timestep = self.action_model(actions_repeated, action_condition)
            # Prediction loss on the noise.
            action_loss = self.action_model.loss(noise_pred, noise)

        return {"action_loss": action_loss}

    @torch.inference_mode()
    def predict_action(
        self,
        batch_images: List[List[Image.Image]],  # B * List of PIL Images as [view1, view2]
        instructions: List[str],
        cfg_scale: float = 1.5,
        use_ddim: bool = True,
        num_ddim_steps: int = 5,
        resize_image=[224, 224],
        **kwargs: str,
    ) -> dict:
        """
        Inference: generate a future normalized action sequence via diffusion sampling.

        Steps:
            1. Resize images to the training resolution (if specified)
            2. Encode with QwenVL (hidden states retained)
            3. Extract DINO tokens and project to the VLM hidden size
            4. Build multi-layer fused QwenVL and DINO features via the QFormer
            5. Run diffusion sampling (DDIM optional, CFG optional)
            6. Return the normalized action trajectory

        Args:
            batch_images: List of samples; each sample is List[PIL.Image] (multi-view).
            instructions: List[str] natural language task instructions.
            cfg_scale: Values > 1 enable classifier-free guidance (scales conditional vs. unconditional predictions).
            use_ddim: Whether to use deterministic DDIM sampling.
            num_ddim_steps: Number of DDIM steps if enabled.
            resize_image: Currently unused; the resize target is read from the dataset config instead.
            **kwargs: Reserved.

        Returns:
            dict:
                normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions.
        """
        # Align observations and language with the training setup.
        # It is the policy's responsibility to match the image size used during training.
        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)
        instructions = [instruction.lower() for instruction in instructions]

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_hidden_states=True,
                return_dict=True,
            )

        # DINO preprocessing does not apply Qwen's smart resize; the (already resized) images are encoded directly.
        image_tensors = self.dino_encoder.prepare_dino_input(batch_images)
        dino_features = self.dino_encoder(image_tensors)
        B = len(batch_images)
        dino_encoded_features = dino_features.reshape(B, -1, dino_features.shape[-1])  # [B, num_view * token, dim]
        dino_encoded_features = self.dino_pro(dino_encoded_features)  # [B, num_view * token, hidden_size]

        with torch.autocast("cuda", dtype=torch.bfloat16):
            start_layer = self.config.framework.layer_qformer.qformer_start_layer
            end_layer = self.config.framework.layer_qformer.qformer_end_layer
            condition_features = qwenvl_outputs.hidden_states[start_layer:end_layer]

            cat_conditions = []
            for layer_index in range(len(condition_features)):
                layer_features = condition_features[layer_index]  # [B, n_qformer_token, D]
                layer_features = torch.cat(
                    [layer_features, dino_encoded_features], dim=1
                )  # [B, n_qformer_token + num_view * token, D]
                cat_conditions.append(layer_features)

            action_condition_feature = self.layer_qformer(cat_conditions)  # [B, 64, D_action]

        using_cfg = cfg_scale > 1.0
        model_dtype = next(self.action_model.net.parameters()).dtype
        B = action_condition_feature.shape[0]

        # Sample random noise for the diffusion sampler.
        noise = torch.randn(
            B,
            self.future_action_window_size + 1,
            self.action_model.in_channels,
            device=action_condition_feature.device,
        ).to(model_dtype)  # [B, T, D]

        # Set up classifier-free guidance by pairing each condition with the learned unconditional embedding.
        if using_cfg:
            noise = torch.cat([noise, noise], 0)  # [2 * B, T, D]
            uncondition = self.action_model.net.z_embedder.uncondition  # [n_qformer_token, D]
            uncondition_shape = uncondition.shape
            uncondition = uncondition.unsqueeze(0)  # [1, n_qformer_token, D]
            uncondition = uncondition.expand(
                B, uncondition_shape[0], uncondition_shape[1]
            )  # [B, n_qformer_token, D]
            z = torch.cat([action_condition_feature, uncondition], 0)  # [2 * B, n_qformer_token, D]
            model_kwargs = dict(z=z, cfg_scale=cfg_scale)
            sample_fn = self.action_model.net.forward_with_cfg
        else:
            model_kwargs = dict(z=action_condition_feature)
            sample_fn = self.action_model.net.forward

        # DDIM sampling.
        if use_ddim and num_ddim_steps is not None:
            if self.action_model.ddim_diffusion is None:
                self.action_model.create_ddim(ddim_step=num_ddim_steps)
            samples = self.action_model.ddim_diffusion.ddim_sample_loop(
                sample_fn,
                noise.shape,
                noise,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                device=action_condition_feature.device,
                eta=0.0,
            )
        else:
            # Only DDIM sampling is wired up here; fail loudly instead of hitting an undefined `samples`.
            raise NotImplementedError("predict_action currently requires use_ddim=True with a valid num_ddim_steps.")

        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Drop the unconditional half.

        normalized_actions = samples.cpu().numpy()
        return {"normalized_actions": normalized_actions}  # [B, T, action_dim]

    @torch.inference_mode()
    def chat_with_M1(
        self,
        image: Image.Image,
        text: str,
        max_new_tokens: int = 128,
        device: Optional[str] = "cuda",
    ) -> List[str]:
        """Single-image, free-form chat with the underlying Qwen2.5-VL backbone (useful for debugging grounding)."""
        processor = getattr(self.qwen_vl_interface, "processor", None)
        model = getattr(self.qwen_vl_interface, "model", None)
        # if processor is None or model is None:
        #     raise RuntimeError("qwen_vl_interface is missing its processor or model.")

        messages0 = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {"type": "text", "text": text},
                ],
            }
        ]
        messages = [messages0]

        # Text info
        texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages]
        # Visual info
        image_inputs, video_inputs = process_vision_info(messages)
        # Tokenize and move to the target device
        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        model.eval()
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        outputs = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return outputs


if __name__ == "__main__":
    import argparse

    import debugpy
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./starVLA/config/training/starvla_cotrain_oxe.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    cfg = OmegaConf.load(args.config_yaml)

    # Try building the model.
    model = InternVLA_M1(cfg)
    print(model)

    # Fake sample for a smoke test.
    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
    sample = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),  # action_chunk, action_dim
        "image": [image, image],  # two views
        "lang": "This is a fake instruction for testing.",
        # "state": np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16),  # chunk, state_dim
    }
    batch = [sample, sample]  # batch size 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    forward_output = model(batch)
    action_loss = forward_output["action_loss"]
    print(f"Action Loss: {action_loss.item()}")

    # Test action prediction.
    predict_output = model.predict_action(batch_images=[batch[0]["image"]], instructions=[batch[0]["lang"]])
    normalized_actions = predict_output["normalized_actions"]
    print(f"Normalized Actions: {normalized_actions}")

    # model_path = "./results/Checkpoints/1_need/0906_bestvla_retrain_sota2/checkpoints/steps_50000_pytorch_model.pt"
    # state_dict = torch.load(model_path, map_location="cpu")
    # model.load_state_dict(state_dict, strict=True)

    # # Try a forward pass with real data (the fake sample above also works; a dataloader is used here for realism).
    # from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn
    # vla_dataset_cfg = cfg.datasets.vla_data
    # dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)

    # from torch.utils.data import DataLoader
    # train_dataloader = DataLoader(
    #     dataset,
    #     batch_size=2,
    #     num_workers=1,  # for debugging
    #     collate_fn=collate_fn,
    # )
    # for batch in tqdm(train_dataloader, desc="Processing Batches"):
    #     batch
    #     break

    # # Try the model on the loaded batch.
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = model.to(device)
    # model(batch)
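
    # Optional: also exercise the VQA-style chat interface on the same fake image.
    # A minimal sketch, assuming the Qwen2.5-VL processor/model wrapped by
    # qwen_vl_interface are fully loaded and support `.generate()` on this device.
    chat_outputs = model.chat_with_M1(
        image=image,
        text="Describe the scene in one sentence.",
        max_new_tokens=32,
        device=str(device),
    )
    print(f"Chat output: {chat_outputs}")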