| """ |
| Qwen-GR00T Framework |
| Qwen-VL + Flow-matching head to directly predict continuous actions |
| |
| LangForceV5: |
| (1) Assert language span consistency between prior/post branches (token-level exact match) |
| (2) Hard-token LLR + Shortcut gate |
| (3) Optional detach of prior condition to avoid pushing backbone to vision-only shortcut |
| """ |
import sys
from pathlib import Path

# Make the workspace root importable when this file is executed directly
# (e.g., via the debug entry point under `if __name__ == "__main__":` below).
_workspace_root = Path(__file__).parent.parent.parent.parent
if str(_workspace_root) not in sys.path:
    sys.path.insert(0, str(_workspace_root))
from typing import List, Optional, Tuple, Set
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image

from starVLA.training.trainer_utils import initialize_overwatch
from deployment.model_server.tools.image_tools import to_pil_preserve

logger = initialize_overwatch(__name__)

IGNORE_INDEX = -100
VISION_START_TOKEN_INDEX = 151652
VISION_END_TOKEN_INDEX = 151654
IMAGE_TOKEN_INDEX = 151655
VIDEO_TOKEN_INDEX = 151656
IM_START_TOKEN_INDEX = 151644
IM_END_TOKEN_INDEX = 151645
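# NOTE: hard-coded special-token ids, assumed to match the Qwen-VL tokenizer returned by
# get_vlm_model(). If the backbone or tokenizer changes, re-derive them with
# tokenizer.convert_tokens_to_ids(...) rather than trusting these constants.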

from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.GR00T_ActionHeader import get_action_model, FlowmatchingActionHead
from starVLA.training.trainer_utils.trainer_tools import resize_images
from starVLA.model.tools import FRAMEWORK_REGISTRY


@FRAMEWORK_REGISTRY.register("LangForce")
class LangForce(baseframework):
    """
    LangForce: Bayesian Decomposition of Vision Language Action Models via Latent Action Queries (arXiv 2601.15197).

    Dual-branch VLA with:
    - Prior branch: (V + A + L) => proposal-like p(a|v) head
    - Posterior branch: (V + L + A) => pi(a|v,l)
    - LLR regularizer: maximize log p(L|V,A_prior) - sg(log p(L|V))
      with:
        * Hard-token LLR (top-k hardest tokens under the posterior-branch language model)
        * Shortcut gate (down-weight the LLR when log p(L|V) is already very low)
    - Optional detach of the prior condition (protects the backbone from vision-only drift)

    Additionally:
    - Training-time assertion: the language spans extracted from the prior and posterior prompts
      must match exactly at the token level; on mismatch, an AssertionError with the decoded
      spans is raised.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)

        # Match the diffusion head's cross-attention width to the VLM hidden size.
        self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = (
            self.qwen_vl_interface.model.config.hidden_size
        )

        # Latent action queries: K special tokens spliced into the prompt; their hidden states
        # condition the flow-matching action head.
        self.num_latent_action_query = self.config.framework.qwenvl.get("num_latent_action_query", 32)
        self.latent_action_query = "".join([f"<|action_{i}|>" for i in range(self.num_latent_action_query)])
        self.action_token_ids = None

        self.action_model: FlowmatchingActionHead = get_action_model(config=self.config)

        self.future_action_window_size = config.framework.action_model.future_action_window_size
        self.past_action_window_size = config.framework.action_model.past_action_window_size
        self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size
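        # e.g., past_action_window_size=0 and future_action_window_size=15 give a 16-step chunk,
        # matching the (16, 7) action arrays in the __main__ demo below (actual values come from the config).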

        # Loss weights.
        self.kl_weight = float(self.config.framework.get("kl_weight", 0.1))
        self.prior_loss_weight = float(self.config.framework.get("prior_loss_weight", 0.3))

        # Training-time consistency check between the prior and posterior language spans.
        self.assert_lang_span_match = bool(self.config.framework.get("assert_lang_span_match", True))

        # Optionally stop gradients from the prior action head into the backbone.
        self.detach_prior_cond = bool(self.config.framework.get("detach_prior_cond", True))

        # Hard-token LLR: keep only the top-k hardest language tokens (highest posterior-branch NLL).
        self.use_hard_token_llr = bool(self.config.framework.get("use_hard_token_llr", True))
        self.hard_token_k = int(self.config.framework.get("hard_token_k", 16))
        assert self.hard_token_k > 0

        # Shortcut gate on the LLR term.
        self.use_kl_gate = bool(self.config.framework.get("use_kl_gate", True))
        self.kl_gate_momentum = float(self.config.framework.get("kl_gate_momentum", 0.99))
        self.kl_gate_temp = float(self.config.framework.get("kl_gate_temp", 0.5))
        self.kl_gate_tau_scale = float(self.config.framework.get("kl_gate_tau_scale", 0.7))
        self.kl_gate_min = float(self.config.framework.get("kl_gate_min", 0.0))
        self.kl_gate_max = float(self.config.framework.get("kl_gate_max", 1.0))
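        # Gate sketch: g = sigmoid((tau - post_nll) / temp) with tau = kl_gate_tau_scale * EMA(post_nll).
        # Example with tau = 2.0 and temp = 0.5: post_nll = 1.0 -> g ~= 0.88, post_nll = 3.0 -> g ~= 0.12,
        # i.e. samples whose language is already poorly predicted from vision alone get a small LLR weight.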

        # Cached <|im_end|> token id (resolved lazily from the tokenizer).
        self._im_end_id = None

        # Running EMA of the posterior-branch language NLL, used as the gate threshold.
        self.register_buffer("post_nll_ema", torch.tensor(0.0, dtype=torch.float32))
        self.register_buffer("post_nll_ema_inited", torch.tensor(0, dtype=torch.uint8))

    def _ensure_action_token_ids(self, tokenizer):
        if self.action_token_ids is None:
            self.action_token_ids = {
                "first": tokenizer.convert_tokens_to_ids("<|action_0|>"),
                "last": tokenizer.convert_tokens_to_ids(f"<|action_{self.num_latent_action_query - 1}|>"),
            }

    def _ensure_im_end_id(self, tokenizer):
        if self._im_end_id is None:
            self._im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    def _find_last_pos(self, seq_1d: torch.Tensor, token_id: int) -> int:
        idx = (seq_1d == int(token_id)).nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            return -1
        return int(idx[-1].item())

    def _find_first_pos_after(self, seq_1d: torch.Tensor, token_id: int, start: int) -> int:
        if start < 0:
            start = 0
        sub = seq_1d[start:]
        idx = (sub == int(token_id)).nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            return -1
        return int(start + idx[0].item())

    def _get_action_block_start(self, input_ids_1d: torch.Tensor, tokenizer) -> int:
        """Return the start index of the contiguous block of K action-query tokens, or -1 if absent."""
        self._ensure_action_token_ids(tokenizer)
        first_id = self.action_token_ids["first"]
        last_id = self.action_token_ids["last"]

        pos = (input_ids_1d == int(first_id)).nonzero(as_tuple=True)[0]
        if pos.numel() == 0:
            return -1

        start = int(pos[0].item())
        end = start + self.num_latent_action_query
        if end > input_ids_1d.shape[0]:
            return -1
        if int(input_ids_1d[end - 1].item()) != int(last_id):
            return -1
        return start

    def _extract_action_query_hidden_states(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        tokenizer,
        return_starts: bool = False,
    ):
        """Gather the hidden states of the K action-query tokens for every sample: (B, K, H)."""
        self._ensure_action_token_ids(tokenizer)

        B = hidden_states.shape[0]
        out = []
        starts = []
        for b in range(B):
            start = self._get_action_block_start(input_ids[b], tokenizer)
            assert start != -1, "No valid contiguous action token block found in the sequence."
            end = start + self.num_latent_action_query
            out.append(hidden_states[b, start:end, :])
            starts.append(start)

        out = torch.stack(out, dim=0)
        if return_starts:
            return out, torch.tensor(starts, device=input_ids.device, dtype=torch.long)
        return out

    def _token_nll_span(
        self,
        logits_1d: torch.Tensor,
        input_ids_1d: torch.Tensor,
        start: int,
        end: int,
        ignore_ids: Optional[Set[int]] = None,
    ):
        """
        Return (nll_vec, target_ids_vec) for the tokens in [start, end),
        using next-token alignment: the token at position j is scored by logits[j - 1] (requires j > 0).
        """
        if end <= start:
            return None, None
        S = int(input_ids_1d.shape[0])
        start = max(0, int(start))
        end = min(S, int(end))
        if end <= start:
            return None, None

        j = torch.arange(start, end, device=input_ids_1d.device, dtype=torch.long)
        j = j[j > 0]
        if j.numel() == 0:
            return None, None

        targets = input_ids_1d[j].long()

        if ignore_ids is not None and len(ignore_ids) > 0:
            keep = torch.ones_like(targets, dtype=torch.bool)
            for tid in ignore_ids:
                keep &= (targets != int(tid))
            j = j[keep]
            if j.numel() == 0:
                return None, None
            targets = input_ids_1d[j].long()

        pred_pos = j - 1
        pred_logits = logits_1d[pred_pos].float()
        nll = F.cross_entropy(pred_logits, targets, reduction="none")
        return nll, targets

    def _compute_language_llr_from_boundaries(
        self,
        priori_logits: torch.Tensor,
        posteriori_logits: torch.Tensor,
        priori_input_ids: torch.Tensor,
        posteriori_input_ids: torch.Tensor,
        priori_action_starts: torch.Tensor,
        posteriori_action_starts: torch.Tensor,
    ) -> torch.Tensor:
        tokenizer = self.qwen_vl_interface.processor.tokenizer
        self._ensure_im_end_id(tokenizer)

        pad_id = tokenizer.pad_token_id
        ignore_ids: Set[int] = set()
        if pad_id is not None:
            ignore_ids.add(int(pad_id))
        ignore_ids.add(int(IMAGE_TOKEN_INDEX))
        ignore_ids.add(int(VIDEO_TOKEN_INDEX))
        ignore_ids.add(int(VISION_START_TOKEN_INDEX))
        ignore_ids.add(int(VISION_END_TOKEN_INDEX))
        ignore_ids.add(int(IM_START_TOKEN_INDEX))
        ignore_ids.add(int(IM_END_TOKEN_INDEX))
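        # Excluding padding, vision placeholders, and chat-control tokens keeps the span NLL over
        # genuine language tokens only, so they do not dilute the LLR estimate.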

        B = int(priori_input_ids.shape[0])
        K = self.num_latent_action_query

        llr_vals = []
        post_nll_means = []

        for b in range(B):
            ids_prior = priori_input_ids[b]
            ids_post = posteriori_input_ids[b]

            a_start_prior = int(priori_action_starts[b].item())
            a_start_post = int(posteriori_action_starts[b].item())

            # Prior branch: the language span runs from the end of the action block to <|im_end|>.
            lang_start_prior = a_start_prior + K
            if lang_start_prior >= ids_prior.shape[0]:
                continue
            im_end = self._find_first_pos_after(ids_prior, self._im_end_id, lang_start_prior)
            lang_end_prior = im_end if im_end != -1 else int(ids_prior.shape[0])
            if lang_end_prior <= lang_start_prior:
                continue

            # Posterior branch: the language span runs from the last vision-end token to the action block.
            v_end_post = self._find_last_pos(ids_post, VISION_END_TOKEN_INDEX)
            if v_end_post == -1:
                continue
            lang_start_post = v_end_post + 1
            lang_end_post = a_start_post
            if lang_end_post <= lang_start_post:
                continue

            # Both branches must extract exactly the same instruction tokens.
            if self.training and self.assert_lang_span_match:
                prior_span_ids = ids_prior[lang_start_prior:lang_end_prior]
                post_span_ids = ids_post[lang_start_post:lang_end_post]

                if (prior_span_ids.numel() != post_span_ids.numel()) or (not torch.equal(prior_span_ids, post_span_ids)):
                    prior_text = tokenizer.decode(prior_span_ids.tolist())
                    post_text = tokenizer.decode(post_span_ids.tolist())

                    raise AssertionError(
                        "\n[LangForceV5] Language span mismatch detected!\n"
                        f"Sample b={b}\n"
                        f"PRIOR span idx: [{lang_start_prior}:{lang_end_prior}] (len={prior_span_ids.numel()})\n"
                        f"POST  span idx: [{lang_start_post}:{lang_end_post}] (len={post_span_ids.numel()})\n"
                        f"PRIOR span: {repr(prior_text)}\n"
                        f"POST  span: {repr(post_text)}\n"
                        f"PRIOR token ids (first 50): {prior_span_ids[:50].tolist()}\n"
                        f"POST  token ids (first 50): {post_span_ids[:50].tolist()}\n"
                        "This indicates the boundary-based language extraction is inconsistent (likely a prompt/template issue)."
                    )

            # Per-token NLL of the shared language span under each branch.
            nll_prior, tok_prior = self._token_nll_span(
                logits_1d=priori_logits[b],
                input_ids_1d=ids_prior,
                start=lang_start_prior,
                end=lang_end_prior,
                ignore_ids=ignore_ids,
            )
            nll_post, tok_post = self._token_nll_span(
                logits_1d=posteriori_logits[b],
                input_ids_1d=ids_post,
                start=lang_start_post,
                end=lang_end_post,
                ignore_ids=ignore_ids,
            )
            if nll_prior is None or nll_post is None:
                continue

            # Mean posterior-branch NLL (detached) feeds the shortcut gate below.
            post_nll_mean = nll_post.mean().detach()
            post_nll_means.append(post_nll_mean)

            # Per-sample LLR: nll_post - nll_prior == log p(L|V,A) - log p(L|V).
            if self.use_hard_token_llr:
                # Fall back to the full-span mean if the two branches scored different token sets.
                if tok_prior is None or tok_post is None or tok_prior.shape != tok_post.shape or (not torch.equal(tok_prior, tok_post)):
                    llr = (nll_post.mean() - nll_prior.mean())
                else:
                    k = min(self.hard_token_k, int(nll_post.numel()))
                    if k <= 0:
                        continue
                    idx = torch.topk(nll_post.detach(), k=k, largest=True).indices
                    llr = (nll_post[idx] - nll_prior[idx]).mean()
            else:
                llr = (nll_post.mean() - nll_prior.mean())

            llr_vals.append(llr)

        if len(llr_vals) == 0:
            return torch.tensor(0.0, device=priori_logits.device, dtype=torch.float32)

        llr_vals_t = torch.stack(llr_vals).float()
        post_nll_means_t = torch.stack(post_nll_means).float()

        # Update the EMA of the posterior-branch language NLL (training only).
        if self.use_kl_gate and self.training:
            batch_mean = post_nll_means_t.mean().detach()
            with torch.no_grad():
                if int(self.post_nll_ema_inited.item()) == 0:
                    self.post_nll_ema.copy_(batch_mean)
                    self.post_nll_ema_inited.fill_(1)
                else:
                    m = self.kl_gate_momentum
                    self.post_nll_ema.copy_(m * self.post_nll_ema + (1.0 - m) * batch_mean)

        # Shortcut gate: down-weight samples whose language is already hard to predict from vision alone.
        if self.use_kl_gate:
            tau = (self.post_nll_ema.detach() * float(self.kl_gate_tau_scale))
            temp = max(float(self.kl_gate_temp), 1e-6)
            g = torch.sigmoid((tau - post_nll_means_t) / temp)
            if self.kl_gate_min != 0.0 or self.kl_gate_max != 1.0:
                g = float(self.kl_gate_min) + (float(self.kl_gate_max) - float(self.kl_gate_min)) * g
        else:
            g = torch.ones_like(post_nll_means_t)

        return (g * llr_vals_t).mean()

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        batch_images = [example["image"] for example in examples]
        instructions_priori = [self.latent_action_query + example["lang"] for example in examples]
        instructions_posteriori = [example["lang"] + self.latent_action_query for example in examples]
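        # With a causal LM, this ordering means the prior branch's language tokens are conditioned on
        # (vision, action queries) ~ p(L|V,A), while the posterior branch's language tokens are
        # conditioned on vision only ~ p(L|V), and its action queries attend to (vision, language).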

        actions = [example["action"] for example in examples]
        state = [example["state"] for example in examples] if "state" in examples[0] else None

        # Prior branch: (V + action queries + L).
        qwen_inputs_priori = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images,
            instructions=instructions_priori,
        )

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs_priori = self.qwen_vl_interface(
                **qwen_inputs_priori,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
                use_cache=False,
            )
        priori_last_hidden = qwenvl_outputs_priori.hidden_states[-1]
        priori_action_hidden, priori_action_starts = self._extract_action_query_hidden_states(
            priori_last_hidden,
            qwen_inputs_priori["input_ids"],
            self.qwen_vl_interface.processor.tokenizer,
            return_starts=True,
        )
        priori_logits = qwenvl_outputs_priori.logits

        # Posterior branch: (V + L + action queries).
        qwen_inputs_posteriori = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images,
            instructions=instructions_posteriori,
        )

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs_posteriori = self.qwen_vl_interface(
                **qwen_inputs_posteriori,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
                use_cache=False,
            )
        posteriori_last_hidden = qwenvl_outputs_posteriori.hidden_states[-1]
        posteriori_action_hidden, posteriori_action_starts = self._extract_action_query_hidden_states(
            posteriori_last_hidden,
            qwen_inputs_posteriori["input_ids"],
            self.qwen_vl_interface.processor.tokenizer,
            return_starts=True,
        )

        posteriori_logits = qwenvl_outputs_posteriori.logits.detach()
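        # Detaching implements the stop-gradient baseline sg(log p(L|V)) from the class docstring:
        # the LLR term pushes the prior branch to explain the language rather than pushing the
        # posterior branch to unlearn it.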

        # Language LLR regularizer (gated, hard-token).
        kl_loss = self._compute_language_llr_from_boundaries(
            priori_logits=priori_logits,
            posteriori_logits=posteriori_logits,
            priori_input_ids=qwen_inputs_priori["input_ids"],
            posteriori_input_ids=qwen_inputs_posteriori["input_ids"],
            priori_action_starts=priori_action_starts,
            posteriori_action_starts=posteriori_action_starts,
        )

        # Flow-matching action losses for both branches.
        with torch.autocast("cuda", dtype=torch.float32):
            actions_t = torch.tensor(
                np.array(actions), device=priori_action_hidden.device, dtype=priori_action_hidden.dtype
            )
            actions_target = actions_t[:, -(self.future_action_window_size + 1):, :]

            repeated_diffusion_steps = (
                self.config.trainer.get("repeated_diffusion_steps", 4) if self.config and self.config.trainer else 4
            )

            state_tensor = None
            if state is not None:
                state_tensor = torch.tensor(
                    np.array(state), device=priori_action_hidden.device, dtype=priori_action_hidden.dtype
                )

            actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)

            # Optionally keep the prior head from dragging the backbone toward vision-only solutions.
            if self.detach_prior_cond:
                priori_cond_base = priori_action_hidden.detach()
            else:
                priori_cond_base = priori_action_hidden

            priori_cond = priori_cond_base.repeat(repeated_diffusion_steps, 1, 1).float()
            posteriori_cond = posteriori_action_hidden.repeat(repeated_diffusion_steps, 1, 1).float()
            state_repeated = state_tensor.repeat(repeated_diffusion_steps, 1, 1) if state_tensor is not None else None

            prior_loss = self.action_model(priori_cond, actions_target_repeated, state_repeated)
            main_loss = self.action_model(posteriori_cond, actions_target_repeated, state_repeated)

            total_loss = (
                (1.0 - self.prior_loss_weight) * main_loss
                + self.prior_loss_weight * prior_loss
                - self.kl_weight * kl_loss
            )
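            # kl_loss is a log-likelihood ratio to be maximized, hence the minus sign; the two action
            # terms form a convex combination weighted by prior_loss_weight.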

        return {
            "action_loss": total_loss,
            # Detached copies for logging only.
            "main_loss": main_loss.detach(),
            "prior_loss": prior_loss.detach(),
            "kl_loss": kl_loss.detach(),
        }

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict],
        **kwargs: str,
    ) -> dict:
        """
        Inference uses the posterior branch only: (V + L + action queries).
        """
        if not isinstance(examples, list):
            examples = [examples]

        # Normalize every observation to a list of PIL images.
        batch_images = []
        for ex in examples:
            imgs = ex["image"]
            if isinstance(imgs, list):
                batch_images.append([to_pil_preserve(im) for im in imgs])
            else:
                batch_images.append([to_pil_preserve(imgs)])

        instructions_posteriori = [ex["lang"] + self.latent_action_query for ex in examples]
        state = [ex["state"] for ex in examples] if "state" in examples[0] else None

        # Resize to the training observation resolution, if one was configured.
        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images,
            instructions=instructions_posteriori,
        )

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
                use_cache=False,
            )

        last_hidden = qwenvl_outputs.hidden_states[-1]
        action_hidden = self._extract_action_query_hidden_states(
            last_hidden,
            qwen_inputs["input_ids"],
            self.qwen_vl_interface.processor.tokenizer,
            return_starts=False,
        )

        state_tensor = None
        if state is not None:
            state_tensor = torch.from_numpy(np.array(state)).to(action_hidden.device, dtype=action_hidden.dtype)

        with torch.autocast("cuda", dtype=torch.float32):
            pred_actions = self.action_model.predict_action(action_hidden, state_tensor)

        return {"normalized_actions": pred_actions.detach().cpu().numpy()}


if __name__ == "__main__":
    from omegaconf import OmegaConf
    import debugpy
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_yaml", type=str, default="./examples/Robotwin/train_files/starvla_cotrain_robotwin.yaml")
    args, clipargs = parser.parse_known_args()

    # Block until a debugger attaches (this entry point is meant to be stepped through).
    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    # Hard-coded override of the CLI argument for local debugging.
    args.config_yaml = "examples/MultiRobot/train_files/starvla_cotrain_multiRobot.yaml"
    cfg = OmegaConf.load(args.config_yaml)

    model: LangForce = LangForce(cfg)
    print(model)

    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
    sample = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),
        "image": [image],
        "lang": "Put all the toys in the child's room ... inside the toy box.",
    }
    sample2 = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),
        "image": [image],
        "lang": "Put all the toys in the child's room ... inside the toy box.",
    }

    batch = [sample, sample2]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    out = model(batch)
    print("Action Loss:", out["action_loss"].item(), "KL Loss:", out["kl_loss"].item())

    pred = model.predict_action([sample])
    print("Pred shape:", pred["normalized_actions"].shape)

    # Smoke-test one batch from the VLA dataloader.
    vla_dataset_cfg = cfg.datasets.vla_data
    from torch.utils.data import DataLoader
    from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn

    cfg.datasets.vla_data.include_state = "False"
    dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)

    train_dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,
        collate_fn=collate_fn,
    )

    for batch in tqdm(train_dataloader, desc="Processing Batches"):
        model(batch)
        break

    print("Finished")