| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | from __future__ import annotations
|
| |
|
| | import os
|
| | import glob
|
| | import json
|
| | import argparse
|
| | import importlib
|
| | from typing import Any, Dict, List, Optional, Tuple
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from torch.utils.data import Dataset, DataLoader
|
| |
|
| | from dotenv import load_dotenv
|
| | from accelerate import Accelerator
|
| | from accelerate.utils import set_seed
|
| |
|
| | try:
|
| | from huggingface_hub import snapshot_download
|
| | except Exception:
|
| | snapshot_download = None
|
| |
|
| | from vision_sigma_vla import TelepathyVisionModule, VisionConfig
|
| | from language_sigma_vla import TelepathyLanguageModule, LanguageConfig
|
| | from action_sigma_vla import TelepathyActionModule, ActionConfig
|
| | from sigma_telepathy_adapter import SigmaTelepathyAdapter, SigmaTelepathyAdapterConfig
|
| |
|
| |
|
def ensure_sigma_artifacts_from_hf(
    repo_id: str,
    hf_token: Optional[str],
    local_cache_root: str,
) -> Dict[str, str]:
    """
    Fetch the Sigma artifacts of a Hugging Face repo into a local cache folder.

    Only files matching these patterns are requested:
        storage/sigma_pickplace/**
        storage/sigma_lora_out/**

    Args:
        repo_id: Hub repository id; "/" is replaced by "__" to build the
            per-repo folder name under ``local_cache_root``.
        hf_token: Token forwarded to ``snapshot_download`` (may be None).
        local_cache_root: Folder that holds the per-repo download folders;
            created if it does not exist.

    Returns:
        Dict with "local_repo_dir", "shard_dir" and "telepathy_heads_path".
        The last two are built by joining path components onto the download
        folder; this function does not check that they exist.

    Raises:
        ImportError: if ``huggingface_hub`` could not be imported.
    """
    if snapshot_download is None:
        raise ImportError(
            "huggingface_hub is not available but auto-download was requested. "
            "Please `pip install huggingface_hub` or download artifacts manually."
        )

    os.makedirs(local_cache_root, exist_ok=True)

    # One download folder per repo, named after the flattened repo id.
    repo_folder = os.path.join(local_cache_root, repo_id.replace("/", "__"))
    wanted_patterns = [
        "storage/sigma_pickplace/**",
        "storage/sigma_lora_out/**",
    ]
    local_dir = snapshot_download(
        repo_id=repo_id,
        token=hf_token,
        local_dir=repo_folder,
        local_dir_use_symlinks=False,
        allow_patterns=wanted_patterns,
    )

    storage_root = os.path.join(local_dir, "storage")
    return {
        "local_repo_dir": local_dir,
        "shard_dir": os.path.join(storage_root, "sigma_pickplace"),
        "telepathy_heads_path": os.path.join(
            storage_root, "sigma_lora_out", "sigma_telepathy_heads.pt"
        ),
    }
|
| |
|
| |
|
def load_pi05_policy(
    repo_id: str,
    hf_token: Optional[str],
    device: torch.device,
    strict_load: bool = True,
):
    """
    Load PI05Policy from LeRobot, move it to ``device`` and put it in eval mode.

    The class is looked up under a couple of module paths so that different
    LeRobot versions are supported. When ``strict_load`` is set, loading is
    first attempted with ``strict=True``; if that attempt raises, loading is
    retried without the ``strict`` argument.

    Args:
        repo_id: Policy repo id / path handed to ``PI05Policy.from_pretrained``.
        hf_token: Token forwarded as ``token=`` (may be None).
        device: Target device for the loaded policy.
        strict_load: Whether to try ``strict=True`` first.

    Returns:
        The loaded policy, on ``device``, after ``eval()``.

    Raises:
        ImportError: if PI05Policy cannot be imported from any candidate path.
        RuntimeError: if ``from_pretrained`` returned None.
    """
    candidate_paths = [
        ("lerobot.policies.pi05.modeling_pi05", "PI05Policy"),
        ("lerobot.policies.pi05", "PI05Policy"),
    ]

    policy_cls = None
    import_errors = []
    for mod_name, cls_name in candidate_paths:
        try:
            mod = importlib.import_module(mod_name)
            policy_cls = getattr(mod, cls_name)
            break
        except Exception as e:
            # Remember every failure so the final error shows all attempts.
            import_errors.append(f"{mod_name}.{cls_name}: {type(e).__name__}: {e}")

    if policy_cls is None:
        raise ImportError(
            "Failed to import PI05Policy from LeRobot. Tried:\n - "
            + "\n - ".join(import_errors)
        )

    policy = None
    tried = []
    if strict_load:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token, strict=True)
            tried.append("from_pretrained(..., strict=True)")
        except TypeError:
            # Treated as "this from_pretrained has no `strict` parameter".
            tried.append("strict=True not supported")
        except Exception as e:
            tried.append(f"strict=True failed: {type(e).__name__}: {e}")
            # The caller asked for strict loading; do not hide that it failed.
            print(
                f"[WARN] PI05Policy strict load failed ({type(e).__name__}: {e}). "
                "Retrying without strict=True."
            )

    if policy is None:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token)
            tried.append("from_pretrained(repo_id, token=...)")
        except TypeError:
            # Some signatures only accept the repo as a keyword argument.
            policy = policy_cls.from_pretrained(pretrained_name_or_path=repo_id, token=hf_token)
            tried.append("from_pretrained(pretrained_name_or_path=..., token=...)")

    if policy is None:
        raise RuntimeError("PI05Policy loading returned None. Tried: " + "; ".join(tried))

    policy = policy.to(device)
    policy.eval()
    return policy
