# eval_sigma_vla_rollout.py # Offline closed-loop evaluation for Telepathy-augmented VLA on top of PI05 policy backbone. # # Key design: # - base_model_id is a LeRobot/OpenPI policy repo (e.g., lerobot/pi05_base or your fine-tuned Sigma repo). # - We load PI05Policy via LeRobot, NOT AutoModelForCausalLM. # - Text embeddings are taken from the PI05 internal text backbone so that TelepathyLanguageModule # receives the same type of inputs used during training. # # Hardened in this revision: # - Robust recursive shard discovery under any naming & subfolders. # - Shard content structure normalization (list-of-samples, or dict{samples/data}). # - Collate auto-adapts to real schema: vision/state/action/text, with time-dim collapse for vision. # - Action GT supports dict-style branches or a single tensor. # - Metrics tolerate missing multi-branch outputs (fallback to "action"). # - Text tokens dtype/device aligned to model dtype for mixed precision safety. # - Robot state time-dim collapse + pad/trim to state encoder expected dim. # - Dynamic projection to align vision/state token hidden size to vision backbone dim (768), # and project text to the same dim BEFORE feeding language module. # - Optional max_text_len to avoid tokenizer truncation warnings. # - action input contract hardening: # * high_level_rep 2D -> 3D # * tau None/2D -> 3D # * tau length aligned to high_level_rep length # * tau last-dim auto pad/trim so concat(high_level_rep, tau) matches action_condition_proj in_features # - tokenizer_id can be a LOCAL path; when it exists locally we load with local_files_only # - _align_target handles 2D<->3D mismatches (fixes MSE crashes) # - remove duplicated "high_level_rep/tau re-normalization" that overwrote the hardening # # NEW in this patch: # - cosine_alignment auto-aligns hidden sizes (fixes 256 vs 2048 crash). # - semantic pooling guard supports 2D/3D factors safely. # - alignment metric ignores zero-length cases robustly. # # EXTRA HARDENING (this patch for your baseline issue): # - Try strict load for PI05Policy if the LeRobot version supports it. # - Verify tokenizer vocab size and special-token ids match PI05 text embedding table. # - Fail fast with a clear message if mismatch is detected (unless explicitly overridden). # # NEW in this hard-set patch: # - Per-sample MSE is exposed from success proxy. # - A "hard set" is defined as samples whose branch-wise MSE exceeds hard thresholds. # - Hard-set averages (MSE and fraction of samples) are reported alongside global metrics. # # NEW in this adapter patch: # - sigma_telepathy_adapter is applied at eval time (when telepathy is enabled) to gate # Telepathy residuals based on their magnitude and tau strength, optionally using # offline base_action_* if present in the shards. from __future__ import annotations import os import glob import json import argparse import importlib from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from dotenv import load_dotenv from accelerate import Accelerator from accelerate.utils import set_seed try: from huggingface_hub import snapshot_download except Exception: snapshot_download = None # type: ignore from vision_sigma_vla import TelepathyVisionModule, VisionConfig from language_sigma_vla import TelepathyLanguageModule, LanguageConfig from action_sigma_vla import TelepathyActionModule, ActionConfig from sigma_telepathy_adapter import SigmaTelepathyAdapter, SigmaTelepathyAdapterConfig def ensure_sigma_artifacts_from_hf( repo_id: str, hf_token: Optional[str], local_cache_root: str, ) -> Dict[str, str]: """ Download Sigma artifacts from HF repo into a local cache folder. Returns local paths for shard_dir and telepathy_heads_path. We only pull: storage/sigma_pickplace/** storage/sigma_lora_out/** """ if snapshot_download is None: raise ImportError( "huggingface_hub is not available but auto-download was requested. " "Please `pip install huggingface_hub` or download artifacts manually." ) os.makedirs(local_cache_root, exist_ok=True) local_dir = snapshot_download( repo_id=repo_id, token=hf_token, local_dir=os.path.join(local_cache_root, repo_id.replace("/", "__")), local_dir_use_symlinks=False, allow_patterns=[ "storage/sigma_pickplace/**", "storage/sigma_lora_out/**", ], ) shard_dir = os.path.join(local_dir, "storage", "sigma_pickplace") telepathy_heads_path = os.path.join( local_dir, "storage", "sigma_lora_out", "sigma_telepathy_heads.pt" ) return { "local_repo_dir": local_dir, "shard_dir": shard_dir, "telepathy_heads_path": telepathy_heads_path, } def load_pi05_policy( repo_id: str, hf_token: Optional[str], device: torch.device, strict_load: bool = True, ): """ Load PI05Policy from LeRobot. We try a few import paths to be robust across versions. If the LeRobot PI05Policy.from_pretrained supports strict loading, we enable it. """ policy_cls = None import_errors = [] candidate_paths = [ ("lerobot.policies.pi05.modeling_pi05", "PI05Policy"), ("lerobot.policies.pi05", "PI05Policy"), ] for mod_name, cls_name in candidate_paths: try: mod = importlib.import_module(mod_name) policy_cls = getattr(mod, cls_name) break except Exception as e: import_errors.append(f"{mod_name}.{cls_name}: {type(e).__name__}: {e}") if policy_cls is None: raise ImportError( "Failed to import PI05Policy from LeRobot. Tried:\n - " + "\n - ".join(import_errors) ) policy = None tried = [] if strict_load: try: policy = policy_cls.from_pretrained(repo_id, token=hf_token, strict=True) tried.append("from_pretrained(..., strict=True)") except TypeError: tried.append("strict=True not supported") except Exception as e: tried.append(f"strict=True failed: {type(e).__name__}: {e}") if policy is None: try: policy = policy_cls.from_pretrained(repo_id, token=hf_token) tried.append("from_pretrained(repo_id, token=...)") except TypeError: policy = policy_cls.from_pretrained(pretrained_name_or_path=repo_id, token=hf_token) tried.append("from_pretrained(pretrained_name_or_path=..., token=...)") if policy is None: raise RuntimeError("PI05Policy loading returned None. Tried: " + "; ".join(tried)) policy = policy.to(device) policy.eval() return policy def get_policy_tokenizer( policy, repo_id: str, hf_token: Optional[str], forced_tokenizer_id: str = "", ): """ Robust tokenizer getter for PI05Policy. IMPORTANT: - Never call AutoTokenizer.from_pretrained(repo_id) because repo_id is a policy repo. - If --tokenizer_id is provided and points to a LOCAL folder, load locally. - Otherwise load from HF id. - If still missing, recursively search for tokenizer/processor inside policy. """ from transformers import AutoTokenizer if forced_tokenizer_id: if os.path.exists(forced_tokenizer_id): tok = AutoTokenizer.from_pretrained( forced_tokenizer_id, local_files_only=True, trust_remote_code=True, ) else: tok = AutoTokenizer.from_pretrained( forced_tokenizer_id, token=hf_token, trust_remote_code=True, ) if tok.pad_token is None: tok.pad_token = tok.eos_token return tok def _recursive_find_tokenizer(obj, max_depth: int = 4): if obj is None or max_depth <= 0: return None for key in ["tokenizer", "processor", "text_tokenizer", "language_tokenizer"]: if hasattr(obj, key): v = getattr(obj, key) if v is None: continue if key == "processor" and hasattr(v, "tokenizer") and v.tokenizer is not None: return v.tokenizer if hasattr(v, "__call__"): return v nested_names = [ "paligemma_with_expert", "paligemma", "gemma_expert", "language_model", "text_model", "model", "policy", ] for name in nested_names: if hasattr(obj, name): found = _recursive_find_tokenizer( getattr(obj, name), max_depth=max_depth - 1 ) if found is not None: return found return None tok = _recursive_find_tokenizer(policy) if tok is not None: if getattr(tok, "pad_token", None) is None and getattr(tok, "eos_token", None) is not None: tok.pad_token = tok.eos_token return tok backbone_name = None config_candidates = [] for attr in ["config", "model", "paligemma_with_expert", "paligemma"]: if hasattr(policy, attr): config_candidates.append(getattr(policy, attr)) def _try_get_name(cfg_obj): if cfg_obj is None: return None for k in [ "_name_or_path", "text_backbone_id", "text_model_id", "language_model_id", "processor_name_or_path", "tokenizer_name_or_path", ]: if hasattr(cfg_obj, k): v = getattr(cfg_obj, k) if isinstance(v, str) and v: return v if hasattr(cfg_obj, "config"): c = getattr(cfg_obj, "config") if hasattr(c, "_name_or_path") and isinstance(c._name_or_path, str) and c._name_or_path: return c._name_or_path return None for c in config_candidates: backbone_name = _try_get_name(c) if backbone_name: break if backbone_name: tok = AutoTokenizer.from_pretrained( backbone_name, token=hf_token, trust_remote_code=True ) if tok.pad_token is None: tok.pad_token = tok.eos_token return tok raise ValueError( f"Cannot obtain tokenizer from PI05Policy for repo '{repo_id}'. " "Your lerobot PI05 port does not expose tokenizer/processor nor backbone name. " "Please pass --tokenizer_id explicitly." ) def get_policy_text_embedding_layer(policy): """ Locate the text embedding layer inside PI05Policy robustly. """ def _recursive_find(obj, depth: int = 6): if obj is None or depth <= 0: return None if hasattr(obj, "get_input_embeddings"): try: emb = obj.get_input_embeddings() if emb is not None: return emb except Exception: pass for key in ["embed_tokens", "embeddings", "token_embedding"]: if hasattr(obj, key): v = getattr(obj, key) if isinstance(v, nn.Module): return v nested_names = [ "model", "paligemma_with_expert", "paligemma", "language_model", "gemma_expert", "text_model", "policy", ] for name in nested_names: if hasattr(obj, name): found = _recursive_find(getattr(obj, name), depth=depth - 1) if found is not None: return found return None emb = _recursive_find(policy) if emb is None: raise AttributeError( "Cannot locate PI05 text embedding layer via recursive search. " "Your PI05Policy likely changed internal naming. " "Please inspect policy.model.* to confirm embed_tokens location." ) return emb def verify_tokenizer_embedding_compat( tokenizer, text_embed_layer: nn.Module, allow_mismatch: bool = False, ): """ Ensure tokenizer vocab/special ids are consistent with PI05 text embedding table. This directly prevents the 'embed_tokens.weight missing or misaligned' baseline issue. """ emb_vocab = None if isinstance(text_embed_layer, nn.Embedding): emb_vocab = int(text_embed_layer.num_embeddings) elif hasattr(text_embed_layer, "weight") and text_embed_layer.weight is not None: emb_vocab = int(text_embed_layer.weight.size(0)) tok_vocab = getattr(tokenizer, "vocab_size", None) if tok_vocab is None: try: tok_vocab = len(tokenizer) except Exception: tok_vocab = None if emb_vocab is None or tok_vocab is None: print("[WARN] Cannot infer tokenizer/embedding vocab sizes. Skipping compatibility check.") return if emb_vocab != tok_vocab: msg = ( f"[ERROR] Tokenizer vocab size ({tok_vocab}) != PI05 embedding table size ({emb_vocab}). " "This will corrupt text embeddings and invalidate baseline. " "Fix by passing --tokenizer_id matching the PI05 text backbone " "(e.g., the original openpi/PI05 tokenizer) or re-exporting policy with aligned vocab." ) if allow_mismatch: print(msg.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.") else: raise ValueError(msg) for name in ["pad_token_id", "eos_token_id", "bos_token_id", "unk_token_id"]: tid = getattr(tokenizer, name, None) if tid is None: continue if not (0 <= int(tid) < emb_vocab): msg = ( f"[ERROR] Tokenizer {name}={tid} out of embedding range [0, {emb_vocab-1}]. " "Your tokenizer does not belong to this PI05 backbone." ) if allow_mismatch: print(msg.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.") else: raise ValueError(msg) class TelepathyVLA(nn.Module): """ Full model matching your final arrows. """ def __init__( self, v_cfg: VisionConfig, l_cfg: LanguageConfig, a_cfg: ActionConfig, disable_telepathy: bool = False, ): super().__init__() self.vision = TelepathyVisionModule(v_cfg) self.language = TelepathyLanguageModule(l_cfg) self.action = TelepathyActionModule(a_cfg) self.disable_telepathy = disable_telepathy self.register_buffer("_m_prev", None, persistent=False) self._proj_inited = False self.text_proj: Optional[nn.Module] = None self.vision_proj: Optional[nn.Module] = None self.state_proj: Optional[nn.Module] = None def reset_memory(self): self._m_prev = None @torch.no_grad() def forward_once( self, vis_obs: torch.Tensor, robot_state: torch.Tensor, text_tokens: torch.Tensor, depth_obs: Optional[torch.Tensor] = None, audio_obs: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, return_intermediate: bool = False, ) -> Dict[str, torch.Tensor]: vis0 = self.vision( vis_obs=vis_obs, robot_state=robot_state, depth_obs=depth_obs, audio_obs=audio_obs, telepathy_factors=None, return_intermediate=return_intermediate, ) vis_d = vis0["vision_tokens"].size(-1) state_d = vis0["state_tokens"].size(-1) target_d = vis_d if not self._proj_inited: self.text_proj = nn.Linear(text_tokens.size(-1), target_d, bias=False) \ if text_tokens.size(-1) != target_d else nn.Identity() self.vision_proj = nn.Identity() if vis_d == target_d else nn.Linear(vis_d, target_d, bias=False) self.state_proj = nn.Identity() if state_d == target_d else nn.Linear(state_d, target_d, bias=False) self.text_proj = self.text_proj.to(device=text_tokens.device, dtype=text_tokens.dtype) self.vision_proj = self.vision_proj.to(device=text_tokens.device, dtype=text_tokens.dtype) self.state_proj = self.state_proj.to(device=text_tokens.device, dtype=text_tokens.dtype) self._proj_inited = True assert self.text_proj is not None and self.vision_proj is not None and self.state_proj is not None text_tokens = self.text_proj(text_tokens) vision_tokens = self.vision_proj(vis0["vision_tokens"]) state_tokens = self.state_proj(vis0["state_tokens"]) lang_out = self.language( text_tokens=text_tokens, vision_tokens=vision_tokens, state_tokens=state_tokens, m_prev=self._m_prev, attn_mask=attn_mask, return_intermediate=return_intermediate, ) raw_tau = lang_out.get("telepathy_factors", None) self._m_prev = lang_out.get("m_t", None) telepathy_scale = float(getattr(self, "telepathy_scale", 1.0)) if self.disable_telepathy: tau = None vis_out = vis0 else: tau = raw_tau if tau is not None: tau = tau * telepathy_scale vis_out = self.vision( vis_obs=vis_obs, robot_state=robot_state, depth_obs=depth_obs, audio_obs=audio_obs, telepathy_factors=tau, return_intermediate=return_intermediate, ) high_level_rep = lang_out.get("high_level_rep", None) if high_level_rep is None: raise KeyError("language output missing 'high_level_rep'.") if high_level_rep.dim() == 2: high_level_rep = high_level_rep.unsqueeze(1) if tau is None: B, L, _ = high_level_rep.shape tau_dim = getattr(self.language, "tau_dim", 128) tau = torch.zeros(B, L, tau_dim, device=high_level_rep.device, dtype=high_level_rep.dtype) else: if tau.dim() == 2: tau = tau.unsqueeze(1) if tau.size(1) != high_level_rep.size(1): L = high_level_rep.size(1) if tau.size(1) == 1: tau = tau.expand(-1, L, -1) else: tau = tau[:, :L, :] expected_in = None acp = getattr(self.action, "action_condition_proj", None) if acp is not None: if hasattr(acp, "in_features"): expected_in = int(acp.in_features) elif hasattr(acp, "net") and len(acp.net) > 0 and hasattr(acp.net[0], "in_features"): expected_in = int(acp.net[0].in_features) if expected_in is not None: d_high = high_level_rep.size(-1) target_tau = expected_in - d_high if target_tau <= 0: pass else: if tau.size(-1) < target_tau: tau = F.pad(tau, (0, target_tau - tau.size(-1))) elif tau.size(-1) > target_tau: tau = tau[..., :target_tau] state_for_action = vis_out["state_tokens"] if state_for_action.dim() == 2: state_for_action = state_for_action.unsqueeze(1) elif state_for_action.dim() > 3: state_for_action = state_for_action.view( state_for_action.size(0), -1, state_for_action.size(-1) ) lang_d = high_level_rep.size(-1) def _pad_or_trim_to(x: torch.Tensor, d: int) -> torch.Tensor: cur_d = x.size(-1) if cur_d == d: return x if cur_d < d: return F.pad(x, (0, d - cur_d)) return x[..., :d] state_for_action = _pad_or_trim_to(state_for_action, lang_d) act_out = self.action( high_level_rep=high_level_rep, telepathy_factors=tau, state_tokens=state_for_action, return_intermediate=return_intermediate, ) out: Dict[str, torch.Tensor] = {} out.update(vis_out) out.update(lang_out) out.update(act_out) return out class SigmaShardDataset(Dataset): """ Loads .pt shards produced by dataset_preprocess_sigma_vla.py. Each shard is a list of dict samples OR a dict containing a list (samples/data). """ def __init__(self, shard_dir: str): super().__init__() if not os.path.isdir(shard_dir): raise FileNotFoundError( f"shard_dir does not exist: {shard_dir}. Double-check the path." ) patterns = [ os.path.join(shard_dir, "sigma_vla_shard_*.pt"), os.path.join(shard_dir, "*.pt"), os.path.join(shard_dir, "**", "*.pt"), ] paths: List[str] = [] for p in patterns: paths.extend(glob.glob(p, recursive=True)) self.shard_paths = sorted(list(set(paths))) if len(self.shard_paths) == 0: raise FileNotFoundError( f"No .pt shards found under {shard_dir}. " "Your HF cache is empty or shards are not tracked by LFS." ) print(f"[INFO] Found {len(self.shard_paths)} shard files. Example: {self.shard_paths[:3]}") self.index_map: List[Tuple[int, int]] = [] self._shard_cache: Dict[int, List[Dict[str, Any]]] = {} for sid, p in enumerate(self.shard_paths): shard = torch.load(p, map_location="cpu") shard_list = self._normalize_shard(shard, p) for lid in range(len(shard_list)): self.index_map.append((sid, lid)) self.total = len(self.index_map) def __len__(self): return self.total def _normalize_shard(self, shard_obj: Any, path: str) -> List[Dict[str, Any]]: if isinstance(shard_obj, (list, tuple)): return list(shard_obj) if isinstance(shard_obj, dict): for k in ["samples", "data", "items"]: if k in shard_obj and isinstance(shard_obj[k], (list, tuple)): return list(shard_obj[k]) raise TypeError( f"Unsupported shard format in {path}. " f"Expected list/tuple of samples or dict{{samples/data}}. " f"Got type: {type(shard_obj).__name__}" ) def _get_shard(self, sid: int) -> List[Dict[str, Any]]: if sid not in self._shard_cache: raw = torch.load(self.shard_paths[sid], map_location="cpu") self._shard_cache[sid] = self._normalize_shard(raw, self.shard_paths[sid]) return self._shard_cache[sid] def __getitem__(self, idx: int) -> Dict[str, Any]: sid, lid = self.index_map[idx] shard = self._get_shard(sid) return shard[lid] def collate_sigma(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]: """ Robust collate for Sigma shards. """ s0 = batch_list[0] def pick_key(sample: Dict[str, Any], candidates: List[str], field_name: str): for k in candidates: if k in sample: return k raise KeyError( f"Shard sample missing required field '{field_name}'. " f"Tried keys: {candidates}. " f"Available keys: {list(sample.keys())}" ) if "vision" in s0: vis_k = "vision" else: vis_k = pick_key(s0, ["vis_obs", "rgb_obs", "image", "images", "obs"], "vision/vis_obs") vis_obs = torch.stack([b[vis_k] for b in batch_list], dim=0).float() if vis_obs.dim() == 5: vis_obs = vis_obs[:, -1] depth_obs = None if "depth" in s0: depth_obs = torch.stack([b["depth"] for b in batch_list], dim=0).float() elif any(k in s0 for k in ["depth_obs", "depths"]): dk = pick_key(s0, ["depth_obs", "depths"], "depth") depth_obs = torch.stack([b[dk] for b in batch_list], dim=0).float() audio_obs = None if "audio" in s0: audio_obs = torch.stack([b["audio"] for b in batch_list], dim=0).float() elif any(k in s0 for k in ["audio_obs", "audios"]): ak = pick_key(s0, ["audio_obs", "audios"], "audio") audio_obs = torch.stack([b[ak] for b in batch_list], dim=0).float() if "state" in s0: state_k = "state" else: state_k = pick_key(s0, ["robot_state", "proprio", "proprio_obs"], "state/robot_state") robot_state = torch.stack([b[state_k] for b in batch_list], dim=0).float() if "text" in s0: texts = [b.get("text", "") for b in batch_list] else: text_k = pick_key(s0, ["text", "prompt", "instruction"], "text") texts = [b.get(text_k, "") for b in batch_list] if "action" in s0: a0 = s0["action"] if isinstance(a0, dict): def pick_action_key(d, candidates, name): for k in candidates: if k in d: return k raise KeyError( f"Action dict missing '{name}'. Tried {candidates}. " f"Available action keys: {list(d.keys())}" ) vec_k = pick_action_key(a0, ["gt_action_vector", "action_vector", "vector", "vec"], "gt_action_vector") chk_k = pick_action_key(a0, ["gt_action_chunk", "action_chunk", "chunk", "chk"], "gt_action_chunk") trj_k = pick_action_key(a0, ["gt_action_trajectory", "action_trajectory", "trajectory", "traj"], "gt_action_trajectory") gt_action_vector = torch.stack([b["action"][vec_k] for b in batch_list], dim=0).float() gt_action_chunk = torch.stack([b["action"][chk_k] for b in batch_list], dim=0).float() gt_action_trajectory = torch.stack([b["action"][trj_k] for b in batch_list], dim=0).float() else: act = torch.stack([b["action"] for b in batch_list], dim=0).float() gt_action_vector = act gt_action_chunk = act gt_action_trajectory = act else: gt_vec_k = pick_key(s0, ["gt_action_vector", "action_vector", "gt_vec"], "gt_action_vector") gt_chk_k = pick_key(s0, ["gt_action_chunk", "action_chunk", "gt_chunk"], "gt_action_chunk") gt_trj_k = pick_key(s0, ["gt_action_trajectory", "action_trajectory", "gt_traj"], "gt_action_trajectory") gt_action_vector = torch.stack([b[gt_vec_k] for b in batch_list], dim=0).float() gt_action_chunk = torch.stack([b[gt_chk_k] for b in batch_list], dim=0).float() gt_action_trajectory = torch.stack([b[gt_trj_k] for b in batch_list], dim=0).float() # Optional offline base actions for adapter; if missing, we simply do not include them. base_action_vector = None base_action_chunk = None base_action_trajectory = None has_base_top = any( k in s0 for k in ["base_action_vector", "base_action_chunk", "base_action_trajectory"] ) has_base_in_action = "action" in s0 and isinstance(s0["action"], dict) and any( k in s0["action"] for k in ["base_action_vector", "base_action_chunk", "base_action_trajectory"] ) if has_base_top: if "base_action_vector" in s0: base_action_vector = torch.stack([b["base_action_vector"] for b in batch_list], dim=0).float() if "base_action_chunk" in s0: base_action_chunk = torch.stack([b["base_action_chunk"] for b in batch_list], dim=0).float() if "base_action_trajectory" in s0: base_action_trajectory = torch.stack([b["base_action_trajectory"] for b in batch_list], dim=0).float() elif has_base_in_action: a0 = s0["action"] def pick_base_key(d, candidates): for k in candidates: if k in d: return k return None vec_bk = pick_base_key(a0, ["base_action_vector", "base_vec"]) chk_bk = pick_base_key(a0, ["base_action_chunk", "base_chunk"]) trj_bk = pick_base_key(a0, ["base_action_trajectory", "base_traj"]) if vec_bk is not None: base_action_vector = torch.stack([b["action"][vec_bk] for b in batch_list], dim=0).float() if chk_bk is not None: base_action_chunk = torch.stack([b["action"][chk_bk] for b in batch_list], dim=0).float() if trj_bk is not None: base_action_trajectory = torch.stack([b["action"][trj_bk] for b in batch_list], dim=0).float() batch: Dict[str, Any] = { "vis_obs": vis_obs, "depth_obs": depth_obs, "audio_obs": audio_obs, "robot_state": robot_state, "texts": texts, "gt_action_vector": gt_action_vector, "gt_action_chunk": gt_action_chunk, "gt_action_trajectory": gt_action_trajectory, } if base_action_vector is not None: batch["base_action_vector"] = base_action_vector if base_action_chunk is not None: batch["base_action_chunk"] = base_action_chunk if base_action_trajectory is not None: batch["base_action_trajectory"] = base_action_trajectory return batch def _align_target(pred_t: torch.Tensor, gt_t: torch.Tensor) -> torch.Tensor: """ Align GT to prediction for MSE: - handle 2D vs 3D mismatches by collapsing or expanding time dimension. - then align last-dim by pad/trim. """ if gt_t.dim() == 3 and pred_t.dim() == 2: gt_t = gt_t[:, -1, :] if pred_t.dim() == 3 and gt_t.dim() == 2: gt_t = gt_t.unsqueeze(1) if gt_t.size(1) != pred_t.size(1): gt_t = gt_t.expand(-1, pred_t.size(1), -1) if pred_t.dim() == 3 and gt_t.dim() == 3: Tp = pred_t.size(1) Tg = gt_t.size(1) if Tg < Tp: pad = torch.zeros( gt_t.size(0), Tp - Tg, gt_t.size(2), device=gt_t.device, dtype=gt_t.dtype ) gt_t = torch.cat([gt_t, pad], dim=1) elif Tg > Tp: gt_t = gt_t[:, :Tp, :] pd = pred_t.size(-1) gd = gt_t.size(-1) if gd < pd: gt_t = F.pad(gt_t, (0, pd - gd)) elif gd > pd: gt_t = gt_t[..., :pd] return gt_t def _pred_action(pred: Dict[str, torch.Tensor], key: str) -> torch.Tensor: if key in pred: return pred[key] if "action" in pred: return pred["action"] raise KeyError( f"Pred dict missing action key '{key}' and fallback 'action'. " f"Available pred keys: {list(pred.keys())}" ) @torch.no_grad() def compute_branch_mse(pred: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, float]: vec_pred = _pred_action(pred, "action_vector") chk_pred = _pred_action(pred, "action_chunk") trj_pred = _pred_action(pred, "action_trajectory") device = vec_pred.device gt_vec = _align_target(vec_pred, batch["gt_action_vector"].to(device)) gt_chk = _align_target(chk_pred, batch["gt_action_chunk"].to(device)) gt_trj = _align_target(trj_pred, batch["gt_action_trajectory"].to(device)) mse_vec = F.mse_loss(vec_pred, gt_vec).item() mse_chk = F.mse_loss(chk_pred, gt_chk).item() mse_trj = F.mse_loss(trj_pred, gt_trj).item() return {"mse_vector": mse_vec, "mse_chunk": mse_chk, "mse_traj": mse_trj} @torch.no_grad() def compute_success_proxy( pred: Dict[str, torch.Tensor], batch: Dict[str, Any], thr_vec: float, thr_chk: float, thr_trj: float, ) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: num_success, num_total, mse_vec_per_sample, mse_chk_per_sample, mse_trj_per_sample where per-sample MSE is averaged over all non-batch dims. """ vec_pred = _pred_action(pred, "action_vector") chk_pred = _pred_action(pred, "action_chunk") trj_pred = _pred_action(pred, "action_trajectory") device = vec_pred.device gt_vec = _align_target(vec_pred, batch["gt_action_vector"].to(device)) gt_chk = _align_target(chk_pred, batch["gt_action_chunk"].to(device)) gt_trj = _align_target(trj_pred, batch["gt_action_trajectory"].to(device)) reduce_dims_vec = list(range(1, vec_pred.dim())) reduce_dims_chk = list(range(1, chk_pred.dim())) reduce_dims_trj = list(range(1, trj_pred.dim())) mse_vec_s = ((vec_pred - gt_vec) ** 2).mean(dim=reduce_dims_vec) mse_chk_s = ((chk_pred - gt_chk) ** 2).mean(dim=reduce_dims_chk) mse_trj_s = ((trj_pred - gt_trj) ** 2).mean(dim=reduce_dims_trj) success_mask = (mse_vec_s < thr_vec) & (mse_chk_s < thr_chk) & (mse_trj_s < thr_trj) num_success = int(success_mask.sum().item()) num_total = int(success_mask.numel()) return num_success, num_total, mse_vec_s, mse_chk_s, mse_trj_s @torch.no_grad() def compute_telepathy_stability(pred: Dict[str, torch.Tensor]) -> float: tau = pred.get("telepathy_factors", None) if tau is None: return float("nan") return float((tau ** 2).mean().item()) @torch.no_grad() def cosine_alignment(a: torch.Tensor, b: torch.Tensor) -> float: """ Cosine alignment that is robust to hidden-size mismatch. Accepts [B, D] or [B, T, D]. Pools time if present. If dims differ, crops both to min(Da, Db) for a fair cosine check. """ if a.dim() == 3: a = a.mean(dim=1) if b.dim() == 3: b = b.mean(dim=1) if a.numel() == 0 or b.numel() == 0: return float("nan") da, db = a.size(-1), b.size(-1) if da != db: d = min(da, db) a = a[..., :d] b = b[..., :d] a = F.normalize(a, dim=-1) b = F.normalize(b, dim=-1) return float((a * b).sum(dim=-1).mean().item()) @torch.no_grad() def build_text_tokens_from_policy( tokenizer, text_embed_layer: nn.Module, texts: List[str], device: torch.device, target_dtype: torch.dtype, max_text_len: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Tokenize prompts and map to embeddings using PI05 internal embedding layer. Returns (text_tokens, attn_mask). """ if max_text_len and max_text_len > 0: tok = tokenizer( texts, padding=True, truncation=True, max_length=max_text_len, return_tensors="pt", ) else: tok = tokenizer( texts, padding=True, truncation=False, return_tensors="pt", ) if hasattr(tok, "input_ids"): input_ids = tok.input_ids attn_mask = tok.attention_mask else: input_ids = tok["input_ids"] attn_mask = tok.get("attention_mask", None) if attn_mask is None: attn_mask = torch.ones_like(input_ids) input_ids = input_ids.to(device) attn_mask = attn_mask.to(device) text_tokens = text_embed_layer(input_ids).to(dtype=target_dtype) return text_tokens, attn_mask def main(): parser = argparse.ArgumentParser() parser.add_argument("--sigma_env", type=str, default="sigma.env") parser.add_argument("--shard_dir", type=str, default="") parser.add_argument("--output_dir", type=str, default="./sigma_eval_out") parser.add_argument( "--base_model_id", type=str, required=True, help="LeRobot/OpenPI policy repo, e.g., lerobot/pi05_base or your Sigma policy repo.", ) parser.add_argument( "--telepathy_heads_path", type=str, default="", help="Path to sigma_telepathy_heads.pt. If empty, auto-fetch may fill it.", ) parser.add_argument( "--disable_telepathy", action="store_true", help="Disable telepathy injection (control run).", ) parser.add_argument( "--tokenizer_id", type=str, default="", help="Explicit HF tokenizer id OR local tokenizer folder path.", ) parser.add_argument("--max_text_len", type=int, default=0) parser.add_argument( "--artifacts_repo_id", type=str, default="", help="HF repo containing storage/sigma_pickplace and storage/sigma_lora_out.", ) parser.add_argument( "--hf_cache_root", type=str, default="/workspace/.hf_sigma_cache", ) parser.add_argument("--load_in_4bit", action="store_true") parser.add_argument("--dtype", type=str, default="bf16") parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--num_workers", type=int, default=2) parser.add_argument("--max_batches", type=int, default=-1) parser.add_argument("--seed", type=int, default=42) parser.add_argument( "--shuffle", action="store_true", help="Shuffle dataset order to enable different random subsets per seed.", ) parser.add_argument( "--telepathy_scale", type=float, default=1.0, help="Multiply telepathy_factors (tau) to control injection strength.", ) parser.add_argument("--succ_thr_vec", type=float, default=0.05) parser.add_argument("--succ_thr_chk", type=float, default=0.10) parser.add_argument("--succ_thr_trj", type=float, default=0.10) # Hard-set thresholds: if <=0, they default to 2x the success thresholds. parser.add_argument( "--hard_thr_vec", type=float, default=-1.0, help="Per-sample MSE threshold for the 'hard' set on vector branch; <=0 means 2x succ_thr_vec.", ) parser.add_argument( "--hard_thr_chk", type=float, default=-1.0, help="Per-sample MSE threshold for the 'hard' set on chunk branch; <=0 means 2x succ_thr_chk.", ) parser.add_argument( "--hard_thr_trj", type=float, default=-1.0, help="Per-sample MSE threshold for the 'hard' set on trajectory branch; <=0 means 2x succ_thr_trj.", ) parser.add_argument( "--strict_pi05_load", action="store_true", help="Try strict PI05Policy loading if supported by LeRobot.", ) parser.add_argument( "--allow_tokenizer_mismatch", action="store_true", help="Do not fail on tokenizer/embedding mismatch (NOT recommended for baseline).", ) # Simple flag to enable/disable the adapter without touching telepathy itself. parser.add_argument( "--use_telepathy_adapter", action="store_true", help="If set and telepathy is enabled, apply sigma_telepathy_adapter to actions in eval.", ) args = parser.parse_args() if os.path.exists(args.sigma_env): load_dotenv(args.sigma_env) hf_token = os.getenv("HF_TOKEN", None) accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != "fp32" else "no") set_seed(args.seed) device = accelerator.device if args.load_in_4bit: print("[WARN] --load_in_4bit is ignored for PI05Policy evaluator.") artifacts_repo = args.artifacts_repo_id.strip() if not artifacts_repo and args.base_model_id.startswith("Veltraxor/"): artifacts_repo = args.base_model_id need_shards = (not args.shard_dir) or (not os.path.isdir(args.shard_dir)) need_heads = (not args.telepathy_heads_path) or (not os.path.isfile(args.telepathy_heads_path)) if artifacts_repo and (need_shards or need_heads): paths = ensure_sigma_artifacts_from_hf( repo_id=artifacts_repo, hf_token=hf_token, local_cache_root=args.hf_cache_root, ) if need_shards: args.shard_dir = paths["shard_dir"] print(f"[INFO] Using cached shard_dir: {args.shard_dir}") if need_heads: args.telepathy_heads_path = paths["telepathy_heads_path"] print(f"[INFO] Using cached telepathy_heads_path: {args.telepathy_heads_path}") if not args.shard_dir or not os.path.isdir(args.shard_dir): raise FileNotFoundError( f"shard_dir not found locally: {args.shard_dir}. " "Either provide a valid local path or an artifacts_repo_id for auto-download." ) if not args.telepathy_heads_path or not os.path.isfile(args.telepathy_heads_path): raise FileNotFoundError( f"telepathy_heads_path not found locally: {args.telepathy_heads_path}. " "Either provide a valid local path or store it under storage/sigma_lora_out/ " "in artifacts_repo_id for auto-download." ) policy = load_pi05_policy( args.base_model_id, hf_token, device=device, strict_load=args.strict_pi05_load, ) tokenizer = get_policy_tokenizer( policy, args.base_model_id, hf_token, forced_tokenizer_id=args.tokenizer_id, ) text_embed_layer = get_policy_text_embedding_layer(policy) verify_tokenizer_embedding_compat( tokenizer=tokenizer, text_embed_layer=text_embed_layer, allow_mismatch=args.allow_tokenizer_mismatch, ) v_cfg = VisionConfig() l_cfg = LanguageConfig() a_cfg = ActionConfig() telepathy_vla = TelepathyVLA(v_cfg, l_cfg, a_cfg, disable_telepathy=args.disable_telepathy) telepathy_vla.telepathy_scale = args.telepathy_scale # Instantiate Telepathy adapter (used only when telepathy is enabled and flag is set). adapter_cfg = SigmaTelepathyAdapterConfig() telepathy_adapter = SigmaTelepathyAdapter(adapter_cfg).to(device) if accelerator.is_main_process: file_size_mb = os.path.getsize(args.telepathy_heads_path) / (1024 * 1024) print(f"[CHECK-A] disable_telepathy={args.disable_telepathy}") print(f"[CHECK-A] telepathy_heads_path={args.telepathy_heads_path} size={file_size_mb:.2f}MB") sd = torch.load(args.telepathy_heads_path, map_location="cpu") tensor_list = [v.detach().float().reshape(-1) for v in sd.values() if torch.is_tensor(v)] if accelerator.is_main_process and len(tensor_list) > 0: capped = [t[:100000] for t in tensor_list] flat = torch.cat(capped, dim=0) rms = torch.sqrt((flat ** 2).mean()).item() print(f"[CHECK-A] heads_tensors={len(tensor_list)} mean={flat.mean().item():.6f} std={flat.std().item():.6f} rms={rms:.6f}") missing, unexpected = telepathy_vla.load_state_dict(sd, strict=False) if accelerator.is_main_process: if len(missing) > 0 or len(unexpected) > 0: print(f"[CHECK-A] loaded with strict=False. Missing={len(missing)} Unexpected={len(unexpected)}") print(f"[CHECK-A] Missing keys (first 20): {missing[:20]}") print(f"[CHECK-A] Unexpected keys (first 20): {unexpected[:20]}") else: print("[CHECK-A] heads fully matched (no missing/unexpected).") telepathy_vla.eval() ds = SigmaShardDataset(args.shard_dir) dl = DataLoader( ds, batch_size=args.batch_size, shuffle=args.shuffle, num_workers=args.num_workers, collate_fn=collate_sigma, drop_last=False, pin_memory=torch.cuda.is_available(), ) telepathy_vla, dl = accelerator.prepare(telepathy_vla, dl) target_dtype = next(telepathy_vla.parameters()).dtype sum_mse_vec = 0.0 sum_mse_chk = 0.0 sum_mse_trj = 0.0 sum_tau_l2 = 0.0 sum_sem_align = 0.0 # Hard-set aggregators hard_thr_vec = args.hard_thr_vec if args.hard_thr_vec > 0.0 else 2.0 * args.succ_thr_vec hard_thr_chk = args.hard_thr_chk if args.hard_thr_chk > 0.0 else 2.0 * args.succ_thr_chk hard_thr_trj = args.hard_thr_trj if args.hard_thr_trj > 0.0 else 2.0 * args.succ_thr_trj sum_hard_mse_vec = 0.0 sum_hard_mse_chk = 0.0 sum_hard_mse_trj = 0.0 total_hard_samples = 0 n_batches = 0 n_samples = 0 os.makedirs(args.output_dir, exist_ok=True) for bidx, batch in enumerate(dl): if args.max_batches > 0 and bidx >= args.max_batches: break telepathy_vla.reset_memory() B = batch["vis_obs"].size(0) n_samples += B text_tokens, attn_mask = build_text_tokens_from_policy( tokenizer=tokenizer, text_embed_layer=text_embed_layer, texts=batch["texts"], device=device, target_dtype=target_dtype, max_text_len=args.max_text_len, ) robot_state = batch["robot_state"].to(device) if robot_state.dim() == 3: robot_state = robot_state[:, -1] # Move optional base actions to device for the adapter. if "base_action_vector" in batch: batch["base_action_vector"] = batch["base_action_vector"].to(device) if "base_action_chunk" in batch: batch["base_action_chunk"] = batch["base_action_chunk"].to(device) if "base_action_trajectory" in batch: batch["base_action_trajectory"] = batch["base_action_trajectory"].to(device) try: expected_d = telepathy_vla.vision.state_encoder.mlp[0].in_features except Exception: expected_d = robot_state.size(-1) cur_d = robot_state.size(-1) if cur_d < expected_d: robot_state = F.pad(robot_state, (0, expected_d - cur_d)) elif cur_d > expected_d: robot_state = robot_state[..., :expected_d] pred = telepathy_vla.forward_once( vis_obs=batch["vis_obs"].to(device), robot_state=robot_state, depth_obs=batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None, audio_obs=batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None, text_tokens=text_tokens, attn_mask=attn_mask, return_intermediate=True, ) if accelerator.is_main_process and bidx == 0: model_ref = telepathy_vla.module if hasattr(telepathy_vla, "module") else telepathy_vla model_ref.reset_memory() prev_flag = bool(model_ref.disable_telepathy) model_ref.disable_telepathy = True pred_ctrl = model_ref.forward_once( vis_obs=batch["vis_obs"].to(device), robot_state=robot_state, depth_obs=batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None, audio_obs=batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None, text_tokens=text_tokens, attn_mask=attn_mask, return_intermediate=False, ) model_ref.disable_telepathy = prev_flag try: act_exp = _pred_action(pred, "action_vector") act_ctl = _pred_action(pred_ctrl, "action_vector") diff = (act_exp - act_ctl).abs().mean().item() print(f"[CHECK-B] telepathy_effect_mean_abs_diff(action_vector)={diff:.6f}") except Exception as e: print(f"[CHECK-B] action diff check failed: {type(e).__name__}: {e}") # Apply Telepathy adapter only when telepathy is enabled and the flag is set. if (not args.disable_telepathy) and args.use_telepathy_adapter: pred = telepathy_adapter(pred, batch) mse = compute_branch_mse(pred, batch) tau_l2 = compute_telepathy_stability(pred) ( _, _, mse_vec_s, mse_chk_s, mse_trj_s, ) = compute_success_proxy( pred, batch, thr_vec=args.succ_thr_vec, thr_chk=args.succ_thr_chk, thr_trj=args.succ_thr_trj, ) # Hard-set accumulation: samples where any branch MSE exceeds hard thresholds hard_mask = (mse_vec_s > hard_thr_vec) | (mse_chk_s > hard_thr_chk) | (mse_trj_s > hard_thr_trj) hard_count = int(hard_mask.sum().item()) if hard_count > 0: sum_hard_mse_vec += mse_vec_s[hard_mask].sum().item() sum_hard_mse_chk += mse_chk_s[hard_mask].sum().item() sum_hard_mse_trj += mse_trj_s[hard_mask].sum().item() total_hard_samples += hard_count sem_factors = pred.get("semantic_factors", None) if sem_factors is not None: if sem_factors.dim() == 3: sem_pool = sem_factors.mean(dim=1) elif sem_factors.dim() == 2: sem_pool = sem_factors else: sem_pool = sem_factors.view(sem_factors.size(0), -1) txt_pool = text_tokens.mean(dim=1) sem_align = cosine_alignment(sem_pool, txt_pool) else: sem_align = float("nan") sum_mse_vec += mse["mse_vector"] sum_mse_chk += mse["mse_chunk"] sum_mse_trj += mse["mse_traj"] if not (tau_l2 != tau_l2): sum_tau_l2 += tau_l2 if not (sem_align != sem_align): sum_sem_align += sem_align n_batches += 1 if accelerator.is_main_process and bidx % 20 == 0: print( f"batch={bidx} " f"mse_vec={mse['mse_vector']:.4f} mse_chk={mse['mse_chunk']:.4f} mse_trj={mse['mse_traj']:.4f} " f"tau_l2={tau_l2:.4f} sem_align={sem_align:.4f}" ) if accelerator.is_main_process: avg_mse_vec = sum_mse_vec / max(1, n_batches) avg_mse_chk = sum_mse_chk / max(1, n_batches) avg_mse_trj = sum_mse_trj / max(1, n_batches) avg_tau_l2 = sum_tau_l2 / max(1, n_batches) avg_sem_align = sum_sem_align / max(1, n_batches) if total_hard_samples > 0: avg_hard_mse_vec = sum_hard_mse_vec / float(total_hard_samples) avg_hard_mse_chk = sum_hard_mse_chk / float(total_hard_samples) avg_hard_mse_trj = sum_hard_mse_trj / float(total_hard_samples) else: avg_hard_mse_vec = float("nan") avg_hard_mse_chk = float("nan") avg_hard_mse_trj = float("nan") hard_fraction = float(total_hard_samples / max(1, n_samples)) report = { "num_samples": n_samples, "num_batches": n_batches, "avg_mse_vector": avg_mse_vec, "avg_mse_chunk": avg_mse_chk, "avg_mse_traj": avg_mse_trj, "avg_tau_l2": avg_tau_l2, "avg_semantic_text_alignment": avg_sem_align, "hard_thresholds": { "vec": hard_thr_vec, "chk": hard_thr_chk, "trj": hard_thr_trj, }, "avg_hard_mse_vector": avg_hard_mse_vec, "avg_hard_mse_chunk": avg_hard_mse_chk, "avg_hard_mse_traj": avg_hard_mse_trj, "hard_sample_fraction": hard_fraction, "total_hard_samples": int(total_hard_samples), } with open( os.path.join(args.output_dir, "sigma_eval_report.json"), "w", encoding="utf-8", ) as f: json.dump(report, f, indent=2) print("[DONE] Saved report:", report) if __name__ == "__main__": main()