| |
| |
| |
"""
QwenLatent History Naive Baseline

Ablation / baseline variant of QwenLatent_history. Instead of using a
dedicated action encoder (Qwen3-based transformer) to compress history
action+state sequences into a compact latent embedding, this model projects
each history timestep directly into the VLM token space via two lightweight
MLP projectors:

    - history_action_projector : R^{action_size} -> R^{llm_hidden_size}
    - history_state_projector  : R^{state_size}  -> R^{llm_hidden_size}
      (only built when the action model uses state and state_size > 0)

The resulting per-step tokens are interleaved as
    [a_0, s_0, a_1, s_1, ..., a_{T-1}, s_{T-1}]
(or [a_0, a_1, ..., a_{T-1}] when no state projector is available) and
inserted into the VLM context after the visual/language tokens and
immediately before the learned query token; the dataset soft-prompt tokens
come first:
    [ds_embed | VL_tokens | history_tokens | query_token]

This preserves the identical training objective, loss weights, and dual-branch
(no-history / with-history) structure as QwenLatent_history, so results are
directly comparable. The only difference is how history information is
encoded: here we use a flat MLP projection instead of the action encoder.
"""


import sys
# NOTE(review): hard-coded, machine-specific path appended to the import search
# path at import time; presumably required so project-local packages resolve on
# the original training machine. Consider replacing with an installed package
# or an environment variable -- confirm no caller relies on it before removing.
sys.path.append("/mnt/data/fangyu/code/rewardmodel")
|
|
| from typing import List, Optional |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from omegaconf import OmegaConf |
| from PIL import Image |
|
|
| from starVLA.training.trainer_utils import initialize_overwatch |
| from starVLA.model.framework.base_framework import baseframework |
| from starVLA.model.modules.vlm import get_vlm_model |
| from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM |
| from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig |
| from starVLA.dataloader.gr00t_lerobot.datasets import ACTION_REPRESENTATION_SLICES |
| from starVLA.training.trainer_utils.trainer_tools import resize_images |
| from starVLA.model.tools import FRAMEWORK_REGISTRY |
|
|
# Module logger built by starVLA's ``initialize_overwatch`` helper.
logger = initialize_overwatch(__name__)