|
| |
|
| |
|
def get_policy_tokenizer(
    policy,
    repo_id: str,
    hf_token: Optional[str],
    forced_tokenizer_id: str = "",
):
    """
    Robust tokenizer getter for PI05Policy.

    Resolution order:
      1. ``forced_tokenizer_id`` if given: loaded from disk when the path
         exists locally, otherwise from the Hub.
      2. A tokenizer/processor attribute found by a bounded recursive search
         inside ``policy``.
      3. A backbone name found on the policy (or its config-like attributes),
         loaded through ``AutoTokenizer``.

    ``AutoTokenizer.from_pretrained(repo_id)`` is never called because
    ``repo_id`` is a policy repo; it is only used in the error message.

    Args:
        policy: Loaded policy object to inspect.
        repo_id: Policy repo id (for error reporting only).
        hf_token: Token used for Hub downloads (may be None).
        forced_tokenizer_id: Explicit HF tokenizer id or local folder path.

    Returns:
        A tokenizer. When it has no pad token, the EOS token is assigned as
        the pad token.

    Raises:
        ValueError: if no tokenizer can be obtained by any of the strategies.
    """
    def _load_auto_tokenizer(name_or_path: str, **kwargs):
        # Imported lazily: a tokenizer already attached to the policy can be
        # returned without importing transformers at all.
        from transformers import AutoTokenizer

        # NOTE(review): trust_remote_code=True lets the tokenizer repo run its
        # own Python code on load; only pass trusted ids/paths here.
        loaded = AutoTokenizer.from_pretrained(
            name_or_path, trust_remote_code=True, **kwargs
        )
        if loaded.pad_token is None:
            loaded.pad_token = loaded.eos_token
        return loaded

    if forced_tokenizer_id:
        if os.path.exists(forced_tokenizer_id):
            return _load_auto_tokenizer(forced_tokenizer_id, local_files_only=True)
        return _load_auto_tokenizer(forced_tokenizer_id, token=hf_token)

    def _recursive_find_tokenizer(obj, max_depth: int = 4):
        if obj is None or max_depth <= 0:
            return None

        for key in ["tokenizer", "processor", "text_tokenizer", "language_tokenizer"]:
            if hasattr(obj, key):
                v = getattr(obj, key)
                if v is None:
                    continue
                # A processor that wraps a tokenizer: return the inner tokenizer.
                if key == "processor" and hasattr(v, "tokenizer") and v.tokenizer is not None:
                    return v.tokenizer
                if callable(v):
                    return v

        nested_names = [
            "paligemma_with_expert",
            "paligemma",
            "gemma_expert",
            "language_model",
            "text_model",
            "model",
            "policy",
        ]
        for name in nested_names:
            if hasattr(obj, name):
                found = _recursive_find_tokenizer(
                    getattr(obj, name), max_depth=max_depth - 1
                )
                if found is not None:
                    return found
        return None

    tok = _recursive_find_tokenizer(policy)
    if tok is not None:
        # The found object may not be a HF tokenizer, hence the getattr guards.
        if getattr(tok, "pad_token", None) is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token
        return tok

    config_candidates = []
    for attr in ["config", "model", "paligemma_with_expert", "paligemma"]:
        if hasattr(policy, attr):
            config_candidates.append(getattr(policy, attr))

    def _try_get_name(cfg_obj):
        if cfg_obj is None:
            return None
        for k in [
            "_name_or_path",
            "text_backbone_id",
            "text_model_id",
            "language_model_id",
            "processor_name_or_path",
            "tokenizer_name_or_path",
        ]:
            if hasattr(cfg_obj, k):
                v = getattr(cfg_obj, k)
                if isinstance(v, str) and v:
                    return v
        # One extra hop: the candidate may itself carry a `.config`.
        if hasattr(cfg_obj, "config"):
            c = getattr(cfg_obj, "config")
            if hasattr(c, "_name_or_path") and isinstance(c._name_or_path, str) and c._name_or_path:
                return c._name_or_path
        return None

    backbone_name = None
    for c in config_candidates:
        backbone_name = _try_get_name(c)
        if backbone_name:
            break

    if backbone_name:
        return _load_auto_tokenizer(backbone_name, token=hf_token)

    raise ValueError(
        f"Cannot obtain tokenizer from PI05Policy for repo '{repo_id}'. "
        "Your lerobot PI05 port does not expose tokenizer/processor nor backbone name. "
        "Please pass --tokenizer_id explicitly."
    )
|
| |
|
| |
|
def get_policy_text_embedding_layer(policy):
    """
    Find the text embedding layer inside a PI05Policy by a bounded search.

    At each visited object the search tries, in order:
      1. ``get_input_embeddings()`` (errors from the call are ignored),
      2. an ``nn.Module`` stored under a known embedding attribute name,
      3. recursion into known sub-module attribute names.

    Returns:
        The first embedding layer found.

    Raises:
        AttributeError: if nothing is found within the depth limit.
    """
    direct_attrs = ("embed_tokens", "embeddings", "token_embedding")
    child_attrs = (
        "model",
        "paligemma_with_expert",
        "paligemma",
        "language_model",
        "gemma_expert",
        "text_model",
        "policy",
    )

    def _search(node, budget: int = 6):
        if node is None or budget <= 0:
            return None

        if hasattr(node, "get_input_embeddings"):
            try:
                layer = node.get_input_embeddings()
            except Exception:
                # Best effort: fall through to the attribute-based lookups.
                layer = None
            if layer is not None:
                return layer

        for attr in direct_attrs:
            candidate = getattr(node, attr, None)
            if isinstance(candidate, nn.Module):
                return candidate

        for attr in child_attrs:
            if hasattr(node, attr):
                hit = _search(getattr(node, attr), budget - 1)
                if hit is not None:
                    return hit

        return None

    layer = _search(policy)
    if layer is not None:
        return layer
    raise AttributeError(
        "Cannot locate PI05 text embedding layer via recursive search. "
        "Your PI05Policy likely changed internal naming. "
        "Please inspect policy.model.* to confirm embed_tokens location."
    )
