# Sigma / eval_sigma_vla_rollout.py
# Uploaded by ConorWang to the Hugging Face "Sigma" repo ("Upload 10 files", commit 03426f9, verified).
# eval_sigma_vla_rollout.py
# Offline closed-loop evaluation for Telepathy-augmented VLA on top of PI05 policy backbone.
#
# Key design:
# - base_model_id is a LeRobot/OpenPI policy repo (e.g., lerobot/pi05_base or your fine-tuned Sigma repo).
# - We load PI05Policy via LeRobot, NOT AutoModelForCausalLM.
# - Text embeddings are taken from the PI05 internal text backbone so that TelepathyLanguageModule
# receives the same type of inputs used during training.
#
# Hardened in this revision:
# - Robust recursive shard discovery under any naming & subfolders.
# - Shard content structure normalization (list-of-samples, or dict{samples/data}).
# - Collate auto-adapts to real schema: vision/state/action/text, with time-dim collapse for vision.
# - Action GT supports dict-style branches or a single tensor.
# - Metrics tolerate missing multi-branch outputs (fallback to "action").
# - Text tokens dtype/device aligned to model dtype for mixed precision safety.
# - Robot state time-dim collapse + pad/trim to state encoder expected dim.
# - Dynamic projection to align vision/state token hidden size to vision backbone dim (768),
# and project text to the same dim BEFORE feeding language module.
# - Optional max_text_len to avoid tokenizer truncation warnings.
# - action input contract hardening:
# * high_level_rep 2D -> 3D
# * tau None/2D -> 3D
# * tau length aligned to high_level_rep length
# * tau last-dim auto pad/trim so concat(high_level_rep, tau) matches action_condition_proj in_features
# - tokenizer_id can be a LOCAL path; when it exists locally we load with local_files_only
# - _align_target handles 2D<->3D mismatches (fixes MSE crashes)
# - remove duplicated "high_level_rep/tau re-normalization" that overwrote the hardening
#
# NEW in this patch:
# - cosine_alignment auto-aligns hidden sizes (fixes 256 vs 2048 crash).
# - semantic pooling guard supports 2D/3D factors safely.
# - alignment metric ignores zero-length cases robustly.
#
# EXTRA HARDENING (this patch for your baseline issue):
# - Try strict load for PI05Policy if the LeRobot version supports it.
# - Verify tokenizer vocab size and special-token ids match PI05 text embedding table.
# - Fail fast with a clear message if mismatch is detected (unless explicitly overridden).
#
# NEW in this hard-set patch:
# - Per-sample MSE is exposed from success proxy.
# - A "hard set" is defined as samples whose branch-wise MSE exceeds hard thresholds.
# - Hard-set averages (MSE and fraction of samples) are reported alongside global metrics.
#
# NEW in this adapter patch:
# - sigma_telepathy_adapter is applied at eval time (when telepathy is enabled) to gate
# Telepathy residuals based on their magnitude and tau strength, optionally using
# offline base_action_* if present in the shards.
from __future__ import annotations
import os
import glob
import json
import argparse
import importlib
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from dotenv import load_dotenv
from accelerate import Accelerator
from accelerate.utils import set_seed
try:
from huggingface_hub import snapshot_download
except Exception:
snapshot_download = None # type: ignore
from vision_sigma_vla import TelepathyVisionModule, VisionConfig
from language_sigma_vla import TelepathyLanguageModule, LanguageConfig
from action_sigma_vla import TelepathyActionModule, ActionConfig
from sigma_telepathy_adapter import SigmaTelepathyAdapter, SigmaTelepathyAdapterConfig
def ensure_sigma_artifacts_from_hf(
    repo_id: str,
    hf_token: Optional[str],
    local_cache_root: str,
) -> Dict[str, str]:
    """
    Fetch the Sigma evaluation artifacts from a Hugging Face repo into a local cache.

    Only two subtrees are downloaded:
        storage/sigma_pickplace/**
        storage/sigma_lora_out/**

    The snapshot is written to ``<local_cache_root>/<repo_id with "/" -> "__">``.

    Returns:
        Dict with ``local_repo_dir``, ``shard_dir`` and ``telepathy_heads_path``. The
        paths are derived, not checked: they may not exist if the repo lacks them.

    Raises:
        ImportError: if ``huggingface_hub`` could not be imported at module load.
    """
    if snapshot_download is None:
        raise ImportError(
            "huggingface_hub is not available but auto-download was requested. "
            "Please `pip install huggingface_hub` or download artifacts manually."
        )

    os.makedirs(local_cache_root, exist_ok=True)
    snapshot_target = os.path.join(local_cache_root, repo_id.replace("/", "__"))
    wanted_subtrees = [
        "storage/sigma_pickplace/**",
        "storage/sigma_lora_out/**",
    ]
    repo_dir = snapshot_download(
        repo_id=repo_id,
        token=hf_token,
        local_dir=snapshot_target,
        local_dir_use_symlinks=False,
        allow_patterns=wanted_subtrees,
    )

    storage_dir = os.path.join(repo_dir, "storage")
    return {
        "local_repo_dir": repo_dir,
        "shard_dir": os.path.join(storage_dir, "sigma_pickplace"),
        "telepathy_heads_path": os.path.join(
            storage_dir, "sigma_lora_out", "sigma_telepathy_heads.pt"
        ),
    }
def load_pi05_policy(
    repo_id: str,
    hf_token: Optional[str],
    device: torch.device,
    strict_load: bool = True,
):
    """
    Load PI05Policy from LeRobot, move it to ``device`` and switch it to eval mode.

    Several import paths are tried so that different LeRobot layouts work.

    When ``strict_load`` is True, ``from_pretrained(..., strict=True)`` is attempted
    first. If that raises, either because the installed version rejects the keyword
    (TypeError) or because strict loading itself failed, a ``[WARN]`` line is printed
    and a regular (non-strict) load is used instead. A strict-load failure is
    therefore never silent.

    Raises:
        ImportError: if PI05Policy cannot be imported from any candidate path.
        RuntimeError: if every loading attempt returned None.
    """
    policy_cls = None
    import_errors: List[str] = []
    candidate_paths = [
        ("lerobot.policies.pi05.modeling_pi05", "PI05Policy"),
        ("lerobot.policies.pi05", "PI05Policy"),
    ]
    for mod_name, cls_name in candidate_paths:
        try:
            mod = importlib.import_module(mod_name)
            policy_cls = getattr(mod, cls_name)
            break
        except Exception as e:
            import_errors.append(f"{mod_name}.{cls_name}: {type(e).__name__}: {e}")
    if policy_cls is None:
        raise ImportError(
            "Failed to import PI05Policy from LeRobot. Tried:\n  - "
            + "\n  - ".join(import_errors)
        )

    policy = None
    tried: List[str] = []
    if strict_load:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token, strict=True)
            tried.append("from_pretrained(..., strict=True)")
        except TypeError as e:
            tried.append("strict=True not supported")
            print(
                f"[WARN] PI05Policy.from_pretrained rejected strict=True ({e}); "
                "falling back to non-strict load."
            )
        except Exception as e:
            tried.append(f"strict=True failed: {type(e).__name__}: {e}")
            # Previously swallowed silently: a strict failure usually means missing or
            # unexpected weights, which is exactly what strict loading should surface.
            print(
                f"[WARN] Strict load of PI05Policy failed ({type(e).__name__}: {e}); "
                "falling back to non-strict load. Missing/unexpected weights may go unnoticed."
            )

    if policy is None:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token)
            tried.append("from_pretrained(repo_id, token=...)")
        except TypeError:
            # Some versions only accept the keyword form of the repo argument.
            policy = policy_cls.from_pretrained(pretrained_name_or_path=repo_id, token=hf_token)
            tried.append("from_pretrained(pretrained_name_or_path=..., token=...)")

    if policy is None:
        raise RuntimeError("PI05Policy loading returned None. Tried: " + "; ".join(tried))

    policy = policy.to(device)
    policy.eval()
    return policy
