# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by Jinhui YE / HKUST University in [2025].
"""
QwenLatent History Naive Baseline

Ablation / baseline variant of QwenLatent_history.

Instead of using a dedicated action encoder (Qwen3-based transformer) to
compress history action+state sequences into a compact latent embedding, this
model projects each history timestep directly into the VLM token space via two
lightweight MLP projectors:

  - history_action_projector : R^{action_size} -> R^{llm_hidden_size}
  - history_state_projector  : R^{state_size}  -> R^{llm_hidden_size}

The resulting per-step tokens are interleaved as
[a_0, s_0, a_1, s_1, ..., a_{T-1}, s_{T-1}] and inserted into the VLM context
after the visual/language token embeddings, directly before the learnable
query token (see ``_build_qwen_inputs``).

NOTE(review): earlier documentation of this file described the history tokens
as sitting *before* the visual/language tokens. The implementation has always
concatenated them *after*; the layout documented here follows the code.
Confirm which placement is intended before changing either.

The training objective, loss weights, and dual-branch (no-history /
with-history) structure are meant to mirror QwenLatent_history so results are
directly comparable. The only intended difference is how history information
is encoded: here we use a flat MLP projection instead of the action encoder.
"""

import sys

# NOTE(review): machine-specific path injected at import time; kept for
# backward compatibility, but it should move to environment configuration.
sys.path.append("/mnt/data/fangyu/code/rewardmodel")

from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from PIL import Image

from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM
from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig
from starVLA.dataloader.gr00t_lerobot.datasets import ACTION_REPRESENTATION_SLICES
from starVLA.training.trainer_utils.trainer_tools import resize_images
from starVLA.model.tools import FRAMEWORK_REGISTRY

logger = initialize_overwatch(__name__)

IGNORE_INDEX = -100


