ConorWang committed
Commit 03426f9 · verified · 1 Parent(s): 6140022

Upload 10 files
eval_sigma_vla_rollout.py ADDED
@@ -0,0 +1,1402 @@
+ # eval_sigma_vla_rollout.py
+ # Offline closed-loop evaluation for a Telepathy-augmented VLA on top of the PI05 policy backbone.
+ #
+ # Key design:
+ # - base_model_id is a LeRobot/OpenPI policy repo (e.g., lerobot/pi05_base or your fine-tuned Sigma repo).
+ # - We load PI05Policy via LeRobot, NOT AutoModelForCausalLM.
+ # - Text embeddings are taken from the PI05 internal text backbone so that TelepathyLanguageModule
+ #   receives the same type of inputs used during training.
+ #
+ # Hardened in this revision:
+ # - Robust recursive shard discovery under any naming scheme and subfolder layout.
+ # - Shard content structure normalization (list-of-samples, or dict{samples/data}).
+ # - Collate auto-adapts to the real schema: vision/state/action/text, with time-dim collapse for vision.
+ # - Action GT supports dict-style branches or a single tensor.
+ # - Metrics tolerate missing multi-branch outputs (fallback to "action").
+ # - Text token dtype/device aligned to the model dtype for mixed-precision safety.
+ # - Robot state time-dim collapse + pad/trim to the state encoder's expected dim.
+ # - Dynamic projection to align vision/state token hidden size to the vision backbone dim (768),
+ #   and project text to the same dim BEFORE feeding the language module.
+ # - Optional max_text_len to avoid tokenizer truncation warnings.
+ # - Action input contract hardening:
+ #     * high_level_rep 2D -> 3D
+ #     * tau None/2D -> 3D
+ #     * tau length aligned to high_level_rep length
+ #     * tau last-dim auto pad/trim so concat(high_level_rep, tau) matches action_condition_proj in_features
+ # - tokenizer_id can be a LOCAL path; when it exists locally we load with local_files_only
+ # - _align_target handles 2D<->3D mismatches (fixes MSE crashes)
+ # - removed the duplicated "high_level_rep/tau re-normalization" that overwrote the hardening
+ #
+ # NEW in this patch:
+ # - cosine_alignment auto-aligns hidden sizes (fixes the 256 vs 2048 crash).
+ # - semantic pooling guard supports 2D/3D factors safely.
+ # - alignment metric ignores zero-length cases robustly.
+ #
+ # EXTRA HARDENING (this patch, for the baseline issue):
+ # - Try strict loading of PI05Policy if the LeRobot version supports it.
+ # - Verify tokenizer vocab size and special-token ids match the PI05 text embedding table.
+ # - Fail fast with a clear message if a mismatch is detected (unless explicitly overridden).
+ #
+ # NEW in this hard-set patch:
+ # - Per-sample MSE is exposed from the success proxy.
+ # - A "hard set" is defined as samples whose branch-wise MSE exceeds hard thresholds.
+ # - Hard-set averages (MSE and fraction of samples) are reported alongside global metrics.
+ #
+ # NEW in this adapter patch:
+ # - sigma_telepathy_adapter is applied at eval time (when telepathy is enabled) to gate
+ #   Telepathy residuals based on their magnitude and tau strength, optionally using
+ #   offline base_action_* if present in the shards.
+
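+ # Illustrative CLI invocation (a sketch; the repo ids below are placeholders,
+ # not shipped defaults):
+ #   python eval_sigma_vla_rollout.py \
+ #       --base_model_id lerobot/pi05_base \
+ #       --artifacts_repo_id <your-artifacts-repo> \
+ #       --tokenizer_id <pi05-text-backbone-tokenizer> \
+ #       --dtype bf16 --batch_size 4 --use_telepathy_adapter
+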
+ from __future__ import annotations
+
+ import os
+ import glob
+ import json
+ import argparse
+ import importlib
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+
+ from dotenv import load_dotenv
+ from accelerate import Accelerator
+ from accelerate.utils import set_seed
+
+ try:
+     from huggingface_hub import snapshot_download
+ except Exception:
+     snapshot_download = None  # type: ignore
+
+ from vision_sigma_vla import TelepathyVisionModule, VisionConfig
+ from language_sigma_vla import TelepathyLanguageModule, LanguageConfig
+ from action_sigma_vla import TelepathyActionModule, ActionConfig
+ from sigma_telepathy_adapter import SigmaTelepathyAdapter, SigmaTelepathyAdapterConfig
+
+
+ def ensure_sigma_artifacts_from_hf(
+     repo_id: str,
+     hf_token: Optional[str],
+     local_cache_root: str,
+ ) -> Dict[str, str]:
+     """
+     Download Sigma artifacts from an HF repo into a local cache folder.
+     Returns local paths for shard_dir and telepathy_heads_path.
+
+     We only pull:
+         storage/sigma_pickplace/**
+         storage/sigma_lora_out/**
+     """
+     if snapshot_download is None:
+         raise ImportError(
+             "huggingface_hub is not available but auto-download was requested. "
+             "Please `pip install huggingface_hub` or download artifacts manually."
+         )
+
+     os.makedirs(local_cache_root, exist_ok=True)
+     local_dir = snapshot_download(
+         repo_id=repo_id,
+         token=hf_token,
+         local_dir=os.path.join(local_cache_root, repo_id.replace("/", "__")),
+         local_dir_use_symlinks=False,
+         allow_patterns=[
+             "storage/sigma_pickplace/**",
+             "storage/sigma_lora_out/**",
+         ],
+     )
+
+     shard_dir = os.path.join(local_dir, "storage", "sigma_pickplace")
+     telepathy_heads_path = os.path.join(
+         local_dir, "storage", "sigma_lora_out", "sigma_telepathy_heads.pt"
+     )
+
+     return {
+         "local_repo_dir": local_dir,
+         "shard_dir": shard_dir,
+         "telepathy_heads_path": telepathy_heads_path,
+     }
+
+
+ def load_pi05_policy(
+     repo_id: str,
+     hf_token: Optional[str],
+     device: torch.device,
+     strict_load: bool = True,
+ ):
+     """
+     Load PI05Policy from LeRobot. We try a few import paths to be robust across versions.
+     If the LeRobot PI05Policy.from_pretrained supports strict loading, we enable it.
+     """
+     policy_cls = None
+     import_errors = []
+
+     candidate_paths = [
+         ("lerobot.policies.pi05.modeling_pi05", "PI05Policy"),
+         ("lerobot.policies.pi05", "PI05Policy"),
+     ]
+
+     for mod_name, cls_name in candidate_paths:
+         try:
+             mod = importlib.import_module(mod_name)
+             policy_cls = getattr(mod, cls_name)
+             break
+         except Exception as e:
+             import_errors.append(f"{mod_name}.{cls_name}: {type(e).__name__}: {e}")
+
+     if policy_cls is None:
+         raise ImportError(
+             "Failed to import PI05Policy from LeRobot. Tried:\n - "
+             + "\n - ".join(import_errors)
+         )
+
+     policy = None
+     tried = []
+     if strict_load:
+         try:
+             policy = policy_cls.from_pretrained(repo_id, token=hf_token, strict=True)
+             tried.append("from_pretrained(..., strict=True)")
+         except TypeError:
+             tried.append("strict=True not supported")
+         except Exception as e:
+             tried.append(f"strict=True failed: {type(e).__name__}: {e}")
+
+     if policy is None:
+         try:
+             policy = policy_cls.from_pretrained(repo_id, token=hf_token)
+             tried.append("from_pretrained(repo_id, token=...)")
+         except TypeError:
+             policy = policy_cls.from_pretrained(pretrained_name_or_path=repo_id, token=hf_token)
+             tried.append("from_pretrained(pretrained_name_or_path=..., token=...)")
+
+     if policy is None:
+         raise RuntimeError("PI05Policy loading returned None. Tried: " + "; ".join(tried))
+
+     policy = policy.to(device)
+     policy.eval()
+     return policy
+
+
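+ # Usage sketch (assumes a LeRobot install that actually exposes PI05Policy at one
+ # of the import paths tried above):
+ #   policy = load_pi05_policy("lerobot/pi05_base", hf_token=None,
+ #                             device=torch.device("cuda"), strict_load=True)
+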
+ def get_policy_tokenizer(
+     policy,
+     repo_id: str,
+     hf_token: Optional[str],
+     forced_tokenizer_id: str = "",
+ ):
+     """
+     Robust tokenizer getter for PI05Policy.
+
+     IMPORTANT:
+     - Never call AutoTokenizer.from_pretrained(repo_id) because repo_id is a policy repo.
+     - If --tokenizer_id is provided and points to a LOCAL folder, load locally.
+     - Otherwise load from the HF id.
+     - If still missing, recursively search for a tokenizer/processor inside the policy.
+     """
+     from transformers import AutoTokenizer
+
+     if forced_tokenizer_id:
+         if os.path.exists(forced_tokenizer_id):
+             tok = AutoTokenizer.from_pretrained(
+                 forced_tokenizer_id,
+                 local_files_only=True,
+                 trust_remote_code=True,
+             )
+         else:
+             tok = AutoTokenizer.from_pretrained(
+                 forced_tokenizer_id,
+                 token=hf_token,
+                 trust_remote_code=True,
+             )
+         if tok.pad_token is None:
+             tok.pad_token = tok.eos_token
+         return tok
+
+     def _recursive_find_tokenizer(obj, max_depth: int = 4):
+         if obj is None or max_depth <= 0:
+             return None
+
+         for key in ["tokenizer", "processor", "text_tokenizer", "language_tokenizer"]:
+             if hasattr(obj, key):
+                 v = getattr(obj, key)
+                 if v is None:
+                     continue
+                 if key == "processor" and hasattr(v, "tokenizer") and v.tokenizer is not None:
+                     return v.tokenizer
+                 if hasattr(v, "__call__"):
+                     return v
+
+         nested_names = [
+             "paligemma_with_expert",
+             "paligemma",
+             "gemma_expert",
+             "language_model",
+             "text_model",
+             "model",
+             "policy",
+         ]
+         for name in nested_names:
+             if hasattr(obj, name):
+                 found = _recursive_find_tokenizer(
+                     getattr(obj, name), max_depth=max_depth - 1
+                 )
+                 if found is not None:
+                     return found
+         return None
+
+     tok = _recursive_find_tokenizer(policy)
+     if tok is not None:
+         if getattr(tok, "pad_token", None) is None and getattr(tok, "eos_token", None) is not None:
+             tok.pad_token = tok.eos_token
+         return tok
+
+     backbone_name = None
+     config_candidates = []
+     for attr in ["config", "model", "paligemma_with_expert", "paligemma"]:
+         if hasattr(policy, attr):
+             config_candidates.append(getattr(policy, attr))
+
+     def _try_get_name(cfg_obj):
+         if cfg_obj is None:
+             return None
+         for k in [
+             "_name_or_path",
+             "text_backbone_id",
+             "text_model_id",
+             "language_model_id",
+             "processor_name_or_path",
+             "tokenizer_name_or_path",
+         ]:
+             if hasattr(cfg_obj, k):
+                 v = getattr(cfg_obj, k)
+                 if isinstance(v, str) and v:
+                     return v
+         if hasattr(cfg_obj, "config"):
+             c = getattr(cfg_obj, "config")
+             if hasattr(c, "_name_or_path") and isinstance(c._name_or_path, str) and c._name_or_path:
+                 return c._name_or_path
+         return None
+
+     for c in config_candidates:
+         backbone_name = _try_get_name(c)
+         if backbone_name:
+             break
+
+     if backbone_name:
+         tok = AutoTokenizer.from_pretrained(
+             backbone_name, token=hf_token, trust_remote_code=True
+         )
+         if tok.pad_token is None:
+             tok.pad_token = tok.eos_token
+         return tok
+
+     raise ValueError(
+         f"Cannot obtain tokenizer from PI05Policy for repo '{repo_id}'. "
+         "Your lerobot PI05 port exposes neither a tokenizer/processor nor a backbone name. "
+         "Please pass --tokenizer_id explicitly."
+     )
+
+
+ def get_policy_text_embedding_layer(policy):
+     """
+     Locate the text embedding layer inside PI05Policy robustly.
+     """
+     def _recursive_find(obj, depth: int = 6):
+         if obj is None or depth <= 0:
+             return None
+
+         if hasattr(obj, "get_input_embeddings"):
+             try:
+                 emb = obj.get_input_embeddings()
+                 if emb is not None:
+                     return emb
+             except Exception:
+                 pass
+
+         for key in ["embed_tokens", "embeddings", "token_embedding"]:
+             if hasattr(obj, key):
+                 v = getattr(obj, key)
+                 if isinstance(v, nn.Module):
+                     return v
+
+         nested_names = [
+             "model",
+             "paligemma_with_expert",
+             "paligemma",
+             "language_model",
+             "gemma_expert",
+             "text_model",
+             "policy",
+         ]
+         for name in nested_names:
+             if hasattr(obj, name):
+                 found = _recursive_find(getattr(obj, name), depth=depth - 1)
+                 if found is not None:
+                     return found
+
+         return None
+
+     emb = _recursive_find(policy)
+     if emb is None:
+         raise AttributeError(
+             "Cannot locate the PI05 text embedding layer via recursive search. "
+             "Your PI05Policy likely changed internal naming. "
+             "Please inspect policy.model.* to confirm the embed_tokens location."
+         )
+     return emb
+
+
+ def verify_tokenizer_embedding_compat(
+     tokenizer,
+     text_embed_layer: nn.Module,
+     allow_mismatch: bool = False,
+ ):
+     """
+     Ensure tokenizer vocab/special ids are consistent with the PI05 text embedding table.
+     This directly prevents the 'embed_tokens.weight missing or misaligned' baseline issue.
+     """
+     emb_vocab = None
+     if isinstance(text_embed_layer, nn.Embedding):
+         emb_vocab = int(text_embed_layer.num_embeddings)
+     elif hasattr(text_embed_layer, "weight") and text_embed_layer.weight is not None:
+         emb_vocab = int(text_embed_layer.weight.size(0))
+
+     tok_vocab = getattr(tokenizer, "vocab_size", None)
+     if tok_vocab is None:
+         try:
+             tok_vocab = len(tokenizer)
+         except Exception:
+             tok_vocab = None
+
+     if emb_vocab is None or tok_vocab is None:
+         print("[WARN] Cannot infer tokenizer/embedding vocab sizes. Skipping compatibility check.")
+         return
+
+     if emb_vocab != tok_vocab:
+         msg = (
+             f"[ERROR] Tokenizer vocab size ({tok_vocab}) != PI05 embedding table size ({emb_vocab}). "
+             "This will corrupt text embeddings and invalidate the baseline. "
+             "Fix by passing --tokenizer_id matching the PI05 text backbone "
+             "(e.g., the original openpi/PI05 tokenizer) or re-exporting the policy with an aligned vocab."
+         )
+         if allow_mismatch:
+             print(msg.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.")
+         else:
+             raise ValueError(msg)
+
+     for name in ["pad_token_id", "eos_token_id", "bos_token_id", "unk_token_id"]:
+         tid = getattr(tokenizer, name, None)
+         if tid is None:
+             continue
+         if not (0 <= int(tid) < emb_vocab):
+             msg = (
+                 f"[ERROR] Tokenizer {name}={tid} out of embedding range [0, {emb_vocab-1}]. "
+                 "Your tokenizer does not belong to this PI05 backbone."
+             )
+             if allow_mismatch:
+                 print(msg.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.")
+             else:
+                 raise ValueError(msg)
+
+
+ class TelepathyVLA(nn.Module):
+     """
+     Full model wiring vision -> language -> action, matching the final architecture arrows.
+     """
+     def __init__(
+         self,
+         v_cfg: VisionConfig,
+         l_cfg: LanguageConfig,
+         a_cfg: ActionConfig,
+         disable_telepathy: bool = False,
+     ):
+         super().__init__()
+         self.vision = TelepathyVisionModule(v_cfg)
+         self.language = TelepathyLanguageModule(l_cfg)
+         self.action = TelepathyActionModule(a_cfg)
+         self.disable_telepathy = disable_telepathy
+         self.register_buffer("_m_prev", None, persistent=False)
+
+         self._proj_inited = False
+         self.text_proj: Optional[nn.Module] = None
+         self.vision_proj: Optional[nn.Module] = None
+         self.state_proj: Optional[nn.Module] = None
+
+     def reset_memory(self):
+         self._m_prev = None
+
+     @torch.no_grad()
+     def forward_once(
+         self,
+         vis_obs: torch.Tensor,
+         robot_state: torch.Tensor,
+         text_tokens: torch.Tensor,
+         depth_obs: Optional[torch.Tensor] = None,
+         audio_obs: Optional[torch.Tensor] = None,
+         attn_mask: Optional[torch.Tensor] = None,
+         return_intermediate: bool = False,
+     ) -> Dict[str, torch.Tensor]:
+
+         vis0 = self.vision(
+             vis_obs=vis_obs,
+             robot_state=robot_state,
+             depth_obs=depth_obs,
+             audio_obs=audio_obs,
+             telepathy_factors=None,
+             return_intermediate=return_intermediate,
+         )
+
+         vis_d = vis0["vision_tokens"].size(-1)
+         state_d = vis0["state_tokens"].size(-1)
+         target_d = vis_d
+
+         if not self._proj_inited:
+             self.text_proj = nn.Linear(text_tokens.size(-1), target_d, bias=False) \
+                 if text_tokens.size(-1) != target_d else nn.Identity()
+             self.vision_proj = nn.Identity() if vis_d == target_d else nn.Linear(vis_d, target_d, bias=False)
+             self.state_proj = nn.Identity() if state_d == target_d else nn.Linear(state_d, target_d, bias=False)
+
+             self.text_proj = self.text_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
+             self.vision_proj = self.vision_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
+             self.state_proj = self.state_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
+             self._proj_inited = True
+
+         assert self.text_proj is not None and self.vision_proj is not None and self.state_proj is not None
+
+         text_tokens = self.text_proj(text_tokens)
+         vision_tokens = self.vision_proj(vis0["vision_tokens"])
+         state_tokens = self.state_proj(vis0["state_tokens"])
+
+         lang_out = self.language(
+             text_tokens=text_tokens,
+             vision_tokens=vision_tokens,
+             state_tokens=state_tokens,
+             m_prev=self._m_prev,
+             attn_mask=attn_mask,
+             return_intermediate=return_intermediate,
+         )
+
+         raw_tau = lang_out.get("telepathy_factors", None)
+         self._m_prev = lang_out.get("m_t", None)
+
+         telepathy_scale = float(getattr(self, "telepathy_scale", 1.0))
+
+         if self.disable_telepathy:
+             tau = None
+             vis_out = vis0
+         else:
+             tau = raw_tau
+             if tau is not None:
+                 tau = tau * telepathy_scale
+             vis_out = self.vision(
+                 vis_obs=vis_obs,
+                 robot_state=robot_state,
+                 depth_obs=depth_obs,
+                 audio_obs=audio_obs,
+                 telepathy_factors=tau,
+                 return_intermediate=return_intermediate,
+             )
+
+         high_level_rep = lang_out.get("high_level_rep", None)
+         if high_level_rep is None:
+             raise KeyError("language output missing 'high_level_rep'.")
+
+         if high_level_rep.dim() == 2:
+             high_level_rep = high_level_rep.unsqueeze(1)
+
+         if tau is None:
+             B, L, _ = high_level_rep.shape
+             tau_dim = getattr(self.language, "tau_dim", 128)
+             tau = torch.zeros(B, L, tau_dim, device=high_level_rep.device, dtype=high_level_rep.dtype)
+         else:
+             if tau.dim() == 2:
+                 tau = tau.unsqueeze(1)
+             if tau.size(1) != high_level_rep.size(1):
+                 L = high_level_rep.size(1)
+                 if tau.size(1) == 1:
+                     tau = tau.expand(-1, L, -1)
+                 else:
+                     tau = tau[:, :L, :]
+
+         expected_in = None
+         acp = getattr(self.action, "action_condition_proj", None)
+         if acp is not None:
+             if hasattr(acp, "in_features"):
+                 expected_in = int(acp.in_features)
+             elif hasattr(acp, "net") and len(acp.net) > 0 and hasattr(acp.net[0], "in_features"):
+                 expected_in = int(acp.net[0].in_features)
+
+         if expected_in is not None:
+             d_high = high_level_rep.size(-1)
+             target_tau = expected_in - d_high
+
+             if target_tau <= 0:
+                 pass
+             else:
+                 if tau.size(-1) < target_tau:
+                     tau = F.pad(tau, (0, target_tau - tau.size(-1)))
+                 elif tau.size(-1) > target_tau:
+                     tau = tau[..., :target_tau]
+
+         state_for_action = vis_out["state_tokens"]
+         if state_for_action.dim() == 2:
+             state_for_action = state_for_action.unsqueeze(1)
+         elif state_for_action.dim() > 3:
+             state_for_action = state_for_action.view(
+                 state_for_action.size(0), -1, state_for_action.size(-1)
+             )
+
+         lang_d = high_level_rep.size(-1)
+
+         def _pad_or_trim_to(x: torch.Tensor, d: int) -> torch.Tensor:
+             cur_d = x.size(-1)
+             if cur_d == d:
+                 return x
+             if cur_d < d:
+                 return F.pad(x, (0, d - cur_d))
+             return x[..., :d]
+
+         state_for_action = _pad_or_trim_to(state_for_action, lang_d)
+
+         act_out = self.action(
+             high_level_rep=high_level_rep,
+             telepathy_factors=tau,
+             state_tokens=state_for_action,
+             return_intermediate=return_intermediate,
+         )
+
+         out: Dict[str, torch.Tensor] = {}
+         out.update(vis_out)
+         out.update(lang_out)
+         out.update(act_out)
+         return out
+
+
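+ # Data flow through forward_once (a sketch; the dims are illustrative, the real
+ # ones come from the module configs):
+ #   vis_obs [B, 3, H, W] + robot_state [B, Ds] --vision--> vision/state tokens
+ #   tokens projected to the vision hidden dim  --language--> high_level_rep, tau
+ #   (high_level_rep, tau, state tokens)        --action--> action_vector/chunk/trajectory
+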
+ class SigmaShardDataset(Dataset):
+     """
+     Loads .pt shards produced by dataset_preprocess_sigma_vla.py.
+     Each shard is a list of dict samples OR a dict containing a list (samples/data).
+     """
+     def __init__(self, shard_dir: str):
+         super().__init__()
+         if not os.path.isdir(shard_dir):
+             raise FileNotFoundError(
+                 f"shard_dir does not exist: {shard_dir}. Double-check the path."
+             )
+
+         patterns = [
+             os.path.join(shard_dir, "sigma_vla_shard_*.pt"),
+             os.path.join(shard_dir, "*.pt"),
+             os.path.join(shard_dir, "**", "*.pt"),
+         ]
+         paths: List[str] = []
+         for p in patterns:
+             paths.extend(glob.glob(p, recursive=True))
+
+         self.shard_paths = sorted(list(set(paths)))
+         if len(self.shard_paths) == 0:
+             raise FileNotFoundError(
+                 f"No .pt shards found under {shard_dir}. "
+                 "Your HF cache is empty or shards are not tracked by LFS."
+             )
+
+         print(f"[INFO] Found {len(self.shard_paths)} shard files. Example: {self.shard_paths[:3]}")
+
+         self.index_map: List[Tuple[int, int]] = []
+         self._shard_cache: Dict[int, List[Dict[str, Any]]] = {}
+
+         for sid, p in enumerate(self.shard_paths):
+             shard = torch.load(p, map_location="cpu")
+             shard_list = self._normalize_shard(shard, p)
+             for lid in range(len(shard_list)):
+                 self.index_map.append((sid, lid))
+
+         self.total = len(self.index_map)
+
+     def __len__(self):
+         return self.total
+
+     def _normalize_shard(self, shard_obj: Any, path: str) -> List[Dict[str, Any]]:
+         if isinstance(shard_obj, (list, tuple)):
+             return list(shard_obj)
+
+         if isinstance(shard_obj, dict):
+             for k in ["samples", "data", "items"]:
+                 if k in shard_obj and isinstance(shard_obj[k], (list, tuple)):
+                     return list(shard_obj[k])
+
+         raise TypeError(
+             f"Unsupported shard format in {path}. "
+             f"Expected list/tuple of samples or dict{{samples/data}}. "
+             f"Got type: {type(shard_obj).__name__}"
+         )
+
+     def _get_shard(self, sid: int) -> List[Dict[str, Any]]:
+         if sid not in self._shard_cache:
+             raw = torch.load(self.shard_paths[sid], map_location="cpu")
+             self._shard_cache[sid] = self._normalize_shard(raw, self.shard_paths[sid])
+         return self._shard_cache[sid]
+
+     def __getitem__(self, idx: int) -> Dict[str, Any]:
+         sid, lid = self.index_map[idx]
+         shard = self._get_shard(sid)
+         return shard[lid]
+
+
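+ # Minimal compatible shard, written offline (illustrative; the shapes are
+ # assumptions of this sketch, not requirements enforced by the dataset class):
+ #   torch.save(
+ #       [{"vision": torch.zeros(3, 224, 224), "state": torch.zeros(8),
+ #         "text": "pick up the cube", "action": torch.zeros(7)}],
+ #       "sigma_vla_shard_000000.pt",
+ #   )
+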
+ def collate_sigma(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
+     """
+     Robust collate for Sigma shards.
+     """
+     s0 = batch_list[0]
+
+     def pick_key(sample: Dict[str, Any], candidates: List[str], field_name: str):
+         for k in candidates:
+             if k in sample:
+                 return k
+         raise KeyError(
+             f"Shard sample missing required field '{field_name}'. "
+             f"Tried keys: {candidates}. "
+             f"Available keys: {list(sample.keys())}"
+         )
+
+     if "vision" in s0:
+         vis_k = "vision"
+     else:
+         vis_k = pick_key(s0, ["vis_obs", "rgb_obs", "image", "images", "obs"], "vision/vis_obs")
+
+     vis_obs = torch.stack([b[vis_k] for b in batch_list], dim=0).float()
+     if vis_obs.dim() == 5:
+         vis_obs = vis_obs[:, -1]
+
+     depth_obs = None
+     if "depth" in s0:
+         depth_obs = torch.stack([b["depth"] for b in batch_list], dim=0).float()
+     elif any(k in s0 for k in ["depth_obs", "depths"]):
+         dk = pick_key(s0, ["depth_obs", "depths"], "depth")
+         depth_obs = torch.stack([b[dk] for b in batch_list], dim=0).float()
+
+     audio_obs = None
+     if "audio" in s0:
+         audio_obs = torch.stack([b["audio"] for b in batch_list], dim=0).float()
+     elif any(k in s0 for k in ["audio_obs", "audios"]):
+         ak = pick_key(s0, ["audio_obs", "audios"], "audio")
+         audio_obs = torch.stack([b[ak] for b in batch_list], dim=0).float()
+
+     if "state" in s0:
+         state_k = "state"
+     else:
+         state_k = pick_key(s0, ["robot_state", "proprio", "proprio_obs"], "state/robot_state")
+
+     robot_state = torch.stack([b[state_k] for b in batch_list], dim=0).float()
+
+     if "text" in s0:
+         texts = [b.get("text", "") for b in batch_list]
+     else:
+         text_k = pick_key(s0, ["text", "prompt", "instruction"], "text")
+         texts = [b.get(text_k, "") for b in batch_list]
+
+     if "action" in s0:
+         a0 = s0["action"]
+         if isinstance(a0, dict):
+             def pick_action_key(d, candidates, name):
+                 for k in candidates:
+                     if k in d:
+                         return k
+                 raise KeyError(
+                     f"Action dict missing '{name}'. Tried {candidates}. "
+                     f"Available action keys: {list(d.keys())}"
+                 )
+
+             vec_k = pick_action_key(a0, ["gt_action_vector", "action_vector", "vector", "vec"], "gt_action_vector")
+             chk_k = pick_action_key(a0, ["gt_action_chunk", "action_chunk", "chunk", "chk"], "gt_action_chunk")
+             trj_k = pick_action_key(a0, ["gt_action_trajectory", "action_trajectory", "trajectory", "traj"], "gt_action_trajectory")
+
+             gt_action_vector = torch.stack([b["action"][vec_k] for b in batch_list], dim=0).float()
+             gt_action_chunk = torch.stack([b["action"][chk_k] for b in batch_list], dim=0).float()
+             gt_action_trajectory = torch.stack([b["action"][trj_k] for b in batch_list], dim=0).float()
+         else:
+             act = torch.stack([b["action"] for b in batch_list], dim=0).float()
+             gt_action_vector = act
+             gt_action_chunk = act
+             gt_action_trajectory = act
+     else:
+         gt_vec_k = pick_key(s0, ["gt_action_vector", "action_vector", "gt_vec"], "gt_action_vector")
+         gt_chk_k = pick_key(s0, ["gt_action_chunk", "action_chunk", "gt_chunk"], "gt_action_chunk")
+         gt_trj_k = pick_key(s0, ["gt_action_trajectory", "action_trajectory", "gt_traj"], "gt_action_trajectory")
+
+         gt_action_vector = torch.stack([b[gt_vec_k] for b in batch_list], dim=0).float()
+         gt_action_chunk = torch.stack([b[gt_chk_k] for b in batch_list], dim=0).float()
+         gt_action_trajectory = torch.stack([b[gt_trj_k] for b in batch_list], dim=0).float()
+
+     # Optional offline base actions for the adapter; if missing, we simply do not include them.
+     base_action_vector = None
+     base_action_chunk = None
+     base_action_trajectory = None
+
+     has_base_top = any(
+         k in s0
+         for k in ["base_action_vector", "base_action_chunk", "base_action_trajectory"]
+     )
+     has_base_in_action = "action" in s0 and isinstance(s0["action"], dict) and any(
+         k in s0["action"]
+         for k in ["base_action_vector", "base_action_chunk", "base_action_trajectory"]
+     )
+
+     if has_base_top:
+         if "base_action_vector" in s0:
+             base_action_vector = torch.stack([b["base_action_vector"] for b in batch_list], dim=0).float()
+         if "base_action_chunk" in s0:
+             base_action_chunk = torch.stack([b["base_action_chunk"] for b in batch_list], dim=0).float()
+         if "base_action_trajectory" in s0:
+             base_action_trajectory = torch.stack([b["base_action_trajectory"] for b in batch_list], dim=0).float()
+     elif has_base_in_action:
+         a0 = s0["action"]
+
+         def pick_base_key(d, candidates):
+             for k in candidates:
+                 if k in d:
+                     return k
+             return None
+
+         vec_bk = pick_base_key(a0, ["base_action_vector", "base_vec"])
+         chk_bk = pick_base_key(a0, ["base_action_chunk", "base_chunk"])
+         trj_bk = pick_base_key(a0, ["base_action_trajectory", "base_traj"])
+
+         if vec_bk is not None:
+             base_action_vector = torch.stack([b["action"][vec_bk] for b in batch_list], dim=0).float()
+         if chk_bk is not None:
+             base_action_chunk = torch.stack([b["action"][chk_bk] for b in batch_list], dim=0).float()
+         if trj_bk is not None:
+             base_action_trajectory = torch.stack([b["action"][trj_bk] for b in batch_list], dim=0).float()
+
+     batch: Dict[str, Any] = {
+         "vis_obs": vis_obs,
+         "depth_obs": depth_obs,
+         "audio_obs": audio_obs,
+         "robot_state": robot_state,
+         "texts": texts,
+         "gt_action_vector": gt_action_vector,
+         "gt_action_chunk": gt_action_chunk,
+         "gt_action_trajectory": gt_action_trajectory,
+     }
+
+     if base_action_vector is not None:
+         batch["base_action_vector"] = base_action_vector
+     if base_action_chunk is not None:
+         batch["base_action_chunk"] = base_action_chunk
+     if base_action_trajectory is not None:
+         batch["base_action_trajectory"] = base_action_trajectory
+
+     return batch
+
+
+ def _align_target(pred_t: torch.Tensor, gt_t: torch.Tensor) -> torch.Tensor:
+     """
+     Align GT to prediction for MSE:
+     - handle 2D vs 3D mismatches by collapsing or expanding the time dimension.
+     - then align the last dim by pad/trim.
+     """
+     if gt_t.dim() == 3 and pred_t.dim() == 2:
+         gt_t = gt_t[:, -1, :]
+
+     if pred_t.dim() == 3 and gt_t.dim() == 2:
+         gt_t = gt_t.unsqueeze(1)
+         if gt_t.size(1) != pred_t.size(1):
+             gt_t = gt_t.expand(-1, pred_t.size(1), -1)
+
+     if pred_t.dim() == 3 and gt_t.dim() == 3:
+         Tp = pred_t.size(1)
+         Tg = gt_t.size(1)
+         if Tg < Tp:
+             pad = torch.zeros(
+                 gt_t.size(0), Tp - Tg, gt_t.size(2),
+                 device=gt_t.device, dtype=gt_t.dtype
+             )
+             gt_t = torch.cat([gt_t, pad], dim=1)
+         elif Tg > Tp:
+             gt_t = gt_t[:, :Tp, :]
+
+     pd = pred_t.size(-1)
+     gd = gt_t.size(-1)
+     if gd < pd:
+         gt_t = F.pad(gt_t, (0, pd - gd))
+     elif gd > pd:
+         gt_t = gt_t[..., :pd]
+
+     return gt_t
+
+
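+ # Alignment examples (illustrative):
+ #   pred [B, 7],    gt [B, T, 7] -> keep the last step:   gt[:, -1, :] -> [B, 7]
+ #   pred [B, T, 7], gt [B, 7]    -> broadcast over time:  gt -> [B, T, 7]
+ #   trailing action dims are zero-padded or trimmed to match pred.
+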
+ def _pred_action(pred: Dict[str, torch.Tensor], key: str) -> torch.Tensor:
+     if key in pred:
+         return pred[key]
+     if "action" in pred:
+         return pred["action"]
+     raise KeyError(
+         f"Pred dict missing action key '{key}' and fallback 'action'. "
+         f"Available pred keys: {list(pred.keys())}"
+     )
+
+
+ @torch.no_grad()
+ def compute_branch_mse(pred: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, float]:
+     vec_pred = _pred_action(pred, "action_vector")
+     chk_pred = _pred_action(pred, "action_chunk")
+     trj_pred = _pred_action(pred, "action_trajectory")
+
+     device = vec_pred.device
+
+     gt_vec = _align_target(vec_pred, batch["gt_action_vector"].to(device))
+     gt_chk = _align_target(chk_pred, batch["gt_action_chunk"].to(device))
+     gt_trj = _align_target(trj_pred, batch["gt_action_trajectory"].to(device))
+
+     mse_vec = F.mse_loss(vec_pred, gt_vec).item()
+     mse_chk = F.mse_loss(chk_pred, gt_chk).item()
+     mse_trj = F.mse_loss(trj_pred, gt_trj).item()
+     return {"mse_vector": mse_vec, "mse_chunk": mse_chk, "mse_traj": mse_trj}
+
+
+ @torch.no_grad()
+ def compute_success_proxy(
+     pred: Dict[str, torch.Tensor],
+     batch: Dict[str, Any],
+     thr_vec: float,
+     thr_chk: float,
+     thr_trj: float,
+ ) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Returns:
+         num_success, num_total, mse_vec_per_sample, mse_chk_per_sample, mse_trj_per_sample
+     where per-sample MSE is averaged over all non-batch dims.
+     """
+     vec_pred = _pred_action(pred, "action_vector")
+     chk_pred = _pred_action(pred, "action_chunk")
+     trj_pred = _pred_action(pred, "action_trajectory")
+
+     device = vec_pred.device
+
+     gt_vec = _align_target(vec_pred, batch["gt_action_vector"].to(device))
+     gt_chk = _align_target(chk_pred, batch["gt_action_chunk"].to(device))
+     gt_trj = _align_target(trj_pred, batch["gt_action_trajectory"].to(device))
+
+     reduce_dims_vec = list(range(1, vec_pred.dim()))
+     reduce_dims_chk = list(range(1, chk_pred.dim()))
+     reduce_dims_trj = list(range(1, trj_pred.dim()))
+
+     mse_vec_s = ((vec_pred - gt_vec) ** 2).mean(dim=reduce_dims_vec)
+     mse_chk_s = ((chk_pred - gt_chk) ** 2).mean(dim=reduce_dims_chk)
+     mse_trj_s = ((trj_pred - gt_trj) ** 2).mean(dim=reduce_dims_trj)
+
+     success_mask = (mse_vec_s < thr_vec) & (mse_chk_s < thr_chk) & (mse_trj_s < thr_trj)
+     num_success = int(success_mask.sum().item())
+     num_total = int(success_mask.numel())
+
+     return num_success, num_total, mse_vec_s, mse_chk_s, mse_trj_s
+
+
+ @torch.no_grad()
+ def compute_telepathy_stability(pred: Dict[str, torch.Tensor]) -> float:
+     tau = pred.get("telepathy_factors", None)
+     if tau is None:
+         return float("nan")
+     return float((tau ** 2).mean().item())
+
+
+ @torch.no_grad()
+ def cosine_alignment(a: torch.Tensor, b: torch.Tensor) -> float:
+     """
+     Cosine alignment that is robust to hidden-size mismatch.
+     Accepts [B, D] or [B, T, D]. Pools time if present.
+     If dims differ, crops both to min(Da, Db) for a fair cosine check.
+     """
+     if a.dim() == 3:
+         a = a.mean(dim=1)
+     if b.dim() == 3:
+         b = b.mean(dim=1)
+
+     if a.numel() == 0 or b.numel() == 0:
+         return float("nan")
+
+     da, db = a.size(-1), b.size(-1)
+     if da != db:
+         d = min(da, db)
+         a = a[..., :d]
+         b = b[..., :d]
+
+     a = F.normalize(a, dim=-1)
+     b = F.normalize(b, dim=-1)
+     return float((a * b).sum(dim=-1).mean().item())
+
+
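+ # Example (illustrative): a [B, 256] vs b [B, 2048] are both cropped to [..., :256]
+ # before normalization, so the metric stays defined under a width mismatch.
+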
+ @torch.no_grad()
+ def build_text_tokens_from_policy(
+     tokenizer,
+     text_embed_layer: nn.Module,
+     texts: List[str],
+     device: torch.device,
+     target_dtype: torch.dtype,
+     max_text_len: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Tokenize prompts and map them to embeddings using the PI05 internal embedding layer.
+     Returns (text_tokens, attn_mask).
+     """
+     if max_text_len and max_text_len > 0:
+         tok = tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=max_text_len,
+             return_tensors="pt",
+         )
+     else:
+         tok = tokenizer(
+             texts,
+             padding=True,
+             truncation=False,
+             return_tensors="pt",
+         )
+
+     if hasattr(tok, "input_ids"):
+         input_ids = tok.input_ids
+         attn_mask = tok.attention_mask
+     else:
+         input_ids = tok["input_ids"]
+         attn_mask = tok.get("attention_mask", None)
+         if attn_mask is None:
+             attn_mask = torch.ones_like(input_ids)
+
+     input_ids = input_ids.to(device)
+     attn_mask = attn_mask.to(device)
+
+     text_tokens = text_embed_layer(input_ids).to(dtype=target_dtype)
+     return text_tokens, attn_mask
+
+
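+ # Sketch: with a Gemma-style tokenizer and a [V, D] embedding table, two prompts
+ # yield text_tokens [2, L, D] (cast to target_dtype) and attn_mask [2, L] on `device`.
+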
+ def main():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--sigma_env", type=str, default="sigma.env")
+     parser.add_argument("--shard_dir", type=str, default="")
+     parser.add_argument("--output_dir", type=str, default="./sigma_eval_out")
+
+     parser.add_argument(
+         "--base_model_id",
+         type=str,
+         required=True,
+         help="LeRobot/OpenPI policy repo, e.g., lerobot/pi05_base or your Sigma policy repo.",
+     )
+     parser.add_argument(
+         "--telepathy_heads_path",
+         type=str,
+         default="",
+         help="Path to sigma_telepathy_heads.pt. If empty, auto-fetch may fill it.",
+     )
+     parser.add_argument(
+         "--disable_telepathy",
+         action="store_true",
+         help="Disable telepathy injection (control run).",
+     )
+     parser.add_argument(
+         "--tokenizer_id",
+         type=str,
+         default="",
+         help="Explicit HF tokenizer id OR local tokenizer folder path.",
+     )
+
+     parser.add_argument("--max_text_len", type=int, default=0)
+
+     parser.add_argument(
+         "--artifacts_repo_id",
+         type=str,
+         default="",
+         help="HF repo containing storage/sigma_pickplace and storage/sigma_lora_out.",
+     )
+     parser.add_argument(
+         "--hf_cache_root",
+         type=str,
+         default="/workspace/.hf_sigma_cache",
+     )
+
+     parser.add_argument("--load_in_4bit", action="store_true")
+     parser.add_argument("--dtype", type=str, default="bf16")
+
+     parser.add_argument("--batch_size", type=int, default=4)
+     parser.add_argument("--num_workers", type=int, default=2)
+     parser.add_argument("--max_batches", type=int, default=-1)
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument(
+         "--shuffle",
+         action="store_true",
+         help="Shuffle dataset order to enable different random subsets per seed.",
+     )
+     parser.add_argument(
+         "--telepathy_scale",
+         type=float,
+         default=1.0,
+         help="Multiply telepathy_factors (tau) to control injection strength.",
+     )
+
+     parser.add_argument("--succ_thr_vec", type=float, default=0.05)
+     parser.add_argument("--succ_thr_chk", type=float, default=0.10)
+     parser.add_argument("--succ_thr_trj", type=float, default=0.10)
+
+     # Hard-set thresholds: if <=0, they default to 2x the success thresholds.
+     parser.add_argument(
+         "--hard_thr_vec",
+         type=float,
+         default=-1.0,
+         help="Per-sample MSE threshold for the 'hard' set on the vector branch; <=0 means 2x succ_thr_vec.",
+     )
+     parser.add_argument(
+         "--hard_thr_chk",
+         type=float,
+         default=-1.0,
+         help="Per-sample MSE threshold for the 'hard' set on the chunk branch; <=0 means 2x succ_thr_chk.",
+     )
+     parser.add_argument(
+         "--hard_thr_trj",
+         type=float,
+         default=-1.0,
+         help="Per-sample MSE threshold for the 'hard' set on the trajectory branch; <=0 means 2x succ_thr_trj.",
+     )
+
+     parser.add_argument(
+         "--strict_pi05_load",
+         action="store_true",
+         help="Try strict PI05Policy loading if supported by LeRobot.",
+     )
+     parser.add_argument(
+         "--allow_tokenizer_mismatch",
+         action="store_true",
+         help="Do not fail on tokenizer/embedding mismatch (NOT recommended for baseline runs).",
+     )
+
+     # Simple flag to enable/disable the adapter without touching telepathy itself.
+     parser.add_argument(
+         "--use_telepathy_adapter",
+         action="store_true",
+         help="If set and telepathy is enabled, apply sigma_telepathy_adapter to actions in eval.",
+     )
+
+     args = parser.parse_args()
+
+     if os.path.exists(args.sigma_env):
+         load_dotenv(args.sigma_env)
+     hf_token = os.getenv("HF_TOKEN", None)
+
+     accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != "fp32" else "no")
+     set_seed(args.seed)
+     device = accelerator.device
+
+     if args.load_in_4bit:
+         print("[WARN] --load_in_4bit is ignored for the PI05Policy evaluator.")
+
+     artifacts_repo = args.artifacts_repo_id.strip()
+     if not artifacts_repo and args.base_model_id.startswith("Veltraxor/"):
+         artifacts_repo = args.base_model_id
+
+     need_shards = (not args.shard_dir) or (not os.path.isdir(args.shard_dir))
+     need_heads = (not args.telepathy_heads_path) or (not os.path.isfile(args.telepathy_heads_path))
+
+     if artifacts_repo and (need_shards or need_heads):
+         paths = ensure_sigma_artifacts_from_hf(
+             repo_id=artifacts_repo,
+             hf_token=hf_token,
+             local_cache_root=args.hf_cache_root,
+         )
+         if need_shards:
+             args.shard_dir = paths["shard_dir"]
+             print(f"[INFO] Using cached shard_dir: {args.shard_dir}")
+         if need_heads:
+             args.telepathy_heads_path = paths["telepathy_heads_path"]
+             print(f"[INFO] Using cached telepathy_heads_path: {args.telepathy_heads_path}")
+
+     if not args.shard_dir or not os.path.isdir(args.shard_dir):
+         raise FileNotFoundError(
+             f"shard_dir not found locally: {args.shard_dir}. "
+             "Either provide a valid local path or an artifacts_repo_id for auto-download."
+         )
+
+     if not args.telepathy_heads_path or not os.path.isfile(args.telepathy_heads_path):
+         raise FileNotFoundError(
+             f"telepathy_heads_path not found locally: {args.telepathy_heads_path}. "
+             "Either provide a valid local path or store it under storage/sigma_lora_out/ "
+             "in artifacts_repo_id for auto-download."
+         )
+
+     policy = load_pi05_policy(
+         args.base_model_id,
+         hf_token,
+         device=device,
+         strict_load=args.strict_pi05_load,
+     )
+
+     tokenizer = get_policy_tokenizer(
+         policy,
+         args.base_model_id,
+         hf_token,
+         forced_tokenizer_id=args.tokenizer_id,
+     )
+     text_embed_layer = get_policy_text_embedding_layer(policy)
+
+     verify_tokenizer_embedding_compat(
+         tokenizer=tokenizer,
+         text_embed_layer=text_embed_layer,
+         allow_mismatch=args.allow_tokenizer_mismatch,
+     )
+
+     v_cfg = VisionConfig()
+     l_cfg = LanguageConfig()
+     a_cfg = ActionConfig()
+     telepathy_vla = TelepathyVLA(v_cfg, l_cfg, a_cfg, disable_telepathy=args.disable_telepathy)
+     telepathy_vla.telepathy_scale = args.telepathy_scale
+
+     # Instantiate the Telepathy adapter (used only when telepathy is enabled and the flag is set).
+     adapter_cfg = SigmaTelepathyAdapterConfig()
+     telepathy_adapter = SigmaTelepathyAdapter(adapter_cfg).to(device)
+
+     if accelerator.is_main_process:
+         file_size_mb = os.path.getsize(args.telepathy_heads_path) / (1024 * 1024)
+         print(f"[CHECK-A] disable_telepathy={args.disable_telepathy}")
+         print(f"[CHECK-A] telepathy_heads_path={args.telepathy_heads_path} size={file_size_mb:.2f}MB")
+
+     sd = torch.load(args.telepathy_heads_path, map_location="cpu")
+
+     tensor_list = [v.detach().float().reshape(-1) for v in sd.values() if torch.is_tensor(v)]
+     if accelerator.is_main_process and len(tensor_list) > 0:
+         capped = [t[:100000] for t in tensor_list]
+         flat = torch.cat(capped, dim=0)
+         rms = torch.sqrt((flat ** 2).mean()).item()
+         print(f"[CHECK-A] heads_tensors={len(tensor_list)} mean={flat.mean().item():.6f} std={flat.std().item():.6f} rms={rms:.6f}")
+
+     missing, unexpected = telepathy_vla.load_state_dict(sd, strict=False)
+     if accelerator.is_main_process:
+         if len(missing) > 0 or len(unexpected) > 0:
+             print(f"[CHECK-A] loaded with strict=False. Missing={len(missing)} Unexpected={len(unexpected)}")
+             print(f"[CHECK-A] Missing keys (first 20): {missing[:20]}")
+             print(f"[CHECK-A] Unexpected keys (first 20): {unexpected[:20]}")
+         else:
+             print("[CHECK-A] heads fully matched (no missing/unexpected).")
+
+     telepathy_vla.eval()
+
+     ds = SigmaShardDataset(args.shard_dir)
+     dl = DataLoader(
+         ds,
+         batch_size=args.batch_size,
+         shuffle=args.shuffle,
+         num_workers=args.num_workers,
+         collate_fn=collate_sigma,
+         drop_last=False,
+         pin_memory=torch.cuda.is_available(),
+     )
+
+     telepathy_vla, dl = accelerator.prepare(telepathy_vla, dl)
+     # Unwrap a potential DDP wrapper once so custom methods (reset_memory,
+     # forward_once) stay reachable under multi-process eval.
+     model_ref = telepathy_vla.module if hasattr(telepathy_vla, "module") else telepathy_vla
+     target_dtype = next(telepathy_vla.parameters()).dtype
+
+     sum_mse_vec = 0.0
+     sum_mse_chk = 0.0
+     sum_mse_trj = 0.0
+     sum_tau_l2 = 0.0
+     sum_sem_align = 0.0
+     # Count batches where the tau/semantic metrics were finite, for unbiased averages.
+     n_tau_batches = 0
+     n_sem_batches = 0
+
+     # Hard-set aggregators
+     hard_thr_vec = args.hard_thr_vec if args.hard_thr_vec > 0.0 else 2.0 * args.succ_thr_vec
+     hard_thr_chk = args.hard_thr_chk if args.hard_thr_chk > 0.0 else 2.0 * args.succ_thr_chk
+     hard_thr_trj = args.hard_thr_trj if args.hard_thr_trj > 0.0 else 2.0 * args.succ_thr_trj
+
+     sum_hard_mse_vec = 0.0
+     sum_hard_mse_chk = 0.0
+     sum_hard_mse_trj = 0.0
+     total_hard_samples = 0
+
+     n_batches = 0
+     n_samples = 0
+
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     for bidx, batch in enumerate(dl):
+         if args.max_batches > 0 and bidx >= args.max_batches:
+             break
+
+         model_ref.reset_memory()
+
+         B = batch["vis_obs"].size(0)
+         n_samples += B
+
+         text_tokens, attn_mask = build_text_tokens_from_policy(
+             tokenizer=tokenizer,
+             text_embed_layer=text_embed_layer,
+             texts=batch["texts"],
+             device=device,
+             target_dtype=target_dtype,
+             max_text_len=args.max_text_len,
+         )
+
+         robot_state = batch["robot_state"].to(device)
+         if robot_state.dim() == 3:
+             robot_state = robot_state[:, -1]
+
+         # Move optional base actions to device for the adapter.
+         if "base_action_vector" in batch:
+             batch["base_action_vector"] = batch["base_action_vector"].to(device)
+         if "base_action_chunk" in batch:
+             batch["base_action_chunk"] = batch["base_action_chunk"].to(device)
+         if "base_action_trajectory" in batch:
+             batch["base_action_trajectory"] = batch["base_action_trajectory"].to(device)
+
+         try:
+             expected_d = model_ref.vision.state_encoder.mlp[0].in_features
+         except Exception:
+             expected_d = robot_state.size(-1)
+
+         cur_d = robot_state.size(-1)
+         if cur_d < expected_d:
+             robot_state = F.pad(robot_state, (0, expected_d - cur_d))
+         elif cur_d > expected_d:
+             robot_state = robot_state[..., :expected_d]
+
+         pred = model_ref.forward_once(
+             vis_obs=batch["vis_obs"].to(device),
+             robot_state=robot_state,
+             depth_obs=batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None,
+             audio_obs=batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None,
+             text_tokens=text_tokens,
+             attn_mask=attn_mask,
+             return_intermediate=True,
+         )
+
+         if accelerator.is_main_process and bidx == 0:
+             model_ref.reset_memory()
+             prev_flag = bool(model_ref.disable_telepathy)
+             model_ref.disable_telepathy = True
+             pred_ctrl = model_ref.forward_once(
+                 vis_obs=batch["vis_obs"].to(device),
+                 robot_state=robot_state,
+                 depth_obs=batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None,
+                 audio_obs=batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None,
+                 text_tokens=text_tokens,
+                 attn_mask=attn_mask,
+                 return_intermediate=False,
+             )
+             model_ref.disable_telepathy = prev_flag
+
+             try:
+                 act_exp = _pred_action(pred, "action_vector")
+                 act_ctl = _pred_action(pred_ctrl, "action_vector")
+                 diff = (act_exp - act_ctl).abs().mean().item()
+                 print(f"[CHECK-B] telepathy_effect_mean_abs_diff(action_vector)={diff:.6f}")
+             except Exception as e:
+                 print(f"[CHECK-B] action diff check failed: {type(e).__name__}: {e}")
+
+         # Apply the Telepathy adapter only when telepathy is enabled and the flag is set.
+         if (not args.disable_telepathy) and args.use_telepathy_adapter:
+             pred = telepathy_adapter(pred, batch)
+
+         mse = compute_branch_mse(pred, batch)
+         tau_l2 = compute_telepathy_stability(pred)
+
+         (
+             _,
+             _,
+             mse_vec_s,
+             mse_chk_s,
+             mse_trj_s,
+         ) = compute_success_proxy(
+             pred,
+             batch,
+             thr_vec=args.succ_thr_vec,
+             thr_chk=args.succ_thr_chk,
+             thr_trj=args.succ_thr_trj,
+         )
+
+         # Hard-set accumulation: samples where any branch MSE exceeds the hard thresholds
+         hard_mask = (mse_vec_s > hard_thr_vec) | (mse_chk_s > hard_thr_chk) | (mse_trj_s > hard_thr_trj)
+         hard_count = int(hard_mask.sum().item())
+         if hard_count > 0:
+             sum_hard_mse_vec += mse_vec_s[hard_mask].sum().item()
+             sum_hard_mse_chk += mse_chk_s[hard_mask].sum().item()
+             sum_hard_mse_trj += mse_trj_s[hard_mask].sum().item()
+             total_hard_samples += hard_count
+
+         sem_factors = pred.get("semantic_factors", None)
+         if sem_factors is not None:
+             if sem_factors.dim() == 3:
+                 sem_pool = sem_factors.mean(dim=1)
+             elif sem_factors.dim() == 2:
+                 sem_pool = sem_factors
+             else:
+                 sem_pool = sem_factors.view(sem_factors.size(0), -1)
+
+             txt_pool = text_tokens.mean(dim=1)
+             sem_align = cosine_alignment(sem_pool, txt_pool)
+         else:
+             sem_align = float("nan")
+
+         sum_mse_vec += mse["mse_vector"]
+         sum_mse_chk += mse["mse_chunk"]
+         sum_mse_trj += mse["mse_traj"]
+         # NaN-safe accumulation: count only batches where the metric was defined,
+         # so the averages below are not diluted by skipped batches.
+         if not (tau_l2 != tau_l2):
+             sum_tau_l2 += tau_l2
+             n_tau_batches += 1
+         if not (sem_align != sem_align):
+             sum_sem_align += sem_align
+             n_sem_batches += 1
+
+         n_batches += 1
+
+         if accelerator.is_main_process and bidx % 20 == 0:
+             print(
+                 f"batch={bidx} "
+                 f"mse_vec={mse['mse_vector']:.4f} mse_chk={mse['mse_chunk']:.4f} mse_trj={mse['mse_traj']:.4f} "
+                 f"tau_l2={tau_l2:.4f} sem_align={sem_align:.4f}"
+             )
+
+     if accelerator.is_main_process:
+         avg_mse_vec = sum_mse_vec / max(1, n_batches)
+         avg_mse_chk = sum_mse_chk / max(1, n_batches)
+         avg_mse_trj = sum_mse_trj / max(1, n_batches)
+
+         # Average tau/semantic metrics only over batches where they were defined.
+         avg_tau_l2 = sum_tau_l2 / max(1, n_tau_batches)
+         avg_sem_align = sum_sem_align / max(1, n_sem_batches)
+
+         if total_hard_samples > 0:
+             avg_hard_mse_vec = sum_hard_mse_vec / float(total_hard_samples)
+             avg_hard_mse_chk = sum_hard_mse_chk / float(total_hard_samples)
+             avg_hard_mse_trj = sum_hard_mse_trj / float(total_hard_samples)
+         else:
+             avg_hard_mse_vec = float("nan")
+             avg_hard_mse_chk = float("nan")
+             avg_hard_mse_trj = float("nan")
+
+         hard_fraction = float(total_hard_samples / max(1, n_samples))
+
+         report = {
+             "num_samples": n_samples,
+             "num_batches": n_batches,
+             "avg_mse_vector": avg_mse_vec,
+             "avg_mse_chunk": avg_mse_chk,
+             "avg_mse_traj": avg_mse_trj,
+             "avg_tau_l2": avg_tau_l2,
+             "avg_semantic_text_alignment": avg_sem_align,
+             "hard_thresholds": {
+                 "vec": hard_thr_vec,
+                 "chk": hard_thr_chk,
+                 "trj": hard_thr_trj,
+             },
+             "avg_hard_mse_vector": avg_hard_mse_vec,
+             "avg_hard_mse_chunk": avg_hard_mse_chk,
+             "avg_hard_mse_traj": avg_hard_mse_trj,
+             "hard_sample_fraction": hard_fraction,
+             "total_hard_samples": int(total_hard_samples),
+         }
+
+         with open(
+             os.path.join(args.output_dir, "sigma_eval_report.json"),
+             "w",
+             encoding="utf-8",
+         ) as f:
+             json.dump(report, f, indent=2)
+         print("[DONE] Saved report:", report)
+
+
+ if __name__ == "__main__":
+     main()
modeling_pi05.py ADDED
@@ -0,0 +1,1264 @@
+ #!/usr/bin/env python
+
+ # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import builtins
+ import logging
+ import math
+ from collections import deque
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Literal, TypedDict
+
+ import torch
+ import torch.nn.functional as F  # noqa: N812
+ from torch import Tensor, nn
+ from typing_extensions import Unpack
+
+ from lerobot.utils.import_utils import _transformers_available
+
+ # Conditional import for type checking and lazy loading
+ if TYPE_CHECKING or _transformers_available:
+     from transformers.models.auto import CONFIG_MAPPING
+     from transformers.models.gemma import modeling_gemma
+     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+ else:
+     CONFIG_MAPPING = None
+     modeling_gemma = None
+     GemmaForCausalLM = None
+     PaliGemmaForConditionalGeneration = None
+
+ from lerobot.configs.policies import PreTrainedConfig
+ from lerobot.policies.pi05.configuration_pi05 import PI05Config
+ from lerobot.policies.pretrained import PreTrainedPolicy, T
+ from lerobot.policies.rtc.modeling_rtc import RTCProcessor
+ from lerobot.utils.constants import (
+     ACTION,
+     OBS_LANGUAGE_ATTENTION_MASK,
+     OBS_LANGUAGE_TOKENS,
+     OPENPI_ATTENTION_MASK_VALUE,
+ )
+
+
55
+ class ActionSelectKwargs(TypedDict, total=False):
56
+ inference_delay: int | None
57
+ prev_chunk_left_over: Tensor | None
58
+ execution_horizon: int | None
59
+
60
+
61
+ def get_safe_dtype(target_dtype, device_type):
62
+ """Get a safe dtype for the given device type."""
63
+ if device_type == "mps" and target_dtype == torch.float64:
64
+ return torch.float32
65
+ if device_type == "cpu":
66
+ # CPU doesn't support bfloat16, use float32 instead
67
+ if target_dtype == torch.bfloat16:
68
+ return torch.float32
69
+ if target_dtype == torch.float64:
70
+ return torch.float64
71
+ return target_dtype
72
+
73
+
74
+ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
75
+ time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
76
+ ) -> Tensor:
77
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
78
+ if dimension % 2 != 0:
79
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
80
+
81
+ if time.ndim != 1:
82
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
83
+
84
+ if isinstance(device, str):  # the default device="cpu" arrives as a plain string
+     device = torch.device(device)
+ dtype = get_safe_dtype(torch.float64, device.type)
85
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
86
+ period = min_period * (max_period / min_period) ** fraction
87
+
88
+ # Compute the outer product
89
+ scaling_factor = 1.0 / period * 2 * math.pi
90
+ sin_input = scaling_factor[None, :] * time[:, None]
91
+ return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
92
+
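+ # Editorial note: a minimal usage sketch for the helper above. The batch
+ # size, dimension, and periods below are arbitrary example values, not the
+ # ones the model uses; the helper expects a 1-D tensor of scalar times.
+ def _demo_sinusoidal_pos_embedding():
+     t = torch.rand(4)  # one scalar position per batch element
+     emb = create_sinusoidal_pos_embedding(
+         t, dimension=128, min_period=4e-3, max_period=4.0, device=torch.device("cpu")
+     )
+     assert emb.shape == (4, 128)  # first half sin, second half cos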
93
+
94
+ def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
95
+ alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
96
+ beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
97
+ dist = torch.distributions.Beta(alpha_t, beta_t)
98
+ return dist.sample((bsize,))
99
+
100
+
101
+ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
102
+ """Copied from big_vision.
103
+
104
+ Tokens can attend to valid inputs tokens which have a cumulative mask_ar
105
+ smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
106
+ setup several types of attention, for example:
107
+
108
+ [[1 1 1 1 1 1]]: pure causal attention.
109
+
110
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
111
+ themselves and the last 3 tokens have a causal attention. The first
112
+ entry could also be a 1 without changing behaviour.
113
+
114
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
115
+ block can attend all previous blocks and all tokens on the same block.
116
+
117
+ Args:
118
+ pad_masks: bool[B, N] true if it's part of the input, false if padding.
119
+ att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
120
+ it and 0 where it shares the same attention mask as the previous token.
121
+ """
122
+ if att_masks.ndim != 2:
123
+ raise ValueError(att_masks.ndim)
124
+ if pad_masks.ndim != 2:
125
+ raise ValueError(pad_masks.ndim)
126
+
127
+ cumsum = torch.cumsum(att_masks, dim=1)
128
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
129
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
130
+ return att_2d_masks & pad_2d_masks
131
+
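+ # Editorial note: a small self-check of the prefix-lm pattern described in
+ # the docstring above; with att_masks = [0, 0, 0, 1, 1, 1] the first three
+ # tokens attend bidirectionally and the last three are causal.
+ def _demo_make_att_2d_masks():
+     pad_masks = torch.ones(1, 6, dtype=torch.bool)
+     att_masks = torch.tensor([[0, 0, 0, 1, 1, 1]])
+     m = make_att_2d_masks(pad_masks, att_masks)
+     assert bool(m[0, 0, 2])      # prefix token sees a later prefix token
+     assert not bool(m[0, 3, 4])  # suffix token cannot attend to the future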
132
+
133
+ def pad_vector(vector, new_dim):
134
+ """Pad the last dimension of a vector to new_dim with zeros.
135
+
136
+ Can be (batch_size x sequence_length x features_dimension)
137
+ or (batch_size x features_dimension)
138
+ """
139
+ if vector.shape[-1] >= new_dim:
140
+ return vector
141
+ return F.pad(vector, (0, new_dim - vector.shape[-1]))
142
+
143
+
144
+ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
145
+ images: torch.Tensor,
146
+ height: int,
147
+ width: int,
148
+ mode: str = "bilinear",
149
+ ) -> torch.Tensor:
150
+ """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
151
+ by padding with black. If the image is float32, it must be in the range [-1, 1].
152
+
153
+ Args:
154
+ images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
155
+ height: Target height
156
+ width: Target width
157
+ mode: Interpolation mode ('bilinear', 'nearest', etc.)
158
+
159
+ Returns:
160
+ Resized and padded tensor with same shape format as input
161
+ """
162
+ # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
163
+ if images.shape[-1] <= 4: # Assume channels-last format
164
+ channels_last = True
165
+ if images.dim() == 3:
166
+ images = images.unsqueeze(0) # Add batch dimension
167
+ images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
168
+ else:
169
+ channels_last = False
170
+ if images.dim() == 3:
171
+ images = images.unsqueeze(0) # Add batch dimension
172
+
173
+ batch_size, channels, cur_height, cur_width = images.shape
174
+
175
+ # Calculate resize ratio
176
+ ratio = max(cur_width / width, cur_height / height)
177
+ resized_height = int(cur_height / ratio)
178
+ resized_width = int(cur_width / ratio)
179
+
180
+ # Resize
181
+ resized_images = F.interpolate(
182
+ images,
183
+ size=(resized_height, resized_width),
184
+ mode=mode,
185
+ align_corners=False if mode == "bilinear" else None,
186
+ )
187
+
188
+ # Handle dtype-specific clipping
189
+ if images.dtype == torch.uint8:
190
+ resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
191
+ elif images.dtype == torch.float32:
192
+ resized_images = resized_images.clamp(-1.0, 1.0)
193
+ else:
194
+ raise ValueError(f"Unsupported image dtype: {images.dtype}")
195
+
196
+ # Calculate padding
197
+ pad_h0, remainder_h = divmod(height - resized_height, 2)
198
+ pad_h1 = pad_h0 + remainder_h
199
+ pad_w0, remainder_w = divmod(width - resized_width, 2)
200
+ pad_w1 = pad_w0 + remainder_w
201
+
202
+ # Pad
203
+ constant_value = 0 if images.dtype == torch.uint8 else -1.0
204
+ padded_images = F.pad(
205
+ resized_images,
206
+ (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
207
+ mode="constant",
208
+ value=constant_value,
209
+ )
210
+
211
+ # Convert back to original format if needed
212
+ if channels_last:
213
+ padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
214
+
215
+ return padded_images
216
+
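+ # Editorial note: a usage sketch for the resize helper above, with an
+ # arbitrary wide frame. Float inputs are expected in [-1, 1]; the helper
+ # letterboxes (pads with -1) instead of stretching.
+ def _demo_resize_with_pad():
+     img = torch.zeros(1, 3, 180, 320)  # already in [-1, 1]
+     out = resize_with_pad_torch(img, 224, 224)
+     assert out.shape == (1, 3, 224, 224)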
217
+
218
+ # Define the complete layer computation function for gradient checkpointing
219
+ def compute_layer_complete(
220
+ layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
221
+ ):
222
+ models = [paligemma.language_model, gemma_expert.model]
223
+ query_states = []
224
+ key_states = []
225
+ value_states = []
226
+ gates = []
227
+ for i, hidden_states in enumerate(inputs_embeds):
228
+ layer = models[i].layers[layer_idx]
229
+ hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
230
+ gates.append(gate)
231
+ input_shape = hidden_states.shape[:-1]
232
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
233
+ query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
234
+ key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
235
+ value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
236
+ query_states.append(query_state)
237
+ key_states.append(key_state)
238
+ value_states.append(value_state)
239
+ # Concatenate and process attention
240
+ query_states = torch.cat(query_states, dim=2)
241
+ key_states = torch.cat(key_states, dim=2)
242
+ value_states = torch.cat(value_states, dim=2)
243
+ dummy_tensor = torch.zeros(
244
+ query_states.shape[0],
245
+ query_states.shape[2],
246
+ query_states.shape[-1],
247
+ device=query_states.device,
248
+ dtype=query_states.dtype,
249
+ )
250
+ cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
251
+ query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
252
+ query_states, key_states, cos, sin, unsqueeze_dim=1
253
+ )
254
+ batch_size = query_states.shape[0]
255
+ scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
256
+ # Attention computation
257
+ att_output, _ = modeling_gemma.eager_attention_forward(
258
+ paligemma.language_model.layers[layer_idx].self_attn,
259
+ query_states,
260
+ key_states,
261
+ value_states,
262
+ attention_mask,
263
+ scaling,
264
+ )
265
+ # Get head_dim from the current layer, not from the model
266
+ head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
267
+ att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)  # 8 = num_heads (same for both gemma variants)
268
+ # Process layer outputs
269
+ outputs_embeds = []
270
+ start_pos = 0
271
+ for i, hidden_states in enumerate(inputs_embeds):
272
+ layer = models[i].layers[layer_idx]
273
+ end_pos = start_pos + hidden_states.shape[1]
274
+ if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
275
+ att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
276
+ out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
277
+ # first residual
278
+ out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
279
+ after_first_residual = out_emb.clone()
280
+ out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
281
+ # Convert to bfloat16 if the next layer (mlp) uses bfloat16
282
+ if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
283
+ out_emb = out_emb.to(dtype=torch.bfloat16)
284
+ out_emb = layer.mlp(out_emb)
285
+ # second residual
286
+ out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
287
+ outputs_embeds.append(out_emb)
288
+ start_pos = end_pos
289
+ return outputs_embeds
290
+
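+ # Editorial note: compute_layer_complete above runs one shared attention
+ # over two token streams (VLM prefix + action-expert suffix) by
+ # concatenating their Q/K/V along the sequence axis. The same idea in
+ # isolation, with toy shapes unrelated to the real model:
+ def _demo_joint_attention_concat():
+     b, h, n1, n2, d = 1, 2, 3, 4, 8
+     q1, k1, v1 = (torch.randn(b, h, n1, d) for _ in range(3))
+     q2, k2, v2 = (torch.randn(b, h, n2, d) for _ in range(3))
+     q = torch.cat([q1, q2], dim=2)  # (b, h, n1 + n2, d)
+     k = torch.cat([k1, k2], dim=2)
+     v = torch.cat([v1, v2], dim=2)
+     att = torch.softmax(q @ k.transpose(-1, -2) / d**0.5, dim=-1) @ v
+     out1, out2 = att[:, :, :n1], att[:, :, n1:]  # split back per stream
+     assert out1.shape == (b, h, n1, d) and out2.shape == (b, h, n2, d)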
291
+
292
+ class GemmaConfig: # see openpi `gemma.py: Config`
293
+ """Configuration for Gemma model variants."""
294
+
295
+ def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
296
+ self.width = width
297
+ self.depth = depth
298
+ self.mlp_dim = mlp_dim
299
+ self.num_heads = num_heads
300
+ self.num_kv_heads = num_kv_heads
301
+ self.head_dim = head_dim
302
+
303
+
304
+ def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config`
305
+ """Returns config for specified gemma variant."""
306
+ if variant == "gemma_300m":
307
+ return GemmaConfig(
308
+ width=1024,
309
+ depth=18,
310
+ mlp_dim=4096,
311
+ num_heads=8,
312
+ num_kv_heads=1,
313
+ head_dim=256,
314
+ )
315
+ elif variant == "gemma_2b":
316
+ return GemmaConfig(
317
+ width=2048,
318
+ depth=18,
319
+ mlp_dim=16_384,
320
+ num_heads=8,
321
+ num_kv_heads=1,
322
+ head_dim=256,
323
+ )
324
+ else:
325
+ raise ValueError(f"Unknown variant: {variant}")
326
+
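+ # Editorial note: quick sanity usage of the variant table above. PI05
+ # typically pairs the 2B VLM backbone with the 300M action expert; both
+ # variants use 8 heads, which is why `8 * head_dim` appears in
+ # compute_layer_complete.
+ def _demo_gemma_configs():
+     vlm = get_gemma_config("gemma_2b")
+     expert = get_gemma_config("gemma_300m")
+     assert vlm.width == 2048 and expert.width == 1024
+     assert vlm.num_heads == expert.num_heads == 8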
327
+
328
+ class PaliGemmaWithExpertModel(
329
+ nn.Module
330
): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel`; this class is an almost exact copy of PaliGemmaWithExpertModel in openpi
331
+ """PaliGemma model with action expert for PI05."""
332
+
333
+ def __init__(
334
+ self,
335
+ vlm_config,
336
+ action_expert_config,
337
+ use_adarms=None,
338
+ precision: Literal["bfloat16", "float32"] = "bfloat16",
339
+ ):
340
+ if use_adarms is None:
341
+ use_adarms = [False, False]
342
+ super().__init__()
343
+
344
+ vlm_config_hf = CONFIG_MAPPING["paligemma"]()
345
+ vlm_config_hf._vocab_size = 257152 # noqa: SLF001
346
+ vlm_config_hf.image_token_index = 257152
347
+ vlm_config_hf.text_config.hidden_size = vlm_config.width
348
+ vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
349
+ vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
350
+ vlm_config_hf.text_config.head_dim = vlm_config.head_dim
351
+ vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
352
+ vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
353
+ vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
354
+ vlm_config_hf.text_config.torch_dtype = "float32"
355
+ vlm_config_hf.text_config.vocab_size = 257152
356
+ vlm_config_hf.text_config.use_adarms = use_adarms[0]
357
+ vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
358
+ vlm_config_hf.vision_config.intermediate_size = 4304
359
+ vlm_config_hf.vision_config.projection_dim = 2048
360
+ vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
361
+ vlm_config_hf.vision_config.torch_dtype = "float32"
362
+
363
+ action_expert_config_hf = CONFIG_MAPPING["gemma"](
364
+ head_dim=action_expert_config.head_dim,
365
+ hidden_size=action_expert_config.width,
366
+ intermediate_size=action_expert_config.mlp_dim,
367
+ num_attention_heads=action_expert_config.num_heads,
368
+ num_hidden_layers=action_expert_config.depth,
369
+ num_key_value_heads=action_expert_config.num_kv_heads,
370
+ vocab_size=257152,
371
+ hidden_activation="gelu_pytorch_tanh",
372
+ torch_dtype="float32",
373
+ use_adarms=use_adarms[1],
374
+ adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
375
+ )
376
+
377
+ self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
378
+ self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
379
+ self.gemma_expert.model.embed_tokens = None
380
+
381
+ self.to_bfloat16_for_selected_params(precision)
382
+
383
+ def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
384
+ if precision == "bfloat16":
385
+ self.to(dtype=torch.bfloat16)
386
+ elif precision == "float32":
387
+ self.to(dtype=torch.float32)
388
+ return
389
+ else:
390
+ raise ValueError(f"Invalid precision: {precision}")
391
+
392
+ params_to_keep_float32 = [
393
+ "vision_tower.vision_model.embeddings.patch_embedding.weight",
394
+ "vision_tower.vision_model.embeddings.patch_embedding.bias",
395
+ "vision_tower.vision_model.embeddings.position_embedding.weight",
396
+ "input_layernorm",
397
+ "post_attention_layernorm",
398
+ "model.norm",
399
+ ]
400
+
401
+ for name, param in self.named_parameters():
402
+ if any(selector in name for selector in params_to_keep_float32):
403
+ param.data = param.data.to(dtype=torch.float32)
404
+
405
+ def embed_image(self, image: torch.Tensor):
406
+ return self.paligemma.model.get_image_features(image)
407
+
408
+ def embed_language_tokens(self, tokens: torch.Tensor):
409
+ return self.paligemma.language_model.embed_tokens(tokens)
410
+
411
+ def forward(
412
+ self,
413
+ attention_mask: torch.Tensor | None = None,
414
+ position_ids: torch.LongTensor | None = None,
415
+ past_key_values: list[torch.FloatTensor] | None = None,
416
+ inputs_embeds: list[torch.FloatTensor] | None = None,
417
+ use_cache: bool | None = None,
418
+ adarms_cond: list[torch.Tensor] | None = None,
419
+ ):
420
+ if adarms_cond is None:
421
+ adarms_cond = [None, None]
422
+ if inputs_embeds[1] is None:
423
+ prefix_output = self.paligemma.language_model.forward(
424
+ inputs_embeds=inputs_embeds[0],
425
+ attention_mask=attention_mask,
426
+ position_ids=position_ids,
427
+ past_key_values=past_key_values,
428
+ use_cache=use_cache,
429
+ adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
430
+ )
431
+ prefix_past_key_values = prefix_output.past_key_values
432
+ prefix_output = prefix_output.last_hidden_state
433
+ suffix_output = None
434
+ elif inputs_embeds[0] is None:
435
+ suffix_output = self.gemma_expert.model.forward(
436
+ inputs_embeds=inputs_embeds[1],
437
+ attention_mask=attention_mask,
438
+ position_ids=position_ids,
439
+ past_key_values=past_key_values,
440
+ use_cache=use_cache,
441
+ adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
442
+ )
443
+ suffix_output = suffix_output.last_hidden_state
444
+ prefix_output = None
445
+ prefix_past_key_values = None
446
+ else:
447
+ models = [self.paligemma.language_model, self.gemma_expert.model]
448
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
449
+
450
+ # Check if gradient checkpointing is enabled for any of the models
451
+ use_gradient_checkpointing = (
452
+ hasattr(self.gemma_expert.model, "gradient_checkpointing")
453
+ and self.gemma_expert.model.gradient_checkpointing
454
+ and self.training
455
+ ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
456
+
457
+ # Process all layers with gradient checkpointing if enabled
458
+ for layer_idx in range(num_layers):
459
+ if use_gradient_checkpointing:
460
+ inputs_embeds = torch.utils.checkpoint.checkpoint(
461
+ compute_layer_complete,
462
+ layer_idx,
463
+ inputs_embeds,
464
+ attention_mask,
465
+ position_ids,
466
+ adarms_cond,
467
+ use_reentrant=False,
468
+ preserve_rng_state=False,
469
+ paligemma=self.paligemma,
470
+ gemma_expert=self.gemma_expert,
471
+ )
472
+ else:
473
+ inputs_embeds = compute_layer_complete(
474
+ layer_idx,
475
+ inputs_embeds,
476
+ attention_mask,
477
+ position_ids,
478
+ adarms_cond,
479
+ paligemma=self.paligemma,
480
+ gemma_expert=self.gemma_expert,
481
+ )
482
+
483
+ # final norm
484
+ def compute_final_norms(inputs_embeds, adarms_cond):
485
+ outputs_embeds = []
486
+ for i, hidden_states in enumerate(inputs_embeds):
487
+ out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
488
+ outputs_embeds.append(out_emb)
489
+ return outputs_embeds
490
+
491
+ # Apply gradient checkpointing to final norm if enabled
492
+ if use_gradient_checkpointing:
493
+ outputs_embeds = torch.utils.checkpoint.checkpoint(
494
+ compute_final_norms,
495
+ inputs_embeds,
496
+ adarms_cond,
497
+ use_reentrant=False,
498
+ preserve_rng_state=False,
499
+ )
500
+ else:
501
+ outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
502
+
503
+ prefix_output = outputs_embeds[0]
504
+ suffix_output = outputs_embeds[1]
505
+ prefix_past_key_values = None
506
+
507
+ return [prefix_output, suffix_output], prefix_past_key_values
508
+
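+ # Editorial note (sketch, not upstream API): the forward above selects one
+ # of three paths depending on which inputs_embeds slot is filled. The same
+ # dispatch, reduced to its skeleton:
+ def _demo_forward_dispatch(inputs_embeds):
+     prefix, suffix = inputs_embeds
+     if suffix is None:
+         return "prefix only: run the VLM and cache K/V for denoising steps"
+     if prefix is None:
+         return "suffix only: run the action expert against cached prefix K/V"
+     return "both: interleave the two towers layer by layer (training path)"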
509
+
510
+ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
511
+ """Core PI05 PyTorch model."""
512
+
513
+ def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
514
+ super().__init__()
515
+ self.config = config
516
+ self.rtc_processor = rtc_processor
517
+
518
+ paligemma_config = get_gemma_config(config.paligemma_variant)
519
+ action_expert_config = get_gemma_config(config.action_expert_variant)
520
+
521
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
522
+ paligemma_config,
523
+ action_expert_config,
524
+ use_adarms=[False, True],
525
+ precision=config.dtype,
526
+ )
527
+
528
+ self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
529
+ self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)
530
+
531
+ self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
532
+ self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
533
+
534
+ # Initialize gradient checkpointing flag
535
+ self.gradient_checkpointing_enabled = False
536
+
537
+ # Compile model if requested
538
+ if config.compile_model:
539
+ torch.set_float32_matmul_precision("high")
540
+ self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
541
+
542
+ msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
543
+
544
+ # PATCH: make transformers version guard non-fatal and robust across versions
545
+ try:
546
+ from transformers.models.siglip import check
547
+
548
+ if hasattr(check, "check_whether_transformers_replace_is_installed_correctly"):
549
+ ok = check.check_whether_transformers_replace_is_installed_correctly()
550
+ if not ok:
551
+ logging.warning("[pi05] %s", msg)
552
+ else:
553
+ logging.warning(
554
+ "[patch_pi05] SigLIP check helper missing; skipping strict transformers version guard."
555
+ )
556
+ except Exception as e: # noqa: BLE001
557
+ logging.warning(
558
+ "[patch_pi05] Could not run transformers version guard (%s). "
559
+ "Continuing without strict transformers check. %s",
560
+ msg,
561
+ e,
562
+ )
563
+
564
+ def gradient_checkpointing_enable(self):
565
+ """Enable gradient checkpointing for memory optimization."""
566
+ self.gradient_checkpointing_enabled = True
567
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
568
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
569
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
570
+ logging.info("Enabled gradient checkpointing for PI05Pytorch model")
571
+
572
+ def gradient_checkpointing_disable(self):
573
+ """Disable gradient checkpointing."""
574
+ self.gradient_checkpointing_enabled = False
575
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
576
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
577
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
578
+ logging.info("Disabled gradient checkpointing for PI05Pytorch model")
579
+
580
+ def _rtc_enabled(self):
581
+ return self.config.rtc_config is not None and self.config.rtc_config.enabled
582
+
583
+ def _apply_checkpoint(self, func, *args, **kwargs):
584
+ """Helper method to apply gradient checkpointing if enabled."""
585
+ if self.gradient_checkpointing_enabled and self.training:
586
+ return torch.utils.checkpoint.checkpoint(
587
+ func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
588
+ )
589
+ return func(*args, **kwargs)
590
+
591
+ def _prepare_attention_masks_4d(self, att_2d_masks):
592
+ """Helper method to prepare 4D attention masks for transformer."""
593
+ att_2d_masks_4d = att_2d_masks[:, None, :, :]
594
+ return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
595
+
596
+ def sample_noise(self, shape, device):
597
+ return torch.normal(
598
+ mean=0.0,
599
+ std=1.0,
600
+ size=shape,
601
+ dtype=torch.float32,
602
+ device=device,
603
+ )
604
+
605
+ def sample_time(self, bsize, device):
606
+ time_beta = sample_beta(
607
+ self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
608
+ )
609
+ time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
610
+ return time.to(dtype=torch.float32, device=device)
611
+
612
+ def embed_prefix(
613
+ self, images, img_masks, tokens, masks
614
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
615
+ """Embed images with SigLIP and language tokens with embedding layer."""
616
+ embs = []
617
+ pad_masks = []
618
+ att_masks = []
619
+
620
+ # Process images
621
+ for img, img_mask in zip(images, img_masks, strict=True):
622
+
623
+ def image_embed_func(img):
624
+ return self.paligemma_with_expert.embed_image(img)
625
+
626
+ img_emb = self._apply_checkpoint(image_embed_func, img)
627
+ bsize, num_img_embs = img_emb.shape[:2]
628
+
629
+ embs.append(img_emb)
630
+ pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
631
+ att_masks += [0] * num_img_embs
632
+
633
+ # Process language tokens
634
+ def lang_embed_func(tokens):
635
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
636
+ lang_emb_dim = lang_emb.shape[-1]
637
+ return lang_emb * math.sqrt(lang_emb_dim)
638
+
639
+ lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
640
+ embs.append(lang_emb)
641
+ pad_masks.append(masks)
642
+
643
+ num_lang_embs = lang_emb.shape[1]
644
+ att_masks += [0] * num_lang_embs
645
+
646
+ embs = torch.cat(embs, dim=1)
647
+ pad_masks = torch.cat(pad_masks, dim=1)
648
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
649
+
650
+ bsize = pad_masks.shape[0]
651
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
652
+
653
+ return embs, pad_masks, att_masks
654
+
655
+ def embed_suffix(self, noisy_actions, timestep):
656
+ """Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
657
+ embs = []
658
+ pad_masks = []
659
+ att_masks = []
660
+
661
+ # Embed timestep using sine-cosine positional encoding
662
+ time_emb = create_sinusoidal_pos_embedding(
663
+ timestep,
664
+ self.action_in_proj.out_features,
665
+ min_period=self.config.min_period,
666
+ max_period=self.config.max_period,
667
+ device=timestep.device,
668
+ )
669
+ time_emb = time_emb.type(dtype=timestep.dtype)
670
+
671
+ # Fuse timestep + action information using an MLP
672
+ def action_proj_func(noisy_actions):
673
+ return self.action_in_proj(noisy_actions)
674
+
675
+ action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
676
+
677
+ def time_mlp_func(time_emb):
678
+ x = self.time_mlp_in(time_emb)
679
+ x = F.silu(x)
680
+ x = self.time_mlp_out(x)
681
+ return F.silu(x)
682
+
683
+ time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
684
+ action_time_emb = action_emb
685
+ adarms_cond = time_emb
686
+
687
+ embs.append(action_time_emb)
688
+ bsize, action_time_dim = action_time_emb.shape[:2]
689
+ action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
690
+ pad_masks.append(action_time_mask)
691
+
692
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
693
+ att_masks += [1] + ([0] * (self.config.chunk_size - 1))
694
+
695
+ embs = torch.cat(embs, dim=1)
696
+ pad_masks = torch.cat(pad_masks, dim=1)
697
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
698
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
699
+
700
+ return embs, pad_masks, att_masks, adarms_cond
701
+
702
+ def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
703
+ """Do a full training forward pass and compute the loss."""
704
+ if noise is None:
705
+ noise = self.sample_noise(actions.shape, actions.device)
706
+
707
+ if time is None:
708
+ time = self.sample_time(actions.shape[0], actions.device)
709
+
710
+ time_expanded = time[:, None, None]
711
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
712
+ u_t = noise - actions
713
+
714
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
715
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
716
+
717
+ if (
718
+ self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
719
+ == torch.bfloat16
720
+ ):
721
+ suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
722
+ prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
723
+
724
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
725
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
726
+
727
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
728
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
729
+
730
+ att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
731
+
732
+ def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
733
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
734
+ attention_mask=att_2d_masks_4d,
735
+ position_ids=position_ids,
736
+ past_key_values=None,
737
+ inputs_embeds=[prefix_embs, suffix_embs],
738
+ use_cache=False,
739
+ adarms_cond=[None, adarms_cond],
740
+ )
741
+ return suffix_out
742
+
743
+ suffix_out = self._apply_checkpoint(
744
+ forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
745
+ )
746
+
747
+ suffix_out = suffix_out[:, -self.config.chunk_size :]
748
+ suffix_out = suffix_out.to(dtype=torch.float32)
749
+
750
+ def action_out_proj_func(suffix_out):
751
+ return self.action_out_proj(suffix_out)
752
+
753
+ v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
754
+
755
+ return F.mse_loss(u_t, v_t, reduction="none")
756
+
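+     # Editorial note: the loss above follows the flow-matching recipe:
+     # corrupt actions toward noise along a straight line and regress the
+     # constant velocity pointing from data to noise. A standalone sketch:
+     @staticmethod
+     def _demo_flow_matching_target():
+         actions = torch.zeros(2, 4, 3)       # (batch, chunk, action_dim)
+         noise = torch.randn_like(actions)
+         t = torch.rand(2)[:, None, None]     # one time per batch element
+         x_t = t * noise + (1 - t) * actions  # interpolant (t=1 is pure noise)
+         u_t = noise - actions                # regression target d x_t / d t
+         assert torch.allclose(actions + t * u_t, x_t)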
757
+ @torch.no_grad() # see openpi `sample_actions` (slightly adapted)
758
+ def sample_actions(
759
+ self,
760
+ images,
761
+ img_masks,
762
+ tokens,
763
+ masks,
764
+ noise=None,
765
+ num_steps=None,
766
+ **kwargs: Unpack[ActionSelectKwargs],
767
+ ) -> Tensor:
768
+ """Do a full inference forward and compute the action."""
769
+ if num_steps is None:
770
+ num_steps = self.config.num_inference_steps
771
+
772
+ bsize = tokens.shape[0]
773
+ device = tokens.device
774
+
775
+ if noise is None:
776
+ # Sample noise with padded dimension as expected by action_in_proj
777
+ actions_shape = (
778
+ bsize,
779
+ self.config.chunk_size,
780
+ self.config.max_action_dim,
781
+ ) # Use config max_action_dim for internal processing
782
+ noise = self.sample_noise(actions_shape, device)
783
+
784
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
785
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
786
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
787
+
788
+ prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
789
+ self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
790
+
791
+ _, past_key_values = self.paligemma_with_expert.forward(
792
+ attention_mask=prefix_att_2d_masks_4d,
793
+ position_ids=prefix_position_ids,
794
+ past_key_values=None,
795
+ inputs_embeds=[prefix_embs, None],
796
+ use_cache=True,
797
+ )
798
+
799
+ dt = -1.0 / num_steps
800
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
801
+
802
+ x_t = noise
803
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
804
+ while time >= -dt / 2:
805
+ expanded_time = time.expand(bsize)
806
+
807
+ # Define a closure function to properly capture expanded_time
808
+ # This avoids the lambda expression (E731) and loop variable binding (B023) issues
809
+ def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
810
+ return self.denoise_step(
811
+ prefix_pad_masks=prefix_pad_masks,
812
+ past_key_values=past_key_values,
813
+ x_t=input_x_t,
814
+ timestep=current_timestep,
815
+ )
816
+
817
+ if self._rtc_enabled():
818
+ inference_delay = kwargs.get("inference_delay")
819
+ prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
820
+ execution_horizon = kwargs.get("execution_horizon")
821
+
822
+ v_t = self.rtc_processor.denoise_step(
823
+ x_t=x_t,
824
+ prev_chunk_left_over=prev_chunk_left_over,
825
+ inference_delay=inference_delay,
826
+ time=time,
827
+ original_denoise_step_partial=denoise_step_partial_call,
828
+ execution_horizon=execution_horizon,
829
+ )
830
+ else:
831
+ v_t = denoise_step_partial_call(x_t)
832
+
833
+ # Euler step
834
+ x_t += dt * v_t
835
+
836
+ # Record x_t and v_t after Euler step
837
+ if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
838
+ self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
839
+
840
+ time += dt
841
+
842
+ return x_t
843
+
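+     # Editorial note: sample_actions above integrates the learned velocity
+     # field with fixed-step Euler from t=1 (noise) down to t=0 (actions).
+     # The same loop on a toy field v(x, t) = x, which contracts the state
+     # by a factor approaching e**-1 over the unit reverse-time interval:
+     @staticmethod
+     def _demo_euler_flow_integration(num_steps: int = 10) -> torch.Tensor:
+         x_t = torch.ones(1)
+         dt = -1.0 / num_steps
+         time = 1.0
+         while time >= -dt / 2:
+             v_t = x_t  # stand-in for the model's denoise_step(x_t, time)
+             x_t = x_t + dt * v_t
+             time += dt
+         return x_t  # (1 - 1/num_steps) ** num_steps -> e**-1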
844
+ def denoise_step(
845
+ self,
846
+ prefix_pad_masks,
847
+ past_key_values,
848
+ x_t,
849
+ timestep,
850
+ ):
851
+ """Apply one denoising step of the noise `x_t` at a given timestep."""
852
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
853
+
854
+ suffix_len = suffix_pad_masks.shape[1]
855
+ batch_size = prefix_pad_masks.shape[0]
856
+ prefix_len = prefix_pad_masks.shape[1]
857
+
858
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
859
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
860
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
861
+
862
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
863
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
864
+
865
+ full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
866
+ self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
867
+
868
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
869
+ attention_mask=full_att_2d_masks_4d,
870
+ position_ids=position_ids,
871
+ past_key_values=past_key_values,
872
+ inputs_embeds=[None, suffix_embs],
873
+ use_cache=False,
874
+ adarms_cond=[None, adarms_cond],
875
+ )
876
+
877
+ suffix_out = outputs_embeds[1]
878
+ suffix_out = suffix_out[:, -self.config.chunk_size :]
879
+ suffix_out = suffix_out.to(dtype=torch.float32)
880
+ return self.action_out_proj(suffix_out)
881
+
882
+
883
+ class PI05Policy(PreTrainedPolicy):
884
+ """PI05 Policy for LeRobot."""
885
+
886
+ config_class = PI05Config
887
+ name = "pi05"
888
+
889
+ def __init__(
890
+ self,
891
+ config: PI05Config,
892
+ ):
893
+ """
894
+ Args:
895
+ config: Policy configuration class instance.
896
+ """
897
+ super().__init__(config)
898
+ config.validate_features()
899
+ self.config = config
900
+
901
+ # Initialize the core PI05 model
902
+ self.init_rtc_processor()
903
+ self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
904
+
905
+ # Enable gradient checkpointing if requested
906
+ if config.gradient_checkpointing:
907
+ self.model.gradient_checkpointing_enable()
908
+
909
+ self.model.to(config.device)
910
+
911
+ self.reset()
912
+
913
+ @classmethod
914
+ def from_pretrained(
915
+ cls: builtins.type[T],
916
+ pretrained_name_or_path: str | Path,
917
+ *,
918
+ config: PreTrainedConfig | None = None,
919
+ force_download: bool = False,
920
+ resume_download: bool | None = None,
921
+ proxies: dict | None = None,
922
+ token: str | bool | None = None,
923
+ cache_dir: str | Path | None = None,
924
+ local_files_only: bool = False,
925
+ revision: str | None = None,
926
+ strict: bool = True,
927
+ **kwargs,
928
+ ) -> T:
929
+ """Override the from_pretrained method to handle key remapping and display important disclaimer."""
930
+ print(
931
+ "The PI05 model is a direct port of the OpenPI implementation. \n"
932
+ "This implementation follows the original OpenPI structure for compatibility. \n"
933
+ "Original implementation: https://github.com/Physical-Intelligence/openpi"
934
+ )
935
+ if pretrained_name_or_path is None:
936
+ raise ValueError("pretrained_name_or_path is required")
937
+
938
+ # Use provided config if available, otherwise create default config
939
+ if config is None:
940
+ config = PreTrainedConfig.from_pretrained(
941
+ pretrained_name_or_path=pretrained_name_or_path,
942
+ force_download=force_download,
943
+ resume_download=resume_download,
944
+ proxies=proxies,
945
+ token=token,
946
+ cache_dir=cache_dir,
947
+ local_files_only=local_files_only,
948
+ revision=revision,
949
+ **kwargs,
950
+ )
951
+
952
+ # Initialize model without loading weights
953
+ # Check if dataset_stats were provided in kwargs
954
+ model = cls(config, **kwargs)
955
+
956
+ # Now manually load and remap the state dict
957
+ try:
958
+ # Try to load the pytorch_model.bin or model.safetensors file
959
+ print(f"Loading model from: {pretrained_name_or_path}")
960
+ try:
961
+ from transformers.utils import cached_file
962
+
963
+ # Try safetensors first
964
+ resolved_file = cached_file(
965
+ pretrained_name_or_path,
966
+ "model.safetensors",
967
+ cache_dir=kwargs.get("cache_dir"),
968
+ force_download=kwargs.get("force_download", False),
969
+ resume_download=kwargs.get("resume_download"),
970
+ proxies=kwargs.get("proxies"),
971
+ use_auth_token=kwargs.get("use_auth_token"),
972
+ revision=kwargs.get("revision"),
973
+ local_files_only=kwargs.get("local_files_only", False),
974
+ )
975
+ from safetensors.torch import load_file
976
+
977
+ original_state_dict = load_file(resolved_file)
978
+ print("✓ Loaded state dict from model.safetensors")
979
+ except Exception as e: # noqa: BLE001
980
+ print(f"Could not load state dict from remote files: {e}")
981
+ print("Returning model without loading pretrained weights")
982
+ return model
983
+
984
+ # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
985
+ fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
986
+
987
+ # Then add "model." prefix for all keys that don't already have it
988
+ remapped_state_dict = {}
989
+ remap_count = 0
990
+
991
+ for key, value in fixed_state_dict.items():
992
+ if not key.startswith("model."):
993
+ new_key = f"model.{key}"
994
+ remapped_state_dict[new_key] = value
995
+ remap_count += 1
996
+ if remap_count <= 10: # Only print first 10 to avoid spam
997
+ print(f"Remapped: {key} -> {new_key}")
998
+ else:
999
+ remapped_state_dict[key] = value
1000
+
1001
+ if remap_count > 0:
1002
+ print(f"Remapped {remap_count} state dict keys")
1003
+
1004
+ # Load the remapped state dict into the model
1005
+ missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
1006
+
1007
+ # --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---
1008
+ if any("embed_tokens.weight" in k for k in missing_keys):
1009
+ try:
1010
+ with torch.no_grad():
1011
+ embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens
1012
+ lm_head = model.model.paligemma_with_expert.paligemma.lm_head
1013
+ if embed is not None and lm_head is not None:
1014
+ embed.weight = lm_head.weight
1015
+ except Exception as _e: # noqa: BLE001
1016
+ print("[patch_pi05] Could not tie embed_tokens to lm_head:", _e)
1017
+
1018
+ # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt ---
1026
+ print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
1027
+ if len(missing_keys) <= 5:
1028
+ for key in missing_keys:
1029
+ print(f" - {key}")
1030
+ else:
1031
+ for key in missing_keys[:5]:
1032
+ print(f" - {key}")
1033
+ print(f" ... and {len(missing_keys) - 5} more")
1034
+
1035
+ if unexpected_keys:
1036
+ print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
1037
+ if len(unexpected_keys) <= 5:
1038
+ for key in unexpected_keys:
1039
+ print(f" - {key}")
1040
+ else:
1041
+ for key in unexpected_keys[:5]:
1042
+ print(f" - {key}")
1043
+ print(f" ... and {len(unexpected_keys) - 5} more")
1044
+
1045
+ if not missing_keys and not unexpected_keys:
1046
+ print("All keys loaded successfully!")
1047
+
1048
+ except Exception as e: # noqa: BLE001
1049
+ print(f"Warning: Could not remap state dict keys: {e}")
1050
+
1051
+ return model
1052
+
1053
+ def _fix_pytorch_state_dict_keys(
1054
+ self, state_dict, model_config
1055
+ ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
1056
+ """Fix state dict keys to match current model architecture."""
1057
+ import re
1058
+
1059
+ fixed_state_dict = {}
1060
+
1061
+ for key, value in state_dict.items():
1062
+ new_key = key
1063
+
1064
+ # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
1065
+ # For gemma expert layers
1066
+ if re.match(
1067
+ r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
1068
+ key,
1069
+ ):
1070
+ # Check if the model actually has adaRMS enabled for the expert
1071
+ expert_uses_adarms = getattr(
1072
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
1073
+ )
1074
+ if expert_uses_adarms:
1075
+ logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
1076
+ continue
1077
+
1078
+ if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
1079
+ # Check if the model actually has adaRMS enabled for the expert
1080
+ expert_uses_adarms = getattr(
1081
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
1082
+ )
1083
+ if expert_uses_adarms:
1084
+ logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
1085
+ continue
1086
+
1087
+ # Handle MLP naming changes for pi05
1088
+ # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*
1089
+ if key.startswith("action_time_mlp_in."):
1090
+ new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
1091
+ elif key.startswith("action_time_mlp_out."):
1092
+ new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
1093
+ # Also handle state_proj which shouldn't exist in pi05
1094
+ if key.startswith("state_proj."):
1095
+ logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
1096
+ continue
1097
+
1098
+ # Handle vision tower embedding layer potential differences
1099
+ if "patch_embedding" in key:
1100
+ # Some checkpoints might have this, but current model expects different structure
1101
+ logging.warning(f"Vision embedding key might need handling: {key}")
1102
+
1103
+ fixed_state_dict[new_key] = value
1104
+
1105
+ return fixed_state_dict
1106
+
1107
+ def get_optim_params(self) -> dict:
1108
+ return self.parameters()
1109
+
1110
+ def reset(self):
1111
+ """Reset internal state - called when environment resets."""
1112
+ self._action_queue = deque(maxlen=self.config.n_action_steps)
1113
+ self._queues = {
1114
+ ACTION: deque(maxlen=self.config.n_action_steps),
1115
+ }
1116
+
1117
+ def init_rtc_processor(self):
1118
+ """Initialize RTC processor if RTC is enabled in config."""
1119
+ self.rtc_processor = None
1120
+
1121
+ # Create processor if config provided
1122
+ # If RTC is not enabled - we can still track the denoising data
1123
+ if self.config.rtc_config is not None:
1124
+ self.rtc_processor = RTCProcessor(self.config.rtc_config)
1125
+
1126
+ model_value = getattr(self, "model", None)
1127
+ if model_value is not None:
1128
+ model_value.rtc_processor = self.rtc_processor
1129
+
1130
+ def _rtc_enabled(self) -> bool:
1131
+ return self.config.rtc_config is not None and self.config.rtc_config.enabled
1132
+
1133
+ def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
1134
+ """Preprocess images for the model.
1135
+
1136
+ Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
1137
+ PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
1138
+ """
1139
+ images = []
1140
+ img_masks = []
1141
+
1142
+ # Get device from model parameters
1143
+ device = next(self.parameters()).device
1144
+
1145
+ present_img_keys = [key for key in self.config.image_features if key in batch]
1146
+ missing_img_keys = [key for key in self.config.image_features if key not in batch]
1147
+
1148
+ if len(present_img_keys) == 0:
1149
+ raise ValueError(
1150
+ f"All image features are missing from the batch. At least one expected. "
1151
+ f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
1152
+ )
1153
+
1154
+ # Preprocess image features present in the batch
1155
+ for key in present_img_keys:
1156
+ img = batch[key]
1157
+
1158
+ # Ensure tensor is on the same device as the model
1159
+ if img.device != device:
1160
+ img = img.to(device)
1161
+
1162
+ # Ensure float32 dtype for consistency
1163
+ if img.dtype != torch.float32:
1164
+ img = img.to(torch.float32)
1165
+
1166
+ # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
1167
+ is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
1168
+
1169
+ if is_channels_first:
1170
+ # Convert [B, C, H, W] to [B, H, W, C] for processing
1171
+ img = img.permute(0, 2, 3, 1)
1172
+
1173
+ # from openpi preprocess_observation_pytorch: Resize with padding if needed
1174
+ if img.shape[1:3] != self.config.image_resolution:
1175
+ img = resize_with_pad_torch(img, *self.config.image_resolution)
1176
+
1177
+ # Normalize from [0,1] to [-1,1] as expected by siglip
1178
+ img = img * 2.0 - 1.0
1179
+
1180
+ # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
1181
+ if is_channels_first:
1182
+ img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
1183
+
1184
+ images.append(img)
1185
+ # Create mask (all ones for real images)
1186
+ bsize = img.shape[0]
1187
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
1188
+ img_masks.append(mask)
1189
+
1190
+ # Create image features not present in the batch as fully 0 padded images
1191
+ for _num_empty_cameras in range(len(missing_img_keys)):
1192
+ img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP
1193
+ mask = torch.zeros_like(mask) # Mask is zero for empty cameras
1194
+ images.append(img)
1195
+ img_masks.append(mask)
1196
+
1197
+ return images, img_masks
1198
+
1199
+ def prepare_action(self, batch):
1200
+ """Pad action"""
1201
+ actions = pad_vector(batch[ACTION], self.config.max_action_dim)
1202
+ return actions
1203
+
1204
+ @torch.no_grad()
1205
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
1206
+ """Select a single action given environment observations."""
1207
+ assert not self._rtc_enabled(), (
1208
+ "RTC is not supported for select_action, use it with predict_action_chunk"
1209
+ )
1210
+
1211
+ self.eval()
1212
+
1213
+ # Action queue logic for n_action_steps > 1
1214
+ if len(self._action_queue) == 0:
1215
+ actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
1216
+ # Transpose to get shape (n_action_steps, batch_size, action_dim)
1217
+ self._action_queue.extend(actions.transpose(0, 1))
1218
+
1219
+ return self._action_queue.popleft()
1220
+
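+     # Editorial note (illustrative, not used by the policy): the queue
+     # contract above means one chunk prediction serves n_action_steps
+     # environment steps before the next model call.
+     @staticmethod
+     def _demo_action_queue(n_action_steps: int = 5) -> list:
+         queue = deque(maxlen=n_action_steps)
+         taken = []
+         for _ in range(12):
+             if len(queue) == 0:
+                 queue.extend(torch.arange(n_action_steps))  # stand-in chunk
+             taken.append(int(queue.popleft()))
+         return taken  # [0..4, 0..4, 0, 1]: refilled every 5 steps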
1221
+ @torch.no_grad()
1222
+ def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
1223
+ """Predict a chunk of actions given environment observations."""
1224
+ self.eval()
1225
+
1226
+ # Prepare inputs
1227
+ images, img_masks = self._preprocess_images(batch)
1228
+ tokens, masks = batch[OBS_LANGUAGE_TOKENS], batch[OBS_LANGUAGE_ATTENTION_MASK]
1229
+
1230
+ # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
1231
+ actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
1232
+
1233
+ # Unpad actions to actual action dimension
1234
+ original_action_dim = self.config.output_features[ACTION].shape[0]
1235
+ actions = actions[:, :, :original_action_dim]
1236
+
1237
+ return actions
1238
+
1239
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
1240
+ """Run the batch through the model and compute the loss for training."""
1241
+
1242
+ # Prepare inputs
1243
+ images, img_masks = self._preprocess_images(batch)
1244
+ tokens, masks = batch[OBS_LANGUAGE_TOKENS], batch[OBS_LANGUAGE_ATTENTION_MASK]
1245
+
1246
+ actions = self.prepare_action(batch)
1247
+
1248
+ # Compute loss (no separate state needed for PI05)
1249
+ losses = self.model.forward(images, img_masks, tokens, masks, actions)
1250
+
1251
+ # Truncate losses to actual action dimensions
1252
+ original_action_dim = self.config.output_features[ACTION].shape[0]
1253
+ losses = losses[:, :, :original_action_dim]
1254
+
1255
+ loss = losses.mean()
1256
+
1257
+ loss_dict = {
1258
+ "loss": loss.item(),
1259
+ "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
1260
+ }
1261
+
1262
+ return loss, loss_dict
1263
+
1264
+ # PATCH: downgrade transformer version guard
patch_sigma_env.py ADDED
@@ -0,0 +1,432 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ patch_sigma_env.py
4
+
5
+ Idempotent patcher for Sigma VLA experiments.
6
+
7
+ Patch goals:
8
+ 1) LeRobot PI05Policy (modeling_pi05.py):
9
+ 1.1 If ckpt omits embed_tokens.weight, tie embed_tokens.weight to lm_head.weight
10
+ *after* load_state_dict runs.
11
+ 1.2 Ensure torch is imported if target file lacks it.
12
+ 1.3 Downgrade the "incorrect transformer version" hard guard
13
+ (ValueError) to a WARNING so new GPU environments don't crash.
14
+ IMPORTANT: preserve indentation and patch only the intended guard.
15
+
16
+ 2) LeRobot policies __init__ (lerobot/policies/__init__.py):
17
+ 2.1 Make ONLY Groot/Diffusers-related imports optional (wrapped in try/except),
18
+ leaving all other exports untouched.
19
+ This prevents errors like: No module named 'triton.ops'
20
+ or diffusers/peft chain issues on fresh GPUs.
21
+
22
+ 3) eval_sigma_vla_rollout.py (your /workspace eval script):
23
+ 3.1 Force strict=False for PI05Policy.from_pretrained calls:
24
+ - strict=True -> strict=False
25
+ - if a PI05Policy load call has no strict arg, add strict=False
26
+ 3.2 Ensure randomized subset evaluation is possible:
27
+ - add --shuffle arg if missing
28
+ - change DataLoader shuffle=False -> shuffle=getattr(args,"shuffle",False)
29
+
30
+ Safe to run multiple times; no-op if already patched.
31
+ """
32
+
33
+ import os
34
+ import re
35
+ import sys
36
+ import pathlib
37
+ from typing import Optional, Tuple, List
38
+
39
+
40
+ # -------------------------
41
+ # Utilities
42
+ # -------------------------
43
+
44
+ def _read_text(p: pathlib.Path) -> str:
45
+ return p.read_text(encoding="utf-8")
46
+
47
+ def _write_text(p: pathlib.Path, s: str) -> None:
48
+ p.write_text(s, encoding="utf-8")
49
+
50
+ def _search_file(
51
+ roots: List[os.PathLike],
52
+ filename: str,
53
+ must_contain: Optional[str] = None
54
+ ) -> Optional[pathlib.Path]:
55
+ for r in roots:
56
+ r = pathlib.Path(r)
57
+ if not r.exists():
58
+ continue
59
+ for p in r.rglob(filename):
60
+ if must_contain and must_contain not in str(p):
61
+ continue
62
+ return p
63
+ return None
64
+
65
+ def _default_roots():
66
+ return [
67
+ "/workspace/lerobot/src",
68
+ "/workspace/lerobot",
69
+ pathlib.Path(sys.prefix)
70
+ / "lib"
71
+ / f"python{sys.version_info.major}.{sys.version_info.minor}"
72
+ / "site-packages",
73
+ ]
74
+
75
+
76
+ # -------------------------
77
+ # Patch 1: PI05Policy (LeRobot)
78
+ # -------------------------
79
+
80
+ def find_pi05_file() -> pathlib.Path:
81
+ env = os.getenv("PI05_FILE")
82
+ if env:
83
+ p = pathlib.Path(env)
84
+ if p.exists():
85
+ return p
86
+
87
+ p = _search_file(_default_roots(), "modeling_pi05.py", must_contain="/pi05/")
88
+ if p and p.exists():
89
+ return p
90
+
91
+ raise FileNotFoundError("modeling_pi05.py not found. Set PI05_FILE env var to its path.")
92
+
93
+
94
+ def ensure_torch_import(s: str) -> str:
95
+ if re.search(r"(?m)^\s*import\s+torch\b", s) or re.search(r"(?m)^\s*from\s+torch\b", s):
96
+ return s
97
+
98
+ lines = s.splitlines(True)
99
+ insert_idx = 0
100
+
101
+ if lines and lines[0].startswith("#!"):
102
+ insert_idx = 1
103
+
104
+ # skip module docstring block if present
105
+ if insert_idx < len(lines) and lines[insert_idx].lstrip().startswith('"""'):
106
+ i = insert_idx + 1
107
+ while i < len(lines) and '"""' not in lines[i]:
108
+ i += 1
109
+ if i < len(lines):
110
+ insert_idx = i + 1
111
+
112
+ lines.insert(insert_idx, "import torch # PATCH: required for embed/lm_head tying\n")
113
+ return "".join(lines)
114
+
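+ # Editorial note: ensure_torch_import() in action on a minimal source
+ # string (shebang + two-line docstring, no torch import yet):
+ def _demo_ensure_torch_import() -> str:
+     src = '#!/usr/bin/env python\n"""doc\n"""\nx = 1\n'
+     out = ensure_torch_import(src)
+     assert out.splitlines()[3].startswith("import torch")  # after the docstring
+     return out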
115
+
116
+ def patch_pi05_embed_tie(p: pathlib.Path) -> Tuple[bool, str]:
117
+ s = _read_text(p)
118
+ s = ensure_torch_import(s)
119
+
120
+ marker = "PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens"
121
+ if marker in s:
122
+ _write_text(p, s)
123
+ return False, f"PI05 embed-tie patch already present: {p}"
124
+
125
+ pat = r"(?m)^(\s*)missing_keys,\s*unexpected_keys\s*=\s*model\.load_state_dict\(\s*remapped_state_dict\s*,\s*strict\s*=\s*strict\s*\)\s*$"
126
+ m = re.search(pat, s)
127
+ if not m:
128
+ _write_text(p, s)
129
+ return False, f"Could not find load_state_dict line to patch in PI05 file: {p}"
130
+
131
+ indent = m.group(1)
132
+ inject = (
133
+ f"\n{indent}# --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---\n"
134
+ f"{indent}if any('embed_tokens.weight' in k for k in missing_keys):\n"
135
+ f"{indent} try:\n"
136
+ f"{indent} with torch.no_grad():\n"
137
+ f"{indent} embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens\n"
138
+ f"{indent} lm_head = model.model.paligemma_with_expert.paligemma.lm_head\n"
139
+ f"{indent} if embed is not None and lm_head is not None:\n"
140
+ f"{indent} embed.weight = lm_head.weight # {marker}\n"
141
+ f"{indent} except Exception as _e:\n"
142
+ f"{indent} print('[patch_pi05] Could not tie embed_tokens to lm_head:', _e)\n"
143
+ )
144
+
145
+ s2 = re.sub(pat, lambda mm: mm.group(0) + inject, s, count=1)
146
+ _write_text(p, s2)
147
+ return True, f"Patched PI05 embed-tie in: {p}"
148
+
149
+
150
+ def patch_pi05_transformers_guard(p: pathlib.Path) -> Tuple[bool, str]:
151
+ """
152
+ Downgrade ONLY the PI05 hard guard:
153
+ ValueError: An incorrect transformer version is used...
154
+ to WARNING print, preserving indentation.
155
+
156
+ Strategy:
157
+ - Find raise ValueError(msg) from None lines.
158
+ - Only patch the one whose nearby context contains
159
+ "incorrect transformer version".
160
+ """
161
+ s = _read_text(p)
162
+ marker = "PATCH: downgrade transformer version guard"
163
+ if marker in s:
164
+ return False, f"PI05 transformers-guard patch already present: {p}"
165
+
166
+ if "incorrect transformer version" not in s:
167
+ return False, f"No transformers guard message found to patch in: {p}"
168
+
169
+ lines = s.splitlines(True)
170
+ raise_pat = re.compile(r"^(\s*)raise\s+ValueError\(\s*msg\s*\)\s*from\s*None\s*$")
171
+
172
+ target_idx = None
173
+ target_indent = ""
174
+
175
+ for i, line in enumerate(lines):
176
+ m = raise_pat.match(line)
177
+ if not m:
178
+ continue
179
+ # look back a few lines for the specific guard text
180
+ window_start = max(0, i - 8)
181
+ window = "".join(lines[window_start:i+1]).lower()
182
+ if "incorrect transformer version" in window:
183
+ target_idx = i
184
+ target_indent = m.group(1)
185
+ break
186
+
187
+ if target_idx is None:
188
+ return False, f"Guard raise line with context not found in: {p}"
189
+
190
+ repl = (
191
+ f"{target_indent}# --- PATCH: downgrade transformer version guard ---\n"
192
+ f"{target_indent}print('[patch_pi05] WARNING:', msg) # {marker}\n"
193
+ f"{target_indent}# continues execution despite version mismatch\n"
194
+ )
195
+
196
+ lines[target_idx] = repl
197
+ s2 = "".join(lines)
198
+ _write_text(p, s2)
199
+ return True, f"Patched PI05 transformers guard (raise->warn) in: {p}"
200
+
201
+
202
+ # -------------------------
203
+ # Patch 2: LeRobot policies optional imports
204
+ # -------------------------
205
+
206
+ def find_policies_init() -> pathlib.Path:
207
+ env = os.getenv("POLICIES_INIT_FILE")
208
+ if env:
209
+ p = pathlib.Path(env)
210
+ if p.exists():
211
+ return p
212
+
213
+ p = _search_file(_default_roots(), "__init__.py", must_contain="/lerobot/policies/")
214
+ if p and p.exists():
215
+ return p
216
+
217
+ raise FileNotFoundError("lerobot/policies/__init__.py not found. Set POLICIES_INIT_FILE env var.")
218
+
219
+
220
+ def patch_policies_optional_imports(p: pathlib.Path) -> Tuple[bool, str]:
+     """
+     Make ONLY Groot/Diffusers imports optional.
+     This avoids wrapping unrelated exports/imports.
+     """
+     s = _read_text(p)
+     marker = "PATCH: optional Groot/Diffusers imports"
+     if marker in s:
+         return False, f"Policies optional-import patch already present: {p}"
+ 
+     lines = s.splitlines(True)
+ 
+     def is_groot_line(line: str) -> bool:
+         # strict filter: only lines that import the groot submodule
+         return bool(re.search(r"^\s*from\s+\.\s*groot\b|^\s*from\s+\.groot\b|^\s*import\s+.*\bgroot\b", line))
+ 
+     idxs = [i for i, l in enumerate(lines) if is_groot_line(l)]
+     if not idxs:
+         return False, f"No Groot imports found to wrap in: {p}"
+ 
+     # group consecutive indices
+     groups = []
+     start = prev = idxs[0]
+     for i in idxs[1:]:
+         if i == prev + 1:
+             prev = i
+         else:
+             groups.append((start, prev))
+             start = prev = i
+     groups.append((start, prev))
+ 
+     new_lines = []
+     last_end = -1
+     for (a, b) in groups:
+         # copy lines before this group
+         new_lines.extend(lines[last_end + 1:a])
+ 
+         # wrap group
+         new_lines.append("# --- PATCH: optional Groot/Diffusers imports ---\n")
+         new_lines.append(f"try:  # {marker}\n")
+         for j in range(a, b + 1):
+             new_lines.append("    " + lines[j].lstrip())
+         new_lines.append("except Exception as _e:\n")
+         new_lines.append("    print('[policies_init] WARNING: optional groot deps missing:', _e)\n")
+ 
+         last_end = b
+ 
+     # copy rest
+     new_lines.extend(lines[last_end + 1:])
+ 
+     s2 = "".join(new_lines)
+     if s2 == s:
+         return False, f"Policies file unchanged after optional-import attempt: {p}"
+ 
+     _write_text(p, s2)
+     return True, f"Patched policies __init__ optional imports in: {p}"
+ 
+ 
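+ # --- Illustrative sketch (editor addition, not part of the original script): the
+ # grouping pass above turns sorted line indices into (start, end) runs so each
+ # contiguous block of groot imports is wrapped in a single try/except.
+ def _demo_group_consecutive() -> None:
+     idxs = [3, 4, 5, 9, 12, 13]
+     groups = []
+     start = prev = idxs[0]
+     for i in idxs[1:]:
+         if i == prev + 1:
+             prev = i
+         else:
+             groups.append((start, prev))
+             start = prev = i
+     groups.append((start, prev))
+     assert groups == [(3, 5), (9, 9), (12, 13)]
+ 
+ 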
+ # -------------------------
+ # Patch 3: eval_sigma_vla_rollout.py
+ # -------------------------
+ 
+ def find_eval_file() -> pathlib.Path:
+     env = os.getenv("EVAL_FILE")
+     if env:
+         p = pathlib.Path(env)
+         if p.exists():
+             return p
+ 
+     p = pathlib.Path("/workspace/eval_sigma_vla_rollout.py")
+     if p.exists():
+         return p
+ 
+     pp = _search_file(["/workspace", "/workspace/lerobot"], "eval_sigma_vla_rollout.py")
+     if pp and pp.exists():
+         return pp
+ 
+     raise FileNotFoundError("eval_sigma_vla_rollout.py not found. Set EVAL_FILE env var.")
+ 
+ 
+ def patch_eval_force_strict_false(p: pathlib.Path) -> Tuple[bool, str]:
+     s = _read_text(p)
+     marker = "PATCH: force strict=False for PI05Policy"
+ 
+     # 1) strict=True -> strict=False in PI05 loads
+     pat_strict_true = r"(policy_cls\.from_pretrained\([^)]*strict\s*=\s*)True(\s*[^)]*\))"
+     s2, n_true = re.subn(pat_strict_true, r"\1False\2", s)
+ 
+     # 2) add strict=False if missing on PI05 loads
+     def _add_strict_false_call(match: re.Match) -> str:
+         call = match.group(0)
+         if "strict" in call:
+             return call
+         return call[:-1] + ", strict=False)"
+ 
+     pat_no_strict_1 = r"policy_cls\.from_pretrained\(\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"
+     pat_no_strict_2 = r"policy_cls\.from_pretrained\(\s*pretrained_name_or_path\s*=\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"
+ 
+     s3, n_add1 = re.subn(pat_no_strict_1, _add_strict_false_call, s2)
+     s4, n_add2 = re.subn(pat_no_strict_2, _add_strict_false_call, s3)
+ 
+     changed = (n_true + n_add1 + n_add2) > 0
+     if not changed:
+         if marker in s:
+             return False, f"Eval strict patch already present: {p}"
+         return False, f"Eval already strict=False or no PI05 strict targets found: {p}"
+ 
+     if marker not in s4:
+         # annotate the first strict=False we introduced / touched
+         s4 = s4.replace("strict=False)", f"strict=False)  # {marker}", 1)
+ 
+     _write_text(p, s4)
+     return True, f"Patched eval PI05 strict=False in: {p}"
+ 
+ 
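+ # --- Illustrative sketch (editor addition, not part of the original script): the
+ # two-step rewrite above either flips an explicit strict=True or appends
+ # strict=False when the kwarg is absent. Toy call site shown here.
+ def _demo_strict_rewrite() -> None:
+     src = "policy = policy_cls.from_pretrained(repo_id, token=hf_token)"
+     out = re.sub(
+         r"policy_cls\.from_pretrained\(\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)",
+         lambda m: m.group(0)[:-1] + ", strict=False)",
+         src,
+     )
+     assert out.endswith("token=hf_token, strict=False)")
+ 
+ 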
+ def patch_eval_shuffle_support(p: pathlib.Path) -> Tuple[bool, str]:
+     s = _read_text(p)
+     marker_arg = "PATCH: add --shuffle arg"
+     marker_dl = "PATCH: DataLoader shuffle uses args.shuffle"
+ 
+     changed = False
+ 
+     # 1) add CLI arg --shuffle if absent
+     if re.search(r'add_argument\(\s*["\']--shuffle["\']', s) is None:
+         # find last parser.add_argument(...) to insert after
+         arg_pat = re.compile(r"(?m)^\s*parser\.add_argument\(.+?\)\s*$")
+         matches = list(arg_pat.finditer(s))
+         if matches:
+             last = matches[-1]
+             insert_pos = last.end()
+             # reuse the matched line's indentation: the parser is often built
+             # inside a function, where a column-0 insert would be wrong
+             line_indent = re.match(r"[ \t]*", last.group(0).lstrip("\r\n")).group(0)
+             insert_text = (
+                 f"\n{line_indent}parser.add_argument("
+                 "\"--shuffle\", action=\"store_true\", "
+                 "help=\"Shuffle dataset order to sample different subsets per seed.\")"
+                 f"  # {marker_arg}\n"
+             )
+             s = s[:insert_pos] + insert_text + s[insert_pos:]
+             changed = True
+ 
+     # 2) DataLoader(... shuffle=False ...) -> args.shuffle
+     if marker_dl not in s:
+         def _dl_repl(m: re.Match) -> str:
+             prefix = m.group(1)
+             # no inline "# marker" here: a comment would swallow whatever
+             # follows on the same physical line (a trailing comma or the
+             # closing paren) and break the DataLoader call
+             return prefix + 'getattr(args, "shuffle", False)'
+ 
+         # replace only a literal shuffle=False
+         pat_dl = re.compile(r"(?s)(DataLoader\([\s\S]{0,1200}?shuffle\s*=\s*)False")
+         if pat_dl.search(s):
+             s = pat_dl.sub(_dl_repl, s, count=1)
+             s += f"\n# {marker_dl}\n"  # standalone marker keeps the patch idempotent
+             changed = True
+ 
+     if changed:
+         _write_text(p, s)
+         return True, f"Patched eval shuffle support in: {p}"
+ 
+     return False, f"Eval shuffle support already present or no targets found: {p}"
+ 
+ 
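+ # --- Illustrative sketch (editor addition, not part of the original script): the
+ # bounded non-greedy window ([\s\S]{0,1200}?) above keeps the rewrite from
+ # skipping past the first DataLoader(...) call to some later shuffle=False.
+ def _demo_dataloader_shuffle_rewrite() -> None:
+     src = "loader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=2)"
+     pat = re.compile(r"(?s)(DataLoader\([\s\S]{0,1200}?shuffle\s*=\s*)False")
+     out = pat.sub(lambda m: m.group(1) + 'getattr(args, "shuffle", False)', src, count=1)
+     assert 'shuffle=getattr(args, "shuffle", False), num_workers=2' in out
+ 
+ 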
+ # -------------------------
+ # Main
+ # -------------------------
+ 
+ def main():
+     changed_any = False
+ 
+     try:
+         pi05_file = find_pi05_file()
+         changed, msg = patch_pi05_embed_tie(pi05_file)
+         print(msg)
+         changed_any |= changed
+     except Exception as e:
+         print("[patch_sigma_env] PI05 embed-tie patch skipped:", e)
+ 
+     try:
+         pi05_file = find_pi05_file()
+         changed, msg = patch_pi05_transformers_guard(pi05_file)
+         print(msg)
+         changed_any |= changed
+     except Exception as e:
+         print("[patch_sigma_env] PI05 transformers-guard patch skipped:", e)
+ 
+     try:
+         policies_init = find_policies_init()
+         changed, msg = patch_policies_optional_imports(policies_init)
+         print(msg)
+         changed_any |= changed
+     except Exception as e:
+         print("[patch_sigma_env] policies __init__ patch skipped:", e)
+ 
+     try:
+         eval_file = find_eval_file()
+         changed, msg = patch_eval_force_strict_false(eval_file)
+         print(msg)
+         changed_any |= changed
+     except Exception as e:
+         print("[patch_sigma_env] Eval strict patch skipped:", e)
+ 
+     try:
+         eval_file = find_eval_file()
+         changed, msg = patch_eval_shuffle_support(eval_file)
+         print(msg)
+         changed_any |= changed
+     except Exception as e:
+         print("[patch_sigma_env] Eval shuffle patch skipped:", e)
+ 
+     if changed_any:
+         print("[patch_sigma_env] Done. Patches applied.")
+     else:
+         print("[patch_sigma_env] Done. Nothing to change (already patched).")
+ 
+ 
+ if __name__ == "__main__":
+     main()
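+ 
+ # Usage sketch (editor addition; script name inferred from the log prefix above,
+ # env vars are the optional path overrides defined in the find_* helpers):
+ #   POLICIES_INIT_FILE=/path/to/lerobot/policies/__init__.py \
+ #   EVAL_FILE=/workspace/eval_sigma_vla_rollout.py \
+ #   python patch_sigma_env.py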
pi05_embed_tie.patch ADDED
@@ -0,0 +1,19 @@
+ diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
+ index b017bbc5..d6290da6 100644
+ --- a/src/lerobot/policies/pi05/modeling_pi05.py
+ +++ b/src/lerobot/policies/pi05/modeling_pi05.py
+ @@ -989,6 +989,13 @@ class PI05Policy(PreTrainedPolicy):
+          if remap_count > 0:
+              print(f"Remapped {remap_count} state dict keys")
+  
+          # Load the remapped state dict into the model
+          missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+ +
+ +        # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt ---
+ +        if any("embed_tokens.weight" in k for k in missing_keys):
+ +            with torch.no_grad():
+ +                embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens
+ +                lm_head = model.model.paligemma_with_expert.paligemma.lm_head
+ +                embed.weight = lm_head.weight
+  
+          return model
requirements.txt ADDED
@@ -0,0 +1,33 @@
+ --index-url https://download.pytorch.org/whl/cu121
+ --extra-index-url https://pypi.org/simple
+ 
+ torch==2.5.1+cu121
+ torchvision==0.20.1+cu121
+ 
+ transformers==4.44.2
+ accelerate==1.10.1
+ peft==0.17.0
+ safetensors==0.4.5
+ huggingface_hub[cli,hf-transfer]==0.35.3
+ datasets==4.1.1
+ sentencepiece==0.2.0
+ einops==0.8.0
+ bitsandbytes==0.43.3
+ 
+ numpy==2.0.2
+ pandas==2.2.3
+ pyarrow==21.0.0
+ tqdm==4.66.5
+ 
+ opencv-python-headless==4.10.0.84
+ pillow==10.4.0
+ av==15.1.0
+ imageio==2.36.0
+ imageio-ffmpeg==0.5.1
+ 
+ hydra-core==1.3.2
+ omegaconf==2.3.0
+ pyyaml==6.0.2
+ packaging==24.2
+ python-dotenv==1.0.1
+ wandb==0.21.1