def get_policy_tokenizer(
    policy,
    repo_id: str,
    hf_token: Optional[str],
    forced_tokenizer_id: str = "",
):
    """
    Robust tokenizer getter for PI05Policy.

    Resolution order:
      1. ``forced_tokenizer_id``: loaded with ``local_files_only=True`` when it is an
         existing local path, otherwise treated as a Hugging Face hub id.
      2. A tokenizer attached to the policy. A depth-limited search looks at the
         attributes ``tokenizer`` / ``processor`` (its ``.tokenizer`` when set) /
         ``text_tokenizer`` / ``language_tokenizer`` across well-known PI05
         sub-module names.
      3. A backbone/tokenizer name read from config-like attributes of the policy,
         loaded via ``AutoTokenizer``.
    The returned tokenizer gets ``pad_token = eos_token`` when it has no pad token and
    an EOS token exists.

    IMPORTANT: ``repo_id`` is a policy repo, not a tokenizer repo. It is never passed to
    ``AutoTokenizer.from_pretrained``, so step 3 skips any discovered name equal to it.

    Raises:
        ValueError: if no tokenizer can be resolved (pass --tokenizer_id explicitly).
    """
    from transformers import AutoTokenizer

    def _with_pad_token(tok):
        # Batched tokenisation with padding=True needs a pad token; reuse EOS if present.
        if getattr(tok, "pad_token", None) is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token
        return tok

    # 1) Explicit override.
    if forced_tokenizer_id:
        if os.path.exists(forced_tokenizer_id):
            tok = AutoTokenizer.from_pretrained(
                forced_tokenizer_id,
                local_files_only=True,
                trust_remote_code=True,
            )
        else:
            tok = AutoTokenizer.from_pretrained(
                forced_tokenizer_id,
                token=hf_token,
                trust_remote_code=True,
            )
        return _with_pad_token(tok)

    # 2) Tokenizer object attached somewhere inside the policy.
    tokenizer_attrs = ("tokenizer", "processor", "text_tokenizer", "language_tokenizer")
    nested_names = (
        "paligemma_with_expert",
        "paligemma",
        "gemma_expert",
        "language_model",
        "text_model",
        "model",
        "policy",
    )

    def _recursive_find_tokenizer(obj, max_depth: int = 4):
        if obj is None or max_depth <= 0:
            return None
        for key in tokenizer_attrs:
            v = getattr(obj, key, None)
            if v is None:
                continue
            # A processor usually wraps the real tokenizer; prefer it when available.
            inner = getattr(v, "tokenizer", None) if key == "processor" else None
            if inner is not None:
                return inner
            if callable(v):
                return v
        for name in nested_names:
            found = _recursive_find_tokenizer(getattr(obj, name, None), max_depth=max_depth - 1)
            if found is not None:
                return found
        return None

    tok = _recursive_find_tokenizer(policy)
    if tok is not None:
        return _with_pad_token(tok)

    # 3) Backbone name recorded on config-like attributes.
    name_keys = (
        "_name_or_path",
        "text_backbone_id",
        "text_model_id",
        "language_model_id",
        "processor_name_or_path",
        "tokenizer_name_or_path",
    )

    def _usable(name) -> bool:
        # Reject the policy repo itself: it holds policy weights, not a tokenizer.
        return isinstance(name, str) and bool(name) and name != repo_id

    def _try_get_name(cfg_obj):
        if cfg_obj is None:
            return None
        for k in name_keys:
            v = getattr(cfg_obj, k, None)
            if _usable(v):
                return v
        inner_name = getattr(getattr(cfg_obj, "config", None), "_name_or_path", None)
        return inner_name if _usable(inner_name) else None

    backbone_name = None
    for attr in ("config", "model", "paligemma_with_expert", "paligemma"):
        backbone_name = _try_get_name(getattr(policy, attr, None))
        if backbone_name:
            break

    if backbone_name:
        tok = AutoTokenizer.from_pretrained(
            backbone_name, token=hf_token, trust_remote_code=True
        )
        return _with_pad_token(tok)

    raise ValueError(
        f"Cannot obtain tokenizer from PI05Policy for repo '{repo_id}'. "
        "Your lerobot PI05 port does not expose tokenizer/processor nor backbone name. "
        "Please pass --tokenizer_id explicitly."
    )
def get_policy_text_embedding_layer(policy):
    """
    Find the text (token) embedding layer inside a PI05Policy.

    Depth-first search, at most 6 levels deep. At each node it tries, in order:
      1. ``get_input_embeddings()``; a non-None result wins, and exceptions are ignored;
      2. an ``embed_tokens`` / ``embeddings`` / ``token_embedding`` attribute that is an
         ``nn.Module``;
      3. recursion into well-known child attributes
         (model, paligemma_with_expert, paligemma, language_model, gemma_expert,
         text_model, policy).

    Raises:
        AttributeError: if nothing is found.
    """
    direct_keys = ("embed_tokens", "embeddings", "token_embedding")
    child_names = (
        "model",
        "paligemma_with_expert",
        "paligemma",
        "language_model",
        "gemma_expert",
        "text_model",
        "policy",
    )

    def _search(node, budget: int = 6):
        if node is None or budget <= 0:
            return None

        getter = getattr(node, "get_input_embeddings", None)
        if getter is not None:
            try:
                candidate = getter()
            except Exception:
                # Some models don't implement it; fall through to attribute probing.
                candidate = None
            if candidate is not None:
                return candidate

        for key in direct_keys:
            attr_value = getattr(node, key, None)
            if isinstance(attr_value, nn.Module):
                return attr_value

        for child in child_names:
            hit = _search(getattr(node, child, None), budget - 1)
            if hit is not None:
                return hit
        return None

    layer = _search(policy)
    if layer is not None:
        return layer
    raise AttributeError(
        "Cannot locate PI05 text embedding layer via recursive search. "
        "Your PI05Policy likely changed internal naming. "
        "Please inspect policy.model.* to confirm embed_tokens location."
    )