@FRAMEWORK_REGISTRY.register("QwenLatent_history_naive")
class QwenLatentHistoryNaive(baseframework):
    """
    Naive history baseline: project each history step's action and state
    independently via MLP projectors and add the resulting token sequence to
    the VLM context.

    Architecture overview
    ---------------------
    Input (with history), as assembled by ``_build_qwen_inputs``::

        [ds_embed | VL_tokens | hist_action_0 | hist_state_0 | ...
         | hist_action_{T-1} | hist_state_{T-1} | query_token]

    Compared to QwenLatent_history, the action model's encoder is not used to
    encode history here. The action model itself is still used for:
      - Computing GT action embeddings (for align loss)
      - Decoding predicted embeddings to actions during inference
    """

    # ------------------------------------------------------------------
    # Helper: last non-pad token index
    # ------------------------------------------------------------------

    @staticmethod
    def _get_last_nonpad_indices(attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Return, for every row, the index of the last position whose mask is 1.

        Args:
            attention_mask: 2D tensor ``[B, T]``.

        Returns:
            Long tensor ``[B]`` of indices into the sequence dimension. A row
            without any non-zero entry yields ``T - 1`` (``argmax`` of an
            all-zero row is 0).

        Raises:
            ValueError: if the mask is ``None`` or not two-dimensional.
        """
        if attention_mask is None:
            raise ValueError("attention_mask cannot be None")
        if attention_mask.dim() != 2:
            raise ValueError(
                f"attention_mask must be 2D [B,T], got shape {tuple(attention_mask.shape)}"
            )
        mask = attention_mask.to(dtype=torch.long)
        # First 1 in the reversed mask == last 1 in the original mask.
        rev_first_one = torch.flip(mask, dims=[1]).argmax(dim=1)
        last_nonpad = mask.size(1) - 1 - rev_first_one
        return last_nonpad

    @staticmethod
    def _average_branch_loss(
        hist_loss: Optional[torch.Tensor],
        no_hist_loss: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """
        Equal-weight average of one loss term across the two branches.

        Returns ``None`` when the term is absent from either branch (e.g.
        ``align_loss`` / ``recon_loss`` in ``predict_only`` mode) instead of
        failing on ``0.5 * None``.
        """
        if hist_loss is None or no_hist_loss is None:
            return None
        return 0.5 * hist_loss + 0.5 * no_hist_loss

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def __init__(self, config: Optional[dict] = None, **kwargs) -> None:
        """
        Build the VLM, the action model, and all projection heads.

        Args:
            config: framework configuration; read via attribute access
                (``config.framework.*``, ``config.datasets.vla_data.*``).
            **kwargs: accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)

        # NOTE(review): the layer count is hard-coded rather than read from
        # the VLM config — confirm it matches the backbone actually loaded.
        num_vl_layers, llm_hidden_size = 36, self.qwen_vl_interface.model.config.hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.config.framework.qwenvl.vl_hidden_dim = llm_hidden_size
        self.config.framework.qwenvl.num_vl_layers = num_vl_layers

        # Action model (used only as flow-matching decoder + GT encoder for loss)
        action_model_cfg = getattr(self.config.framework, "action_model", None)
        if action_model_cfg is not None:
            action_model_kwargs = OmegaConf.to_container(action_model_cfg, resolve=True)
            print(f"{action_model_kwargs=}")
            self.action_model = ActionModelFM(ActionModelConfig(**action_model_kwargs))
        else:
            self.action_model = ActionModelFM(ActionModelConfig())

        # Reuse the already-fetched config node so a missing `action_model`
        # section falls back to defaults instead of being looked up again.
        ckpt_path = getattr(action_model_cfg, "ckpt_path", None)
        if ckpt_path:
            self.action_model.load_state_dict(
                torch.load(ckpt_path, map_location="cpu"), strict=True
            )
            print(f"✅ loaded action model from {ckpt_path}")
        print(f"action model loss mode: {self.action_model.config.loss_mode}")

        # Dataset soft-prompt embedding: one row per dataset id, reshaped to
        # `num_data_tokens` tokens of width `llm_hidden_size` when used.
        self.dataset_vocab_size = getattr(action_model_cfg, "dataset_vocab_size", 256)
        self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 32)
        self.dataset_embed = nn.Embedding(
            self.dataset_vocab_size,
            llm_hidden_size * self.num_data_tokens,
        )

        # Learnable query token (VLM output token used for action prediction)
        self.query_token = nn.Parameter(torch.randn(1, 1, llm_hidden_size))

        # VLM → action-space projector (query token hidden → action embedding)
        action_hidden_size = self.action_model.config.hidden_size
        self.action_embed_projector = nn.Sequential(
            nn.Linear(llm_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, action_hidden_size),
        )

        # Chunk / history book-keeping
        self.total_action_chunk_size = self.config.datasets.vla_data.chunk_size
        self.num_history_steps = self.config.datasets.vla_data.num_history_steps
        print(f"num_history_steps: {self.num_history_steps}")
        self.chunk_size = self.total_action_chunk_size - self.num_history_steps
        self.use_state = self.action_model.use_state

        # ------------------------------------------------------------------
        # Naive history projectors
        # Each history timestep's raw action / state is projected to a single
        # VLM-dimension token via a two-layer MLP.
        # ------------------------------------------------------------------
        action_size = self.action_model.config.action_size
        state_size = self.action_model.config.state_size

        self.history_action_projector = nn.Sequential(
            nn.Linear(action_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )
        if self.use_state and state_size > 0:
            self.history_state_projector = nn.Sequential(
                nn.Linear(state_size, llm_hidden_size),
                nn.GELU(),
                nn.Linear(llm_hidden_size, llm_hidden_size),
            )
        else:
            self.history_state_projector = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _maybe_log_align_stats(
        self,
        predicted_action_embeddings: torch.Tensor,
        gt_action_embeddings: torch.Tensor,
    ) -> None:
        """
        Log mean / std / average L2 norm of predicted and GT embeddings once.

        Logging happens at most one time per instance, and only on rank 0
        when ``torch.distributed`` is initialized.
        """
        if getattr(self, "_align_stats_logged", False):
            return
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            if torch.distributed.get_rank() != 0:
                return
        with torch.no_grad():
            pred = predicted_action_embeddings.float()
            gt = gt_action_embeddings.float()
            logger.info(
                "Align stats: pred(mean=%.4f,std=%.4f,avg_norm=%.4f) "
                "gt(mean=%.4f,std=%.4f,avg_norm=%.4f)",
                pred.mean().item(),
                pred.std().item(),
                pred.norm(dim=-1).mean().item(),
                gt.mean().item(),
                gt.std().item(),
                gt.norm(dim=-1).mean().item(),
            )
        self._align_stats_logged = True

    def _encode_history_tokens(
        self,
        history_actions: torch.Tensor,
        history_states: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Project raw history actions (and optionally states) into VLM token space.

        Args:
            history_actions : [B, T_hist, action_size]
            history_states  : [B, T_hist, state_size] or None

        Returns:
            history_tokens : [B, T_hist * (1 or 2), llm_hidden_size]
                Interleaved as [a_0, s_0, a_1, s_1, ...] when a state projector
                exists and states are given, otherwise [a_0, a_1, ...].
        """
        B, T, _ = history_actions.shape

        # Cast to the projector's parameter dtype before the linear layers.
        proj_dtype = self.history_action_projector[0].weight.dtype
        act = history_actions.to(proj_dtype)
        act_tokens = self.history_action_projector(act)  # [B, T, llm_hidden_size]

        if self.history_state_projector is not None and history_states is not None:
            sta = history_states.to(proj_dtype)
            sta_tokens = self.history_state_projector(sta)  # [B, T, llm_hidden_size]
            # Interleave: [a_0, s_0, a_1, s_1, ...]
            # Stack along a new dim then reshape: [B, T, 2, H] → [B, 2T, H]
            interleaved = torch.stack([act_tokens, sta_tokens], dim=2)  # [B, T, 2, H]
            history_tokens = interleaved.view(B, T * 2, self.llm_hidden_size)
        else:
            history_tokens = act_tokens  # [B, T, llm_hidden_size]

        return history_tokens

    def _build_qwen_inputs(
        self,
        images: List,
        instructions: List[str],
        dataset_ids: List[int],
        extra_prefix_embeds: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Build VLM inputs and splice in the soft-prompt, history and query tokens.

        The final sequence layout (embeddings, attention mask and, when
        present, position ids) is::

            [ds_embed (num_data_tokens) | VL tokens | extra_prefix_embeds | query]

        Args:
            images: per-example images forwarded to ``build_qwenvl_inputs``.
            instructions: per-example language instructions.
            dataset_ids: per-example dataset index for the soft-prompt lookup.
            extra_prefix_embeds: optional ``[B, H, llm_hidden_size]`` tokens
                (the projected history). Despite the parameter name they are
                placed after the VL tokens, right before the query token.

        Returns:
            The dict from ``build_qwenvl_inputs`` with ``input_ids`` replaced
            by ``inputs_embeds`` and ``attention_mask`` / ``position_ids``
            extended to the new sequence length (each only if present).
        """
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=images,
            instructions=instructions,
        )
        num_extra = 0 if extra_prefix_embeds is None else extra_prefix_embeds.shape[1]

        if "input_ids" in qwen_inputs:
            dataset_ids_tensor = torch.tensor(
                dataset_ids,
                device=qwen_inputs["input_ids"].device,
                dtype=torch.long,
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(
                qwen_inputs["input_ids"]
            )
            query_embeds = self.query_token.expand(len(dataset_ids), -1, -1)

            embed_parts = [ds_embeds, token_embeds]
            if extra_prefix_embeds is not None:
                embed_parts.append(extra_prefix_embeds)
            embed_parts.append(query_embeds)

            qwen_inputs["inputs_embeds"] = torch.cat(embed_parts, dim=1)
            # The model is driven by embeddings from here on.
            qwen_inputs.pop("input_ids")

        if "attention_mask" in qwen_inputs:
            token_mask = qwen_inputs["attention_mask"]
            batch = token_mask.shape[0]
            mask_kwargs = {"device": token_mask.device, "dtype": token_mask.dtype}

            # All inserted tokens are real (never padding), hence all-ones.
            mask_parts = [
                torch.ones((batch, self.num_data_tokens), **mask_kwargs),
                token_mask,
            ]
            if extra_prefix_embeds is not None:
                mask_parts.append(torch.ones((batch, num_extra), **mask_kwargs))
            mask_parts.append(torch.ones((batch, 1), **mask_kwargs))
            qwen_inputs["attention_mask"] = torch.cat(mask_parts, dim=1)

        if "position_ids" in qwen_inputs:
            # Position ids must follow the same layout as `inputs_embeds`:
            # soft-prompt first, then the (shifted) original tokens, then the
            # history tokens and the query token. Building along the last
            # dimension keeps this valid for both [B, T] and [3, B, T] ids.
            token_pos = qwen_inputs["position_ids"]
            lead_shape = tuple(token_pos.shape[:-1])
            num_tokens = token_pos.shape[-1]
            pos_kwargs = {"device": token_pos.device, "dtype": token_pos.dtype}

            ds_pos = torch.arange(self.num_data_tokens, **pos_kwargs).expand(
                *lead_shape, -1
            )
            tail_start = self.num_data_tokens + num_tokens
            # History tokens followed by the single query token.
            tail_pos = torch.arange(
                tail_start, tail_start + num_extra + 1, **pos_kwargs
            ).expand(*lead_shape, -1)

            qwen_inputs["position_ids"] = torch.cat(
                (ds_pos, token_pos + self.num_data_tokens, tail_pos),
                dim=-1,
            )

        return qwen_inputs

    def _encode_vlm_action_embedding(self, qwen_inputs: dict) -> torch.Tensor:
        """
        Run the VLM and map the query-token hidden state to an action embedding.

        Args:
            qwen_inputs: inputs produced by ``_build_qwen_inputs``.

        Returns:
            Float32 tensor ``[B, action_hidden_size]``, L2-normalized along
            the last dimension.
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            last_hidden_states = qwenvl_outputs.hidden_states[-1]

            if "attention_mask" in qwen_inputs:
                # The query token is the last attended position of every row.
                last_token_indices = self._get_last_nonpad_indices(
                    qwen_inputs["attention_mask"]
                )
                batch_indices = torch.arange(
                    last_hidden_states.shape[0], device=last_hidden_states.device
                )
                action_token_hidden = last_hidden_states[batch_indices, last_token_indices]
            else:
                action_token_hidden = last_hidden_states[:, -1, :]

            predicted_action_embeddings = self.action_embed_projector(
                action_token_hidden
            ).float()
            predicted_action_embeddings = F.normalize(
                predicted_action_embeddings, p=2, dim=-1
            )
        return predicted_action_embeddings

    def _compute_branch_losses(
        self,
        predicted_action_embeddings: torch.Tensor,
        actions_target: torch.Tensor,
        states_target: Optional[torch.Tensor],
        dataset_ids: List[int],
    ) -> dict:
        """
        Compute the losses of one branch (no-history or with-history).

        Args:
            predicted_action_embeddings: output of ``_encode_vlm_action_embedding``.
            actions_target: ground-truth action chunk for this branch.
            states_target: matching states, or ``None``.
            dataset_ids: per-example dataset index.

        Returns:
            Dict with keys ``align_loss``, ``recon_loss``, ``predict_loss``.
            In ``predict_only`` loss mode the first two are ``None``.
        """
        loss_mode = getattr(self.action_model.config, "loss_mode", "full")

        with torch.autocast("cuda", dtype=torch.float32):
            B = actions_target.shape[0]
            # The same flow-matching time and noise are shared by every
            # reconstruction term computed for this branch.
            t = self.action_model._sample_fm_time(
                B, device=actions_target.device, dtype=actions_target.dtype
            )
            noise = torch.randn_like(actions_target)

            if loss_mode == "predict_only":
                align_loss = None
                recon_loss = None
                predict_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=predicted_action_embeddings,
                    t=t,
                    noise=noise,
                )
            else:
                gt_action_embeddings = self.action_model.encode_actions(
                    actions=actions_target,
                    dataset_ids=dataset_ids,
                    state=states_target,
                )
                self._maybe_log_align_stats(predicted_action_embeddings, gt_action_embeddings)
                # GT embeddings are detached: the align loss only trains the
                # prediction path.
                align_loss = F.l1_loss(
                    predicted_action_embeddings,
                    gt_action_embeddings.float().detach(),
                )
                recon_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=gt_action_embeddings,
                    t=t,
                    noise=noise,
                )
                predict_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=predicted_action_embeddings,
                    t=t,
                    noise=noise,
                )

        return {
            "align_loss": align_loss,
            "recon_loss": recon_loss,
            "predict_loss": predict_loss,
        }

    # ------------------------------------------------------------------
    # Forward (training)
    # ------------------------------------------------------------------

    def forward(self, examples: List[dict] = None, **kwargs):
        """
        Dual-branch training forward.

        Branch 1 — no history:
            ``image``, predict ``actions[0 : chunk_size]``.

        Branch 2 — with history (only when ``num_history_steps > 0``):
            ``mid_image``, with the naive MLP projection of the first
            ``num_history_steps`` actions/states added to the VLM context.
            Predict ``actions[num_history_steps : total_chunk_size]``.

        Args:
            examples: list of dicts with keys ``image``, ``lang``, ``action``,
                ``state`` (when the action model uses state), optional
                ``dataset_id`` (default 0) and ``mid_image`` (required when
                ``num_history_steps > 0``).

        Returns:
            Dict with ``align_loss``, ``recon_loss``, ``predict_loss``. With
            history enabled each entry is the equal-weight average of both
            branches; entries a loss mode does not produce are ``None``.

        Raises:
            ValueError: history is enabled but an example lacks ``mid_image``.
        """
        batch_images = [ex["image"] for ex in examples]
        instructions = [ex["lang"] for ex in examples]
        actions = [ex["action"] for ex in examples]
        states = [ex["state"] for ex in examples] if self.use_state else None
        dataset_ids = [ex.get("dataset_id", 0) for ex in examples]

        device = self.query_token.device
        actions_full = torch.as_tensor(
            np.array(actions), device=device, dtype=torch.float32
        )
        assert actions_full.shape[1] == self.total_action_chunk_size

        states_full = None
        if self.use_state:
            states_full = torch.as_tensor(
                np.array(states), device=device, dtype=torch.float32
            )
            assert states_full.shape[1] == self.total_action_chunk_size

        # ---------- Branch 1: no history ----------
        no_hist_qwen_inputs = self._build_qwen_inputs(
            images=batch_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=None,
        )
        no_hist_pred_emb = self._encode_vlm_action_embedding(no_hist_qwen_inputs)
        no_hist_losses = self._compute_branch_losses(
            predicted_action_embeddings=no_hist_pred_emb,
            actions_target=actions_full[:, : self.chunk_size],
            states_target=(
                states_full[:, : self.chunk_size] if states_full is not None else None
            ),
            dataset_ids=dataset_ids,
        )

        if self.num_history_steps <= 0:
            return no_hist_losses

        # ---------- Branch 2: naive history ----------
        if not all("mid_image" in ex for ex in examples):
            raise ValueError("num_history_steps > 0 but `mid_image` is missing in examples.")

        mid_images = [ex["mid_image"] for ex in examples]
        history_actions = actions_full[:, : self.num_history_steps]
        history_states = (
            states_full[:, : self.num_history_steps] if states_full is not None else None
        )

        # Project raw history tokens via naive MLP (the key difference)
        history_tokens = self._encode_history_tokens(history_actions, history_states)

        hist_qwen_inputs = self._build_qwen_inputs(
            images=mid_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=history_tokens,
        )
        hist_pred_emb = self._encode_vlm_action_embedding(hist_qwen_inputs)
        hist_losses = self._compute_branch_losses(
            predicted_action_embeddings=hist_pred_emb,
            actions_target=actions_full[:, self.num_history_steps :],
            states_target=(
                states_full[:, self.num_history_steps :] if states_full is not None else None
            ),
            dataset_ids=dataset_ids,
        )

        # `align_loss` / `recon_loss` are None in `predict_only` mode, so the
        # averaging has to tolerate missing terms.
        return {
            name: self._average_branch_loss(hist_losses[name], no_hist_losses[name])
            for name in ("align_loss", "recon_loss", "predict_loss")
        }

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict] = None,
        embodiment_tag: Optional[str] = None,
        use_history: bool = True,
        **kwargs,
    ) -> dict:
        """
        Inference counterpart of ``forward``.

        When ``use_history=True`` and ``num_history_steps > 0``, uses
        ``mid_image`` together with naive MLP projections of the first
        ``num_history_steps`` entries of ``action`` (and ``state``, when every
        example provides it). Otherwise falls back to the no-history branch.

        Args:
            examples: list of dicts with ``image``, ``lang``, optional
                ``dataset_id``; plus ``mid_image`` / ``action`` (and
                optionally ``state``) for history inference.
            embodiment_tag: optional key into ``ACTION_REPRESENTATION_SLICES``
                used to slice the last action dimension.
            use_history: set ``False`` to force the no-history branch.

        Returns:
            ``{"normalized_actions": np.ndarray}`` with the decoded actions.

        Raises:
            ValueError: history inputs are missing, or ``embodiment_tag`` is
                unknown.
        """
        from deployment.model_server.tools.image_tools import to_pil_preserve

        instructions = [ex["lang"] for ex in examples]
        dataset_ids = [ex.get("dataset_id", 0) for ex in examples]
        batch_images = [to_pil_preserve(ex["image"]) for ex in examples]
        extra_prefix_embeds = None

        if self.num_history_steps > 0 and use_history:
            if not all(("mid_image" in ex and "action" in ex) for ex in examples):
                raise ValueError(
                    "num_history_steps > 0 requires `mid_image` and `action` in each "
                    "example for history inference."
                )
            batch_images = [to_pil_preserve(ex["mid_image"]) for ex in examples]

            proj_dtype = self.history_action_projector[0].weight.dtype
            history_actions_np = np.array(
                [ex["action"][: self.num_history_steps] for ex in examples]
            )
            history_actions = torch.as_tensor(
                history_actions_np,
                device=self.query_token.device,
                dtype=proj_dtype,
            )

            # NOTE(review): if any example lacks `state`, history silently
            # degrades to action-only tokens, which differs from the
            # interleaved layout seen in training — confirm this is intended.
            history_states = None
            if self.use_state and all("state" in ex for ex in examples):
                history_states_np = np.array(
                    [ex["state"][: self.num_history_steps] for ex in examples]
                )
                history_states = torch.as_tensor(
                    history_states_np,
                    device=self.query_token.device,
                    dtype=proj_dtype,
                )

            extra_prefix_embeds = self._encode_history_tokens(
                history_actions, history_states
            )

        train_obs_image_size = getattr(
            self.config.datasets.vla_data, "image_size", None
        )
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        qwen_inputs = self._build_qwen_inputs(
            images=batch_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=extra_prefix_embeds,
        )
        predicted_action_embeddings = self._encode_vlm_action_embedding(qwen_inputs)

        with torch.autocast("cuda", dtype=torch.float32):
            pred_actions = self.action_model.decode_actions(
                predicted_action_embeddings,
                chunk_size=self.chunk_size,
            )

        normalized_actions = pred_actions.detach().cpu().numpy()

        if embodiment_tag is not None:
            if embodiment_tag not in ACTION_REPRESENTATION_SLICES:
                raise ValueError(
                    f"Unknown embodiment tag '{embodiment_tag}'. "
                    f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES.keys())}"
                )
            target_slice = ACTION_REPRESENTATION_SLICES[embodiment_tag]
            normalized_actions = normalized_actions[..., target_slice]

        return {"normalized_actions": normalized_actions}