|
| |
|
| |
|
def verify_tokenizer_embedding_compat(
    tokenizer,
    text_embed_layer: nn.Module,
    allow_mismatch: bool = False,
):
    """
    Check that a tokenizer is consistent with the PI05 text embedding table.

    Two checks are made:
      * the tokenizer vocab size (``vocab_size``, else ``len(tokenizer)``)
        equals the number of embedding rows;
      * each of pad/eos/bos/unk token id that is set lies inside the table.

    If either size cannot be determined, a warning is printed and the check
    is skipped.

    Args:
        tokenizer: Tokenizer to validate.
        text_embed_layer: Embedding layer whose row count is the reference.
        allow_mismatch: When True, problems are printed as warnings instead
            of raising.

    Raises:
        ValueError: on the first detected problem, unless ``allow_mismatch``.
    """
    def _report(problem: str) -> None:
        # Either fail hard or downgrade the same message to a warning.
        if not allow_mismatch:
            raise ValueError(problem)
        print(problem.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.")

    if isinstance(text_embed_layer, nn.Embedding):
        emb_vocab = int(text_embed_layer.num_embeddings)
    elif getattr(text_embed_layer, "weight", None) is not None:
        emb_vocab = int(text_embed_layer.weight.size(0))
    else:
        emb_vocab = None

    tok_vocab = getattr(tokenizer, "vocab_size", None)
    if tok_vocab is None:
        try:
            tok_vocab = len(tokenizer)
        except Exception:
            tok_vocab = None

    if emb_vocab is None or tok_vocab is None:
        print("[WARN] Cannot infer tokenizer/embedding vocab sizes. Skipping compatibility check.")
        return

    if emb_vocab != tok_vocab:
        _report(
            f"[ERROR] Tokenizer vocab size ({tok_vocab}) != PI05 embedding table size ({emb_vocab}). "
            "This will corrupt text embeddings and invalidate baseline. "
            "Fix by passing --tokenizer_id matching the PI05 text backbone "
            "(e.g., the original openpi/PI05 tokenizer) or re-exporting policy with aligned vocab."
        )

    for name in ("pad_token_id", "eos_token_id", "bos_token_id", "unk_token_id"):
        tid = getattr(tokenizer, name, None)
        if tid is None:
            continue
        if 0 <= int(tid) < emb_vocab:
            continue
        _report(
            f"[ERROR] Tokenizer {name}={tid} out of embedding range [0, {emb_vocab - 1}]. "
            "Your tokenizer does not belong to this PI05 backbone."
        )
|
| |
|
| |
|
class TelepathyVLA(nn.Module):
    """
    Full model: chains the Telepathy vision, language and action modules.

    ``forward_once`` runs vision -> language -> (vision again with telepathy
    factors) -> action, and carries the language memory ``m_t`` from one call
    to the next in the non-persistent buffer ``_m_prev``.

    The attribute ``telepathy_scale`` is read with a default of 1.0; it is not
    set in ``__init__`` and may be assigned on the instance by the caller.
    """

    def __init__(
        self,
        v_cfg: VisionConfig,
        l_cfg: LanguageConfig,
        a_cfg: ActionConfig,
        disable_telepathy: bool = False,
    ):
        super().__init__()
        self.vision = TelepathyVisionModule(v_cfg)
        self.language = TelepathyLanguageModule(l_cfg)
        self.action = TelepathyActionModule(a_cfg)
        self.disable_telepathy = disable_telepathy
        self.register_buffer("_m_prev", None, persistent=False)

        # Projections are built on the first forward_once call, once the
        # actual token widths are known.
        self._proj_inited = False
        self.text_proj: Optional[nn.Module] = None
        self.vision_proj: Optional[nn.Module] = None
        self.state_proj: Optional[nn.Module] = None

    def reset_memory(self):
        """Drop the language memory carried between ``forward_once`` calls."""
        self._m_prev = None

    @torch.no_grad()
    def forward_once(
        self,
        vis_obs: torch.Tensor,
        robot_state: torch.Tensor,
        text_tokens: torch.Tensor,
        depth_obs: Optional[torch.Tensor] = None,
        audio_obs: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        return_intermediate: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Run one gradient-free pass through vision, language and action.

        Steps:
          1. Vision pass with ``telepathy_factors=None``.
          2. Project text/vision/state tokens to the vision token width and
             run the language module with the stored memory; store its
             ``m_t`` output as the new memory.
          3. Unless ``disable_telepathy`` is set, scale the language module's
             ``telepathy_factors`` by ``telepathy_scale`` and run vision again
             with them.
          4. Shape ``tau`` and the state tokens to what the action module is
             given, then run the action module.

        Returns:
            One dict holding the vision, language and action outputs merged
            in that order (later modules overwrite equal keys).

        Raises:
            KeyError: if the language output has no ``high_level_rep``.
        """
        vis0 = self.vision(
            vis_obs=vis_obs,
            robot_state=robot_state,
            depth_obs=depth_obs,
            audio_obs=audio_obs,
            telepathy_factors=None,
            return_intermediate=return_intermediate,
        )

        vis_d = vis0["vision_tokens"].size(-1)
        state_d = vis0["state_tokens"].size(-1)
        target_d = vis_d

        if not self._proj_inited:
            # NOTE(review): when widths differ these are freshly initialized
            # nn.Linear layers created at first use -- confirm that untrained
            # projections are intended here.
            self.text_proj = nn.Linear(text_tokens.size(-1), target_d, bias=False) \
                if text_tokens.size(-1) != target_d else nn.Identity()
            self.vision_proj = nn.Identity() if vis_d == target_d else nn.Linear(vis_d, target_d, bias=False)
            self.state_proj = nn.Identity() if state_d == target_d else nn.Linear(state_d, target_d, bias=False)

            self.text_proj = self.text_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self.vision_proj = self.vision_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self.state_proj = self.state_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self._proj_inited = True

        assert self.text_proj is not None and self.vision_proj is not None and self.state_proj is not None

        text_tokens = self.text_proj(text_tokens)
        vision_tokens = self.vision_proj(vis0["vision_tokens"])
        state_tokens = self.state_proj(vis0["state_tokens"])

        lang_out = self.language(
            text_tokens=text_tokens,
            vision_tokens=vision_tokens,
            state_tokens=state_tokens,
            m_prev=self._m_prev,
            attn_mask=attn_mask,
            return_intermediate=return_intermediate,
        )

        raw_tau = lang_out.get("telepathy_factors", None)
        self._m_prev = lang_out.get("m_t", None)

        telepathy_scale = float(getattr(self, "telepathy_scale", 1.0))

        if self.disable_telepathy:
            # Control run: keep the first vision pass, feed no factors onward.
            tau = None
            vis_out = vis0
        else:
            tau = raw_tau
            if tau is not None:
                tau = tau * telepathy_scale
            vis_out = self.vision(
                vis_obs=vis_obs,
                robot_state=robot_state,
                depth_obs=depth_obs,
                audio_obs=audio_obs,
                telepathy_factors=tau,
                return_intermediate=return_intermediate,
            )

        high_level_rep = lang_out.get("high_level_rep", None)
        if high_level_rep is None:
            raise KeyError("language output missing 'high_level_rep'.")

        if high_level_rep.dim() == 2:
            high_level_rep = high_level_rep.unsqueeze(1)

        if tau is None:
            # No factors available: hand the action module an all-zero tau.
            B, L, _ = high_level_rep.shape
            tau_dim = getattr(self.language, "tau_dim", 128)
            tau = torch.zeros(B, L, tau_dim, device=high_level_rep.device, dtype=high_level_rep.dtype)
        else:
            if tau.dim() == 2:
                tau = tau.unsqueeze(1)
            if tau.size(1) != high_level_rep.size(1):
                L = high_level_rep.size(1)
                if tau.size(1) == 1:
                    tau = tau.expand(-1, L, -1)
                else:
                    tau = tau[:, :L, :]

        # Input width of the action conditioning projection, when discoverable.
        expected_in = None
        acp = getattr(self.action, "action_condition_proj", None)
        if acp is not None:
            if hasattr(acp, "in_features"):
                expected_in = int(acp.in_features)
            elif hasattr(acp, "net") and len(acp.net) > 0 and hasattr(acp.net[0], "in_features"):
                expected_in = int(acp.net[0].in_features)

        if expected_in is not None:
            # Pad or trim tau so that high_level_rep width + tau width equals
            # expected_in; left untouched when that target is not positive.
            target_tau = expected_in - high_level_rep.size(-1)
            if target_tau > 0:
                if tau.size(-1) < target_tau:
                    tau = F.pad(tau, (0, target_tau - tau.size(-1)))
                elif tau.size(-1) > target_tau:
                    tau = tau[..., :target_tau]

        state_for_action = vis_out["state_tokens"]
        if state_for_action.dim() == 2:
            state_for_action = state_for_action.unsqueeze(1)
        elif state_for_action.dim() > 3:
            # reshape (not view): also works when the tensor is non-contiguous.
            state_for_action = state_for_action.reshape(
                state_for_action.size(0), -1, state_for_action.size(-1)
            )

        lang_d = high_level_rep.size(-1)

        def _pad_or_trim_to(x: torch.Tensor, d: int) -> torch.Tensor:
            cur_d = x.size(-1)
            if cur_d == d:
                return x
            if cur_d < d:
                return F.pad(x, (0, d - cur_d))
            return x[..., :d]

        state_for_action = _pad_or_trim_to(state_for_action, lang_d)

        act_out = self.action(
            high_level_rep=high_level_rep,
            telepathy_factors=tau,
            state_tokens=state_for_action,
            return_intermediate=return_intermediate,
        )

        out: Dict[str, torch.Tensor] = {}
        out.update(vis_out)
        out.update(lang_out)
        out.update(act_out)
        return out
|
| |
|
| |
|
class SigmaShardDataset(Dataset):
    """
    Map-style dataset over ``.pt`` shard files found under a directory.

    A shard is either a list/tuple of samples, or a dict holding such a
    list/tuple under "samples", "data" or "items". Every shard is loaded once
    at construction to count its samples; shards are loaded again on first
    access and then kept in an in-memory cache.
    """

    def __init__(self, shard_dir: str):
        super().__init__()
        if not os.path.isdir(shard_dir):
            raise FileNotFoundError(
                f"shard_dir does not exist: {shard_dir}. Double-check the path."
            )

        # The patterns overlap; the set removes duplicate hits.
        unique_paths = set()
        for pattern in (
            os.path.join(shard_dir, "sigma_vla_shard_*.pt"),
            os.path.join(shard_dir, "*.pt"),
            os.path.join(shard_dir, "**", "*.pt"),
        ):
            unique_paths.update(glob.glob(pattern, recursive=True))

        self.shard_paths = sorted(unique_paths)
        if not self.shard_paths:
            raise FileNotFoundError(
                f"No .pt shards found under {shard_dir}. "
                "Your HF cache is empty or shards are not tracked by LFS."
            )

        print(f"[INFO] Found {len(self.shard_paths)} shard files. Example: {self.shard_paths[:3]}")

        # index_map[i] == (shard index, position inside that shard)
        self.index_map: List[Tuple[int, int]] = []
        self._shard_cache: Dict[int, List[Dict[str, Any]]] = {}

        for shard_id, shard_path in enumerate(self.shard_paths):
            loaded = torch.load(shard_path, map_location="cpu")
            n_samples = len(self._normalize_shard(loaded, shard_path))
            self.index_map.extend((shard_id, local_id) for local_id in range(n_samples))

        self.total = len(self.index_map)

    def __len__(self):
        return self.total

    def _normalize_shard(self, shard_obj: Any, path: str) -> List[Dict[str, Any]]:
        """Return the samples of a loaded shard as a list, or raise TypeError."""
        sequence_types = (list, tuple)
        if isinstance(shard_obj, sequence_types):
            return list(shard_obj)

        if isinstance(shard_obj, dict):
            for container_key in ("samples", "data", "items"):
                inner = shard_obj.get(container_key)
                if isinstance(inner, sequence_types):
                    return list(inner)

        raise TypeError(
            f"Unsupported shard format in {path}. "
            f"Expected list/tuple of samples or dict{{samples/data}}. "
            f"Got type: {type(shard_obj).__name__}"
        )

    def _get_shard(self, sid: int) -> List[Dict[str, Any]]:
        """Return the samples of shard ``sid``, loading and caching on a miss."""
        samples = self._shard_cache.get(sid)
        if samples is None:
            shard_path = self.shard_paths[sid]
            loaded = torch.load(shard_path, map_location="cpu")
            samples = self._normalize_shard(loaded, shard_path)
            self._shard_cache[sid] = samples
        return samples

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        shard_id, local_id = self.index_map[idx]
        return self._get_shard(shard_id)[local_id]
|
| |
|
| |
|
def collate_sigma(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Robust collate for Sigma shards.

    Key names are resolved on the first sample and then read from every
    sample. All stacked tensors are converted to float.

    Output keys:
        vis_obs: stacked vision input; if the stacked tensor is 5-D, only the
            last index along dim 1 is kept.
        depth_obs / audio_obs: stacked tensors, or None when absent.
        robot_state: stacked state tensor.
        texts: list of strings ("" for samples missing the text key).
        gt_action_vector / gt_action_chunk / gt_action_trajectory: stacked
            targets. When ``sample["action"]`` is a plain tensor, all three
            are that same stacked tensor.
        base_action_vector / base_action_chunk / base_action_trajectory:
            only present when found, either at the top level of the sample or
            inside the ``action`` dict (where the short aliases ``base_vec``,
            ``base_chunk``, ``base_traj`` are also accepted).

    Raises:
        KeyError: if a required field is missing from the first sample.
    """
    s0 = batch_list[0]

    def first_present(container, candidates):
        """Return the first candidate key contained in ``container``, else None."""
        return next((k for k in candidates if k in container), None)

    def pick_key(sample: Dict[str, Any], candidates: List[str], field_name: str):
        k = first_present(sample, candidates)
        if k is None:
            raise KeyError(
                f"Shard sample missing required field '{field_name}'. "
                f"Tried keys: {candidates}. "
                f"Available keys: {list(sample.keys())}"
            )
        return k

    def stack(key):
        """Stack a top-level field over the batch."""
        return torch.stack([b[key] for b in batch_list], dim=0).float()

    def stack_action(key):
        """Stack a field of the nested ``action`` dict over the batch."""
        return torch.stack([b["action"][key] for b in batch_list], dim=0).float()

    # --- observations -----------------------------------------------------
    if "vision" in s0:
        vis_k = "vision"
    else:
        vis_k = pick_key(s0, ["vis_obs", "rgb_obs", "image", "images", "obs"], "vision/vis_obs")
    vis_obs = stack(vis_k)
    if vis_obs.dim() == 5:
        vis_obs = vis_obs[:, -1]

    depth_k = first_present(s0, ["depth", "depth_obs", "depths"])
    depth_obs = stack(depth_k) if depth_k is not None else None

    audio_k = first_present(s0, ["audio", "audio_obs", "audios"])
    audio_obs = stack(audio_k) if audio_k is not None else None

    if "state" in s0:
        state_k = "state"
    else:
        state_k = pick_key(s0, ["robot_state", "proprio", "proprio_obs"], "state/robot_state")
    robot_state = stack(state_k)

    text_k = pick_key(s0, ["text", "prompt", "instruction"], "text")
    texts = [b.get(text_k, "") for b in batch_list]

    # --- ground-truth actions -----------------------------------------------
    if "action" in s0:
        a0 = s0["action"]
        if isinstance(a0, dict):
            def pick_action_key(d, candidates, name):
                k = first_present(d, candidates)
                if k is None:
                    raise KeyError(
                        f"Action dict missing '{name}'. Tried {candidates}. "
                        f"Available action keys: {list(d.keys())}"
                    )
                return k

            vec_k = pick_action_key(a0, ["gt_action_vector", "action_vector", "vector", "vec"], "gt_action_vector")
            chk_k = pick_action_key(a0, ["gt_action_chunk", "action_chunk", "chunk", "chk"], "gt_action_chunk")
            trj_k = pick_action_key(a0, ["gt_action_trajectory", "action_trajectory", "trajectory", "traj"], "gt_action_trajectory")

            gt_action_vector = stack_action(vec_k)
            gt_action_chunk = stack_action(chk_k)
            gt_action_trajectory = stack_action(trj_k)
        else:
            # A single action tensor serves as target for all three branches.
            act = stack("action")
            gt_action_vector = act
            gt_action_chunk = act
            gt_action_trajectory = act
    else:
        gt_vec_k = pick_key(s0, ["gt_action_vector", "action_vector", "gt_vec"], "gt_action_vector")
        gt_chk_k = pick_key(s0, ["gt_action_chunk", "action_chunk", "gt_chunk"], "gt_action_chunk")
        gt_trj_k = pick_key(s0, ["gt_action_trajectory", "action_trajectory", "gt_traj"], "gt_action_trajectory")

        gt_action_vector = stack(gt_vec_k)
        gt_action_chunk = stack(gt_chk_k)
        gt_action_trajectory = stack(gt_trj_k)

    # --- optional baseline actions --------------------------------------------
    base_aliases = {
        "base_action_vector": ["base_action_vector", "base_vec"],
        "base_action_chunk": ["base_action_chunk", "base_chunk"],
        "base_action_trajectory": ["base_action_trajectory", "base_traj"],
    }
    base_tensors: Dict[str, torch.Tensor] = {}
    if any(name in s0 for name in base_aliases):
        # Top-level fields win; only the canonical names are looked up here.
        for name in base_aliases:
            if name in s0:
                base_tensors[name] = stack(name)
    elif "action" in s0 and isinstance(s0["action"], dict):
        # Look up each field by canonical name or alias, so that a dict that
        # only carries the short aliases is not silently ignored.
        for name, aliases in base_aliases.items():
            found_k = first_present(s0["action"], aliases)
            if found_k is not None:
                base_tensors[name] = stack_action(found_k)

    batch: Dict[str, Any] = {
        "vis_obs": vis_obs,
        "depth_obs": depth_obs,
        "audio_obs": audio_obs,
        "robot_state": robot_state,
        "texts": texts,
        "gt_action_vector": gt_action_vector,
        "gt_action_chunk": gt_action_chunk,
        "gt_action_trajectory": gt_action_trajectory,
    }
    batch.update(base_tensors)
    return batch
|
| |
|
| |
|
| | def _align_target(pred_t: torch.Tensor, gt_t: torch.Tensor) -> torch.Tensor:
|
| | """
|
| | Align GT to prediction for MSE:
|
| | - handle 2D vs 3D mismatches by collapsing or expanding time dimension.
|
| | - then align last-dim by pad/trim.
|
| | """
|
| | if gt_t.dim() == 3 and pred_t.dim() == 2:
|
| | gt_t = gt_t[:, -1, :]
|
| |
|
| | if pred_t.dim() == 3 and gt_t.dim() == 2:
|
| | gt_t = gt_t.unsqueeze(1)
|
| | if gt_t.size(1) != pred_t.size(1):
|
| | gt_t = gt_t.expand(-1, pred_t.size(1), -1)
|
| |
|
| | if pred_t.dim() == 3 and gt_t.dim() == 3:
|
| | Tp = pred_t.size(1)
|
| | Tg = gt_t.size(1)
|
| | if Tg < Tp:
|
| | pad = torch.zeros(
|
| | gt_t.size(0), Tp - Tg, gt_t.size(2),
|
| | device=gt_t.device, dtype=gt_t.dtype
|
| | )
|
| | gt_t = torch.cat([gt_t, pad], dim=1)
|
| | elif Tg > Tp:
|
| | gt_t = gt_t[:, :Tp, :]
|
| |
|
| | pd = pred_t.size(-1)
|
| | gd = gt_t.size(-1)
|
| | if gd < pd:
|
| | gt_t = F.pad(gt_t, (0, pd - gd))
|
| | elif gd > pd:
|
| | gt_t = gt_t[..., :pd]
|
| |
|
| | return gt_t
|
| |
|
| |
|
| | def _pred_action(pred: Dict[str, torch.Tensor], key: str) -> torch.Tensor:
|
| | if key in pred:
|
| | return pred[key]
|
| | if "action" in pred:
|
| | return pred["action"]
|
| | raise KeyError(
|
| | f"Pred dict missing action key '{key}' and fallback 'action'. "
|
| | f"Available pred keys: {list(pred.keys())}"
|
| | )
|
| |
|
| |
|
@torch.no_grad()
def compute_branch_mse(pred: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, float]:
    """
    Mean squared error of each action branch against its aligned target.

    Predictions are looked up with ``_pred_action`` (so each branch may fall
    back to ``pred["action"]``); targets come from ``batch["gt_<branch>"]``,
    are moved to the device of the vector prediction and aligned with
    ``_align_target``.

    Returns:
        Dict with float entries "mse_vector", "mse_chunk" and "mse_traj".
    """
    branch_names = ("action_vector", "action_chunk", "action_trajectory")

    # Resolve every prediction first, then every target, then the losses.
    preds = {name: _pred_action(pred, name) for name in branch_names}
    device = preds["action_vector"].device
    targets = {
        name: _align_target(branch_pred, batch["gt_" + name].to(device))
        for name, branch_pred in preds.items()
    }

    result_keys = ("mse_vector", "mse_chunk", "mse_traj")
    return {
        out_key: F.mse_loss(preds[name], targets[name]).item()
        for out_key, name in zip(result_keys, branch_names)
    }
|
| |
|
| |
|
@torch.no_grad()
def compute_success_proxy(
    pred: Dict[str, torch.Tensor],
    batch: Dict[str, Any],
    thr_vec: float,
    thr_chk: float,
    thr_trj: float,
) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Count samples whose three per-sample MSEs are all below their thresholds.

    Per-sample MSE is the squared error averaged over every dimension except
    the first (batch) one. Targets are moved to the device of the vector
    prediction and aligned with ``_align_target``.

    Returns:
        (num_success, num_total, mse_vec_per_sample, mse_chk_per_sample,
        mse_trj_per_sample)
    """
    def _per_sample_mse(branch_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        non_batch_dims = list(range(1, branch_pred.dim()))
        return ((branch_pred - target) ** 2).mean(dim=non_batch_dims)

    vec_pred = _pred_action(pred, "action_vector")
    chk_pred = _pred_action(pred, "action_chunk")
    trj_pred = _pred_action(pred, "action_trajectory")

    device = vec_pred.device
    gt_vec = _align_target(vec_pred, batch["gt_action_vector"].to(device))
    gt_chk = _align_target(chk_pred, batch["gt_action_chunk"].to(device))
    gt_trj = _align_target(trj_pred, batch["gt_action_trajectory"].to(device))

    mse_vec_s = _per_sample_mse(vec_pred, gt_vec)
    mse_chk_s = _per_sample_mse(chk_pred, gt_chk)
    mse_trj_s = _per_sample_mse(trj_pred, gt_trj)

    # A sample counts as a success only if all three branches pass.
    success_mask = (mse_vec_s < thr_vec) & (mse_chk_s < thr_chk) & (mse_trj_s < thr_trj)
    return (
        int(success_mask.sum().item()),
        int(success_mask.numel()),
        mse_vec_s,
        mse_chk_s,
        mse_trj_s,
    )
|
| |
|
| |
|
@torch.no_grad()
def compute_telepathy_stability(pred: Dict[str, torch.Tensor]) -> float:
    """
    Return the mean of the squared ``telepathy_factors`` in ``pred``.

    Returns NaN when the key is missing or its value is None.
    """
    tau = pred.get("telepathy_factors")
    if tau is None:
        return float("nan")
    mean_square = tau.pow(2).mean()
    return float(mean_square.item())
|
| |
|
| |
|
@torch.no_grad()
def cosine_alignment(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Mean cosine similarity between two batches of vectors.

    3-D inputs are first averaged over dim 1. If the last dimensions differ,
    both inputs are cut to the smaller width before comparison. Returns NaN
    when either input has no elements.
    """
    a, b = (t.mean(dim=1) if t.dim() == 3 else t for t in (a, b))

    if 0 in (a.numel(), b.numel()):
        return float("nan")

    # Slicing to the shared width is a no-op when the widths already match.
    shared_width = min(a.size(-1), b.size(-1))
    unit_a = F.normalize(a[..., :shared_width], dim=-1)
    unit_b = F.normalize(b[..., :shared_width], dim=-1)

    per_row = (unit_a * unit_b).sum(dim=-1)
    return float(per_row.mean().item())
|
| |
|
| |
|
@torch.no_grad()
def build_text_tokens_from_policy(
    tokenizer,
    text_embed_layer: nn.Module,
    texts: List[str],
    device: torch.device,
    target_dtype: torch.dtype,
    max_text_len: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenize prompts and embed them with the given embedding layer.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, padding=True,
            truncation=..., return_tensors="pt")``. Its result must expose
            ``input_ids`` either as an attribute or as a mapping key.
        text_embed_layer: Module applied to the token ids.
        texts: Prompts to tokenize.
        device: Device the ids and mask are moved to before embedding.
        target_dtype: Dtype the embeddings are cast to.
        max_text_len: When > 0, truncation is enabled with this max length;
            otherwise no truncation is requested.

    Returns:
        ``(text_tokens, attn_mask)``. If the tokenizer output carries no
        attention mask, an all-ones mask shaped like ``input_ids`` is used.
    """
    tok_kwargs: Dict[str, Any] = {
        "padding": True,
        "truncation": False,
        "return_tensors": "pt",
    }
    if max_text_len and max_text_len > 0:
        tok_kwargs["truncation"] = True
        tok_kwargs["max_length"] = max_text_len
    encoded = tokenizer(texts, **tok_kwargs)

    # Support both attribute-style and mapping-style tokenizer outputs; the
    # attention mask is optional in either case.
    if hasattr(encoded, "input_ids"):
        input_ids = encoded.input_ids
        attn_mask = getattr(encoded, "attention_mask", None)
    else:
        input_ids = encoded["input_ids"]
        attn_mask = encoded.get("attention_mask", None)
    if attn_mask is None:
        attn_mask = torch.ones_like(input_ids)

    input_ids = input_ids.to(device)
    attn_mask = attn_mask.to(device)

    text_tokens = text_embed_layer(input_ids).to(dtype=target_dtype)
    return text_tokens, attn_mask
|
| |
|
| |
|
def _build_eval_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the Sigma evaluation script."""
    parser = argparse.ArgumentParser()

    # Paths.
    parser.add_argument("--sigma_env", type=str, default="sigma.env")
    parser.add_argument("--shard_dir", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="./sigma_eval_out")

    # Model / tokenizer selection.
    parser.add_argument(
        "--base_model_id",
        type=str,
        required=True,
        help="LeRobot/OpenPI policy repo, e.g., lerobot/pi05_base or your Sigma policy repo.",
    )
    parser.add_argument(
        "--telepathy_heads_path",
        type=str,
        default="",
        help="Path to sigma_telepathy_heads.pt. If empty, auto-fetch may fill it.",
    )
    parser.add_argument(
        "--disable_telepathy",
        action="store_true",
        help="Disable telepathy injection (control run).",
    )
    parser.add_argument(
        "--tokenizer_id",
        type=str,
        default="",
        help="Explicit HF tokenizer id OR local tokenizer folder path.",
    )

    parser.add_argument("--max_text_len", type=int, default=0)

    # Optional artifact auto-download.
    parser.add_argument(
        "--artifacts_repo_id",
        type=str,
        default="",
        help="HF repo containing storage/sigma_pickplace and storage/sigma_lora_out.",
    )
    parser.add_argument(
        "--hf_cache_root",
        type=str,
        default="/workspace/.hf_sigma_cache",
    )

    parser.add_argument("--load_in_4bit", action="store_true")
    parser.add_argument("--dtype", type=str, default="bf16")

    # Data loading / run control.
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--max_batches", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="Shuffle dataset order to enable different random subsets per seed.",
    )
    parser.add_argument(
        "--telepathy_scale",
        type=float,
        default=1.0,
        help="Multiply telepathy_factors (tau) to control injection strength.",
    )

    # Success-proxy thresholds.
    parser.add_argument("--succ_thr_vec", type=float, default=0.05)
    parser.add_argument("--succ_thr_chk", type=float, default=0.10)
    parser.add_argument("--succ_thr_trj", type=float, default=0.10)

    # "Hard" subset thresholds (per-sample MSE).
    parser.add_argument(
        "--hard_thr_vec",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on vector branch; <=0 means 2x succ_thr_vec.",
    )
    parser.add_argument(
        "--hard_thr_chk",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on chunk branch; <=0 means 2x succ_thr_chk.",
    )
    parser.add_argument(
        "--hard_thr_trj",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on trajectory branch; <=0 means 2x succ_thr_trj.",
    )

    parser.add_argument(
        "--strict_pi05_load",
        action="store_true",
        help="Try strict PI05Policy loading if supported by LeRobot.",
    )
    parser.add_argument(
        "--allow_tokenizer_mismatch",
        action="store_true",
        help="Do not fail on tokenizer/embedding mismatch (NOT recommended for baseline).",
    )

    parser.add_argument(
        "--use_telepathy_adapter",
        action="store_true",
        help="If set and telepathy is enabled, apply sigma_telepathy_adapter to actions in eval.",
    )
    return parser


def _mean_or_nan(total: float, count: int) -> float:
    """Return ``total / count``, or NaN when nothing was accumulated."""
    return total / count if count > 0 else float("nan")


def main() -> None:
    """
    Evaluate the telepathy heads on Sigma shards and write a JSON report.

    Steps:
      1. Parse CLI arguments, load the optional dotenv file, read HF_TOKEN.
      2. Resolve ``shard_dir`` / ``telepathy_heads_path``, optionally via
         ``ensure_sigma_artifacts_from_hf``.
      3. Load the policy, its tokenizer and text embedding layer, and build
         the TelepathyVLA model with the heads checkpoint (``strict=False``).
      4. Iterate the dataset, accumulating branch MSE, tau L2, semantic/text
         alignment and "hard"-subset statistics.
      5. On the main process, write ``sigma_eval_report.json`` to
         ``output_dir``.

    Raises:
        FileNotFoundError: if the shard directory or the heads file cannot
            be found locally after the optional auto-download.

    NOTE(review): metrics are accumulated per process and only the main
    process writes the report; no cross-process gather is performed, so in a
    multi-process launch the report covers the main process's data only.
    """
    import math  # local import: only needed for NaN checks in this function

    args = _build_eval_arg_parser().parse_args()

    if os.path.exists(args.sigma_env):
        load_dotenv(args.sigma_env)
    hf_token = os.getenv("HF_TOKEN", None)

    accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != "fp32" else "no")
    set_seed(args.seed)
    device = accelerator.device
    is_main = accelerator.is_main_process

    if args.load_in_4bit:
        print("[WARN] --load_in_4bit is ignored for PI05Policy evaluator.")

    # ------------------------------------------------------------------
    # Resolve local artifacts (optionally downloading them).
    # ------------------------------------------------------------------
    artifacts_repo = args.artifacts_repo_id.strip()
    if not artifacts_repo and args.base_model_id.startswith("Veltraxor/"):
        artifacts_repo = args.base_model_id

    need_shards = (not args.shard_dir) or (not os.path.isdir(args.shard_dir))
    need_heads = (not args.telepathy_heads_path) or (not os.path.isfile(args.telepathy_heads_path))

    if artifacts_repo and (need_shards or need_heads):
        paths = ensure_sigma_artifacts_from_hf(
            repo_id=artifacts_repo,
            hf_token=hf_token,
            local_cache_root=args.hf_cache_root,
        )
        if need_shards:
            args.shard_dir = paths["shard_dir"]
            print(f"[INFO] Using cached shard_dir: {args.shard_dir}")
        if need_heads:
            args.telepathy_heads_path = paths["telepathy_heads_path"]
            print(f"[INFO] Using cached telepathy_heads_path: {args.telepathy_heads_path}")

    if not args.shard_dir or not os.path.isdir(args.shard_dir):
        raise FileNotFoundError(
            f"shard_dir not found locally: {args.shard_dir}. "
            "Either provide a valid local path or an artifacts_repo_id for auto-download."
        )

    if not args.telepathy_heads_path or not os.path.isfile(args.telepathy_heads_path):
        raise FileNotFoundError(
            f"telepathy_heads_path not found locally: {args.telepathy_heads_path}. "
            "Either provide a valid local path or store it under storage/sigma_lora_out/ "
            "in artifacts_repo_id for auto-download."
        )

    # ------------------------------------------------------------------
    # Policy, tokenizer and telepathy model.
    # ------------------------------------------------------------------
    policy = load_pi05_policy(
        args.base_model_id,
        hf_token,
        device=device,
        strict_load=args.strict_pi05_load,
    )

    tokenizer = get_policy_tokenizer(
        policy,
        args.base_model_id,
        hf_token,
        forced_tokenizer_id=args.tokenizer_id,
    )
    text_embed_layer = get_policy_text_embedding_layer(policy)

    verify_tokenizer_embedding_compat(
        tokenizer=tokenizer,
        text_embed_layer=text_embed_layer,
        allow_mismatch=args.allow_tokenizer_mismatch,
    )

    v_cfg = VisionConfig()
    l_cfg = LanguageConfig()
    a_cfg = ActionConfig()
    telepathy_vla = TelepathyVLA(v_cfg, l_cfg, a_cfg, disable_telepathy=args.disable_telepathy)
    telepathy_vla.telepathy_scale = args.telepathy_scale

    adapter_cfg = SigmaTelepathyAdapterConfig()
    telepathy_adapter = SigmaTelepathyAdapter(adapter_cfg).to(device)

    if is_main:
        file_size_mb = os.path.getsize(args.telepathy_heads_path) / (1024 * 1024)
        print(f"[CHECK-A] disable_telepathy={args.disable_telepathy}")
        print(f"[CHECK-A] telepathy_heads_path={args.telepathy_heads_path} size={file_size_mb:.2f}MB")

    # SECURITY NOTE: torch.load unpickles the file, and the path may point at
    # a downloaded artifact. Only use checkpoints from trusted sources
    # (consider weights_only=True where the installed torch supports it).
    sd = torch.load(args.telepathy_heads_path, map_location="cpu")

    if is_main:
        # Sanity statistics over (at most 100k elements of) each tensor.
        # Only the main process prints them, so only it pays for flattening.
        tensor_list = [v.detach().float().reshape(-1) for v in sd.values() if torch.is_tensor(v)]
        if tensor_list:
            flat = torch.cat([t[:100000] for t in tensor_list], dim=0)
            rms = torch.sqrt((flat ** 2).mean()).item()
            print(
                f"[CHECK-A] heads_tensors={len(tensor_list)} "
                f"mean={flat.mean().item():.6f} std={flat.std().item():.6f} rms={rms:.6f}"
            )

    missing, unexpected = telepathy_vla.load_state_dict(sd, strict=False)
    if is_main:
        if len(missing) > 0 or len(unexpected) > 0:
            print(f"[CHECK-A] loaded with strict=False. Missing={len(missing)} Unexpected={len(unexpected)}")
            print(f"[CHECK-A] Missing keys (first 20): {missing[:20]}")
            print(f"[CHECK-A] Unexpected keys (first 20): {unexpected[:20]}")
        else:
            print("[CHECK-A] heads fully matched (no missing/unexpected).")

    telepathy_vla.eval()

    # ------------------------------------------------------------------
    # Data.
    # ------------------------------------------------------------------
    ds = SigmaShardDataset(args.shard_dir)
    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        collate_fn=collate_sigma,
        drop_last=False,
        pin_memory=torch.cuda.is_available(),
    )

    telepathy_vla, dl = accelerator.prepare(telepathy_vla, dl)
    # prepare() may return a wrapper exposing the real model as `.module`;
    # custom methods/attributes (reset_memory, forward_once, vision, ...)
    # live on the underlying model, so resolve it once and use it throughout.
    model = telepathy_vla.module if hasattr(telepathy_vla, "module") else telepathy_vla
    target_dtype = next(model.parameters()).dtype

    # Input width expected by the state encoder; loop-invariant, so probe it
    # once. None means "unknown": robot_state is then passed through as-is.
    try:
        expected_d = model.vision.state_encoder.mlp[0].in_features
    except (AttributeError, IndexError, KeyError, TypeError):
        expected_d = None

    # ------------------------------------------------------------------
    # Accumulators.
    # ------------------------------------------------------------------
    sum_mse_vec = 0.0
    sum_mse_chk = 0.0
    sum_mse_trj = 0.0
    sum_tau_l2 = 0.0
    sum_sem_align = 0.0
    # NaN batches are skipped for these two metrics, so they need their own
    # denominators; dividing by n_batches would bias the averages toward 0.
    n_tau_l2 = 0
    n_sem_align = 0

    hard_thr_vec = args.hard_thr_vec if args.hard_thr_vec > 0.0 else 2.0 * args.succ_thr_vec
    hard_thr_chk = args.hard_thr_chk if args.hard_thr_chk > 0.0 else 2.0 * args.succ_thr_chk
    hard_thr_trj = args.hard_thr_trj if args.hard_thr_trj > 0.0 else 2.0 * args.succ_thr_trj

    sum_hard_mse_vec = 0.0
    sum_hard_mse_chk = 0.0
    sum_hard_mse_trj = 0.0
    total_hard_samples = 0

    n_batches = 0
    n_samples = 0

    os.makedirs(args.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Evaluation loop (no gradients are needed; nothing calls backward).
    # ------------------------------------------------------------------
    with torch.no_grad():
        for bidx, batch in enumerate(dl):
            if args.max_batches > 0 and bidx >= args.max_batches:
                break

            model.reset_memory()

            # Move observations once; they are reused by the CHECK-B pass.
            vis_obs = batch["vis_obs"].to(device)
            depth_obs = batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None
            audio_obs = batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None

            n_samples += vis_obs.size(0)

            text_tokens, attn_mask = build_text_tokens_from_policy(
                tokenizer=tokenizer,
                text_embed_layer=text_embed_layer,
                texts=batch["texts"],
                device=device,
                target_dtype=target_dtype,
                max_text_len=args.max_text_len,
            )

            robot_state = batch["robot_state"].to(device)
            if robot_state.dim() == 3:
                # Keep only the last entry along dim 1.
                robot_state = robot_state[:, -1]

            for key in ("base_action_vector", "base_action_chunk", "base_action_trajectory"):
                if key in batch:
                    batch[key] = batch[key].to(device)

            # Zero-pad or truncate the state to the encoder's input width.
            if expected_d is not None:
                cur_d = robot_state.size(-1)
                if cur_d < expected_d:
                    robot_state = F.pad(robot_state, (0, expected_d - cur_d))
                elif cur_d > expected_d:
                    robot_state = robot_state[..., :expected_d]

            pred = model.forward_once(
                vis_obs=vis_obs,
                robot_state=robot_state,
                depth_obs=depth_obs,
                audio_obs=audio_obs,
                text_tokens=text_tokens,
                attn_mask=attn_mask,
                return_intermediate=True,
            )

            # CHECK-B: on the first batch, rerun with telepathy disabled and
            # report how much the action vector changes.
            if is_main and bidx == 0:
                model.reset_memory()
                prev_flag = bool(model.disable_telepathy)
                model.disable_telepathy = True
                try:
                    pred_ctrl = model.forward_once(
                        vis_obs=vis_obs,
                        robot_state=robot_state,
                        depth_obs=depth_obs,
                        audio_obs=audio_obs,
                        text_tokens=text_tokens,
                        attn_mask=attn_mask,
                        return_intermediate=False,
                    )
                finally:
                    # Always restore the flag, even if the control pass fails.
                    model.disable_telepathy = prev_flag

                # Best-effort diagnostic: a failure here must not abort the run.
                try:
                    act_exp = _pred_action(pred, "action_vector")
                    act_ctl = _pred_action(pred_ctrl, "action_vector")
                    diff = (act_exp - act_ctl).abs().mean().item()
                    print(f"[CHECK-B] telepathy_effect_mean_abs_diff(action_vector)={diff:.6f}")
                except Exception as e:
                    print(f"[CHECK-B] action diff check failed: {type(e).__name__}: {e}")

            if (not args.disable_telepathy) and args.use_telepathy_adapter:
                pred = telepathy_adapter(pred, batch)

            mse = compute_branch_mse(pred, batch)
            tau_l2 = compute_telepathy_stability(pred)

            # Only the per-sample MSE outputs are used here.
            (
                _,
                _,
                mse_vec_s,
                mse_chk_s,
                mse_trj_s,
            ) = compute_success_proxy(
                pred,
                batch,
                thr_vec=args.succ_thr_vec,
                thr_chk=args.succ_thr_chk,
                thr_trj=args.succ_thr_trj,
            )

            # A sample is "hard" if any branch exceeds its hard threshold.
            hard_mask = (mse_vec_s > hard_thr_vec) | (mse_chk_s > hard_thr_chk) | (mse_trj_s > hard_thr_trj)
            hard_count = int(hard_mask.sum().item())
            if hard_count > 0:
                sum_hard_mse_vec += mse_vec_s[hard_mask].sum().item()
                sum_hard_mse_chk += mse_chk_s[hard_mask].sum().item()
                sum_hard_mse_trj += mse_trj_s[hard_mask].sum().item()
                total_hard_samples += hard_count

            sem_factors = pred.get("semantic_factors", None)
            if sem_factors is not None:
                if sem_factors.dim() == 3:
                    sem_pool = sem_factors.mean(dim=1)
                elif sem_factors.dim() == 2:
                    sem_pool = sem_factors
                else:
                    # reshape (not view) so non-contiguous inputs also work.
                    sem_pool = sem_factors.reshape(sem_factors.size(0), -1)

                txt_pool = text_tokens.mean(dim=1)
                sem_align = cosine_alignment(sem_pool, txt_pool)
            else:
                sem_align = float("nan")

            sum_mse_vec += mse["mse_vector"]
            sum_mse_chk += mse["mse_chunk"]
            sum_mse_trj += mse["mse_traj"]
            if not math.isnan(tau_l2):
                sum_tau_l2 += tau_l2
                n_tau_l2 += 1
            if not math.isnan(sem_align):
                sum_sem_align += sem_align
                n_sem_align += 1

            n_batches += 1

            if is_main and bidx % 20 == 0:
                print(
                    f"batch={bidx} "
                    f"mse_vec={mse['mse_vector']:.4f} mse_chk={mse['mse_chunk']:.4f} mse_trj={mse['mse_traj']:.4f} "
                    f"tau_l2={tau_l2:.4f} sem_align={sem_align:.4f}"
                )

    # ------------------------------------------------------------------
    # Report (main process only).
    # ------------------------------------------------------------------
    if is_main:
        avg_mse_vec = sum_mse_vec / max(1, n_batches)
        avg_mse_chk = sum_mse_chk / max(1, n_batches)
        avg_mse_trj = sum_mse_trj / max(1, n_batches)

        # Averages over the batches that produced a finite value; NaN when
        # the metric was never available.
        avg_tau_l2 = _mean_or_nan(sum_tau_l2, n_tau_l2)
        avg_sem_align = _mean_or_nan(sum_sem_align, n_sem_align)

        avg_hard_mse_vec = _mean_or_nan(sum_hard_mse_vec, total_hard_samples)
        avg_hard_mse_chk = _mean_or_nan(sum_hard_mse_chk, total_hard_samples)
        avg_hard_mse_trj = _mean_or_nan(sum_hard_mse_trj, total_hard_samples)

        hard_fraction = float(total_hard_samples / max(1, n_samples))

        report = {
            "num_samples": n_samples,
            "num_batches": n_batches,
            "avg_mse_vector": avg_mse_vec,
            "avg_mse_chunk": avg_mse_chk,
            "avg_mse_traj": avg_mse_trj,
            "avg_tau_l2": avg_tau_l2,
            "avg_semantic_text_alignment": avg_sem_align,
            "hard_thresholds": {
                "vec": hard_thr_vec,
                "chk": hard_thr_chk,
                "trj": hard_thr_trj,
            },
            "avg_hard_mse_vector": avg_hard_mse_vec,
            "avg_hard_mse_chunk": avg_hard_mse_chk,
            "avg_hard_mse_traj": avg_hard_mse_trj,
            "hard_sample_fraction": hard_fraction,
            "total_hard_samples": int(total_hard_samples),
        }

        report_path = os.path.join(args.output_dir, "sigma_eval_report.json")
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        print("[DONE] Saved report:", report)
|
| |
|
| |
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
| |
|