def verify_tokenizer_embedding_compat(
    tokenizer,
    text_embed_layer: nn.Module,
    allow_mismatch: bool = False,
):
    """
    Check that a tokenizer can safely index the PI05 text embedding table.

    Two checks are made:
      * the tokenizer vocab size (``vocab_size``, else ``len(tokenizer)``) must equal the
        embedding row count (``num_embeddings`` or ``weight.size(0)``);
      * each of pad/eos/bos/unk token ids that is set must lie in ``[0, rows)``.

    If either size cannot be determined, a warning is printed and the check is skipped.
    Failures raise ``ValueError``, unless ``allow_mismatch`` is True, in which case they
    are printed as ``[WARN]`` and execution continues.
    """

    def _report(message: str) -> None:
        if not allow_mismatch:
            raise ValueError(message)
        print(message.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.")

    # Number of rows in the embedding table.
    if isinstance(text_embed_layer, nn.Embedding):
        emb_vocab = int(text_embed_layer.num_embeddings)
    else:
        weight = getattr(text_embed_layer, "weight", None)
        emb_vocab = None if weight is None else int(weight.size(0))

    # Tokenizer vocabulary size, with len() as a fallback.
    tok_vocab = getattr(tokenizer, "vocab_size", None)
    if tok_vocab is None:
        try:
            tok_vocab = len(tokenizer)
        except Exception:
            tok_vocab = None

    if emb_vocab is None or tok_vocab is None:
        print("[WARN] Cannot infer tokenizer/embedding vocab sizes. Skipping compatibility check.")
        return

    if emb_vocab != tok_vocab:
        _report(
            f"[ERROR] Tokenizer vocab size ({tok_vocab}) != PI05 embedding table size ({emb_vocab}). "
            "This will corrupt text embeddings and invalidate baseline. "
            "Fix by passing --tokenizer_id matching the PI05 text backbone "
            "(e.g., the original openpi/PI05 tokenizer) or re-exporting policy with aligned vocab."
        )

    for id_attr in ("pad_token_id", "eos_token_id", "bos_token_id", "unk_token_id"):
        token_id = getattr(tokenizer, id_attr, None)
        if token_id is None or 0 <= int(token_id) < emb_vocab:
            continue
        _report(
            f"[ERROR] Tokenizer {id_attr}={token_id} out of embedding range [0, {emb_vocab-1}]. "
            "Your tokenizer does not belong to this PI05 backbone."
        )
class TelepathyVLA(nn.Module):
    """
    Telepathy-augmented VLA built from vision, language and action modules.

    ``forward_once`` runs one step:
      1. vision pass without telepathy factors;
      2. language pass over width-aligned text/vision/state tokens, carrying the
         recurrent memory ``m_t`` from the previous call (cleared by ``reset_memory``);
      3. unless ``disable_telepathy`` is set, a second vision pass conditioned on the
         language module's ``telepathy_factors`` (tau), multiplied by the optional
         ``telepathy_scale`` attribute (default 1.0);
      4. action pass on ``high_level_rep``, the shape-aligned tau and state tokens.

    Width-alignment projections (text/vision/state -> vision-token width) are created
    lazily on the first ``forward_once`` call, when the input widths are first known.
    NOTE(review): a projection that is an ``nn.Linear`` is randomly initialised at that
    point; confirm its weights are loaded or seeded as intended for evaluation.
    """

    def __init__(
        self,
        v_cfg: VisionConfig,
        l_cfg: LanguageConfig,
        a_cfg: ActionConfig,
        disable_telepathy: bool = False,
    ):
        super().__init__()
        self.vision = TelepathyVisionModule(v_cfg)
        self.language = TelepathyLanguageModule(l_cfg)
        self.action = TelepathyActionModule(a_cfg)
        self.disable_telepathy = disable_telepathy
        # Recurrent language memory; non-persistent so it is excluded from state_dict.
        self.register_buffer("_m_prev", None, persistent=False)
        self._proj_inited = False
        self.text_proj: Optional[nn.Module] = None
        self.vision_proj: Optional[nn.Module] = None
        self.state_proj: Optional[nn.Module] = None

    def reset_memory(self):
        """Drop the recurrent language memory carried between ``forward_once`` calls."""
        self._m_prev = None

    @staticmethod
    def _pad_or_trim_last(x: torch.Tensor, d: int) -> torch.Tensor:
        """Zero-pad or crop the last dim of ``x`` to exactly ``d``."""
        cur_d = x.size(-1)
        if cur_d < d:
            return F.pad(x, (0, d - cur_d))
        if cur_d > d:
            return x[..., :d]
        return x

    def _init_projections(
        self,
        text_tokens: torch.Tensor,
        vision_tokens: torch.Tensor,
        state_tokens: torch.Tensor,
    ) -> None:
        """Create width-alignment projections once, each on its own input's device/dtype."""
        target_d = vision_tokens.size(-1)

        def _make(src: torch.Tensor) -> nn.Module:
            src_d = src.size(-1)
            proj: nn.Module = (
                nn.Identity() if src_d == target_d else nn.Linear(src_d, target_d, bias=False)
            )
            # Match the tensor it will be applied to; a shared dtype would break when
            # state and text tokens arrive in different dtypes.
            return proj.to(device=src.device, dtype=src.dtype)

        # Creation order (text, vision, state) is kept stable so seeded runs reproduce.
        self.text_proj = _make(text_tokens)
        self.vision_proj = _make(vision_tokens)
        self.state_proj = _make(state_tokens)
        self._proj_inited = True

    def _action_condition_in_features(self) -> Optional[int]:
        """Input width of ``action.action_condition_proj`` if it can be read, else None."""
        acp = getattr(self.action, "action_condition_proj", None)
        if acp is None:
            return None
        if hasattr(acp, "in_features"):
            return int(acp.in_features)
        net = getattr(acp, "net", None)
        if net is not None and len(net) > 0 and hasattr(net[0], "in_features"):
            return int(net[0].in_features)
        return None

    def _align_tau(
        self, tau: Optional[torch.Tensor], high_level_rep: torch.Tensor
    ) -> torch.Tensor:
        """
        Make tau concatenable with ``high_level_rep`` (assumed [B, L, D_high]).

        * None becomes zeros of shape [B, L, language.tau_dim (default 128)].
        * A 2-D tau gets a length-1 time dim.
        * Time length: length 1 is broadcast to L, longer is cropped, and shorter is
          zero-padded (zero tau matches the "no telepathy" input).
        * Last dim is padded or cropped so that D_high + D_tau equals the
          ``action_condition_proj`` input width, when that width is known and larger
          than D_high.
        """
        batch, steps = high_level_rep.size(0), high_level_rep.size(1)
        if tau is None:
            tau_dim = getattr(self.language, "tau_dim", 128)
            tau = torch.zeros(
                batch, steps, tau_dim, device=high_level_rep.device, dtype=high_level_rep.dtype
            )
        else:
            if tau.dim() == 2:
                tau = tau.unsqueeze(1)
            tau_steps = tau.size(1)
            if tau_steps == 1 and steps != 1:
                tau = tau.expand(-1, steps, -1)
            elif tau_steps > steps:
                tau = tau[:, :steps, :]
            elif tau_steps < steps:
                # Previously left unaligned, which crashed the action concat.
                tau = F.pad(tau, (0, 0, 0, steps - tau_steps))

        expected_in = self._action_condition_in_features()
        if expected_in is not None:
            target_tau = expected_in - high_level_rep.size(-1)
            if target_tau > 0:
                tau = self._pad_or_trim_last(tau, target_tau)
        return tau

    @torch.no_grad()
    def forward_once(
        self,
        vis_obs: torch.Tensor,
        robot_state: torch.Tensor,
        text_tokens: torch.Tensor,
        depth_obs: Optional[torch.Tensor] = None,
        audio_obs: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        return_intermediate: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Run one vision -> language -> (telepathy) vision -> action step.

        Returns:
            Merged dict of the final vision outputs, then the language outputs, then the
            action outputs (later keys overwrite earlier ones).

        Raises:
            KeyError: if the language module output lacks ``high_level_rep``.
        """
        vis0 = self.vision(
            vis_obs=vis_obs,
            robot_state=robot_state,
            depth_obs=depth_obs,
            audio_obs=audio_obs,
            telepathy_factors=None,
            return_intermediate=return_intermediate,
        )

        if not self._proj_inited:
            self._init_projections(text_tokens, vis0["vision_tokens"], vis0["state_tokens"])
        assert self.text_proj is not None and self.vision_proj is not None and self.state_proj is not None

        text_tokens = self.text_proj(text_tokens)
        vision_tokens = self.vision_proj(vis0["vision_tokens"])
        state_tokens = self.state_proj(vis0["state_tokens"])

        lang_out = self.language(
            text_tokens=text_tokens,
            vision_tokens=vision_tokens,
            state_tokens=state_tokens,
            m_prev=self._m_prev,
            attn_mask=attn_mask,
            return_intermediate=return_intermediate,
        )
        raw_tau = lang_out.get("telepathy_factors", None)
        self._m_prev = lang_out.get("m_t", None)

        telepathy_scale = float(getattr(self, "telepathy_scale", 1.0))
        if self.disable_telepathy:
            tau = None
            vis_out = vis0
        else:
            tau = raw_tau * telepathy_scale if raw_tau is not None else None
            vis_out = self.vision(
                vis_obs=vis_obs,
                robot_state=robot_state,
                depth_obs=depth_obs,
                audio_obs=audio_obs,
                telepathy_factors=tau,
                return_intermediate=return_intermediate,
            )

        high_level_rep = lang_out.get("high_level_rep", None)
        if high_level_rep is None:
            raise KeyError("language output missing 'high_level_rep'.")
        if high_level_rep.dim() == 2:
            high_level_rep = high_level_rep.unsqueeze(1)

        tau = self._align_tau(tau, high_level_rep)

        state_for_action = vis_out["state_tokens"]
        if state_for_action.dim() == 2:
            state_for_action = state_for_action.unsqueeze(1)
        elif state_for_action.dim() > 3:
            # reshape (not view): state tokens may be non-contiguous.
            state_for_action = state_for_action.reshape(
                state_for_action.size(0), -1, state_for_action.size(-1)
            )
        state_for_action = self._pad_or_trim_last(state_for_action, high_level_rep.size(-1))

        act_out = self.action(
            high_level_rep=high_level_rep,
            telepathy_factors=tau,
            state_tokens=state_for_action,
            return_intermediate=return_intermediate,
        )

        out: Dict[str, torch.Tensor] = {}
        out.update(vis_out)
        out.update(lang_out)
        out.update(act_out)
        return out
class SigmaShardDataset(Dataset):
    """
    Map-style dataset over ``.pt`` shards produced by dataset_preprocess_sigma_vla.py.

    A shard file holds either a list/tuple of sample dicts or a dict wrapping such a
    list under ``samples``, ``data`` or ``items``. Every ``.pt`` file anywhere under
    ``shard_dir`` (recursive glob) is treated as a shard, in sorted path order.

    Each shard is loaded once at construction only to count its samples. On first
    access it is loaded again and kept in ``_shard_cache`` for the dataset's lifetime.
    NOTE: ``torch.load`` may unpickle arbitrary objects; only point this at trusted shards.
    """

    def __init__(self, shard_dir: str):
        super().__init__()
        if not os.path.isdir(shard_dir):
            raise FileNotFoundError(
                f"shard_dir does not exist: {shard_dir}. Double-check the path."
            )

        globs = (
            os.path.join(shard_dir, "sigma_vla_shard_*.pt"),
            os.path.join(shard_dir, "*.pt"),
            os.path.join(shard_dir, "**", "*.pt"),
        )
        unique_paths = set()
        for pattern in globs:
            unique_paths.update(glob.glob(pattern, recursive=True))
        self.shard_paths: List[str] = sorted(unique_paths)

        if not self.shard_paths:
            raise FileNotFoundError(
                f"No .pt shards found under {shard_dir}. "
                "Your HF cache is empty or shards are not tracked by LFS."
            )
        print(f"[INFO] Found {len(self.shard_paths)} shard files. Example: {self.shard_paths[:3]}")

        # Flat index -> (shard id, position inside that shard).
        self.index_map: List[Tuple[int, int]] = []
        self._shard_cache: Dict[int, List[Dict[str, Any]]] = {}
        for shard_id, shard_path in enumerate(self.shard_paths):
            samples = self._normalize_shard(torch.load(shard_path, map_location="cpu"), shard_path)
            self.index_map.extend((shard_id, pos) for pos in range(len(samples)))
        self.total = len(self.index_map)

    def __len__(self):
        return self.total

    def _normalize_shard(self, shard_obj: Any, path: str) -> List[Dict[str, Any]]:
        """Return the shard's samples as a list, unwrapping samples/data/items dicts."""
        if isinstance(shard_obj, (list, tuple)):
            return list(shard_obj)
        if isinstance(shard_obj, dict):
            wrapped = next(
                (
                    shard_obj[k]
                    for k in ("samples", "data", "items")
                    if k in shard_obj and isinstance(shard_obj[k], (list, tuple))
                ),
                None,
            )
            if wrapped is not None:
                return list(wrapped)
        raise TypeError(
            f"Unsupported shard format in {path}. "
            f"Expected list/tuple of samples or dict{{samples/data}}. "
            f"Got type: {type(shard_obj).__name__}"
        )

    def _get_shard(self, sid: int) -> List[Dict[str, Any]]:
        """Load shard ``sid`` on first use and serve it from the cache afterwards."""
        cached = self._shard_cache.get(sid)
        if cached is None:
            shard_path = self.shard_paths[sid]
            cached = self._normalize_shard(torch.load(shard_path, map_location="cpu"), shard_path)
            self._shard_cache[sid] = cached
        return cached

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        shard_id, position = self.index_map[idx]
        return self._get_shard(shard_id)[position]
def collate_sigma(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Robust collate for Sigma shards.

    Builds a batch dict from per-sample dicts and accepts several key spellings:

    * vision: ``vision`` | ``vis_obs`` | ``rgb_obs`` | ``image`` | ``images`` | ``obs``.
      A 5-D stacked result keeps only the last index of dim 1 (treated as time).
    * depth / audio (optional): ``depth`` | ``depth_obs`` | ``depths`` and
      ``audio`` | ``audio_obs`` | ``audios``. A key whose value is ``None`` counts as absent.
    * state: ``state`` | ``robot_state`` | ``proprio`` | ``proprio_obs``.
    * text: ``text`` | ``prompt`` | ``instruction``. Missing or ``None`` values become ``""``.
    * ground-truth actions, in one of three forms: an ``action`` dict with
      vector/chunk/trajectory branches, a single ``action`` tensor reused for all three
      branches, or top-level ``gt_action_*`` keys.
    * optional base actions: top-level ``base_action_*`` keys take precedence.
      Otherwise the ``action`` dict is searched for ``base_action_*`` or the short
      aliases ``base_vec`` / ``base_chunk`` / ``base_traj``. ``None`` values count as absent.

    Which keys to use is decided from the first sample; all samples are assumed to share
    its schema. Tensors are stacked on a new leading batch dim and cast to float32.

    Raises:
        KeyError: if a required field cannot be found.
    """
    s0 = batch_list[0]

    def stack_top(key: str) -> torch.Tensor:
        return torch.stack([b[key] for b in batch_list], dim=0).float()

    def stack_action(key: str) -> torch.Tensor:
        return torch.stack([b["action"][key] for b in batch_list], dim=0).float()

    def pick_key(sample: Dict[str, Any], candidates: List[str], field_name: str):
        for k in candidates:
            if k in sample:
                return k
        raise KeyError(
            f"Shard sample missing required field '{field_name}'. "
            f"Tried keys: {candidates}. "
            f"Available keys: {list(sample.keys())}"
        )

    def first_present(d: Dict[str, Any], candidates: List[str]) -> Optional[str]:
        # Optional fields: a key holding None is treated as missing (it cannot be stacked).
        for k in candidates:
            if d.get(k) is not None:
                return k
        return None

    # Vision, with optional time-dim collapse to the last frame.
    vis_k = "vision" if "vision" in s0 else pick_key(
        s0, ["vis_obs", "rgb_obs", "image", "images", "obs"], "vision/vis_obs"
    )
    vis_obs = stack_top(vis_k)
    if vis_obs.dim() == 5:
        vis_obs = vis_obs[:, -1]

    # Optional modalities.
    depth_k = first_present(s0, ["depth", "depth_obs", "depths"])
    depth_obs = stack_top(depth_k) if depth_k is not None else None
    audio_k = first_present(s0, ["audio", "audio_obs", "audios"])
    audio_obs = stack_top(audio_k) if audio_k is not None else None

    state_k = "state" if "state" in s0 else pick_key(
        s0, ["robot_state", "proprio", "proprio_obs"], "state/robot_state"
    )
    robot_state = stack_top(state_k)

    text_k = pick_key(s0, ["text", "prompt", "instruction"], "text")
    texts = []
    for b in batch_list:
        t = b.get(text_k)
        texts.append("" if t is None else t)

    # Ground-truth actions.
    if "action" in s0:
        a0 = s0["action"]
        if isinstance(a0, dict):
            def pick_action_key(d, candidates, name):
                for k in candidates:
                    if k in d:
                        return k
                raise KeyError(
                    f"Action dict missing '{name}'. Tried {candidates}. "
                    f"Available action keys: {list(d.keys())}"
                )

            vec_k = pick_action_key(a0, ["gt_action_vector", "action_vector", "vector", "vec"], "gt_action_vector")
            chk_k = pick_action_key(a0, ["gt_action_chunk", "action_chunk", "chunk", "chk"], "gt_action_chunk")
            trj_k = pick_action_key(a0, ["gt_action_trajectory", "action_trajectory", "trajectory", "traj"], "gt_action_trajectory")
            gt_action_vector = stack_action(vec_k)
            gt_action_chunk = stack_action(chk_k)
            gt_action_trajectory = stack_action(trj_k)
        else:
            act = stack_top("action")
            gt_action_vector = act
            gt_action_chunk = act
            gt_action_trajectory = act
    else:
        gt_vec_k = pick_key(s0, ["gt_action_vector", "action_vector", "gt_vec"], "gt_action_vector")
        gt_chk_k = pick_key(s0, ["gt_action_chunk", "action_chunk", "gt_chunk"], "gt_action_chunk")
        gt_trj_k = pick_key(s0, ["gt_action_trajectory", "action_trajectory", "gt_traj"], "gt_action_trajectory")
        gt_action_vector = stack_top(gt_vec_k)
        gt_action_chunk = stack_top(gt_chk_k)
        gt_action_trajectory = stack_top(gt_trj_k)

    # Optional offline base actions for the adapter; omitted from the batch when absent.
    base_aliases: Dict[str, List[str]] = {
        "base_action_vector": ["base_action_vector", "base_vec"],
        "base_action_chunk": ["base_action_chunk", "base_chunk"],
        "base_action_trajectory": ["base_action_trajectory", "base_traj"],
    }
    base_actions: Dict[str, torch.Tensor] = {}
    top_level = [name for name in base_aliases if s0.get(name) is not None]
    if top_level:
        for name in top_level:
            base_actions[name] = stack_top(name)
    elif isinstance(s0.get("action"), dict):
        # Alias keys are checked here too; previously only canonical names gated this
        # branch, making base_vec/base_chunk/base_traj unreachable.
        for name, aliases in base_aliases.items():
            k = first_present(s0["action"], aliases)
            if k is not None:
                base_actions[name] = stack_action(k)

    batch: Dict[str, Any] = {
        "vis_obs": vis_obs,
        "depth_obs": depth_obs,
        "audio_obs": audio_obs,
        "robot_state": robot_state,
        "texts": texts,
        "gt_action_vector": gt_action_vector,
        "gt_action_chunk": gt_action_chunk,
        "gt_action_trajectory": gt_action_trajectory,
    }
    batch.update(base_actions)
    return batch
def _align_target(pred_t: torch.Tensor, gt_t: torch.Tensor) -> torch.Tensor:
"""
Align GT to prediction for MSE:
- handle 2D vs 3D mismatches by collapsing or expanding time dimension.
- then align last-dim by pad/trim.
"""
if gt_t.dim() == 3 and pred_t.dim() == 2:
gt_t = gt_t[:, -1, :]
if pred_t.dim() == 3 and gt_t.dim() == 2:
gt_t = gt_t.unsqueeze(1)
if gt_t.size(1) != pred_t.size(1):
gt_t = gt_t.expand(-1, pred_t.size(1), -1)
if pred_t.dim() == 3 and gt_t.dim() == 3:
Tp = pred_t.size(1)
Tg = gt_t.size(1)
if Tg < Tp:
pad = torch.zeros(
gt_t.size(0), Tp - Tg, gt_t.size(2),
device=gt_t.device, dtype=gt_t.dtype
)
gt_t = torch.cat([gt_t, pad], dim=1)
elif Tg > Tp:
gt_t = gt_t[:, :Tp, :]
pd = pred_t.size(-1)
gd = gt_t.size(-1)
if gd < pd:
gt_t = F.pad(gt_t, (0, pd - gd))
elif gd > pd:
gt_t = gt_t[..., :pd]
return gt_t
def _pred_action(pred: Dict[str, torch.Tensor], key: str) -> torch.Tensor:
if key in pred:
return pred[key]
if "action" in pred:
return pred["action"]
raise KeyError(
f"Pred dict missing action key '{key}' and fallback 'action'. "
f"Available pred keys: {list(pred.keys())}"
)
@torch.no_grad()
def compute_branch_mse(pred: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, float]:
    """
    Batch-mean MSE for the vector, chunk and trajectory action branches.

    A missing branch prediction falls back to ``pred["action"]``. Ground truth is moved
    to the vector prediction's device and aligned to each prediction via ``_align_target``.
    Returns ``{"mse_vector", "mse_chunk", "mse_traj"}`` as Python floats.
    """
    branches = (
        ("mse_vector", "action_vector", "gt_action_vector"),
        ("mse_chunk", "action_chunk", "gt_action_chunk"),
        ("mse_traj", "action_trajectory", "gt_action_trajectory"),
    )
    preds = {out_key: _pred_action(pred, pred_key) for out_key, pred_key, _ in branches}
    target_device = preds["mse_vector"].device
    targets = {
        out_key: _align_target(preds[out_key], batch[gt_key].to(target_device))
        for out_key, _, gt_key in branches
    }
    return {
        out_key: F.mse_loss(preds[out_key], targets[out_key]).item()
        for out_key, _, _ in branches
    }
@torch.no_grad()
def compute_success_proxy(
    pred: Dict[str, torch.Tensor],
    batch: Dict[str, Any],
    thr_vec: float,
    thr_chk: float,
    thr_trj: float,
) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Threshold-based success proxy computed from per-sample, branch-wise MSE.

    A sample succeeds when its vector, chunk and trajectory MSEs are all strictly below
    ``thr_vec``, ``thr_chk`` and ``thr_trj`` respectively. A missing branch prediction
    falls back to ``pred["action"]``, and targets are aligned with ``_align_target``.

    Returns:
        num_success, num_total, mse_vec_per_sample, mse_chk_per_sample, mse_trj_per_sample
        where per-sample MSE is averaged over all non-batch dims.
    """
    pred_keys = ("action_vector", "action_chunk", "action_trajectory")
    gt_keys = ("gt_action_vector", "gt_action_chunk", "gt_action_trajectory")

    preds = [_pred_action(pred, k) for k in pred_keys]
    target_device = preds[0].device
    targets = [_align_target(p, batch[g].to(target_device)) for p, g in zip(preds, gt_keys)]

    mse_vec_s, mse_chk_s, mse_trj_s = (
        ((p - t) ** 2).mean(dim=list(range(1, p.dim())))
        for p, t in zip(preds, targets)
    )

    success_mask = (mse_vec_s < thr_vec) & (mse_chk_s < thr_chk) & (mse_trj_s < thr_trj)
    return (
        int(success_mask.sum().item()),
        int(success_mask.numel()),
        mse_vec_s,
        mse_chk_s,
        mse_trj_s,
    )
@torch.no_grad()
def compute_telepathy_stability(pred: Dict[str, torch.Tensor]) -> float:
    """
    Mean squared magnitude of ``pred["telepathy_factors"]``, or NaN when absent.

    Half-precision inputs (bf16/fp16) are promoted to float32 before squaring and
    averaging. Otherwise the metric would be quantised to half-precision resolution;
    higher-precision inputs keep their dtype.
    """
    tau = pred.get("telepathy_factors", None)
    if tau is None:
        return float("nan")
    tau = tau.to(torch.promote_types(tau.dtype, torch.float32))
    return float((tau ** 2).mean().item())
@torch.no_grad()
def cosine_alignment(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Mean cosine similarity between two batches of representations.

    Accepts [B, D] or [B, T, D]; 3-D inputs are mean-pooled over dim 1. If the hidden
    sizes differ, both are cropped to min(Da, Db) for a like-for-like comparison.
    Returns NaN if either pooled input is empty.

    Inputs are promoted to at least float32 before pooling and normalising, so that
    bf16/fp16 activations do not quantise the metric to half-precision resolution.
    """

    def _prepare(x: torch.Tensor) -> torch.Tensor:
        x = x.to(torch.promote_types(x.dtype, torch.float32))
        return x.mean(dim=1) if x.dim() == 3 else x

    a = _prepare(a)
    b = _prepare(b)
    if a.numel() == 0 or b.numel() == 0:
        return float("nan")

    d = min(a.size(-1), b.size(-1))
    a = F.normalize(a[..., :d], dim=-1)
    b = F.normalize(b[..., :d], dim=-1)
    return float((a * b).sum(dim=-1).mean().item())
@torch.no_grad()
def build_text_tokens_from_policy(
    tokenizer,
    text_embed_layer: nn.Module,
    texts: List[str],
    device: torch.device,
    target_dtype: torch.dtype,
    max_text_len: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenize prompts and embed them with the PI05 internal embedding layer.

    Prompts are padded to the longest in the batch. When ``max_text_len > 0`` they are
    also truncated to that length; otherwise truncation is explicitly disabled. If the
    tokenizer output has no attention mask (or it is None), an all-ones mask is used.

    Returns:
        (text_tokens, attn_mask): embeddings cast to ``target_dtype`` and the mask,
        both on ``device``.
    """
    tokenize_kwargs: Dict[str, Any] = {"padding": True, "return_tensors": "pt"}
    if max_text_len and max_text_len > 0:
        tokenize_kwargs["truncation"] = True
        tokenize_kwargs["max_length"] = max_text_len
    else:
        tokenize_kwargs["truncation"] = False
    encoded = tokenizer(texts, **tokenize_kwargs)

    # BatchEncoding exposes attributes; plain dict outputs only support item access.
    if hasattr(encoded, "input_ids"):
        ids = encoded.input_ids
        mask = encoded.attention_mask
    else:
        ids = encoded["input_ids"]
        mask = encoded.get("attention_mask", None)
    if mask is None:
        mask = torch.ones_like(ids)

    ids = ids.to(device)
    mask = mask.to(device)
    embedded = text_embed_layer(ids).to(dtype=target_dtype)
    return embedded, mask
def _eval_build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the offline Sigma/Telepathy evaluator."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--sigma_env", type=str, default="sigma.env")
    parser.add_argument("--shard_dir", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="./sigma_eval_out")
    parser.add_argument(
        "--base_model_id",
        type=str,
        required=True,
        help="LeRobot/OpenPI policy repo, e.g., lerobot/pi05_base or your Sigma policy repo.",
    )
    parser.add_argument(
        "--telepathy_heads_path",
        type=str,
        default="",
        help="Path to sigma_telepathy_heads.pt. If empty, auto-fetch may fill it.",
    )
    parser.add_argument(
        "--disable_telepathy",
        action="store_true",
        help="Disable telepathy injection (control run).",
    )
    parser.add_argument(
        "--tokenizer_id",
        type=str,
        default="",
        help="Explicit HF tokenizer id OR local tokenizer folder path.",
    )
    parser.add_argument("--max_text_len", type=int, default=0)
    parser.add_argument(
        "--artifacts_repo_id",
        type=str,
        default="",
        help="HF repo containing storage/sigma_pickplace and storage/sigma_lora_out.",
    )
    parser.add_argument(
        "--hf_cache_root",
        type=str,
        default="/workspace/.hf_sigma_cache",
    )
    parser.add_argument("--load_in_4bit", action="store_true")
    parser.add_argument("--dtype", type=str, default="bf16")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--max_batches", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="Shuffle dataset order to enable different random subsets per seed.",
    )
    parser.add_argument(
        "--telepathy_scale",
        type=float,
        default=1.0,
        help="Multiply telepathy_factors (tau) to control injection strength.",
    )
    parser.add_argument("--succ_thr_vec", type=float, default=0.05)
    parser.add_argument("--succ_thr_chk", type=float, default=0.10)
    parser.add_argument("--succ_thr_trj", type=float, default=0.10)
    # Hard-set thresholds: if <=0, they default to 2x the success thresholds.
    parser.add_argument(
        "--hard_thr_vec",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on vector branch; <=0 means 2x succ_thr_vec.",
    )
    parser.add_argument(
        "--hard_thr_chk",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on chunk branch; <=0 means 2x succ_thr_chk.",
    )
    parser.add_argument(
        "--hard_thr_trj",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on trajectory branch; <=0 means 2x succ_thr_trj.",
    )
    parser.add_argument(
        "--strict_pi05_load",
        action="store_true",
        help="Try strict PI05Policy loading if supported by LeRobot.",
    )
    parser.add_argument(
        "--allow_tokenizer_mismatch",
        action="store_true",
        help="Do not fail on tokenizer/embedding mismatch (NOT recommended for baseline).",
    )
    # Simple flag to enable/disable the adapter without touching telepathy itself.
    parser.add_argument(
        "--use_telepathy_adapter",
        action="store_true",
        help="If set and telepathy is enabled, apply sigma_telepathy_adapter to actions in eval.",
    )
    return parser


def _eval_resolve_inputs(args: argparse.Namespace, hf_token) -> None:
    """Fill in missing ``args.shard_dir`` / ``args.telepathy_heads_path`` from HF artifacts.

    Mutates ``args`` in place. Raises FileNotFoundError if either path is still
    missing locally after the optional download.
    """
    artifacts_repo = args.artifacts_repo_id.strip()
    if not artifacts_repo and args.base_model_id.startswith("Veltraxor/"):
        # Fall back to the policy repo itself as the artifacts source for Veltraxor/* models.
        artifacts_repo = args.base_model_id
    need_shards = (not args.shard_dir) or (not os.path.isdir(args.shard_dir))
    need_heads = (not args.telepathy_heads_path) or (not os.path.isfile(args.telepathy_heads_path))
    if artifacts_repo and (need_shards or need_heads):
        paths = ensure_sigma_artifacts_from_hf(
            repo_id=artifacts_repo,
            hf_token=hf_token,
            local_cache_root=args.hf_cache_root,
        )
        if need_shards:
            args.shard_dir = paths["shard_dir"]
            print(f"[INFO] Using cached shard_dir: {args.shard_dir}")
        if need_heads:
            args.telepathy_heads_path = paths["telepathy_heads_path"]
            print(f"[INFO] Using cached telepathy_heads_path: {args.telepathy_heads_path}")
    if not args.shard_dir or not os.path.isdir(args.shard_dir):
        raise FileNotFoundError(
            f"shard_dir not found locally: {args.shard_dir}. "
            "Either provide a valid local path or an artifacts_repo_id for auto-download."
        )
    if not args.telepathy_heads_path or not os.path.isfile(args.telepathy_heads_path):
        raise FileNotFoundError(
            f"telepathy_heads_path not found locally: {args.telepathy_heads_path}. "
            "Either provide a valid local path or store it under storage/sigma_lora_out/ "
            "in artifacts_repo_id for auto-download."
        )


def _eval_load_heads(model: nn.Module, heads_path: str, is_main: bool) -> None:
    """Load Telepathy head weights into ``model`` (strict=False) and print CHECK-A diagnostics."""
    if is_main:
        file_size_mb = os.path.getsize(heads_path) / (1024 * 1024)
        print(f"[CHECK-A] telepathy_heads_path={heads_path} size={file_size_mb:.2f}MB")
    # NOTE(review): torch.load unpickles the file; it may come from an HF download.
    # Consider weights_only=True on torch versions that support it.
    sd = torch.load(heads_path, map_location="cpu")
    if is_main:
        # Summary stats over at most the first 100k values of each tensor, to stay cheap.
        capped = [v.detach().reshape(-1)[:100000].float() for v in sd.values() if torch.is_tensor(v)]
        if capped:
            flat = torch.cat(capped, dim=0)
            rms = torch.sqrt((flat ** 2).mean()).item()
            print(
                f"[CHECK-A] heads_tensors={len(capped)} mean={flat.mean().item():.6f} "
                f"std={flat.std().item():.6f} rms={rms:.6f}"
            )
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if is_main:
        if len(missing) > 0 or len(unexpected) > 0:
            print(f"[CHECK-A] loaded with strict=False. Missing={len(missing)} Unexpected={len(unexpected)}")
            print(f"[CHECK-A] Missing keys (first 20): {missing[:20]}")
            print(f"[CHECK-A] Unexpected keys (first 20): {unexpected[:20]}")
        else:
            print("[CHECK-A] heads fully matched (no missing/unexpected).")


def _eval_state_in_dim(model: nn.Module):
    """Return ``model.vision.state_encoder.mlp[0].in_features``, or None if that path is absent."""
    try:
        return int(model.vision.state_encoder.mlp[0].in_features)
    except (AttributeError, IndexError, KeyError, TypeError):
        return None


def _eval_fit_state_dim(robot_state: torch.Tensor, expected_d) -> torch.Tensor:
    """Zero-pad on the right or trim the last dim of ``robot_state`` to ``expected_d`` (no-op if None)."""
    if expected_d is None:
        return robot_state
    cur_d = robot_state.size(-1)
    if cur_d < expected_d:
        return F.pad(robot_state, (0, expected_d - cur_d))
    if cur_d > expected_d:
        return robot_state[..., :expected_d]
    return robot_state


def _eval_to_device(x, device: torch.device):
    """Move ``x`` to ``device`` unless it is None."""
    return None if x is None else x.to(device)


def _eval_masked_token_mean(tokens: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool [B, T, D] ``tokens`` over T in float32, ignoring positions where ``mask`` is 0.

    Assumes ``mask`` is [B, T] with 1 = real token and 0 = padding, as produced by HF tokenizers.
    """
    m = mask.to(device=tokens.device, dtype=torch.float32).unsqueeze(-1)
    # clamp_min avoids 0/0 for an all-padding row.
    return (tokens.float() * m).sum(dim=1) / m.sum(dim=1).clamp_min(1.0)


def _eval_telepathy_effect_check(model: nn.Module, obs_kwargs: Dict, pred: Dict) -> None:
    """CHECK-B: rerun the batch with telepathy disabled and print mean |delta| of action_vector."""
    model.reset_memory()
    prev_flag = bool(model.disable_telepathy)
    model.disable_telepathy = True
    try:
        pred_ctrl = model.forward_once(**obs_kwargs, return_intermediate=False)
    finally:
        # Always restore, so a failing control pass cannot leave telepathy disabled for the run.
        model.disable_telepathy = prev_flag
    try:
        act_exp = _pred_action(pred, "action_vector")
        act_ctl = _pred_action(pred_ctrl, "action_vector")
        diff = (act_exp - act_ctl).abs().mean().item()
        print(f"[CHECK-B] telepathy_effect_mean_abs_diff(action_vector)={diff:.6f}")
    except Exception as e:  # best-effort diagnostic only; never abort the eval over it
        print(f"[CHECK-B] action diff check failed: {type(e).__name__}: {e}")


def _eval_sum_across_processes(accelerator, values: List[float], device: torch.device) -> List[float]:
    """Element-wise sum of ``values`` over all processes (identity when single-process)."""
    if accelerator.num_processes <= 1:
        return [float(v) for v in values]
    packed = torch.tensor([float(v) for v in values], dtype=torch.float64, device=device)
    return accelerator.reduce(packed, reduction="sum").tolist()


def _eval_mean_or_nan(total: float, count: float) -> float:
    """Return ``total / count``, or NaN when ``count`` is not positive."""
    return total / count if count > 0 else float("nan")


def main():
    """Evaluate the Telepathy-augmented VLA offline over Sigma shards and write a JSON report.

    Steps:
      1. Parse the CLI and resolve or download the shards and heads.
      2. Load the PI05 policy; it is used here only for its tokenizer and text embeddings.
      3. Build TelepathyVLA and load the heads.
      4. Iterate the shards under no_grad, accumulating per-branch MSE, tau energy,
         semantic/text alignment and "hard"-sample MSE.
      5. Sum the accumulators across processes.
      6. The main process writes ``sigma_eval_report.json``.
    """
    args = _eval_build_arg_parser().parse_args()
    if os.path.exists(args.sigma_env):
        load_dotenv(args.sigma_env)
    hf_token = os.getenv("HF_TOKEN", None)
    accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != "fp32" else "no")
    set_seed(args.seed)
    device = accelerator.device
    is_main = accelerator.is_main_process
    if args.load_in_4bit:
        print("[WARN] --load_in_4bit is ignored for PI05Policy evaluator.")
    _eval_resolve_inputs(args, hf_token)

    policy = load_pi05_policy(
        args.base_model_id,
        hf_token,
        device=device,
        strict_load=args.strict_pi05_load,
    )
    tokenizer = get_policy_tokenizer(
        policy,
        args.base_model_id,
        hf_token,
        forced_tokenizer_id=args.tokenizer_id,
    )
    text_embed_layer = get_policy_text_embedding_layer(policy)
    verify_tokenizer_embedding_compat(
        tokenizer=tokenizer,
        text_embed_layer=text_embed_layer,
        allow_mismatch=args.allow_tokenizer_mismatch,
    )

    telepathy_vla = TelepathyVLA(
        VisionConfig(), LanguageConfig(), ActionConfig(), disable_telepathy=args.disable_telepathy
    )
    telepathy_vla.telepathy_scale = args.telepathy_scale
    # Adapter is only applied when telepathy is enabled AND --use_telepathy_adapter is set.
    telepathy_adapter = SigmaTelepathyAdapter(SigmaTelepathyAdapterConfig()).to(device)
    if isinstance(telepathy_adapter, nn.Module):
        telepathy_adapter.eval()  # disable dropout / use running stats during evaluation
    if is_main:
        print(f"[CHECK-A] disable_telepathy={args.disable_telepathy}")
    _eval_load_heads(telepathy_vla, args.telepathy_heads_path, is_main)
    telepathy_vla.eval()

    ds = SigmaShardDataset(args.shard_dir)
    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        collate_fn=collate_sigma,
        drop_last=False,
        pin_memory=torch.cuda.is_available(),
    )
    telepathy_vla, dl = accelerator.prepare(telepathy_vla, dl)
    # prepare() may wrap the model (e.g. DDP). Custom methods such as reset_memory /
    # forward_once and submodules such as .vision live on the underlying module.
    model = accelerator.unwrap_model(telepathy_vla)
    target_dtype = next(model.parameters()).dtype
    state_in_dim = _eval_state_in_dim(model)  # loop-invariant; None -> leave state width unchanged
    use_adapter = (not args.disable_telepathy) and args.use_telepathy_adapter

    hard_thr_vec = args.hard_thr_vec if args.hard_thr_vec > 0.0 else 2.0 * args.succ_thr_vec
    hard_thr_chk = args.hard_thr_chk if args.hard_thr_chk > 0.0 else 2.0 * args.succ_thr_chk
    hard_thr_trj = args.hard_thr_trj if args.hard_thr_trj > 0.0 else 2.0 * args.succ_thr_trj

    sum_mse_vec = sum_mse_chk = sum_mse_trj = 0.0
    # tau / sem-align may be NaN for a batch; count only the batches that contributed.
    sum_tau_l2, n_tau = 0.0, 0
    sum_sem_align, n_sem = 0.0, 0
    sum_hard_mse_vec = sum_hard_mse_chk = sum_hard_mse_trj = 0.0
    total_hard_samples = 0
    n_batches = 0
    n_samples = 0
    os.makedirs(args.output_dir, exist_ok=True)

    with torch.no_grad():  # pure evaluation: do not build autograd graphs
        for bidx, batch in enumerate(dl):
            if args.max_batches > 0 and bidx >= args.max_batches:
                break
            model.reset_memory()
            n_samples += batch["vis_obs"].size(0)

            text_tokens, attn_mask = build_text_tokens_from_policy(
                tokenizer=tokenizer,
                text_embed_layer=text_embed_layer,
                texts=batch["texts"],
                device=device,
                target_dtype=target_dtype,
                max_text_len=args.max_text_len,
            )
            robot_state = batch["robot_state"].to(device)
            if robot_state.dim() == 3:
                robot_state = robot_state[:, -1]  # keep the last time step
            robot_state = _eval_fit_state_dim(robot_state, state_in_dim)
            # Optional base actions consumed by the adapter.
            for key in ("base_action_vector", "base_action_chunk", "base_action_trajectory"):
                if batch.get(key) is not None:
                    batch[key] = batch[key].to(device)

            obs_kwargs = dict(
                vis_obs=batch["vis_obs"].to(device),
                robot_state=robot_state,
                depth_obs=_eval_to_device(batch.get("depth_obs"), device),
                audio_obs=_eval_to_device(batch.get("audio_obs"), device),
                text_tokens=text_tokens,
                attn_mask=attn_mask,
            )
            pred = model.forward_once(**obs_kwargs, return_intermediate=True)
            if is_main and bidx == 0:
                _eval_telepathy_effect_check(model, obs_kwargs, pred)
            if use_adapter:
                pred = telepathy_adapter(pred, batch)

            mse = compute_branch_mse(pred, batch)
            tau_l2 = compute_telepathy_stability(pred)
            _, _, mse_vec_s, mse_chk_s, mse_trj_s = compute_success_proxy(
                pred,
                batch,
                thr_vec=args.succ_thr_vec,
                thr_chk=args.succ_thr_chk,
                thr_trj=args.succ_thr_trj,
            )
            # Hard set: samples where any branch's per-sample MSE exceeds its hard threshold.
            hard_mask = (mse_vec_s > hard_thr_vec) | (mse_chk_s > hard_thr_chk) | (mse_trj_s > hard_thr_trj)
            hard_count = int(hard_mask.sum().item())
            if hard_count > 0:
                sum_hard_mse_vec += mse_vec_s[hard_mask].sum().item()
                sum_hard_mse_chk += mse_chk_s[hard_mask].sum().item()
                sum_hard_mse_trj += mse_trj_s[hard_mask].sum().item()
                total_hard_samples += hard_count

            sem_factors = pred.get("semantic_factors", None)
            if sem_factors is not None:
                if sem_factors.dim() == 3:
                    sem_pool = sem_factors.mean(dim=1)
                elif sem_factors.dim() == 2:
                    sem_pool = sem_factors
                else:
                    sem_pool = sem_factors.reshape(sem_factors.size(0), -1)
                # Pool text over real tokens only; padding embeddings would dilute short prompts.
                txt_pool = _eval_masked_token_mean(text_tokens, attn_mask)
                sem_align = cosine_alignment(sem_pool, txt_pool)
            else:
                sem_align = float("nan")

            sum_mse_vec += mse["mse_vector"]
            sum_mse_chk += mse["mse_chunk"]
            sum_mse_trj += mse["mse_traj"]
            if tau_l2 == tau_l2:  # False only for NaN
                sum_tau_l2 += tau_l2
                n_tau += 1
            if sem_align == sem_align:  # False only for NaN
                sum_sem_align += sem_align
                n_sem += 1
            n_batches += 1
            if is_main and bidx % 20 == 0:
                print(
                    f"batch={bidx} "
                    f"mse_vec={mse['mse_vector']:.4f} mse_chk={mse['mse_chunk']:.4f} mse_trj={mse['mse_traj']:.4f} "
                    f"tau_l2={tau_l2:.4f} sem_align={sem_align:.4f}"
                )

    # Each process only saw its shard of the prepared DataLoader; combine before reporting.
    # NOTE: with several processes the prepared DataLoader may repeat a few samples to even
    # out the final batch, so those would be counted more than once.
    (
        sum_mse_vec, sum_mse_chk, sum_mse_trj,
        sum_tau_l2, n_tau, sum_sem_align, n_sem,
        sum_hard_mse_vec, sum_hard_mse_chk, sum_hard_mse_trj,
        total_hard_samples, n_batches, n_samples,
    ) = _eval_sum_across_processes(
        accelerator,
        [
            sum_mse_vec, sum_mse_chk, sum_mse_trj,
            sum_tau_l2, n_tau, sum_sem_align, n_sem,
            sum_hard_mse_vec, sum_hard_mse_chk, sum_hard_mse_trj,
            total_hard_samples, n_batches, n_samples,
        ],
        device,
    )
    total_hard_samples = int(round(total_hard_samples))
    n_batches = int(round(n_batches))
    n_samples = int(round(n_samples))

    if is_main:
        avg_mse_vec = sum_mse_vec / max(1, n_batches)
        avg_mse_chk = sum_mse_chk / max(1, n_batches)
        avg_mse_trj = sum_mse_trj / max(1, n_batches)
        # Average over the batches that actually produced a value (NaN if none did).
        avg_tau_l2 = _eval_mean_or_nan(sum_tau_l2, n_tau)
        avg_sem_align = _eval_mean_or_nan(sum_sem_align, n_sem)
        avg_hard_mse_vec = _eval_mean_or_nan(sum_hard_mse_vec, total_hard_samples)
        avg_hard_mse_chk = _eval_mean_or_nan(sum_hard_mse_chk, total_hard_samples)
        avg_hard_mse_trj = _eval_mean_or_nan(sum_hard_mse_trj, total_hard_samples)
        hard_fraction = float(total_hard_samples / max(1, n_samples))
        report = {
            "num_samples": n_samples,
            "num_batches": n_batches,
            "avg_mse_vector": avg_mse_vec,
            "avg_mse_chunk": avg_mse_chk,
            "avg_mse_traj": avg_mse_trj,
            "avg_tau_l2": avg_tau_l2,
            "avg_semantic_text_alignment": avg_sem_align,
            "hard_thresholds": {
                "vec": hard_thr_vec,
                "chk": hard_thr_chk,
                "trj": hard_thr_trj,
            },
            "avg_hard_mse_vector": avg_hard_mse_vec,
            "avg_hard_mse_chunk": avg_hard_mse_chk,
            "avg_hard_mse_traj": avg_hard_mse_trj,
            "hard_sample_fraction": hard_fraction,
            "total_hard_samples": int(total_hard_samples),
        }
        with open(
            os.path.join(args.output_dir, "sigma_eval_report.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(report, f, indent=2)
        print("[DONE] Saved report:", report)
if __name__ == "__main__":
    # Script entry point: run the evaluator only when executed directly, not on import.
    main()