# Label value matching the default ``ignore_index`` of
# ``torch.nn.CrossEntropyLoss``. It is not referenced by the code visible in
# this module; NOTE(review): presumably kept for parity with sibling
# frameworks or imported elsewhere -- confirm before removing.
IGNORE_INDEX = -100
|
|
|
|
@FRAMEWORK_REGISTRY.register("QwenLatent_history_naive")
class QwenLatentHistoryNaive(baseframework):
    """
    Naive history baseline: project each history step's action and state
    independently via MLP projectors and insert the resulting token sequence
    into the VLM context.

    Architecture overview
    ---------------------
    Input (with history)::

        [ds_embed | VL_tokens |
         hist_action_0 | hist_state_0 | ... |
         hist_action_{T-1} | hist_state_{T-1} |
         query_token]

    The history tokens sit after the visual/language tokens and directly
    before the learned query token (see ``_build_qwen_inputs``). The final
    hidden state at the query token is projected to the action embedding.

    Compared to QwenLatent_history, the action model's *encoder* path is
    bypassed for history encoding (the MLP projectors replace it). The
    action model itself is still used for:
      - Computing GT action embeddings (for the align loss)
      - Computing the reconstruction / prediction losses
        (``recon_loss_from_embedding``)
      - Decoding predicted embeddings to actions during inference
    """

    @staticmethod
    def _get_last_nonpad_indices(attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Return, per row of a 2-D ``[B, T]`` attention mask, the index of the
        last non-zero (non-padding) position.

        A row containing no non-zero entry yields ``T - 1``.

        Raises:
            ValueError: if ``attention_mask`` is None or not 2-D.
        """
        if attention_mask is None:
            raise ValueError("attention_mask cannot be None")
        if attention_mask.dim() != 2:
            raise ValueError(
                f"attention_mask must be 2D [B,T], got shape {tuple(attention_mask.shape)}"
            )
        mask = attention_mask.to(dtype=torch.long)
        # argmax returns the first maximal index, so on the reversed mask it
        # gives the distance of the last 1 from the end of the row.
        rev_first_one = torch.flip(mask, dims=[1]).argmax(dim=1)
        return mask.size(1) - 1 - rev_first_one

    def __init__(self, config: Optional[dict] = None, **kwargs) -> None:
        """
        Build the VLM backbone, the flow-matching action model, the dataset
        soft prompt, the query token and the naive history projectors.

        Args:
            config: config object exposing ``framework.qwenvl``, an optional
                ``framework.action_model`` section and ``datasets.vla_data``
                (with ``chunk_size`` and ``num_history_steps``).
            **kwargs: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``chunk_size`` does not exceed ``num_history_steps``.
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)

        # NOTE(review): num_vl_layers is hard-coded rather than read from the
        # backbone config -- confirm it matches the loaded VLM.
        num_vl_layers = 36
        llm_hidden_size = self.qwen_vl_interface.model.config.hidden_size
        self.llm_hidden_size = llm_hidden_size
        # Written back into the config (presumably consumed by other components).
        self.config.framework.qwenvl.vl_hidden_dim = llm_hidden_size
        self.config.framework.qwenvl.num_vl_layers = num_vl_layers

        # ---- Flow-matching action model ---------------------------------
        action_model_cfg = getattr(self.config.framework, "action_model", None)
        if action_model_cfg is not None:
            action_model_kwargs = OmegaConf.to_container(action_model_cfg, resolve=True)
            print(f"{action_model_kwargs=}")
            self.action_model = ActionModelFM(ActionModelConfig(**action_model_kwargs))
        else:
            self.action_model = ActionModelFM(ActionModelConfig())

        # Optional fields are read from ``action_model_cfg`` (getattr on None
        # returns the default), so a config without an ``action_model``
        # section no longer raises AttributeError here.
        ckpt_path = getattr(action_model_cfg, "ckpt_path", None)
        if ckpt_path:
            self.action_model.load_state_dict(
                torch.load(ckpt_path, map_location="cpu"), strict=True
            )
            print(f"✅ loaded action model from {ckpt_path}")
        print(f"action model loss mode: {self.action_model.config.loss_mode}")

        # ---- Dataset soft prompt ----------------------------------------
        self.dataset_vocab_size = getattr(action_model_cfg, "dataset_vocab_size", 256)
        self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 32)
        # One row per dataset id; reshaped to ``num_data_tokens`` tokens of
        # width ``llm_hidden_size`` in ``_build_qwen_inputs``.
        self.dataset_embed = nn.Embedding(
            self.dataset_vocab_size,
            llm_hidden_size * self.num_data_tokens,
        )

        # Learned token appended last to every sequence; its final hidden
        # state is read out as the action embedding.
        self.query_token = nn.Parameter(torch.randn(1, 1, llm_hidden_size))

        # VLM hidden state -> action-model embedding space.
        action_hidden_size = self.action_model.config.hidden_size
        self.action_embed_projector = nn.Sequential(
            nn.Linear(llm_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, action_hidden_size),
        )

        # ---- Chunk / history bookkeeping --------------------------------
        # Each sample carries ``total_action_chunk_size`` steps. The
        # with-history branch uses the first ``num_history_steps`` as history
        # and the remaining ``chunk_size`` steps as the target.
        self.total_action_chunk_size = self.config.datasets.vla_data.chunk_size
        self.num_history_steps = self.config.datasets.vla_data.num_history_steps
        print(f"num_history_steps: {self.num_history_steps}")
        self.chunk_size = self.total_action_chunk_size - self.num_history_steps
        if self.chunk_size <= 0:
            raise ValueError(
                f"datasets.vla_data.chunk_size ({self.total_action_chunk_size}) must be "
                f"larger than num_history_steps ({self.num_history_steps})."
            )
        self.use_state = self.action_model.use_state

        # ---- Naive history projectors -----------------------------------
        action_size = self.action_model.config.action_size
        state_size = self.action_model.config.state_size

        self.history_action_projector = nn.Sequential(
            nn.Linear(action_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

        # Only built when the action model consumes state; otherwise history
        # states are ignored by ``_encode_history_tokens``.
        if self.use_state and state_size > 0:
            self.history_state_projector = nn.Sequential(
                nn.Linear(state_size, llm_hidden_size),
                nn.GELU(),
                nn.Linear(llm_hidden_size, llm_hidden_size),
            )
        else:
            self.history_state_projector = None

    def _maybe_log_align_stats(
        self,
        predicted_action_embeddings: torch.Tensor,
        gt_action_embeddings: torch.Tensor,
    ) -> None:
        """
        Log mean / std / average L2 norm of predicted vs. ground-truth action
        embeddings once per instance (only on rank 0 when torch.distributed is
        initialized).
        """
        if getattr(self, "_align_stats_logged", False):
            return
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            if torch.distributed.get_rank() != 0:
                return
        with torch.no_grad():
            pred = predicted_action_embeddings.float()
            gt = gt_action_embeddings.float()
            logger.info(
                "Align stats: pred(mean=%.4f,std=%.4f,avg_norm=%.4f) "
                "gt(mean=%.4f,std=%.4f,avg_norm=%.4f)",
                pred.mean().item(),
                pred.std().item(),
                pred.norm(dim=-1).mean().item(),
                gt.mean().item(),
                gt.std().item(),
                gt.norm(dim=-1).mean().item(),
            )
        self._align_stats_logged = True

    def _encode_history_tokens(
        self,
        history_actions: torch.Tensor,
        history_states: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Project raw history actions (and optionally states) into VLM token space.

        Args:
            history_actions : [B, T_hist, action_size]
            history_states  : [B, T_hist, state_size] or None. Ignored when no
                              state projector was built.

        Returns:
            history_tokens : [B, T_hist * (1 or 2), llm_hidden_size] in the
                             projector's dtype. Interleaved as
                             [a_0, s_0, a_1, s_1, ...] when states are used,
                             otherwise [a_0, a_1, ...].
        """
        B, T, _ = history_actions.shape

        # Match the projector weights' dtype (inputs may arrive as float32).
        proj_dtype = self.history_action_projector[0].weight.dtype
        act_tokens = self.history_action_projector(history_actions.to(proj_dtype))

        if self.history_state_projector is None or history_states is None:
            return act_tokens

        sta_tokens = self.history_state_projector(history_states.to(proj_dtype))
        # [B, T, 2, H] -> [B, 2T, H] yields a_0, s_0, a_1, s_1, ...
        interleaved = torch.stack([act_tokens, sta_tokens], dim=2)
        return interleaved.view(B, T * 2, self.llm_hidden_size)

    def _build_qwen_inputs(
        self,
        images: List,
        instructions: List[str],
        dataset_ids: List[int],
        extra_prefix_embeds: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Build VLM inputs with the dataset soft prompt, optional history tokens
        and the learned query token spliced in as ``inputs_embeds``.

        Sequence layout::

            [ds_embeds (num_data_tokens) | VL tokens | extra_prefix_embeds (H) | query (1)]

        Despite its name, ``extra_prefix_embeds`` is inserted *after* the
        visual/language tokens, directly before the query token.
        ``attention_mask`` and (when present) ``position_ids`` are extended to
        follow exactly this layout.

        Args:
            images: per-sample images forwarded to ``build_qwenvl_inputs``.
            instructions: per-sample language instructions.
            dataset_ids: per-sample ids indexing ``self.dataset_embed``.
            extra_prefix_embeds: optional ``[B, H, llm_hidden_size]`` tokens.

        Returns:
            The processor output dict; when ``input_ids`` is present it is
            replaced by ``inputs_embeds``.
        """
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=images,
            instructions=instructions,
        )
        has_extra = extra_prefix_embeds is not None
        extra_len = extra_prefix_embeds.shape[1] if has_extra else 0

        if "input_ids" in qwen_inputs:
            input_ids = qwen_inputs["input_ids"]
            batch_size = len(dataset_ids)
            dataset_ids_tensor = torch.tensor(
                dataset_ids, device=input_ids.device, dtype=torch.long
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                batch_size, self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(input_ids)
            query_embeds = self.query_token.expand(batch_size, -1, -1)

            embed_parts = [ds_embeds, token_embeds]
            if has_extra:
                embed_parts.append(extra_prefix_embeds)
            embed_parts.append(query_embeds)
            qwen_inputs["inputs_embeds"] = torch.cat(embed_parts, dim=1)
            qwen_inputs.pop("input_ids")

        if "attention_mask" in qwen_inputs:
            vl_mask = qwen_inputs["attention_mask"]

            def _ones(width: int) -> torch.Tensor:
                # Always-attended span matching the VL mask's batch/device/dtype.
                return torch.ones(
                    (vl_mask.shape[0], width), device=vl_mask.device, dtype=vl_mask.dtype
                )

            mask_parts = [_ones(self.num_data_tokens), vl_mask]
            if has_extra:
                mask_parts.append(_ones(extra_len))
            mask_parts.append(_ones(1))
            qwen_inputs["attention_mask"] = torch.cat(mask_parts, dim=1)

        if "position_ids" in qwen_inputs:
            # Positions follow the actual embedding layout built above:
            # ds -> [0, D), VL -> original + D, history -> [D + L, D + L + H),
            # query -> D + L + H.
            # NOTE(review): assumes 2-D [B, L] position ids; multi-axis
            # (M-RoPE) position ids would need different handling.
            vl_pos = qwen_inputs["position_ids"]
            batch_size, vl_len = vl_pos.shape[0], vl_pos.shape[1]
            ds_pos = torch.arange(
                self.num_data_tokens, device=vl_pos.device, dtype=vl_pos.dtype
            )
            tail_start = self.num_data_tokens + vl_len
            tail_pos = torch.arange(
                tail_start,
                tail_start + extra_len + 1,
                device=vl_pos.device,
                dtype=vl_pos.dtype,
            )
            qwen_inputs["position_ids"] = torch.cat(
                (
                    ds_pos.unsqueeze(0).expand(batch_size, -1),
                    vl_pos + self.num_data_tokens,
                    tail_pos.unsqueeze(0).expand(batch_size, -1),
                ),
                dim=1,
            )
        return qwen_inputs

    def _encode_vlm_action_embedding(self, qwen_inputs: dict) -> torch.Tensor:
        """
        Run the VLM (bf16 autocast on CUDA) and map the final-layer hidden
        state at the last attended position to an L2-normalized float32
        action embedding.

        For inputs built by ``_build_qwen_inputs`` the last attended position
        is the appended query token. Without an ``attention_mask`` the final
        sequence position is used.
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            last_hidden_states = qwenvl_outputs.hidden_states[-1]

            if "attention_mask" in qwen_inputs:
                last_token_indices = self._get_last_nonpad_indices(
                    qwen_inputs["attention_mask"]
                )
                batch_indices = torch.arange(
                    last_hidden_states.shape[0], device=last_hidden_states.device
                )
                action_token_hidden = last_hidden_states[batch_indices, last_token_indices]
            else:
                action_token_hidden = last_hidden_states[:, -1, :]

            predicted_action_embeddings = self.action_embed_projector(
                action_token_hidden
            ).float()
        return F.normalize(predicted_action_embeddings, p=2, dim=-1)

    def _compute_branch_losses(
        self,
        predicted_action_embeddings: torch.Tensor,
        actions_target: torch.Tensor,
        states_target: Optional[torch.Tensor],
        dataset_ids: List[int],
    ) -> dict:
        """
        Compute the losses of one branch.

        Both flow-matching losses share one sampled time ``t`` and one noise
        draw.

        * ``loss_mode == "predict_only"``: only ``predict_loss`` is computed
          from the predicted embedding. ``align_loss`` and ``recon_loss`` are
          ``None``.
        * otherwise: GT embeddings come from ``action_model.encode_actions``.
          The three losses are:
          - ``align_loss``: L1 between the prediction and the detached GT.
          - ``recon_loss``: computed from the (non-detached) GT embedding.
          - ``predict_loss``: computed from the predicted embedding.

        Returns:
            dict with keys ``align_loss``, ``recon_loss``, ``predict_loss``.
        """
        loss_mode = getattr(self.action_model.config, "loss_mode", "full")
        # Presumably intended to keep the flow-matching losses in full precision.
        with torch.autocast("cuda", dtype=torch.float32):
            B = actions_target.shape[0]
            t = self.action_model._sample_fm_time(
                B, device=actions_target.device, dtype=actions_target.dtype
            )
            noise = torch.randn_like(actions_target)

            if loss_mode == "predict_only":
                align_loss = None
                recon_loss = None
            else:
                gt_action_embeddings = self.action_model.encode_actions(
                    actions=actions_target,
                    dataset_ids=dataset_ids,
                    state=states_target,
                )
                self._maybe_log_align_stats(predicted_action_embeddings, gt_action_embeddings)

                align_loss = F.l1_loss(
                    predicted_action_embeddings,
                    gt_action_embeddings.float().detach(),
                )
                recon_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=gt_action_embeddings,
                    t=t,
                    noise=noise,
                )
            predict_loss = self.action_model.recon_loss_from_embedding(
                actions=actions_target,
                action_embedding=predicted_action_embeddings,
                t=t,
                noise=noise,
            )
        return {
            "align_loss": align_loss,
            "recon_loss": recon_loss,
            "predict_loss": predict_loss,
        }

    @staticmethod
    def _average_branch_losses(hist_losses: dict, no_hist_losses: dict) -> dict:
        """
        Equal-weight average of each loss term over the two branches.

        A term that is ``None`` in both branches (``align_loss`` /
        ``recon_loss`` in ``predict_only`` mode) stays ``None`` instead of
        raising ``TypeError`` on ``0.5 * None``.
        """
        averaged = {}
        for key in ("align_loss", "recon_loss", "predict_loss"):
            hist_val, no_hist_val = hist_losses[key], no_hist_losses[key]
            if hist_val is None and no_hist_val is None:
                averaged[key] = None
            else:
                averaged[key] = 0.5 * hist_val + 0.5 * no_hist_val
        return averaged

    def forward(self, examples: List[dict] = None, **kwargs):
        """
        Dual-branch training forward.

        Branch 1 -- no history:
            ``image`` + instruction; target is ``action[0 : chunk_size]``
            (and the matching ``state`` slice when states are used).

        Branch 2 -- with history (only when ``num_history_steps > 0``):
            ``mid_image`` (image at step ``num_history_steps``) plus the naive
            MLP projection of ``action``/``state`` steps
            ``[0, num_history_steps)``. The target is
            ``action[num_history_steps : total_action_chunk_size]``.

        Args:
            examples: per-sample dicts with the following keys:
                ``image``, ``lang``.
                ``action``: a sequence of length ``total_action_chunk_size``.
                ``dataset_id``: optional, default 0.
                ``state``: required when ``use_state`` is set.
                ``mid_image``: required when ``num_history_steps > 0``.

        Returns:
            dict with ``align_loss``, ``recon_loss`` and ``predict_loss``.
            With history, each term is the equal-weight average of both
            branches. Terms that are ``None`` in ``predict_only`` mode stay
            ``None``.

        Raises:
            ValueError: on an action/state length mismatch or a missing
                ``mid_image``.
        """
        batch_images = [ex["image"] for ex in examples]
        instructions = [ex["lang"] for ex in examples]
        actions = [ex["action"] for ex in examples]
        states = [ex["state"] for ex in examples] if self.use_state else None
        dataset_ids = [ex.get("dataset_id", 0) for ex in examples]

        device = self.query_token.device
        actions_full = torch.as_tensor(np.array(actions), device=device, dtype=torch.float32)
        if actions_full.shape[1] != self.total_action_chunk_size:
            raise ValueError(
                f"Expected {self.total_action_chunk_size} action steps per example, "
                f"got {actions_full.shape[1]}."
            )

        states_full = None
        if self.use_state:
            states_full = torch.as_tensor(np.array(states), device=device, dtype=torch.float32)
            if states_full.shape[1] != self.total_action_chunk_size:
                raise ValueError(
                    f"Expected {self.total_action_chunk_size} state steps per example, "
                    f"got {states_full.shape[1]}."
                )

        # ---- Branch 1: no history -------------------------------------------
        no_hist_qwen_inputs = self._build_qwen_inputs(
            images=batch_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=None,
        )
        no_hist_pred_emb = self._encode_vlm_action_embedding(no_hist_qwen_inputs)
        no_hist_losses = self._compute_branch_losses(
            predicted_action_embeddings=no_hist_pred_emb,
            actions_target=actions_full[:, : self.chunk_size],
            states_target=(
                states_full[:, : self.chunk_size] if states_full is not None else None
            ),
            dataset_ids=dataset_ids,
        )

        if self.num_history_steps <= 0:
            return no_hist_losses

        # ---- Branch 2: with history -----------------------------------------
        if not all("mid_image" in ex for ex in examples):
            raise ValueError("num_history_steps > 0 but `mid_image` is missing in examples.")
        mid_images = [ex["mid_image"] for ex in examples]

        history_actions = actions_full[:, : self.num_history_steps]
        history_states = (
            states_full[:, : self.num_history_steps] if states_full is not None else None
        )
        history_tokens = self._encode_history_tokens(history_actions, history_states)

        hist_qwen_inputs = self._build_qwen_inputs(
            images=mid_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=history_tokens,
        )
        hist_pred_emb = self._encode_vlm_action_embedding(hist_qwen_inputs)
        hist_losses = self._compute_branch_losses(
            predicted_action_embeddings=hist_pred_emb,
            actions_target=actions_full[:, self.num_history_steps :],
            states_target=(
                states_full[:, self.num_history_steps :] if states_full is not None else None
            ),
            dataset_ids=dataset_ids,
        )

        return self._average_branch_losses(hist_losses, no_hist_losses)

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict] = None,
        embodiment_tag: Optional[str] = None,
        use_history: bool = True,
        **kwargs,
    ) -> dict:
        """
        Inference counterpart of ``forward``.

        When ``use_history=True`` and ``num_history_steps > 0``:
            - uses ``mid_image``;
            - uses the naive MLP projection of the first ``num_history_steps``
              entries of each example's ``action`` (and ``state``, when
              available and used).
        Otherwise it falls back to the no-history branch on ``image``.

        Images are resized to ``datasets.vla_data.image_size`` when configured.
        Predicted embeddings are decoded to ``chunk_size`` actions.

        Args:
            examples: per-sample dicts (see ``forward`` for the keys).
            embodiment_tag: optional key into ``ACTION_REPRESENTATION_SLICES``
                selecting the output action dimensions.
            use_history: whether to use the history branch when available.

        Returns:
            ``{"normalized_actions": np.ndarray}``

        Raises:
            ValueError: on missing history inputs or an unknown embodiment tag.
        """
        import warnings

        from deployment.model_server.tools.image_tools import to_pil_preserve

        instructions = [ex["lang"] for ex in examples]
        dataset_ids = [ex.get("dataset_id", 0) for ex in examples]

        batch_images = [to_pil_preserve(ex["image"]) for ex in examples]
        extra_prefix_embeds = None

        if self.num_history_steps > 0 and use_history:
            if not all(("mid_image" in ex and "action" in ex) for ex in examples):
                raise ValueError(
                    "num_history_steps > 0 requires `mid_image` and `action` in each "
                    "example for history inference."
                )
            batch_images = [to_pil_preserve(ex["mid_image"]) for ex in examples]

            device = self.query_token.device
            proj_dtype = self.history_action_projector[0].weight.dtype
            history_actions = torch.as_tensor(
                np.array([ex["action"][: self.num_history_steps] for ex in examples]),
                device=device,
                dtype=proj_dtype,
            )

            history_states = None
            if self.use_state and all("state" in ex for ex in examples):
                history_states = torch.as_tensor(
                    np.array([ex["state"][: self.num_history_steps] for ex in examples]),
                    device=device,
                    dtype=proj_dtype,
                )
            elif self.history_state_projector is not None:
                # Training always interleaves state tokens when a state
                # projector exists; dropping them changes the token layout.
                warnings.warn(
                    "use_state is enabled but `state` is missing from some examples; "
                    "history state tokens are dropped, so the history token layout "
                    "differs from training.",
                    RuntimeWarning,
                    stacklevel=2,
                )

            extra_prefix_embeds = self._encode_history_tokens(history_actions, history_states)

        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        qwen_inputs = self._build_qwen_inputs(
            images=batch_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=extra_prefix_embeds,
        )
        predicted_action_embeddings = self._encode_vlm_action_embedding(qwen_inputs)

        with torch.autocast("cuda", dtype=torch.float32):
            pred_actions = self.action_model.decode_actions(
                predicted_action_embeddings,
                chunk_size=self.chunk_size,
            )

        normalized_actions = pred_actions.detach().cpu().numpy()

        if embodiment_tag is not None:
            if embodiment_tag not in ACTION_REPRESENTATION_SLICES:
                raise ValueError(
                    f"Unknown embodiment tag '{embodiment_tag}'. "
                    f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES.keys())}"
                )
            target_slice = ACTION_REPRESENTATION_SLICES[embodiment_tag]
            normalized_actions = normalized_actions[..., target_slice]

        return {"normalized_actions": normalized_actions}
|
|