# Source: cross13tasks/code/model/framework/QwenLatent_history_naive.py
# (uploaded by Timsty with huggingface_hub, commit e94400c)
# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by Jinhui YE / HKUST University in [2025].
"""
QwenLatent History Naive Baseline

Ablation / baseline variant of QwenLatent_history. Instead of using a
dedicated action encoder (Qwen3-based transformer) to compress history
action+state sequences into a compact latent embedding, this model projects
each history timestep directly into the VLM token space via two lightweight
MLP projectors:
  - history_action_projector : R^{action_size} -> R^{llm_hidden_size}
  - history_state_projector  : R^{state_size}  -> R^{llm_hidden_size}
The resulting per-step tokens are interleaved as
  [a_0, s_0, a_1, s_1, ..., a_{T-1}, s_{T-1}]
(or [a_0, a_1, ..., a_{T-1}] when no state projector is used) and inserted
into the VLM context after the dataset soft-prompt and the visual/language
tokens, immediately before the learnable query token.
This preserves the identical training objective, loss weights, and dual-branch
(no-history / with-history) structure as QwenLatent_history, so results are
directly comparable. The only difference is how history information is
encoded: here we use a flat MLP projection instead of the action encoder.
"""
import sys
# NOTE(review): hard-coded, machine-specific path appended to the import search
# path at import time. Confirm whether the starVLA imports below rely on it
# before removing or parameterizing it.
sys.path.append("/mnt/data/fangyu/code/rewardmodel")
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from PIL import Image
from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM
from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig
from starVLA.dataloader.gr00t_lerobot.datasets import ACTION_REPRESENTATION_SLICES
from starVLA.training.trainer_utils.trainer_tools import resize_images
from starVLA.model.tools import FRAMEWORK_REGISTRY
# Module-level logger created through starVLA's overwatch helper.
logger = initialize_overwatch(__name__)
# Label-ignore sentinel; presumably mirrors PyTorch's default `ignore_index`
# (-100). It is not referenced anywhere in the code visible in this file.
IGNORE_INDEX = -100
@FRAMEWORK_REGISTRY.register("QwenLatent_history_naive")
class QwenLatentHistoryNaive(baseframework):
    """
    Naive history baseline: project each history step's action and state
    independently via MLP projectors and insert the resulting token sequence
    into the VLM context.

    Architecture overview
    ---------------------
    Input (with history)::

        [ds_embed | VL_tokens |
         hist_action_0 | hist_state_0 | ... |
         hist_action_{T-1} | hist_state_{T-1} |
         query_token]

    The history tokens sit after the visual/language tokens and immediately
    before the learnable query token. ``_build_qwen_inputs`` concatenates the
    embeddings in this order. Without history the layout is
    ``[ds_embed | VL_tokens | query_token]``.

    Compared to QwenLatent_history, the action model's *encoder* path is
    bypassed for history encoding. The action model itself is still used for:
      - Computing GT action embeddings (for the align / recon losses)
      - Decoding predicted embeddings to actions during inference
    """

    # ------------------------------------------------------------------
    # Helper: last non-pad token index
    # ------------------------------------------------------------------
    @staticmethod
    def _get_last_nonpad_indices(attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Return, for each row of a 2D ``[B, T]`` mask, the index of the last
        non-zero entry.

        A row with no non-zero entries yields ``T - 1``, because ``argmax``
        of an all-zero row is 0.

        Raises:
            ValueError: if ``attention_mask`` is None or not 2D.
        """
        if attention_mask is None:
            raise ValueError("attention_mask cannot be None")
        if attention_mask.dim() != 2:
            raise ValueError(
                f"attention_mask must be 2D [B,T], got shape {tuple(attention_mask.shape)}"
            )
        mask = attention_mask.to(dtype=torch.long)
        # The first 1 in the flipped row is the last 1 in the original row.
        rev_first_one = torch.flip(mask, dims=[1]).argmax(dim=1)
        last_nonpad = mask.size(1) - 1 - rev_first_one
        return last_nonpad

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def __init__(self, config: Optional[dict] = None, **kwargs) -> None:
        """
        Build the VLM backbone, the action model, the dataset soft-prompt, the
        query token and the naive history projectors from ``config``.

        Args:
            config: Hierarchical (OmegaConf-style) config. Reads
                ``framework.qwenvl``, the optional ``framework.action_model``
                and ``datasets.vla_data``. Writes back
                ``framework.qwenvl.vl_hidden_dim`` and
                ``framework.qwenvl.num_vl_layers``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)
        # NOTE(review): the VLM layer count is hard-coded to 36; confirm it
        # matches the configured backbone.
        num_vl_layers, llm_hidden_size = 36, self.qwen_vl_interface.model.config.hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.config.framework.qwenvl.vl_hidden_dim = llm_hidden_size
        self.config.framework.qwenvl.num_vl_layers = num_vl_layers

        # Action model (used only as flow-matching decoder + GT encoder for loss)
        action_model_cfg = getattr(self.config.framework, "action_model", None)
        if action_model_cfg is not None:
            action_model_kwargs = OmegaConf.to_container(action_model_cfg, resolve=True)
            print(f"{action_model_kwargs=}")
            self.action_model = ActionModelFM(ActionModelConfig(**action_model_kwargs))
        else:
            self.action_model = ActionModelFM(ActionModelConfig())

        # Optional action-model settings are read from the section fetched
        # above. getattr(None, ..., default) returns the default, so a config
        # without `framework.action_model` no longer fails here.
        ckpt_path = getattr(action_model_cfg, "ckpt_path", None)
        if ckpt_path:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.action_model.load_state_dict(
                torch.load(ckpt_path, map_location="cpu"), strict=True
            )
            print(f"✅ loaded action model from {ckpt_path}")
        print(f"action model loss mode: {self.action_model.config.loss_mode}")

        # Dataset soft-prompt embedding: one row per dataset id, reshaped into
        # `num_data_tokens` tokens of width `llm_hidden_size`.
        self.dataset_vocab_size = getattr(action_model_cfg, "dataset_vocab_size", 256)
        self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 32)
        self.dataset_embed = nn.Embedding(
            self.dataset_vocab_size,
            llm_hidden_size * self.num_data_tokens,
        )

        # Learnable query token (VLM output token used for action prediction)
        self.query_token = nn.Parameter(torch.randn(1, 1, llm_hidden_size))

        # VLM → action-space projector (query token hidden → action embedding)
        action_hidden_size = self.action_model.config.hidden_size
        self.action_embed_projector = nn.Sequential(
            nn.Linear(llm_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, action_hidden_size),
        )

        # Chunk / history book-keeping: the first `num_history_steps` entries
        # of each chunk are history; the remaining `chunk_size` are predicted.
        self.total_action_chunk_size = self.config.datasets.vla_data.chunk_size
        self.num_history_steps = self.config.datasets.vla_data.num_history_steps
        print(f"num_history_steps: {self.num_history_steps}")
        self.chunk_size = self.total_action_chunk_size - self.num_history_steps
        self.use_state = self.action_model.use_state

        # ------------------------------------------------------------------
        # Naive history projectors
        # Each history timestep's raw action / state is projected to a single
        # VLM-dimension token via a two-layer MLP.
        # ------------------------------------------------------------------
        action_size = self.action_model.config.action_size
        state_size = self.action_model.config.state_size
        self.history_action_projector = nn.Sequential(
            nn.Linear(action_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )
        if self.use_state and state_size > 0:
            self.history_state_projector = nn.Sequential(
                nn.Linear(state_size, llm_hidden_size),
                nn.GELU(),
                nn.Linear(llm_hidden_size, llm_hidden_size),
            )
        else:
            self.history_state_projector = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _maybe_log_align_stats(
        self,
        predicted_action_embeddings: torch.Tensor,
        gt_action_embeddings: torch.Tensor,
    ) -> None:
        """
        Log summary statistics of predicted vs. GT action embeddings once per
        model instance, and only on rank 0 when torch.distributed is initialized.
        """
        if getattr(self, "_align_stats_logged", False):
            return
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            if torch.distributed.get_rank() != 0:
                return
        with torch.no_grad():
            pred = predicted_action_embeddings.float()
            gt = gt_action_embeddings.float()
            logger.info(
                "Align stats: pred(mean=%.4f,std=%.4f,avg_norm=%.4f) "
                "gt(mean=%.4f,std=%.4f,avg_norm=%.4f)",
                pred.mean().item(),
                pred.std().item(),
                pred.norm(dim=-1).mean().item(),
                gt.mean().item(),
                gt.std().item(),
                gt.norm(dim=-1).mean().item(),
            )
        self._align_stats_logged = True

    def _encode_history_tokens(
        self,
        history_actions: torch.Tensor,
        history_states: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Project raw history actions (and optionally states) into VLM token space.

        Args:
            history_actions : [B, T_hist, action_size]
            history_states  : [B, T_hist, state_size] or None

        Returns:
            history_tokens : [B, T_hist * (1 or 2), llm_hidden_size]
                Interleaved as [a_0, s_0, a_1, s_1, ...] when a state projector
                exists and states are given, otherwise [a_0, a_1, ...].
        """
        B, T, _ = history_actions.shape
        # Cast inputs to the projector's parameter dtype.
        proj_dtype = self.history_action_projector[0].weight.dtype
        act_tokens = self.history_action_projector(history_actions.to(proj_dtype))
        if self.history_state_projector is None or history_states is None:
            return act_tokens  # [B, T, llm_hidden_size]

        sta_tokens = self.history_state_projector(history_states.to(proj_dtype))
        # Stack along a new dim then merge: [B, T, 2, H] → [B, 2T, H], giving
        # [a_0, s_0, a_1, s_1, ...].
        interleaved = torch.stack([act_tokens, sta_tokens], dim=2)
        return interleaved.view(B, T * 2, self.llm_hidden_size)

    def _build_qwen_inputs(
        self,
        images: List,
        instructions: List[str],
        dataset_ids: List[int],
        extra_prefix_embeds: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Tokenize images/instructions and assemble the VLM input embeddings.

        Embedding layout along dim 1::

            [ds_embeds (num_data_tokens) | token_embeds (L) |
             extra_prefix_embeds (H, optional) | query_token (1)]

        Despite its name, ``extra_prefix_embeds`` is placed after the
        visual/language tokens, immediately before the query token. The
        attention mask and, if present, ``position_ids`` are extended in the
        same order, so every entry stays aligned with its embedding.

        Args:
            images: Per-sample images forwarded to ``build_qwenvl_inputs``.
            instructions: Per-sample language instructions.
            dataset_ids: Per-sample dataset ids indexing ``dataset_embed``.
            extra_prefix_embeds: Optional ``[B, H, llm_hidden_size]`` tokens
                (the history tokens).

        Returns:
            The processor output dict, with ``input_ids`` replaced by
            ``inputs_embeds`` and ``attention_mask`` / ``position_ids``
            extended when they are present.

        Raises:
            ValueError: if the processor output has no ``input_ids``. Without
                them the dataset, history and query tokens cannot be added.
        """
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=images,
            instructions=instructions,
        )
        if "input_ids" not in qwen_inputs:
            raise ValueError(
                "build_qwenvl_inputs did not return `input_ids`; cannot insert "
                "dataset, history and query tokens."
            )

        batch_size = len(dataset_ids)
        num_extra = 0 if extra_prefix_embeds is None else extra_prefix_embeds.shape[1]

        # ---- embeddings ----
        input_ids = qwen_inputs.pop("input_ids")
        dataset_ids_tensor = torch.tensor(
            dataset_ids,
            device=input_ids.device,
            dtype=torch.long,
        )
        ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
            batch_size, self.num_data_tokens, self.llm_hidden_size
        )
        token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(input_ids)
        query_embeds = self.query_token.expand(batch_size, -1, -1)
        embed_parts = [ds_embeds, token_embeds]
        if extra_prefix_embeds is not None:
            embed_parts.append(extra_prefix_embeds)
        embed_parts.append(query_embeds)
        qwen_inputs["inputs_embeds"] = torch.cat(embed_parts, dim=1)

        # ---- attention mask (same order as the embeddings) ----
        if "attention_mask" in qwen_inputs:
            attention_mask = qwen_inputs["attention_mask"]
            rows = attention_mask.shape[0]
            # new_ones keeps the mask's dtype and device.
            mask_parts = [
                attention_mask.new_ones((rows, self.num_data_tokens)),
                attention_mask,
            ]
            if extra_prefix_embeds is not None:
                mask_parts.append(attention_mask.new_ones((rows, num_extra)))
            mask_parts.append(attention_mask.new_ones((rows, 1)))
            qwen_inputs["attention_mask"] = torch.cat(mask_parts, dim=1)

        # ---- position ids (same order as the embeddings) ----
        if "position_ids" in qwen_inputs:
            position_ids = qwen_inputs["position_ids"]
            # The sequence axis is assumed to be the last one; all leading
            # dims (e.g. [B] or [3, B]) are broadcast.
            lead_shape = tuple(position_ids.shape[:-1])
            seq_len = position_ids.shape[-1]

            def _span(start: int, length: int) -> torch.Tensor:
                # Contiguous positions [start, start + length) broadcast over
                # the leading dims of `position_ids`.
                return torch.arange(
                    start,
                    start + length,
                    device=position_ids.device,
                    dtype=position_ids.dtype,
                ).expand(*lead_shape, length)

            tail_start = seq_len + self.num_data_tokens
            pos_parts = [
                _span(0, self.num_data_tokens),
                position_ids + self.num_data_tokens,
            ]
            if extra_prefix_embeds is not None:
                pos_parts.append(_span(tail_start, num_extra))
            # Query position: seq_len + num_data_tokens + num_extra, the same
            # value the previous prefix-based layout produced.
            pos_parts.append(_span(tail_start + num_extra, 1))
            qwen_inputs["position_ids"] = torch.cat(pos_parts, dim=-1)
        return qwen_inputs

    def _encode_vlm_action_embedding(self, qwen_inputs: dict) -> torch.Tensor:
        """
        Run the VLM and map the hidden state of the last attended token (the
        query token) to an L2-normalized action embedding.

        Returns:
            float32 tensor ``[B, action_hidden_size]`` with unit L2 norm per row.
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            last_hidden_states = qwenvl_outputs.hidden_states[-1]
            if "attention_mask" in qwen_inputs:
                last_token_indices = self._get_last_nonpad_indices(
                    qwen_inputs["attention_mask"]
                )
                batch_indices = torch.arange(
                    last_hidden_states.shape[0], device=last_hidden_states.device
                )
                action_token_hidden = last_hidden_states[batch_indices, last_token_indices]
            else:
                action_token_hidden = last_hidden_states[:, -1, :]
            predicted_action_embeddings = self.action_embed_projector(
                action_token_hidden
            ).float()
        return F.normalize(predicted_action_embeddings, p=2, dim=-1)

    def _compute_branch_losses(
        self,
        predicted_action_embeddings: torch.Tensor,
        actions_target: torch.Tensor,
        states_target: Optional[torch.Tensor],
        dataset_ids: List[int],
    ) -> dict:
        """
        Compute the losses for one branch.

        Returns:
            dict with ``align_loss``, ``recon_loss`` and ``predict_loss``. In
            ``predict_only`` loss mode, ``align_loss`` and ``recon_loss`` are None.
        """
        loss_mode = getattr(self.action_model.config, "loss_mode", "full")
        with torch.autocast("cuda", dtype=torch.float32):
            B = actions_target.shape[0]
            # One flow-matching time / noise sample is shared by all losses of
            # this branch.
            t = self.action_model._sample_fm_time(
                B, device=actions_target.device, dtype=actions_target.dtype
            )
            noise = torch.randn_like(actions_target)
            if loss_mode == "predict_only":
                align_loss = None
                recon_loss = None
            else:
                gt_action_embeddings = self.action_model.encode_actions(
                    actions=actions_target,
                    dataset_ids=dataset_ids,
                    state=states_target,
                )
                self._maybe_log_align_stats(predicted_action_embeddings, gt_action_embeddings)
                # The GT embedding is detached: the align loss only moves the
                # VLM-side prediction.
                align_loss = F.l1_loss(
                    predicted_action_embeddings,
                    gt_action_embeddings.float().detach(),
                )
                recon_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=gt_action_embeddings,
                    t=t,
                    noise=noise,
                )
            predict_loss = self.action_model.recon_loss_from_embedding(
                actions=actions_target,
                action_embedding=predicted_action_embeddings,
                t=t,
                noise=noise,
            )
        return {
            "align_loss": align_loss,
            "recon_loss": recon_loss,
            "predict_loss": predict_loss,
        }

    @staticmethod
    def _average_branch_losses(first: dict, second: dict) -> dict:
        """
        Average two branch loss dicts key by key with equal (0.5) weights.

        A loss that is None in a branch (e.g. align/recon in ``predict_only``
        mode) is not multiplied. If both are None the result is None; if only
        one is None the other value is returned unchanged.
        """
        averaged = {}
        for key in ("align_loss", "recon_loss", "predict_loss"):
            a, b = first[key], second[key]
            if a is None or b is None:
                averaged[key] = a if b is None else b
            else:
                averaged[key] = 0.5 * a + 0.5 * b
        return averaged

    # ------------------------------------------------------------------
    # Forward (training)
    # ------------------------------------------------------------------
    def forward(self, examples: List[dict] = None, **kwargs):
        """
        Dual-branch forward (mirrors QwenLatent_history):

        Branch 1 — no history:
            Image at step 0, predict actions[0 : chunk_size].
        Branch 2 — with history (only when num_history_steps > 0):
            mid_image (image at step num_history_steps), with naive MLP
            projections of history actions/states inserted into the VLM
            context just before the query token.
            Predict actions[num_history_steps : total_chunk_size].

        Returns:
            Branch-1 losses when ``num_history_steps <= 0``. Otherwise the
            equal-weight average of both branches' losses; losses that are None
            in both branches stay None.
        """
        batch_images = [ex["image"] for ex in examples]
        instructions = [ex["lang"] for ex in examples]
        actions = [ex["action"] for ex in examples]
        states = [ex["state"] for ex in examples] if self.use_state else None
        dataset_ids = [ex.get("dataset_id", 0) for ex in examples]
        device = self.query_token.device

        actions_full = torch.as_tensor(
            np.array(actions), device=device, dtype=torch.float32
        )
        assert actions_full.shape[1] == self.total_action_chunk_size, (
            f"expected {self.total_action_chunk_size} action steps, "
            f"got {actions_full.shape[1]}"
        )
        states_full = None
        if self.use_state:
            states_full = torch.as_tensor(
                np.array(states), device=device, dtype=torch.float32
            )
            assert states_full.shape[1] == self.total_action_chunk_size, (
                f"expected {self.total_action_chunk_size} state steps, "
                f"got {states_full.shape[1]}"
            )

        # ---------- Branch 1: no history ----------
        no_hist_qwen_inputs = self._build_qwen_inputs(
            images=batch_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=None,
        )
        no_hist_pred_emb = self._encode_vlm_action_embedding(no_hist_qwen_inputs)
        no_hist_losses = self._compute_branch_losses(
            predicted_action_embeddings=no_hist_pred_emb,
            actions_target=actions_full[:, : self.chunk_size],
            states_target=(
                states_full[:, : self.chunk_size] if states_full is not None else None
            ),
            dataset_ids=dataset_ids,
        )
        if self.num_history_steps <= 0:
            return no_hist_losses

        # ---------- Branch 2: naive history ----------
        if not all("mid_image" in ex for ex in examples):
            raise ValueError("num_history_steps > 0 but `mid_image` is missing in examples.")
        mid_images = [ex["mid_image"] for ex in examples]
        history_actions = actions_full[:, : self.num_history_steps]
        history_states = (
            states_full[:, : self.num_history_steps]
            if states_full is not None
            else None
        )
        # Project raw history tokens via naive MLP (the key difference)
        history_tokens = self._encode_history_tokens(history_actions, history_states)
        hist_qwen_inputs = self._build_qwen_inputs(
            images=mid_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=history_tokens,
        )
        hist_pred_emb = self._encode_vlm_action_embedding(hist_qwen_inputs)
        hist_losses = self._compute_branch_losses(
            predicted_action_embeddings=hist_pred_emb,
            actions_target=actions_full[:, self.num_history_steps :],
            states_target=(
                states_full[:, self.num_history_steps :]
                if states_full is not None
                else None
            ),
            dataset_ids=dataset_ids,
        )
        return self._average_branch_losses(hist_losses, no_hist_losses)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict] = None,
        embodiment_tag: Optional[str] = None,
        use_history: bool = True,
        **kwargs,
    ) -> dict:
        """
        Inference counterpart of ``forward``.

        When ``use_history=True`` and ``num_history_steps > 0``, uses
        ``mid_image`` together with naive MLP projections of the first
        ``num_history_steps`` history actions (and states, when ``use_state``
        is on and every example has ``state``). Otherwise uses the no-history
        branch with ``image``.

        Args:
            examples: Per-sample dicts with ``lang``, ``image`` and optionally
                ``dataset_id``, ``mid_image``, ``action`` and ``state``.
            embodiment_tag: If given, selects the action slice registered in
                ``ACTION_REPRESENTATION_SLICES``.
            use_history: Whether to use the history branch when available.

        Returns:
            ``{"normalized_actions": np.ndarray}``.

        Raises:
            ValueError: if history inputs are missing, or if
                ``embodiment_tag`` is unknown.
        """
        from deployment.model_server.tools.image_tools import to_pil_preserve

        instructions = [ex["lang"] for ex in examples]
        dataset_ids = [ex.get("dataset_id", 0) for ex in examples]
        batch_images = [to_pil_preserve(ex["image"]) for ex in examples]
        extra_prefix_embeds = None
        if self.num_history_steps > 0 and use_history:
            if not all(("mid_image" in ex and "action" in ex) for ex in examples):
                raise ValueError(
                    "num_history_steps > 0 requires `mid_image` and `action` in each "
                    "example for history inference."
                )
            batch_images = [to_pil_preserve(ex["mid_image"]) for ex in examples]
            device = self.query_token.device
            proj_dtype = self.history_action_projector[0].weight.dtype
            history_actions = torch.as_tensor(
                np.array([ex["action"][: self.num_history_steps] for ex in examples]),
                device=device,
                dtype=proj_dtype,
            )
            history_states = None
            if self.use_state:
                if all("state" in ex for ex in examples):
                    history_states = torch.as_tensor(
                        np.array([ex["state"][: self.num_history_steps] for ex in examples]),
                        device=device,
                        dtype=proj_dtype,
                    )
                else:
                    # Falling back to action-only history tokens changes the
                    # token layout relative to training; make it visible.
                    logger.warning(
                        "use_state is enabled but `state` is missing from at least "
                        "one example; history tokens will contain actions only."
                    )
            extra_prefix_embeds = self._encode_history_tokens(
                history_actions, history_states
            )

        train_obs_image_size = getattr(
            self.config.datasets.vla_data, "image_size", None
        )
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        qwen_inputs = self._build_qwen_inputs(
            images=batch_images,
            instructions=instructions,
            dataset_ids=dataset_ids,
            extra_prefix_embeds=extra_prefix_embeds,
        )
        predicted_action_embeddings = self._encode_vlm_action_embedding(qwen_inputs)
        with torch.autocast("cuda", dtype=torch.float32):
            pred_actions = self.action_model.decode_actions(
                predicted_action_embeddings,
                chunk_size=self.chunk_size,
            )
        normalized_actions = pred_actions.detach().cpu().numpy()

        if embodiment_tag is not None:
            if embodiment_tag not in ACTION_REPRESENTATION_SLICES:
                raise ValueError(
                    f"Unknown embodiment tag '{embodiment_tag}'. "
                    f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES.keys())}"
                )
            target_slice = ACTION_REPRESENTATION_SLICES[embodiment_tag]
            normalized_actions = normalized_actions[..., target_slice]
        return {"normalized_actions": normalized_actions}