# Copyright 2025 starVLA community. All rights reserved. # Licensed under the MIT License, Version 1.0 (the "License"); # Implemented by [Junqiu YU / Fudan University] in [2025]. # Design and Merged by [Jinhui YE / HKUST University] in [2025]. """ Qwen-GR00T Framework Qwen-VL + Flow-matching head to directly predict continuous actions LangForceV5: (1) Assert language span consistency between prior/post branches (token-level exact match) (2) Hard-token LLR + Shortcut gate (3) Optional detach of prior condition to avoid pushing backbone to vision-only shortcut """ import sys from pathlib import Path # Add workspace root to Python path if not already there _workspace_root = Path(__file__).parent.parent.parent.parent if str(_workspace_root) not in sys.path: sys.path.insert(0, str(_workspace_root)) from typing import List, Optional, Tuple, Set from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from PIL import Image from starVLA.training.trainer_utils import initialize_overwatch from deployment.model_server.tools.image_tools import to_pil_preserve logger = initialize_overwatch(__name__) # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) IGNORE_INDEX = -100 # ===== Qwen special tokens (you confirmed) ===== VISION_START_TOKEN_INDEX = 151652 # <|vision_start|> VISION_END_TOKEN_INDEX = 151654 # <|vision_end|> IMAGE_TOKEN_INDEX = 151655 # <|image_pad|> VIDEO_TOKEN_INDEX = 151656 # <|video_pad|> IM_START_TOKEN_INDEX = 151644 # <|im_start|> IM_END_TOKEN_INDEX = 151645 # <|im_end|> from starVLA.model.framework.base_framework import baseframework from starVLA.model.modules.vlm import get_vlm_model from starVLA.model.modules.action_model.GR00T_ActionHeader import get_action_model, FlowmatchingActionHead from starVLA.training.trainer_utils.trainer_tools import resize_images from starVLA.model.tools import FRAMEWORK_REGISTRY @FRAMEWORK_REGISTRY.register("LangForce") class LangForce(baseframework): """ LangForce: Bayesian Decomposition of Vision Language Action Models via Latent Action Queries (arxiv 2601.15197) Dual-branch VLA with: - Prior branch: (V + A + L) => proposal-like p(a|v) head - Posterior branch: (V + L + A) => pi(a|v,l) - LLR regularizer: maximize log p(L|V,A_prior) - sg(log p(L|V)) with: * Hard-token LLR (top-k hardest tokens under post) * Shortcut gate (down-weight LLR when log p(L|V) is already very low) - Optional detach prior cond (protect backbone from vision-only drift) Additionally: - Training-time assertion: extracted language spans in prior/post must match exactly (token-level). If mismatch => raise AssertionError with decoded spans. """ def __init__( self, config: Optional[dict] = None, **kwargs, ) -> None: super().__init__() self.config = config self.qwen_vl_interface = get_vlm_model(config=self.config) # align dims --> should go into config ideally self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = ( self.qwen_vl_interface.model.config.hidden_size ) self.num_latent_action_query = self.config.framework.qwenvl.get("num_latent_action_query", 32) self.latent_action_query = "".join([f"<|action_{i}|>" for i in range(self.num_latent_action_query)]) self.action_token_ids = None # cached {'first','last'} self.action_model: FlowmatchingActionHead = get_action_model(config=self.config) self.future_action_window_size = config.framework.action_model.future_action_window_size self.past_action_window_size = config.framework.action_model.past_action_window_size self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size # ===== Loss weights ===== self.kl_weight = float(self.config.framework.get("kl_weight", 0.1)) # maximize LLR via -kl_weight * kl_loss self.prior_loss_weight = float(self.config.framework.get("prior_loss_weight", 0.3)) # ===== (0) training assert switch ===== self.assert_lang_span_match = bool(self.config.framework.get("assert_lang_span_match", True)) # ===== (1) detach prior cond switch ===== self.detach_prior_cond = bool(self.config.framework.get("detach_prior_cond", True)) # ===== (2) Hard-token LLR ===== self.use_hard_token_llr = bool(self.config.framework.get("use_hard_token_llr", True)) self.hard_token_k = int(self.config.framework.get("hard_token_k", 16)) assert self.hard_token_k > 0 # ===== (3) Shortcut gate ===== # gate computed from posterior language-span NLL: high NLL => log p(L|V) low => gate small self.use_kl_gate = bool(self.config.framework.get("use_kl_gate", True)) self.kl_gate_momentum = float(self.config.framework.get("kl_gate_momentum", 0.99)) self.kl_gate_temp = float(self.config.framework.get("kl_gate_temp", 0.5)) self.kl_gate_tau_scale = float(self.config.framework.get("kl_gate_tau_scale", 0.7)) # scale EMA threshold self.kl_gate_min = float(self.config.framework.get("kl_gate_min", 0.0)) self.kl_gate_max = float(self.config.framework.get("kl_gate_max", 1.0)) # cache some special token ids from tokenizer lazily self._im_end_id = None # EMA buffer for posterior language-span NLL self.register_buffer("post_nll_ema", torch.tensor(0.0, dtype=torch.float32)) self.register_buffer("post_nll_ema_inited", torch.tensor(0, dtype=torch.uint8)) # --------------------------------------------------------------------- # Token id helpers # --------------------------------------------------------------------- def _ensure_action_token_ids(self, tokenizer): if self.action_token_ids is None: self.action_token_ids = { "first": tokenizer.convert_tokens_to_ids("<|action_0|>"), "last": tokenizer.convert_tokens_to_ids(f"<|action_{self.num_latent_action_query-1}|>"), } def _ensure_im_end_id(self, tokenizer): if self._im_end_id is None: self._im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") def _find_last_pos(self, seq_1d: torch.Tensor, token_id: int) -> int: idx = (seq_1d == int(token_id)).nonzero(as_tuple=True)[0] if idx.numel() == 0: return -1 return int(idx[-1].item()) def _find_first_pos_after(self, seq_1d: torch.Tensor, token_id: int, start: int) -> int: if start < 0: start = 0 sub = seq_1d[start:] idx = (sub == int(token_id)).nonzero(as_tuple=True)[0] if idx.numel() == 0: return -1 return int(start + idx[0].item()) # --------------------------------------------------------------------- # Action block helpers # --------------------------------------------------------------------- def _get_action_block_start(self, input_ids_1d: torch.Tensor, tokenizer) -> int: self._ensure_action_token_ids(tokenizer) first_id = self.action_token_ids["first"] last_id = self.action_token_ids["last"] pos = (input_ids_1d == int(first_id)).nonzero(as_tuple=True)[0] if pos.numel() == 0: return -1 start = int(pos[0].item()) end = start + self.num_latent_action_query if end > input_ids_1d.shape[0]: return -1 if int(input_ids_1d[end - 1].item()) != int(last_id): return -1 return start def _extract_action_query_hidden_states( self, hidden_states: torch.Tensor, # [B, S, H] input_ids: torch.Tensor, # [B, S] tokenizer, return_starts: bool = False, ): self._ensure_action_token_ids(tokenizer) B = hidden_states.shape[0] out = [] starts = [] for b in range(B): start = self._get_action_block_start(input_ids[b], tokenizer) assert start != -1, "No valid contiguous action token block found in the sequence." end = start + self.num_latent_action_query out.append(hidden_states[b, start:end, :]) starts.append(start) out = torch.stack(out, dim=0) # [B, K, H] if return_starts: return out, torch.tensor(starts, device=input_ids.device, dtype=torch.long) return out # --------------------------------------------------------------------- # SHIFT-correct token-level NLL span # --------------------------------------------------------------------- def _token_nll_span( self, logits_1d: torch.Tensor, # [S, V] input_ids_1d: torch.Tensor, # [S] start: int, end: int, ignore_ids: Optional[Set[int]] = None, ): """ Return (nll_vec, target_ids_vec) for tokens in [start,end), using next-token alignment: token at position j is scored by logits[j-1] (requires j>0). """ if end <= start: return None, None S = int(input_ids_1d.shape[0]) start = max(0, int(start)) end = min(S, int(end)) if end <= start: return None, None j = torch.arange(start, end, device=input_ids_1d.device, dtype=torch.long) j = j[j > 0] if j.numel() == 0: return None, None targets = input_ids_1d[j].long() if ignore_ids is not None and len(ignore_ids) > 0: keep = torch.ones_like(targets, dtype=torch.bool) for tid in ignore_ids: keep &= (targets != int(tid)) j = j[keep] if j.numel() == 0: return None, None targets = input_ids_1d[j].long() pred_pos = j - 1 pred_logits = logits_1d[pred_pos].float() # [T, V] nll = F.cross_entropy(pred_logits, targets, reduction="none") # [T] return nll, targets # --------------------------------------------------------------------- # Compute LLR with: # - strict span equality assertion (training) # - hard-token LLR (top-k) # - shortcut gate based on posterior NLL # --------------------------------------------------------------------- def _compute_language_llr_from_boundaries( self, priori_logits: torch.Tensor, # [B, S, V] posteriori_logits: torch.Tensor, # [B, S, V] (detached) priori_input_ids: torch.Tensor, # [B, S] posteriori_input_ids: torch.Tensor, # [B, S] priori_action_starts: torch.Tensor, # [B] posteriori_action_starts: torch.Tensor, # [B] ) -> torch.Tensor: tokenizer = self.qwen_vl_interface.processor.tokenizer self._ensure_im_end_id(tokenizer) pad_id = tokenizer.pad_token_id ignore_ids: Set[int] = set() if pad_id is not None: ignore_ids.add(int(pad_id)) ignore_ids.add(int(IMAGE_TOKEN_INDEX)) ignore_ids.add(int(VIDEO_TOKEN_INDEX)) ignore_ids.add(int(VISION_START_TOKEN_INDEX)) ignore_ids.add(int(VISION_END_TOKEN_INDEX)) ignore_ids.add(int(IM_START_TOKEN_INDEX)) ignore_ids.add(int(IM_END_TOKEN_INDEX)) B = int(priori_input_ids.shape[0]) K = self.num_latent_action_query llr_vals = [] post_nll_means = [] for b in range(B): ids_prior = priori_input_ids[b] ids_post = posteriori_input_ids[b] a_start_prior = int(priori_action_starts[b].item()) a_start_post = int(posteriori_action_starts[b].item()) # ===== prior language span: [action_end : im_end) ===== lang_start_prior = a_start_prior + K if lang_start_prior >= ids_prior.shape[0]: continue im_end = self._find_first_pos_after(ids_prior, self._im_end_id, lang_start_prior) lang_end_prior = im_end if im_end != -1 else int(ids_prior.shape[0]) if lang_end_prior <= lang_start_prior: continue # ===== post language span: [last(vision_end)+1 : action_start) ===== v_end_post = self._find_last_pos(ids_post, VISION_END_TOKEN_INDEX) if v_end_post == -1: continue lang_start_post = v_end_post + 1 lang_end_post = a_start_post if lang_end_post <= lang_start_post: continue # ===== (1) strict assertion: token-level equality ===== if self.training and self.assert_lang_span_match: prior_span_ids = ids_prior[lang_start_prior:lang_end_prior] post_span_ids = ids_post[lang_start_post:lang_end_post] if (prior_span_ids.numel() != post_span_ids.numel()) or (not torch.equal(prior_span_ids, post_span_ids)): # decode for human-readable debugging prior_text = tokenizer.decode(prior_span_ids.tolist()) post_text = tokenizer.decode(post_span_ids.tolist()) raise AssertionError( "\n[LangForceV5] Language span mismatch detected!\n" f"Sample b={b}\n" f"PRIOR span idx: [{lang_start_prior}:{lang_end_prior}] (len={prior_span_ids.numel()})\n" f"POST span idx: [{lang_start_post}:{lang_end_post}] (len={post_span_ids.numel()})\n" f"PRIOR span: {repr(prior_text)}\n" f"POST span: {repr(post_text)}\n" f"PRIOR token ids (first 50): {prior_span_ids[:50].tolist()}\n" f"POST token ids (first 50): {post_span_ids[:50].tolist()}\n" "This indicates your boundary-based language extraction is inconsistent (likely prompt/template issue)." ) # ===== (2) hard-token LLR needs token-level aligned targets ===== nll_prior, tok_prior = self._token_nll_span( logits_1d=priori_logits[b], input_ids_1d=ids_prior, start=lang_start_prior, end=lang_end_prior, ignore_ids=ignore_ids, ) nll_post, tok_post = self._token_nll_span( logits_1d=posteriori_logits[b], input_ids_1d=ids_post, start=lang_start_post, end=lang_end_post, ignore_ids=ignore_ids, ) if nll_prior is None or nll_post is None: continue # record post nll mean for gate post_nll_mean = nll_post.mean().detach() post_nll_means.append(post_nll_mean) # logp_prior - logp_post = (-nll_prior) - (-nll_post) = nll_post - nll_prior if self.use_hard_token_llr: # require same target token sequence if tok_prior is None or tok_post is None or tok_prior.shape != tok_post.shape or (not torch.equal(tok_prior, tok_post)): # This should not happen if your spans match, but keep safe fallback. llr = (nll_post.mean() - nll_prior.mean()) else: k = min(self.hard_token_k, int(nll_post.numel())) if k <= 0: continue idx = torch.topk(nll_post.detach(), k=k, largest=True).indices llr = (nll_post[idx] - nll_prior[idx]).mean() else: llr = (nll_post.mean() - nll_prior.mean()) llr_vals.append(llr) if len(llr_vals) == 0: return torch.tensor(0.0, device=priori_logits.device, dtype=torch.float32) llr_vals_t = torch.stack(llr_vals).float() # [M] post_nll_means_t = torch.stack(post_nll_means).float() # [M] # ===== (2) shortcut gate: update EMA threshold ===== if self.use_kl_gate and self.training: batch_mean = post_nll_means_t.mean().detach() with torch.no_grad(): if int(self.post_nll_ema_inited.item()) == 0: self.post_nll_ema.copy_(batch_mean) self.post_nll_ema_inited.fill_(1) else: m = self.kl_gate_momentum self.post_nll_ema.copy_(m * self.post_nll_ema + (1.0 - m) * batch_mean) # ===== gate computation ===== if self.use_kl_gate: tau = (self.post_nll_ema.detach() * float(self.kl_gate_tau_scale)) temp = max(float(self.kl_gate_temp), 1e-6) # high nll => log p(L|V) low => gate small g = torch.sigmoid((tau - post_nll_means_t) / temp) # optional clamp/scale if self.kl_gate_min != 0.0 or self.kl_gate_max != 1.0: g = float(self.kl_gate_min) + (float(self.kl_gate_max) - float(self.kl_gate_min)) * g else: g = torch.ones_like(post_nll_means_t) # weighted LLR return (g * llr_vals_t).mean() # --------------------------------------------------------------------- # Forward # --------------------------------------------------------------------- def forward( self, examples: List[dict] = None, **kwargs, ) -> dict: batch_images = [example["image"] for example in examples] # [B, [PIL...]] instructions_priori = [self.latent_action_query + example["lang"] for example in examples] # A + L instructions_posteriori = [example["lang"] + self.latent_action_query for example in examples] # L + A actions = [example["action"] for example in examples] state = [example["state"] for example in examples] if "state" in examples[0] else None # ===== Step 1: Priori Branch (V + A + L) ===== qwen_inputs_priori = self.qwen_vl_interface.build_qwenvl_inputs( images=batch_images, instructions=instructions_priori ) with torch.autocast("cuda", dtype=torch.bfloat16): qwenvl_outputs_priori = self.qwen_vl_interface( **qwen_inputs_priori, output_attentions=False, output_hidden_states=True, return_dict=True, use_cache=False, ) priori_last_hidden = qwenvl_outputs_priori.hidden_states[-1] # [B, S, H] priori_action_hidden, priori_action_starts = self._extract_action_query_hidden_states( priori_last_hidden, qwen_inputs_priori["input_ids"], self.qwen_vl_interface.processor.tokenizer, return_starts=True ) # [B, K, H], [B] priori_logits = qwenvl_outputs_priori.logits # [B, S, V] # ===== Step 2: Posteriori Branch (V + L + A) ===== qwen_inputs_posteriori = self.qwen_vl_interface.build_qwenvl_inputs( images=batch_images, instructions=instructions_posteriori ) with torch.autocast("cuda", dtype=torch.bfloat16): qwenvl_outputs_posteriori = self.qwen_vl_interface( **qwen_inputs_posteriori, output_attentions=False, output_hidden_states=True, return_dict=True, use_cache=False, ) posteriori_last_hidden = qwenvl_outputs_posteriori.hidden_states[-1] # [B, S, H] posteriori_action_hidden, posteriori_action_starts = self._extract_action_query_hidden_states( posteriori_last_hidden, qwen_inputs_posteriori["input_ids"], self.qwen_vl_interface.processor.tokenizer, return_starts=True ) # [B, K, H], [B] # detach baseline logits: do not allow worsening log p(L|V) to inflate LLR posteriori_logits = qwenvl_outputs_posteriori.logits.detach() # [B, S, V] # ===== Step 3: LLR loss (Hard-token + Gate + Assert) ===== kl_loss = self._compute_language_llr_from_boundaries( priori_logits=priori_logits, posteriori_logits=posteriori_logits, priori_input_ids=qwen_inputs_priori["input_ids"], posteriori_input_ids=qwen_inputs_posteriori["input_ids"], priori_action_starts=priori_action_starts, posteriori_action_starts=posteriori_action_starts, ) # ===== Step 4: Action head losses ===== with torch.autocast("cuda", dtype=torch.float32): actions_t = torch.tensor( np.array(actions), device=priori_action_hidden.device, dtype=priori_action_hidden.dtype ) actions_target = actions_t[:, -(self.future_action_window_size + 1):, :] # [B, chunk_len, action_dim] repeated_diffusion_steps = ( self.config.trainer.get("repeated_diffusion_steps", 4) if self.config and self.config.trainer else 4 ) state_tensor = None if state is not None: state_tensor = torch.tensor( np.array(state), device=priori_action_hidden.device, dtype=priori_action_hidden.dtype ) actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1) # (3) detach prior condition switch if self.detach_prior_cond: priori_cond_base = priori_action_hidden.detach() else: priori_cond_base = priori_action_hidden priori_cond = priori_cond_base.repeat(repeated_diffusion_steps, 1, 1).float() posteriori_cond = posteriori_action_hidden.repeat(repeated_diffusion_steps, 1, 1).float() state_repeated = state_tensor.repeat(repeated_diffusion_steps, 1, 1) if state_tensor is not None else None prior_loss = self.action_model(priori_cond, actions_target_repeated, state_repeated) main_loss = self.action_model(posteriori_cond, actions_target_repeated, state_repeated) # ===== Step 5: Total loss (keep your preferred convex mixture) ===== total_loss = ( (1.0 - self.prior_loss_weight) * main_loss + self.prior_loss_weight * prior_loss - self.kl_weight * kl_loss ) return { "action_loss": total_loss, # optional logs: "main_loss": main_loss.detach(), "prior_loss": prior_loss.detach(), "kl_loss": kl_loss.detach(), } # --------------------------------------------------------------------- # Inference # --------------------------------------------------------------------- @torch.inference_mode() def predict_action( self, examples: List[dict], **kwargs: str, ) -> dict: """ Inference uses Posteriori branch: (V + L + action_query) """ if type(examples) is not list: examples = [examples] # robustly preserve PIL for each view batch_images = [] for ex in examples: imgs = ex["image"] if isinstance(imgs, list): batch_images.append([to_pil_preserve(im) for im in imgs]) else: batch_images.append([to_pil_preserve(imgs)]) instructions_posteriori = [ex["lang"] + self.latent_action_query for ex in examples] state = [ex["state"] for ex in examples] if "state" in examples[0] else None train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None) if train_obs_image_size: batch_images = resize_images(batch_images, target_size=train_obs_image_size) qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs( images=batch_images, instructions=instructions_posteriori ) with torch.autocast("cuda", dtype=torch.bfloat16): qwenvl_outputs = self.qwen_vl_interface( **qwen_inputs, output_attentions=False, output_hidden_states=True, return_dict=True, use_cache=False, ) last_hidden = qwenvl_outputs.hidden_states[-1] action_hidden = self._extract_action_query_hidden_states( last_hidden, qwen_inputs["input_ids"], self.qwen_vl_interface.processor.tokenizer, return_starts=False ) # [B, K, H] state_tensor = None if state is not None: state_tensor = torch.from_numpy(np.array(state)).to(action_hidden.device, dtype=action_hidden.dtype) with torch.autocast("cuda", dtype=torch.float32): pred_actions = self.action_model.predict_action(action_hidden, state_tensor) return {"normalized_actions": pred_actions.detach().cpu().numpy()} if __name__ == "__main__": from omegaconf import OmegaConf import debugpy import argparse parser = argparse.ArgumentParser() parser.add_argument("--config_yaml", type=str, default="./examples/Robotwin/train_files/starvla_cotrain_robotwin.yaml") args, clipargs = parser.parse_known_args() debugpy.listen(("0.0.0.0", 10092)) print("🔍 Rank 0 waiting for debugger attach on port 10092...") debugpy.wait_for_client() args.config_yaml = "examples/MultiRobot/train_files/starvla_cotrain_multiRobot.yaml" cfg = OmegaConf.load(args.config_yaml) model: LangForce = LangForce(cfg) print(model) image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) sample = { "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16), "image": [image], "lang": "Put all the toys in the child's room ... inside the toy box.", } sample2 = { "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16), "image": [image], "lang": "Put all the toys in the child's room ... inside the toy box.", } batch = [sample, sample2] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) out = model(batch) print("Action Loss:", out["action_loss"].item(), "KL Loss:", out["kl_loss"].item()) pred = model.predict_action([sample]) print("Pred shape:", pred["normalized_actions"].shape) # optional dataloader test vla_dataset_cfg = cfg.datasets.vla_data from torch.utils.data import DataLoader from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn cfg.datasets.vla_data.include_state = "False" dataset = get_vla_dataset(data_cfg=vla_dataset_cfg) train_dataloader = DataLoader( dataset, batch_size=2, num_workers=1, collate_fn=collate_fn, ) for batch in tqdm(train_dataloader, desc="Processing Batches"): model(batch) break print("Finished")