Upload 10 files

- eval_sigma_vla_rollout.py +1402 -0
- modeling_pi05.py +1264 -0
- patch_sigma_env.py +432 -0
- pi05_embed_tie.patch +19 -0
- requirements.txt +33 -0

eval_sigma_vla_rollout.py
ADDED
@@ -0,0 +1,1402 @@

| 1 |
+
# eval_sigma_vla_rollout.py
|
| 2 |
+
# Offline closed-loop evaluation for Telepathy-augmented VLA on top of PI05 policy backbone.
|
| 3 |
+
#
|
| 4 |
+
# Key design:
|
| 5 |
+
# - base_model_id is a LeRobot/OpenPI policy repo (e.g., lerobot/pi05_base or your fine-tuned Sigma repo).
|
| 6 |
+
# - We load PI05Policy via LeRobot, NOT AutoModelForCausalLM.
|
| 7 |
+
# - Text embeddings are taken from the PI05 internal text backbone so that TelepathyLanguageModule
|
| 8 |
+
# receives the same type of inputs used during training.
|
| 9 |
+
#
|
| 10 |
+
# Hardened in this revision:
|
| 11 |
+
# - Robust recursive shard discovery under any naming & subfolders.
|
| 12 |
+
# - Shard content structure normalization (list-of-samples, or dict{samples/data}).
|
| 13 |
+
# - Collate auto-adapts to real schema: vision/state/action/text, with time-dim collapse for vision.
|
| 14 |
+
# - Action GT supports dict-style branches or a single tensor.
|
| 15 |
+
# - Metrics tolerate missing multi-branch outputs (fallback to "action").
|
| 16 |
+
# - Text tokens dtype/device aligned to model dtype for mixed precision safety.
|
| 17 |
+
# - Robot state time-dim collapse + pad/trim to state encoder expected dim.
|
| 18 |
+
# - Dynamic projection to align vision/state token hidden size to vision backbone dim (768),
|
| 19 |
+
# and project text to the same dim BEFORE feeding language module.
|
| 20 |
+
# - Optional max_text_len to avoid tokenizer truncation warnings.
|
| 21 |
+
# - action input contract hardening:
|
| 22 |
+
# * high_level_rep 2D -> 3D
|
| 23 |
+
# * tau None/2D -> 3D
|
| 24 |
+
# * tau length aligned to high_level_rep length
|
| 25 |
+
# * tau last-dim auto pad/trim so concat(high_level_rep, tau) matches action_condition_proj in_features
|
| 26 |
+
# - tokenizer_id can be a LOCAL path; when it exists locally we load with local_files_only
|
| 27 |
+
# - _align_target handles 2D<->3D mismatches (fixes MSE crashes)
|
| 28 |
+
# - remove duplicated "high_level_rep/tau re-normalization" that overwrote the hardening
|
| 29 |
+
#
|
| 30 |
+
# NEW in this patch:
|
| 31 |
+
# - cosine_alignment auto-aligns hidden sizes (fixes 256 vs 2048 crash).
|
| 32 |
+
# - semantic pooling guard supports 2D/3D factors safely.
|
| 33 |
+
# - alignment metric ignores zero-length cases robustly.
|
| 34 |
+
#
|
| 35 |
+
# EXTRA HARDENING (this patch for your baseline issue):
|
| 36 |
+
# - Try strict load for PI05Policy if the LeRobot version supports it.
|
| 37 |
+
# - Verify tokenizer vocab size and special-token ids match PI05 text embedding table.
|
| 38 |
+
# - Fail fast with a clear message if mismatch is detected (unless explicitly overridden).
|
| 39 |
+
#
|
| 40 |
+
# NEW in this hard-set patch:
|
| 41 |
+
# - Per-sample MSE is exposed from success proxy.
|
| 42 |
+
# - A "hard set" is defined as samples whose branch-wise MSE exceeds hard thresholds.
|
| 43 |
+
# - Hard-set averages (MSE and fraction of samples) are reported alongside global metrics.
|
| 44 |
+
#
|
| 45 |
+
# NEW in this adapter patch:
|
| 46 |
+
# - sigma_telepathy_adapter is applied at eval time (when telepathy is enabled) to gate
|
| 47 |
+
# Telepathy residuals based on their magnitude and tau strength, optionally using
|
| 48 |
+
# offline base_action_* if present in the shards.
|
| 49 |
+
|
| 50 |
+
from __future__ import annotations
|
| 51 |
+
|
| 52 |
+
import os
|
| 53 |
+
import glob
|
| 54 |
+
import json
|
| 55 |
+
import argparse
|
| 56 |
+
import importlib
|
| 57 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 58 |
+
|
| 59 |
+
import torch
|
| 60 |
+
import torch.nn as nn
|
| 61 |
+
import torch.nn.functional as F
|
| 62 |
+
from torch.utils.data import Dataset, DataLoader
|
| 63 |
+
|
| 64 |
+
from dotenv import load_dotenv
|
| 65 |
+
from accelerate import Accelerator
|
| 66 |
+
from accelerate.utils import set_seed
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
from huggingface_hub import snapshot_download
|
| 70 |
+
except Exception:
|
| 71 |
+
snapshot_download = None # type: ignore
|
| 72 |
+
|
| 73 |
+
from vision_sigma_vla import TelepathyVisionModule, VisionConfig
|
| 74 |
+
from language_sigma_vla import TelepathyLanguageModule, LanguageConfig
|
| 75 |
+
from action_sigma_vla import TelepathyActionModule, ActionConfig
|
| 76 |
+
from sigma_telepathy_adapter import SigmaTelepathyAdapter, SigmaTelepathyAdapterConfig
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def ensure_sigma_artifacts_from_hf(
|
| 80 |
+
repo_id: str,
|
| 81 |
+
hf_token: Optional[str],
|
| 82 |
+
local_cache_root: str,
|
| 83 |
+
) -> Dict[str, str]:
|
| 84 |
+
"""
|
| 85 |
+
Download Sigma artifacts from HF repo into a local cache folder.
|
| 86 |
+
Returns local paths for shard_dir and telepathy_heads_path.
|
| 87 |
+
|
| 88 |
+
We only pull:
|
| 89 |
+
storage/sigma_pickplace/**
|
| 90 |
+
storage/sigma_lora_out/**
|
| 91 |
+
"""
|
| 92 |
+
if snapshot_download is None:
|
| 93 |
+
raise ImportError(
|
| 94 |
+
"huggingface_hub is not available but auto-download was requested. "
|
| 95 |
+
"Please `pip install huggingface_hub` or download artifacts manually."
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
os.makedirs(local_cache_root, exist_ok=True)
|
| 99 |
+
local_dir = snapshot_download(
|
| 100 |
+
repo_id=repo_id,
|
| 101 |
+
token=hf_token,
|
| 102 |
+
local_dir=os.path.join(local_cache_root, repo_id.replace("/", "__")),
|
| 103 |
+
local_dir_use_symlinks=False,
|
| 104 |
+
allow_patterns=[
|
| 105 |
+
"storage/sigma_pickplace/**",
|
| 106 |
+
"storage/sigma_lora_out/**",
|
| 107 |
+
],
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
shard_dir = os.path.join(local_dir, "storage", "sigma_pickplace")
|
| 111 |
+
telepathy_heads_path = os.path.join(
|
| 112 |
+
local_dir, "storage", "sigma_lora_out", "sigma_telepathy_heads.pt"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
"local_repo_dir": local_dir,
|
| 117 |
+
"shard_dir": shard_dir,
|
| 118 |
+
"telepathy_heads_path": telepathy_heads_path,
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
|
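

# Usage sketch for the fetch helper above (illustrative only: the repo id is a
# placeholder and this helper is never called by the evaluator itself):
def _example_fetch_artifacts() -> Dict[str, str]:
    # Pull shards + telepathy heads into the local cache and return their paths.
    return ensure_sigma_artifacts_from_hf(
        repo_id="your-org/sigma-artifacts",  # hypothetical repo id
        hf_token=os.getenv("HF_TOKEN", None),
        local_cache_root="/workspace/.hf_sigma_cache",
    )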


def load_pi05_policy(
    repo_id: str,
    hf_token: Optional[str],
    device: torch.device,
    strict_load: bool = True,
):
    """
    Load PI05Policy from LeRobot. We try a few import paths to be robust across versions.
    If the LeRobot PI05Policy.from_pretrained supports strict loading, we enable it.
    """
    policy_cls = None
    import_errors = []

    candidate_paths = [
        ("lerobot.policies.pi05.modeling_pi05", "PI05Policy"),
        ("lerobot.policies.pi05", "PI05Policy"),
    ]

    for mod_name, cls_name in candidate_paths:
        try:
            mod = importlib.import_module(mod_name)
            policy_cls = getattr(mod, cls_name)
            break
        except Exception as e:
            import_errors.append(f"{mod_name}.{cls_name}: {type(e).__name__}: {e}")

    if policy_cls is None:
        raise ImportError(
            "Failed to import PI05Policy from LeRobot. Tried:\n - "
            + "\n - ".join(import_errors)
        )

    policy = None
    tried = []
    if strict_load:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token, strict=True)
            tried.append("from_pretrained(..., strict=True)")
        except TypeError:
            tried.append("strict=True not supported")
        except Exception as e:
            tried.append(f"strict=True failed: {type(e).__name__}: {e}")

    if policy is None:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token)
            tried.append("from_pretrained(repo_id, token=...)")
        except TypeError:
            policy = policy_cls.from_pretrained(pretrained_name_or_path=repo_id, token=hf_token)
            tried.append("from_pretrained(pretrained_name_or_path=..., token=...)")

    if policy is None:
        raise RuntimeError("PI05Policy loading returned None. Tried: " + "; ".join(tried))

    policy = policy.to(device)
    policy.eval()
    return policy
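

# Usage sketch (assumes a LeRobot install that ships PI05Policy; not called by
# the evaluator, which wires this up inside main() from CLI arguments):
def _example_load_policy():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # strict_load=False keeps compatibility with older LeRobot versions.
    return load_pi05_policy("lerobot/pi05_base", hf_token=None, device=device, strict_load=False)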


def get_policy_tokenizer(
    policy,
    repo_id: str,
    hf_token: Optional[str],
    forced_tokenizer_id: str = "",
):
    """
    Robust tokenizer getter for PI05Policy.

    IMPORTANT:
    - Never call AutoTokenizer.from_pretrained(repo_id) because repo_id is a policy repo.
    - If --tokenizer_id is provided and points to a LOCAL folder, load locally.
    - Otherwise load from HF id.
    - If still missing, recursively search for tokenizer/processor inside policy.
    """
    from transformers import AutoTokenizer

    if forced_tokenizer_id:
        if os.path.exists(forced_tokenizer_id):
            tok = AutoTokenizer.from_pretrained(
                forced_tokenizer_id,
                local_files_only=True,
                trust_remote_code=True,
            )
        else:
            tok = AutoTokenizer.from_pretrained(
                forced_tokenizer_id,
                token=hf_token,
                trust_remote_code=True,
            )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        return tok

    def _recursive_find_tokenizer(obj, max_depth: int = 4):
        if obj is None or max_depth <= 0:
            return None

        for key in ["tokenizer", "processor", "text_tokenizer", "language_tokenizer"]:
            if hasattr(obj, key):
                v = getattr(obj, key)
                if v is None:
                    continue
                if key == "processor" and hasattr(v, "tokenizer") and v.tokenizer is not None:
                    return v.tokenizer
                if hasattr(v, "__call__"):
                    return v

        nested_names = [
            "paligemma_with_expert",
            "paligemma",
            "gemma_expert",
            "language_model",
            "text_model",
            "model",
            "policy",
        ]
        for name in nested_names:
            if hasattr(obj, name):
                found = _recursive_find_tokenizer(
                    getattr(obj, name), max_depth=max_depth - 1
                )
                if found is not None:
                    return found
        return None

    tok = _recursive_find_tokenizer(policy)
    if tok is not None:
        if getattr(tok, "pad_token", None) is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token
        return tok

    backbone_name = None
    config_candidates = []
    for attr in ["config", "model", "paligemma_with_expert", "paligemma"]:
        if hasattr(policy, attr):
            config_candidates.append(getattr(policy, attr))

    def _try_get_name(cfg_obj):
        if cfg_obj is None:
            return None
        for k in [
            "_name_or_path",
            "text_backbone_id",
            "text_model_id",
            "language_model_id",
            "processor_name_or_path",
            "tokenizer_name_or_path",
        ]:
            if hasattr(cfg_obj, k):
                v = getattr(cfg_obj, k)
                if isinstance(v, str) and v:
                    return v
        if hasattr(cfg_obj, "config"):
            c = getattr(cfg_obj, "config")
            if hasattr(c, "_name_or_path") and isinstance(c._name_or_path, str) and c._name_or_path:
                return c._name_or_path
        return None

    for c in config_candidates:
        backbone_name = _try_get_name(c)
        if backbone_name:
            break

    if backbone_name:
        tok = AutoTokenizer.from_pretrained(
            backbone_name, token=hf_token, trust_remote_code=True
        )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        return tok

    raise ValueError(
        f"Cannot obtain tokenizer from PI05Policy for repo '{repo_id}'. "
        "Your lerobot PI05 port does not expose tokenizer/processor nor backbone name. "
        "Please pass --tokenizer_id explicitly."
    )


def get_policy_text_embedding_layer(policy):
    """
    Locate the text embedding layer inside PI05Policy robustly.
    """
    def _recursive_find(obj, depth: int = 6):
        if obj is None or depth <= 0:
            return None

        if hasattr(obj, "get_input_embeddings"):
            try:
                emb = obj.get_input_embeddings()
                if emb is not None:
                    return emb
            except Exception:
                pass

        for key in ["embed_tokens", "embeddings", "token_embedding"]:
            if hasattr(obj, key):
                v = getattr(obj, key)
                if isinstance(v, nn.Module):
                    return v

        nested_names = [
            "model",
            "paligemma_with_expert",
            "paligemma",
            "language_model",
            "gemma_expert",
            "text_model",
            "policy",
        ]
        for name in nested_names:
            if hasattr(obj, name):
                found = _recursive_find(getattr(obj, name), depth=depth - 1)
                if found is not None:
                    return found

        return None

    emb = _recursive_find(policy)
    if emb is None:
        raise AttributeError(
            "Cannot locate PI05 text embedding layer via recursive search. "
            "Your PI05Policy likely changed internal naming. "
            "Please inspect policy.model.* to confirm embed_tokens location."
        )
    return emb


def verify_tokenizer_embedding_compat(
    tokenizer,
    text_embed_layer: nn.Module,
    allow_mismatch: bool = False,
):
    """
    Ensure tokenizer vocab/special ids are consistent with PI05 text embedding table.
    This directly prevents the 'embed_tokens.weight missing or misaligned' baseline issue.
    """
    emb_vocab = None
    if isinstance(text_embed_layer, nn.Embedding):
        emb_vocab = int(text_embed_layer.num_embeddings)
    elif hasattr(text_embed_layer, "weight") and text_embed_layer.weight is not None:
        emb_vocab = int(text_embed_layer.weight.size(0))

    tok_vocab = getattr(tokenizer, "vocab_size", None)
    if tok_vocab is None:
        try:
            tok_vocab = len(tokenizer)
        except Exception:
            tok_vocab = None

    if emb_vocab is None or tok_vocab is None:
        print("[WARN] Cannot infer tokenizer/embedding vocab sizes. Skipping compatibility check.")
        return

    if emb_vocab != tok_vocab:
        msg = (
            f"[ERROR] Tokenizer vocab size ({tok_vocab}) != PI05 embedding table size ({emb_vocab}). "
            "This will corrupt text embeddings and invalidate baseline. "
            "Fix by passing --tokenizer_id matching the PI05 text backbone "
            "(e.g., the original openpi/PI05 tokenizer) or re-exporting policy with aligned vocab."
        )
        if allow_mismatch:
            print(msg.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.")
        else:
            raise ValueError(msg)

    for name in ["pad_token_id", "eos_token_id", "bos_token_id", "unk_token_id"]:
        tid = getattr(tokenizer, name, None)
        if tid is None:
            continue
        if not (0 <= int(tid) < emb_vocab):
            msg = (
                f"[ERROR] Tokenizer {name}={tid} out of embedding range [0, {emb_vocab-1}]. "
                "Your tokenizer does not belong to this PI05 backbone."
            )
            if allow_mismatch:
                print(msg.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.")
            else:
                raise ValueError(msg)
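

# Toy self-check for the verifier above, using a synthetic tokenizer stand-in
# (a sketch only; real callers pass an actual HF tokenizer object):
def _example_vocab_check() -> None:
    class _TokStub:
        vocab_size = 32
        pad_token_id = 0
        eos_token_id = 1

    emb = nn.Embedding(32, 8)  # 32-entry table matching the stub vocab
    verify_tokenizer_embedding_compat(_TokStub(), emb)  # passes silently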


class TelepathyVLA(nn.Module):
    """
    Full model matching your final arrows.
    """
    def __init__(
        self,
        v_cfg: VisionConfig,
        l_cfg: LanguageConfig,
        a_cfg: ActionConfig,
        disable_telepathy: bool = False,
    ):
        super().__init__()
        self.vision = TelepathyVisionModule(v_cfg)
        self.language = TelepathyLanguageModule(l_cfg)
        self.action = TelepathyActionModule(a_cfg)
        self.disable_telepathy = disable_telepathy
        self.register_buffer("_m_prev", None, persistent=False)

        self._proj_inited = False
        self.text_proj: Optional[nn.Module] = None
        self.vision_proj: Optional[nn.Module] = None
        self.state_proj: Optional[nn.Module] = None

    def reset_memory(self):
        self._m_prev = None

    @torch.no_grad()
    def forward_once(
        self,
        vis_obs: torch.Tensor,
        robot_state: torch.Tensor,
        text_tokens: torch.Tensor,
        depth_obs: Optional[torch.Tensor] = None,
        audio_obs: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        return_intermediate: bool = False,
    ) -> Dict[str, torch.Tensor]:

        vis0 = self.vision(
            vis_obs=vis_obs,
            robot_state=robot_state,
            depth_obs=depth_obs,
            audio_obs=audio_obs,
            telepathy_factors=None,
            return_intermediate=return_intermediate,
        )

        vis_d = vis0["vision_tokens"].size(-1)
        state_d = vis0["state_tokens"].size(-1)
        target_d = vis_d

        if not self._proj_inited:
            self.text_proj = nn.Linear(text_tokens.size(-1), target_d, bias=False) \
                if text_tokens.size(-1) != target_d else nn.Identity()
            self.vision_proj = nn.Identity() if vis_d == target_d else nn.Linear(vis_d, target_d, bias=False)
            self.state_proj = nn.Identity() if state_d == target_d else nn.Linear(state_d, target_d, bias=False)

            self.text_proj = self.text_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self.vision_proj = self.vision_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self.state_proj = self.state_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self._proj_inited = True

        assert self.text_proj is not None and self.vision_proj is not None and self.state_proj is not None

        text_tokens = self.text_proj(text_tokens)
        vision_tokens = self.vision_proj(vis0["vision_tokens"])
        state_tokens = self.state_proj(vis0["state_tokens"])

        lang_out = self.language(
            text_tokens=text_tokens,
            vision_tokens=vision_tokens,
            state_tokens=state_tokens,
            m_prev=self._m_prev,
            attn_mask=attn_mask,
            return_intermediate=return_intermediate,
        )

        raw_tau = lang_out.get("telepathy_factors", None)
        self._m_prev = lang_out.get("m_t", None)

        telepathy_scale = float(getattr(self, "telepathy_scale", 1.0))

        if self.disable_telepathy:
            tau = None
            vis_out = vis0
        else:
            tau = raw_tau
            if tau is not None:
                tau = tau * telepathy_scale
            vis_out = self.vision(
                vis_obs=vis_obs,
                robot_state=robot_state,
                depth_obs=depth_obs,
                audio_obs=audio_obs,
                telepathy_factors=tau,
                return_intermediate=return_intermediate,
            )

        high_level_rep = lang_out.get("high_level_rep", None)
        if high_level_rep is None:
            raise KeyError("language output missing 'high_level_rep'.")

        if high_level_rep.dim() == 2:
            high_level_rep = high_level_rep.unsqueeze(1)

        if tau is None:
            B, L, _ = high_level_rep.shape
            tau_dim = getattr(self.language, "tau_dim", 128)
            tau = torch.zeros(B, L, tau_dim, device=high_level_rep.device, dtype=high_level_rep.dtype)
        else:
            if tau.dim() == 2:
                tau = tau.unsqueeze(1)
            if tau.size(1) != high_level_rep.size(1):
                L = high_level_rep.size(1)
                if tau.size(1) == 1:
                    tau = tau.expand(-1, L, -1)
                else:
                    tau = tau[:, :L, :]

        expected_in = None
        acp = getattr(self.action, "action_condition_proj", None)
        if acp is not None:
            if hasattr(acp, "in_features"):
                expected_in = int(acp.in_features)
            elif hasattr(acp, "net") and len(acp.net) > 0 and hasattr(acp.net[0], "in_features"):
                expected_in = int(acp.net[0].in_features)

        if expected_in is not None:
            d_high = high_level_rep.size(-1)
            target_tau = expected_in - d_high

            if target_tau <= 0:
                pass
            else:
                if tau.size(-1) < target_tau:
                    tau = F.pad(tau, (0, target_tau - tau.size(-1)))
                elif tau.size(-1) > target_tau:
                    tau = tau[..., :target_tau]

        state_for_action = vis_out["state_tokens"]
        if state_for_action.dim() == 2:
            state_for_action = state_for_action.unsqueeze(1)
        elif state_for_action.dim() > 3:
            state_for_action = state_for_action.view(
                state_for_action.size(0), -1, state_for_action.size(-1)
            )

        lang_d = high_level_rep.size(-1)

        def _pad_or_trim_to(x: torch.Tensor, d: int) -> torch.Tensor:
            cur_d = x.size(-1)
            if cur_d == d:
                return x
            if cur_d < d:
                return F.pad(x, (0, d - cur_d))
            return x[..., :d]

        state_for_action = _pad_or_trim_to(state_for_action, lang_d)

        act_out = self.action(
            high_level_rep=high_level_rep,
            telepathy_factors=tau,
            state_tokens=state_for_action,
            return_intermediate=return_intermediate,
        )

        out: Dict[str, torch.Tensor] = {}
        out.update(vis_out)
        out.update(lang_out)
        out.update(act_out)
        return out
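

# Smoke-test sketch for the forward contract above. All shapes are assumptions:
# the real dims come from VisionConfig/LanguageConfig and the PI05 text backbone.
def _example_forward_once(model: TelepathyVLA) -> Dict[str, torch.Tensor]:
    B = 2
    vis_obs = torch.zeros(B, 3, 224, 224)   # single RGB frame after time-dim collapse
    robot_state = torch.zeros(B, 14)        # proprioceptive vector (dim illustrative)
    text_tokens = torch.zeros(B, 16, 2048)  # PI05 text embeddings (hidden size illustrative)
    model.reset_memory()  # clear the recurrent telepathy memory between episodes
    return model.forward_once(vis_obs=vis_obs, robot_state=robot_state, text_tokens=text_tokens)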


class SigmaShardDataset(Dataset):
    """
    Loads .pt shards produced by dataset_preprocess_sigma_vla.py.
    Each shard is a list of dict samples OR a dict containing a list (samples/data).
    """
    def __init__(self, shard_dir: str):
        super().__init__()
        if not os.path.isdir(shard_dir):
            raise FileNotFoundError(
                f"shard_dir does not exist: {shard_dir}. Double-check the path."
            )

        patterns = [
            os.path.join(shard_dir, "sigma_vla_shard_*.pt"),
            os.path.join(shard_dir, "*.pt"),
            os.path.join(shard_dir, "**", "*.pt"),
        ]
        paths: List[str] = []
        for p in patterns:
            paths.extend(glob.glob(p, recursive=True))

        self.shard_paths = sorted(list(set(paths)))
        if len(self.shard_paths) == 0:
            raise FileNotFoundError(
                f"No .pt shards found under {shard_dir}. "
                "Your HF cache is empty or shards are not tracked by LFS."
            )

        print(f"[INFO] Found {len(self.shard_paths)} shard files. Example: {self.shard_paths[:3]}")

        self.index_map: List[Tuple[int, int]] = []
        self._shard_cache: Dict[int, List[Dict[str, Any]]] = {}

        for sid, p in enumerate(self.shard_paths):
            shard = torch.load(p, map_location="cpu")
            shard_list = self._normalize_shard(shard, p)
            for lid in range(len(shard_list)):
                self.index_map.append((sid, lid))

        self.total = len(self.index_map)

    def __len__(self):
        return self.total

    def _normalize_shard(self, shard_obj: Any, path: str) -> List[Dict[str, Any]]:
        if isinstance(shard_obj, (list, tuple)):
            return list(shard_obj)

        if isinstance(shard_obj, dict):
            for k in ["samples", "data", "items"]:
                if k in shard_obj and isinstance(shard_obj[k], (list, tuple)):
                    return list(shard_obj[k])

        raise TypeError(
            f"Unsupported shard format in {path}. "
            f"Expected list/tuple of samples or dict{{samples/data}}. "
            f"Got type: {type(shard_obj).__name__}"
        )

    def _get_shard(self, sid: int) -> List[Dict[str, Any]]:
        if sid not in self._shard_cache:
            raw = torch.load(self.shard_paths[sid], map_location="cpu")
            self._shard_cache[sid] = self._normalize_shard(raw, self.shard_paths[sid])
        return self._shard_cache[sid]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sid, lid = self.index_map[idx]
        shard = self._get_shard(sid)
        return shard[lid]


def collate_sigma(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Robust collate for Sigma shards.
    """
    s0 = batch_list[0]

    def pick_key(sample: Dict[str, Any], candidates: List[str], field_name: str):
        for k in candidates:
            if k in sample:
                return k
        raise KeyError(
            f"Shard sample missing required field '{field_name}'. "
            f"Tried keys: {candidates}. "
            f"Available keys: {list(sample.keys())}"
        )

    if "vision" in s0:
        vis_k = "vision"
    else:
        vis_k = pick_key(s0, ["vis_obs", "rgb_obs", "image", "images", "obs"], "vision/vis_obs")

    vis_obs = torch.stack([b[vis_k] for b in batch_list], dim=0).float()
    if vis_obs.dim() == 5:
        vis_obs = vis_obs[:, -1]

    depth_obs = None
    if "depth" in s0:
        depth_obs = torch.stack([b["depth"] for b in batch_list], dim=0).float()
    elif any(k in s0 for k in ["depth_obs", "depths"]):
        dk = pick_key(s0, ["depth_obs", "depths"], "depth")
        depth_obs = torch.stack([b[dk] for b in batch_list], dim=0).float()

    audio_obs = None
    if "audio" in s0:
        audio_obs = torch.stack([b["audio"] for b in batch_list], dim=0).float()
    elif any(k in s0 for k in ["audio_obs", "audios"]):
        ak = pick_key(s0, ["audio_obs", "audios"], "audio")
        audio_obs = torch.stack([b[ak] for b in batch_list], dim=0).float()

    if "state" in s0:
        state_k = "state"
    else:
        state_k = pick_key(s0, ["robot_state", "proprio", "proprio_obs"], "state/robot_state")

    robot_state = torch.stack([b[state_k] for b in batch_list], dim=0).float()

    if "text" in s0:
        texts = [b.get("text", "") for b in batch_list]
    else:
        text_k = pick_key(s0, ["text", "prompt", "instruction"], "text")
        texts = [b.get(text_k, "") for b in batch_list]

    if "action" in s0:
        a0 = s0["action"]
        if isinstance(a0, dict):
            def pick_action_key(d, candidates, name):
                for k in candidates:
                    if k in d:
                        return k
                raise KeyError(
                    f"Action dict missing '{name}'. Tried {candidates}. "
                    f"Available action keys: {list(d.keys())}"
                )

            vec_k = pick_action_key(a0, ["gt_action_vector", "action_vector", "vector", "vec"], "gt_action_vector")
            chk_k = pick_action_key(a0, ["gt_action_chunk", "action_chunk", "chunk", "chk"], "gt_action_chunk")
            trj_k = pick_action_key(a0, ["gt_action_trajectory", "action_trajectory", "trajectory", "traj"], "gt_action_trajectory")

            gt_action_vector = torch.stack([b["action"][vec_k] for b in batch_list], dim=0).float()
            gt_action_chunk = torch.stack([b["action"][chk_k] for b in batch_list], dim=0).float()
            gt_action_trajectory = torch.stack([b["action"][trj_k] for b in batch_list], dim=0).float()
        else:
            act = torch.stack([b["action"] for b in batch_list], dim=0).float()
            gt_action_vector = act
            gt_action_chunk = act
            gt_action_trajectory = act
    else:
        gt_vec_k = pick_key(s0, ["gt_action_vector", "action_vector", "gt_vec"], "gt_action_vector")
        gt_chk_k = pick_key(s0, ["gt_action_chunk", "action_chunk", "gt_chunk"], "gt_action_chunk")
        gt_trj_k = pick_key(s0, ["gt_action_trajectory", "action_trajectory", "gt_traj"], "gt_action_trajectory")

        gt_action_vector = torch.stack([b[gt_vec_k] for b in batch_list], dim=0).float()
        gt_action_chunk = torch.stack([b[gt_chk_k] for b in batch_list], dim=0).float()
        gt_action_trajectory = torch.stack([b[gt_trj_k] for b in batch_list], dim=0).float()

    # Optional offline base actions for adapter; if missing, we simply do not include them.
    base_action_vector = None
    base_action_chunk = None
    base_action_trajectory = None

    has_base_top = any(
        k in s0
        for k in ["base_action_vector", "base_action_chunk", "base_action_trajectory"]
    )
    has_base_in_action = "action" in s0 and isinstance(s0["action"], dict) and any(
        k in s0["action"]
        for k in ["base_action_vector", "base_action_chunk", "base_action_trajectory"]
    )

    if has_base_top:
        if "base_action_vector" in s0:
            base_action_vector = torch.stack([b["base_action_vector"] for b in batch_list], dim=0).float()
        if "base_action_chunk" in s0:
            base_action_chunk = torch.stack([b["base_action_chunk"] for b in batch_list], dim=0).float()
        if "base_action_trajectory" in s0:
            base_action_trajectory = torch.stack([b["base_action_trajectory"] for b in batch_list], dim=0).float()
    elif has_base_in_action:
        a0 = s0["action"]

        def pick_base_key(d, candidates):
            for k in candidates:
                if k in d:
                    return k
            return None

        vec_bk = pick_base_key(a0, ["base_action_vector", "base_vec"])
        chk_bk = pick_base_key(a0, ["base_action_chunk", "base_chunk"])
        trj_bk = pick_base_key(a0, ["base_action_trajectory", "base_traj"])

        if vec_bk is not None:
            base_action_vector = torch.stack([b["action"][vec_bk] for b in batch_list], dim=0).float()
        if chk_bk is not None:
            base_action_chunk = torch.stack([b["action"][chk_bk] for b in batch_list], dim=0).float()
        if trj_bk is not None:
            base_action_trajectory = torch.stack([b["action"][trj_bk] for b in batch_list], dim=0).float()

    batch: Dict[str, Any] = {
        "vis_obs": vis_obs,
        "depth_obs": depth_obs,
        "audio_obs": audio_obs,
        "robot_state": robot_state,
        "texts": texts,
        "gt_action_vector": gt_action_vector,
        "gt_action_chunk": gt_action_chunk,
        "gt_action_trajectory": gt_action_trajectory,
    }

    if base_action_vector is not None:
        batch["base_action_vector"] = base_action_vector
    if base_action_chunk is not None:
        batch["base_action_chunk"] = base_action_chunk
    if base_action_trajectory is not None:
        batch["base_action_trajectory"] = base_action_trajectory

    return batch
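

# Wiring sketch: SigmaShardDataset yields raw shard dicts and collate_sigma
# normalizes the schema per batch, mirroring the DataLoader built in main():
def _example_loader(shard_dir: str) -> DataLoader:
    ds = SigmaShardDataset(shard_dir)
    return DataLoader(ds, batch_size=4, shuffle=False, collate_fn=collate_sigma)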


def _align_target(pred_t: torch.Tensor, gt_t: torch.Tensor) -> torch.Tensor:
    """
    Align GT to prediction for MSE:
    - handle 2D vs 3D mismatches by collapsing or expanding time dimension.
    - then align last-dim by pad/trim.
    """
    if gt_t.dim() == 3 and pred_t.dim() == 2:
        gt_t = gt_t[:, -1, :]

    if pred_t.dim() == 3 and gt_t.dim() == 2:
        gt_t = gt_t.unsqueeze(1)
        if gt_t.size(1) != pred_t.size(1):
            gt_t = gt_t.expand(-1, pred_t.size(1), -1)

    if pred_t.dim() == 3 and gt_t.dim() == 3:
        Tp = pred_t.size(1)
        Tg = gt_t.size(1)
        if Tg < Tp:
            pad = torch.zeros(
                gt_t.size(0), Tp - Tg, gt_t.size(2),
                device=gt_t.device, dtype=gt_t.dtype
            )
            gt_t = torch.cat([gt_t, pad], dim=1)
        elif Tg > Tp:
            gt_t = gt_t[:, :Tp, :]

    pd = pred_t.size(-1)
    gd = gt_t.size(-1)
    if gd < pd:
        gt_t = F.pad(gt_t, (0, pd - gd))
    elif gd > pd:
        gt_t = gt_t[..., :pd]

    return gt_t
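

# Shape example for the aligner above: a [B, T, D] ground truth is matched to a
# [B, D] prediction by taking the last timestep, then trimming the feature dim
# (sizes are illustrative):
def _example_align_target() -> torch.Tensor:
    pred = torch.zeros(4, 7)    # [B=4, D=7] prediction
    gt = torch.zeros(4, 10, 9)  # [B=4, T=10, D=9] trajectory ground truth
    aligned = _align_target(pred, gt)  # last step taken, feature dim trimmed to 7
    assert aligned.shape == (4, 7)
    return aligned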


def _pred_action(pred: Dict[str, torch.Tensor], key: str) -> torch.Tensor:
    if key in pred:
        return pred[key]
    if "action" in pred:
        return pred["action"]
    raise KeyError(
        f"Pred dict missing action key '{key}' and fallback 'action'. "
        f"Available pred keys: {list(pred.keys())}"
    )


@torch.no_grad()
def compute_branch_mse(pred: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, float]:
    vec_pred = _pred_action(pred, "action_vector")
    chk_pred = _pred_action(pred, "action_chunk")
    trj_pred = _pred_action(pred, "action_trajectory")

    device = vec_pred.device

    gt_vec = _align_target(vec_pred, batch["gt_action_vector"].to(device))
    gt_chk = _align_target(chk_pred, batch["gt_action_chunk"].to(device))
    gt_trj = _align_target(trj_pred, batch["gt_action_trajectory"].to(device))

    mse_vec = F.mse_loss(vec_pred, gt_vec).item()
    mse_chk = F.mse_loss(chk_pred, gt_chk).item()
    mse_trj = F.mse_loss(trj_pred, gt_trj).item()
    return {"mse_vector": mse_vec, "mse_chunk": mse_chk, "mse_traj": mse_trj}


@torch.no_grad()
def compute_success_proxy(
    pred: Dict[str, torch.Tensor],
    batch: Dict[str, Any],
    thr_vec: float,
    thr_chk: float,
    thr_trj: float,
) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Returns:
        num_success, num_total, mse_vec_per_sample, mse_chk_per_sample, mse_trj_per_sample
    where per-sample MSE is averaged over all non-batch dims.
    """
    vec_pred = _pred_action(pred, "action_vector")
    chk_pred = _pred_action(pred, "action_chunk")
    trj_pred = _pred_action(pred, "action_trajectory")

    device = vec_pred.device

    gt_vec = _align_target(vec_pred, batch["gt_action_vector"].to(device))
    gt_chk = _align_target(chk_pred, batch["gt_action_chunk"].to(device))
    gt_trj = _align_target(trj_pred, batch["gt_action_trajectory"].to(device))

    reduce_dims_vec = list(range(1, vec_pred.dim()))
    reduce_dims_chk = list(range(1, chk_pred.dim()))
    reduce_dims_trj = list(range(1, trj_pred.dim()))

    mse_vec_s = ((vec_pred - gt_vec) ** 2).mean(dim=reduce_dims_vec)
    mse_chk_s = ((chk_pred - gt_chk) ** 2).mean(dim=reduce_dims_chk)
    mse_trj_s = ((trj_pred - gt_trj) ** 2).mean(dim=reduce_dims_trj)

    success_mask = (mse_vec_s < thr_vec) & (mse_chk_s < thr_chk) & (mse_trj_s < thr_trj)
    num_success = int(success_mask.sum().item())
    num_total = int(success_mask.numel())

    return num_success, num_total, mse_vec_s, mse_chk_s, mse_trj_s
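

# Hard-set sketch built on the per-sample MSEs returned above. The exact
# combination rule lives in main(); treating a sample as "hard" when ANY branch
# exceeds its hard threshold is an assumption here (thresholds illustrative):
def _example_hard_mask(
    mse_vec_s: torch.Tensor, mse_chk_s: torch.Tensor, mse_trj_s: torch.Tensor
) -> torch.Tensor:
    hard_thr_vec, hard_thr_chk, hard_thr_trj = 0.10, 0.20, 0.20  # e.g. 2x success thresholds
    return (mse_vec_s > hard_thr_vec) | (mse_chk_s > hard_thr_chk) | (mse_trj_s > hard_thr_trj)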


@torch.no_grad()
def compute_telepathy_stability(pred: Dict[str, torch.Tensor]) -> float:
    tau = pred.get("telepathy_factors", None)
    if tau is None:
        return float("nan")
    return float((tau ** 2).mean().item())


@torch.no_grad()
def cosine_alignment(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Cosine alignment that is robust to hidden-size mismatch.
    Accepts [B, D] or [B, T, D]. Pools time if present.
    If dims differ, crops both to min(Da, Db) for a fair cosine check.
    """
    if a.dim() == 3:
        a = a.mean(dim=1)
    if b.dim() == 3:
        b = b.mean(dim=1)

    if a.numel() == 0 or b.numel() == 0:
        return float("nan")

    da, db = a.size(-1), b.size(-1)
    if da != db:
        d = min(da, db)
        a = a[..., :d]
        b = b[..., :d]

    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    return float((a * b).sum(dim=-1).mean().item())
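

# Sanity example: mismatched hidden sizes are cropped to a shared prefix before
# the cosine is taken, so identical prefixes score 1.0 (sizes illustrative):
def _example_cosine_alignment() -> float:
    a = torch.ones(2, 256)   # e.g. pooled telepathy representation
    b = torch.ones(2, 2048)  # e.g. pooled language representation
    return cosine_alignment(a, b)  # both cropped to 256 dims -> 1.0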


@torch.no_grad()
def build_text_tokens_from_policy(
    tokenizer,
    text_embed_layer: nn.Module,
    texts: List[str],
    device: torch.device,
    target_dtype: torch.dtype,
    max_text_len: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenize prompts and map to embeddings using PI05 internal embedding layer.
    Returns (text_tokens, attn_mask).
    """
    if max_text_len and max_text_len > 0:
        tok = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_text_len,
            return_tensors="pt",
        )
    else:
        tok = tokenizer(
            texts,
            padding=True,
            truncation=False,
            return_tensors="pt",
        )

    if hasattr(tok, "input_ids"):
        input_ids = tok.input_ids
        attn_mask = tok.attention_mask
    else:
        input_ids = tok["input_ids"]
        attn_mask = tok.get("attention_mask", None)
        if attn_mask is None:
            attn_mask = torch.ones_like(input_ids)

    input_ids = input_ids.to(device)
    attn_mask = attn_mask.to(device)

    text_tokens = text_embed_layer(input_ids).to(dtype=target_dtype)
    return text_tokens, attn_mask
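

# End-to-end text-path sketch: PI05's own tokenizer and embedding table produce
# the text_tokens consumed by TelepathyVLA.forward_once (prompt illustrative;
# with forced_tokenizer_id left empty this falls back to the recursive search):
def _example_text_tokens(policy) -> Tuple[torch.Tensor, torch.Tensor]:
    tokenizer = get_policy_tokenizer(policy, repo_id="", hf_token=None)
    emb = get_policy_text_embedding_layer(policy)
    return build_text_tokens_from_policy(
        tokenizer=tokenizer,
        text_embed_layer=emb,
        texts=["pick up the red block"],
        device=next(emb.parameters()).device,
        target_dtype=torch.float32,
        max_text_len=48,
    )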
| 974 |
+
def main():
|
| 975 |
+
parser = argparse.ArgumentParser()
|
| 976 |
+
|
| 977 |
+
parser.add_argument("--sigma_env", type=str, default="sigma.env")
|
| 978 |
+
parser.add_argument("--shard_dir", type=str, default="")
|
| 979 |
+
parser.add_argument("--output_dir", type=str, default="./sigma_eval_out")
|
| 980 |
+
|
| 981 |
+
parser.add_argument(
|
| 982 |
+
"--base_model_id",
|
| 983 |
+
type=str,
|
| 984 |
+
required=True,
|
| 985 |
+
help="LeRobot/OpenPI policy repo, e.g., lerobot/pi05_base or your Sigma policy repo.",
|
| 986 |
+
)
|
| 987 |
+
parser.add_argument(
|
| 988 |
+
"--telepathy_heads_path",
|
| 989 |
+
type=str,
|
| 990 |
+
default="",
|
| 991 |
+
help="Path to sigma_telepathy_heads.pt. If empty, auto-fetch may fill it.",
|
| 992 |
+
)
|
| 993 |
+
parser.add_argument(
|
| 994 |
+
"--disable_telepathy",
|
| 995 |
+
action="store_true",
|
| 996 |
+
help="Disable telepathy injection (control run).",
|
| 997 |
+
)
|
| 998 |
+
parser.add_argument(
|
| 999 |
+
"--tokenizer_id",
|
| 1000 |
+
type=str,
|
| 1001 |
+
default="",
|
| 1002 |
+
help="Explicit HF tokenizer id OR local tokenizer folder path.",
|
| 1003 |
+
)
|
| 1004 |
+
|
| 1005 |
+
parser.add_argument("--max_text_len", type=int, default=0)
|
| 1006 |
+
|
| 1007 |
+
parser.add_argument(
|
| 1008 |
+
"--artifacts_repo_id",
|
| 1009 |
+
type=str,
|
| 1010 |
+
default="",
|
| 1011 |
+
help="HF repo containing storage/sigma_pickplace and storage/sigma_lora_out.",
|
| 1012 |
+
)
|
| 1013 |
+
parser.add_argument(
|
| 1014 |
+
"--hf_cache_root",
|
| 1015 |
+
type=str,
|
| 1016 |
+
default="/workspace/.hf_sigma_cache",
|
| 1017 |
+
)
|
| 1018 |
+
|
| 1019 |
+
parser.add_argument("--load_in_4bit", action="store_true")
|
| 1020 |
+
parser.add_argument("--dtype", type=str, default="bf16")
|
| 1021 |
+
|
| 1022 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
| 1023 |
+
parser.add_argument("--num_workers", type=int, default=2)
|
| 1024 |
+
parser.add_argument("--max_batches", type=int, default=-1)
|
| 1025 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 1026 |
+
parser.add_argument(
|
| 1027 |
+
"--shuffle",
|
| 1028 |
+
action="store_true",
|
| 1029 |
+
help="Shuffle dataset order to enable different random subsets per seed.",
|
| 1030 |
+
)
|
| 1031 |
+
parser.add_argument(
|
| 1032 |
+
"--telepathy_scale",
|
| 1033 |
+
type=float,
|
| 1034 |
+
default=1.0,
|
| 1035 |
+
help="Multiply telepathy_factors (tau) to control injection strength.",
|
| 1036 |
+
)
|
| 1037 |
+
|
| 1038 |
+
parser.add_argument("--succ_thr_vec", type=float, default=0.05)
|
| 1039 |
+
parser.add_argument("--succ_thr_chk", type=float, default=0.10)
|
| 1040 |
+
parser.add_argument("--succ_thr_trj", type=float, default=0.10)
|
| 1041 |
+
|
| 1042 |
+
# Hard-set thresholds: if <=0, they default to 2x the success thresholds.
|
| 1043 |
+
parser.add_argument(
|
| 1044 |
+
"--hard_thr_vec",
|
| 1045 |
+
type=float,
|
| 1046 |
+
default=-1.0,
|
| 1047 |
+
help="Per-sample MSE threshold for the 'hard' set on vector branch; <=0 means 2x succ_thr_vec.",
|
| 1048 |
+
)
|
| 1049 |
+
parser.add_argument(
|
| 1050 |
+
"--hard_thr_chk",
|
| 1051 |
+
type=float,
|
| 1052 |
+
default=-1.0,
|
| 1053 |
+
help="Per-sample MSE threshold for the 'hard' set on chunk branch; <=0 means 2x succ_thr_chk.",
|
| 1054 |
+
)
|
| 1055 |
+
parser.add_argument(
|
| 1056 |
+
"--hard_thr_trj",
|
| 1057 |
+
type=float,
|
| 1058 |
+
default=-1.0,
|
| 1059 |
+
help="Per-sample MSE threshold for the 'hard' set on trajectory branch; <=0 means 2x succ_thr_trj.",
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
parser.add_argument(
|
| 1063 |
+
"--strict_pi05_load",
|
| 1064 |
+
action="store_true",
|
| 1065 |
+
help="Try strict PI05Policy loading if supported by LeRobot.",
|
| 1066 |
+
)
|
| 1067 |
+
parser.add_argument(
|
| 1068 |
+
"--allow_tokenizer_mismatch",
|
| 1069 |
+
action="store_true",
|
| 1070 |
+
help="Do not fail on tokenizer/embedding mismatch (NOT recommended for baseline).",
|
| 1071 |
+
)
|
| 1072 |
+
|
| 1073 |
+
# Simple flag to enable/disable the adapter without touching telepathy itself.
|
| 1074 |
+
parser.add_argument(
|
| 1075 |
+
"--use_telepathy_adapter",
|
| 1076 |
+
action="store_true",
|
| 1077 |
+
help="If set and telepathy is enabled, apply sigma_telepathy_adapter to actions in eval.",
|
| 1078 |
+
)
|
| 1079 |
+
|
| 1080 |
+
args = parser.parse_args()
|
| 1081 |
+
|
| 1082 |
+
if os.path.exists(args.sigma_env):
|
| 1083 |
+
load_dotenv(args.sigma_env)
|
| 1084 |
+
hf_token = os.getenv("HF_TOKEN", None)
|
| 1085 |
+
|
| 1086 |
+
accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != "fp32" else "no")
|
| 1087 |
+
set_seed(args.seed)
|
| 1088 |
+
device = accelerator.device
|
| 1089 |
+
|
| 1090 |
+
if args.load_in_4bit:
|
| 1091 |
+
print("[WARN] --load_in_4bit is ignored for PI05Policy evaluator.")
|
| 1092 |
+
|
| 1093 |
+
artifacts_repo = args.artifacts_repo_id.strip()
|
| 1094 |
+
if not artifacts_repo and args.base_model_id.startswith("Veltraxor/"):
|
| 1095 |
+
artifacts_repo = args.base_model_id
|
| 1096 |
+
|
| 1097 |
+
need_shards = (not args.shard_dir) or (not os.path.isdir(args.shard_dir))
|
| 1098 |
+
need_heads = (not args.telepathy_heads_path) or (not os.path.isfile(args.telepathy_heads_path))
|
| 1099 |
+
|
| 1100 |
+
if artifacts_repo and (need_shards or need_heads):
|
| 1101 |
+
paths = ensure_sigma_artifacts_from_hf(
|
| 1102 |
+
repo_id=artifacts_repo,
|
| 1103 |
+
hf_token=hf_token,
|
| 1104 |
+
local_cache_root=args.hf_cache_root,
|
| 1105 |
+
)
|
| 1106 |
+
if need_shards:
|
| 1107 |
+
args.shard_dir = paths["shard_dir"]
|
| 1108 |
+
print(f"[INFO] Using cached shard_dir: {args.shard_dir}")
|
| 1109 |
+
if need_heads:
|
| 1110 |
+
args.telepathy_heads_path = paths["telepathy_heads_path"]
|
| 1111 |
+
print(f"[INFO] Using cached telepathy_heads_path: {args.telepathy_heads_path}")
|
| 1112 |
+
|
| 1113 |
+
if not args.shard_dir or not os.path.isdir(args.shard_dir):
|
| 1114 |
+
raise FileNotFoundError(
|
| 1115 |
+
f"shard_dir not found locally: {args.shard_dir}. "
|
| 1116 |
+
"Either provide a valid local path or an artifacts_repo_id for auto-download."
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
if not args.telepathy_heads_path or not os.path.isfile(args.telepathy_heads_path):
|
| 1120 |
+
raise FileNotFoundError(
|
| 1121 |
+
f"telepathy_heads_path not found locally: {args.telepathy_heads_path}. "
|
| 1122 |
+
"Either provide a valid local path or store it under storage/sigma_lora_out/ "
|
| 1123 |
+
"in artifacts_repo_id for auto-download."
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
policy = load_pi05_policy(
|
| 1127 |
+
args.base_model_id,
|
| 1128 |
+
hf_token,
|
| 1129 |
+
device=device,
|
| 1130 |
+
strict_load=args.strict_pi05_load,
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
tokenizer = get_policy_tokenizer(
|
| 1134 |
+
policy,
|
| 1135 |
+
args.base_model_id,
|
| 1136 |
+
hf_token,
|
| 1137 |
+
forced_tokenizer_id=args.tokenizer_id,
|
| 1138 |
+
)
|
| 1139 |
+
text_embed_layer = get_policy_text_embedding_layer(policy)
|
| 1140 |
+
|
| 1141 |
+
verify_tokenizer_embedding_compat(
|
| 1142 |
+
tokenizer=tokenizer,
|
| 1143 |
+
text_embed_layer=text_embed_layer,
|
| 1144 |
+
allow_mismatch=args.allow_tokenizer_mismatch,
|
| 1145 |
+
)

    v_cfg = VisionConfig()
    l_cfg = LanguageConfig()
    a_cfg = ActionConfig()
    telepathy_vla = TelepathyVLA(v_cfg, l_cfg, a_cfg, disable_telepathy=args.disable_telepathy)
    telepathy_vla.telepathy_scale = args.telepathy_scale

    # Instantiate the Telepathy adapter (used only when telepathy is enabled and the flag is set).
    adapter_cfg = SigmaTelepathyAdapterConfig()
    telepathy_adapter = SigmaTelepathyAdapter(adapter_cfg).to(device)

    if accelerator.is_main_process:
        file_size_mb = os.path.getsize(args.telepathy_heads_path) / (1024 * 1024)
        print(f"[CHECK-A] disable_telepathy={args.disable_telepathy}")
        print(f"[CHECK-A] telepathy_heads_path={args.telepathy_heads_path} size={file_size_mb:.2f}MB")

    sd = torch.load(args.telepathy_heads_path, map_location="cpu")

    tensor_list = [v.detach().float().reshape(-1) for v in sd.values() if torch.is_tensor(v)]
    if accelerator.is_main_process and len(tensor_list) > 0:
        capped = [t[:100000] for t in tensor_list]
        flat = torch.cat(capped, dim=0)
        rms = torch.sqrt((flat ** 2).mean()).item()
        print(f"[CHECK-A] heads_tensors={len(tensor_list)} mean={flat.mean().item():.6f} std={flat.std().item():.6f} rms={rms:.6f}")

    missing, unexpected = telepathy_vla.load_state_dict(sd, strict=False)
    if accelerator.is_main_process:
        if len(missing) > 0 or len(unexpected) > 0:
            print(f"[CHECK-A] loaded with strict=False. Missing={len(missing)} Unexpected={len(unexpected)}")
            print(f"[CHECK-A] Missing keys (first 20): {missing[:20]}")
            print(f"[CHECK-A] Unexpected keys (first 20): {unexpected[:20]}")
        else:
            print("[CHECK-A] heads fully matched (no missing/unexpected).")

    telepathy_vla.eval()
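For context on the `strict=False` load above: a heads-only checkpoint is typically produced by saving a filtered subset of the full state dict, so most model keys are legitimately missing at load time. A sketch of the producer side, where the `"telepathy"` name filter is an assumption:

heads_sd = {
    k: v.cpu()
    for k, v in telepathy_vla.state_dict().items()
    if "telepathy" in k  # assumed name filter; keep only the head parameters
}
torch.save(heads_sd, "telepathy_heads.pt")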

    ds = SigmaShardDataset(args.shard_dir)
    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        collate_fn=collate_sigma,
        drop_last=False,
        pin_memory=torch.cuda.is_available(),
    )

    telepathy_vla, dl = accelerator.prepare(telepathy_vla, dl)
    target_dtype = next(telepathy_vla.parameters()).dtype

    sum_mse_vec = 0.0
    sum_mse_chk = 0.0
    sum_mse_trj = 0.0
    sum_tau_l2 = 0.0
    sum_sem_align = 0.0

    # Hard-set aggregators: thresholds default to 2x the success thresholds when unset.
    hard_thr_vec = args.hard_thr_vec if args.hard_thr_vec > 0.0 else 2.0 * args.succ_thr_vec
    hard_thr_chk = args.hard_thr_chk if args.hard_thr_chk > 0.0 else 2.0 * args.succ_thr_chk
    hard_thr_trj = args.hard_thr_trj if args.hard_thr_trj > 0.0 else 2.0 * args.succ_thr_trj

    sum_hard_mse_vec = 0.0
    sum_hard_mse_chk = 0.0
    sum_hard_mse_trj = 0.0
    total_hard_samples = 0

    n_batches = 0
    n_samples = 0

    os.makedirs(args.output_dir, exist_ok=True)

    for bidx, batch in enumerate(dl):
        if args.max_batches > 0 and bidx >= args.max_batches:
            break

        telepathy_vla.reset_memory()

        B = batch["vis_obs"].size(0)
        n_samples += B

        text_tokens, attn_mask = build_text_tokens_from_policy(
            tokenizer=tokenizer,
            text_embed_layer=text_embed_layer,
            texts=batch["texts"],
            device=device,
            target_dtype=target_dtype,
            max_text_len=args.max_text_len,
        )

        robot_state = batch["robot_state"].to(device)
        if robot_state.dim() == 3:
            robot_state = robot_state[:, -1]

        # Move optional base actions to the device for the adapter.
        if "base_action_vector" in batch:
            batch["base_action_vector"] = batch["base_action_vector"].to(device)
        if "base_action_chunk" in batch:
            batch["base_action_chunk"] = batch["base_action_chunk"].to(device)
        if "base_action_trajectory" in batch:
            batch["base_action_trajectory"] = batch["base_action_trajectory"].to(device)
        try:
            expected_d = telepathy_vla.vision.state_encoder.mlp[0].in_features
        except Exception:
            expected_d = robot_state.size(-1)

        cur_d = robot_state.size(-1)
        if cur_d < expected_d:
            robot_state = F.pad(robot_state, (0, expected_d - cur_d))
        elif cur_d > expected_d:
            robot_state = robot_state[..., :expected_d]
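As a reminder of the `F.pad` convention used here, the pair `(0, k)` pads only the last dimension, adding `k` zeros on the right:

x = torch.randn(2, 7)
y = F.pad(x, (0, 3))             # pad the last dim by 3 on the right
assert y.shape == (2, 10)
assert torch.all(y[:, 7:] == 0)  # the new entries are zeros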

        pred = telepathy_vla.forward_once(
            vis_obs=batch["vis_obs"].to(device),
            robot_state=robot_state,
            depth_obs=batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None,
            audio_obs=batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None,
            text_tokens=text_tokens,
            attn_mask=attn_mask,
            return_intermediate=True,
        )

        if accelerator.is_main_process and bidx == 0:
            model_ref = telepathy_vla.module if hasattr(telepathy_vla, "module") else telepathy_vla
            model_ref.reset_memory()
            prev_flag = bool(model_ref.disable_telepathy)
            model_ref.disable_telepathy = True
            pred_ctrl = model_ref.forward_once(
                vis_obs=batch["vis_obs"].to(device),
                robot_state=robot_state,
                depth_obs=batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None,
                audio_obs=batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None,
                text_tokens=text_tokens,
                attn_mask=attn_mask,
                return_intermediate=False,
            )
            model_ref.disable_telepathy = prev_flag

            try:
                act_exp = _pred_action(pred, "action_vector")
                act_ctl = _pred_action(pred_ctrl, "action_vector")
                diff = (act_exp - act_ctl).abs().mean().item()
                print(f"[CHECK-B] telepathy_effect_mean_abs_diff(action_vector)={diff:.6f}")
            except Exception as e:
                print(f"[CHECK-B] action diff check failed: {type(e).__name__}: {e}")

        # Apply the Telepathy adapter only when telepathy is enabled and the flag is set.
        if (not args.disable_telepathy) and args.use_telepathy_adapter:
            pred = telepathy_adapter(pred, batch)

        mse = compute_branch_mse(pred, batch)
        tau_l2 = compute_telepathy_stability(pred)

        (
            _,
            _,
            mse_vec_s,
            mse_chk_s,
            mse_trj_s,
        ) = compute_success_proxy(
            pred,
            batch,
            thr_vec=args.succ_thr_vec,
            thr_chk=args.succ_thr_chk,
            thr_trj=args.succ_thr_trj,
        )

        # Hard-set accumulation: samples where any branch MSE exceeds its hard threshold.
        hard_mask = (mse_vec_s > hard_thr_vec) | (mse_chk_s > hard_thr_chk) | (mse_trj_s > hard_thr_trj)
        hard_count = int(hard_mask.sum().item())
        if hard_count > 0:
            sum_hard_mse_vec += mse_vec_s[hard_mask].sum().item()
            sum_hard_mse_chk += mse_chk_s[hard_mask].sum().item()
            sum_hard_mse_trj += mse_trj_s[hard_mask].sum().item()
            total_hard_samples += hard_count
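The hard mask is a per-sample boolean OR across the three branches, and boolean indexing then keeps only the failing samples. A tiny worked example for a single branch:

mse_vec_s = torch.tensor([0.01, 0.30, 0.05])
mask = mse_vec_s > 0.10                  # tensor([False, True, False])
hard_count = int(mask.sum().item())      # 1 hard sample
hard_sum = mse_vec_s[mask].sum().item()  # ~0.30, only the failing sample contributes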

        sem_factors = pred.get("semantic_factors", None)
        if sem_factors is not None:
            if sem_factors.dim() == 3:
                sem_pool = sem_factors.mean(dim=1)
            elif sem_factors.dim() == 2:
                sem_pool = sem_factors
            else:
                sem_pool = sem_factors.view(sem_factors.size(0), -1)

            txt_pool = text_tokens.mean(dim=1)
            sem_align = cosine_alignment(sem_pool, txt_pool)
        else:
            sem_align = float("nan")

        sum_mse_vec += mse["mse_vector"]
        sum_mse_chk += mse["mse_chunk"]
        sum_mse_trj += mse["mse_traj"]
        if not (tau_l2 != tau_l2):  # `x != x` holds only for NaN
            sum_tau_l2 += tau_l2
        if not (sem_align != sem_align):  # skip NaN alignment values
            sum_sem_align += sem_align

        n_batches += 1

        if accelerator.is_main_process and bidx % 20 == 0:
            print(
                f"batch={bidx} "
                f"mse_vec={mse['mse_vector']:.4f} mse_chk={mse['mse_chunk']:.4f} mse_trj={mse['mse_traj']:.4f} "
                f"tau_l2={tau_l2:.4f} sem_align={sem_align:.4f}"
            )

    if accelerator.is_main_process:
        avg_mse_vec = sum_mse_vec / max(1, n_batches)
        avg_mse_chk = sum_mse_chk / max(1, n_batches)
        avg_mse_trj = sum_mse_trj / max(1, n_batches)

        avg_tau_l2 = sum_tau_l2 / max(1, n_batches)
        avg_sem_align = sum_sem_align / max(1, n_batches)

        if total_hard_samples > 0:
            avg_hard_mse_vec = sum_hard_mse_vec / float(total_hard_samples)
            avg_hard_mse_chk = sum_hard_mse_chk / float(total_hard_samples)
            avg_hard_mse_trj = sum_hard_mse_trj / float(total_hard_samples)
        else:
            avg_hard_mse_vec = float("nan")
            avg_hard_mse_chk = float("nan")
            avg_hard_mse_trj = float("nan")

        hard_fraction = float(total_hard_samples / max(1, n_samples))

        report = {
            "num_samples": n_samples,
            "num_batches": n_batches,
            "avg_mse_vector": avg_mse_vec,
            "avg_mse_chunk": avg_mse_chk,
            "avg_mse_traj": avg_mse_trj,
            "avg_tau_l2": avg_tau_l2,
            "avg_semantic_text_alignment": avg_sem_align,
            "hard_thresholds": {
                "vec": hard_thr_vec,
                "chk": hard_thr_chk,
                "trj": hard_thr_trj,
            },
            "avg_hard_mse_vector": avg_hard_mse_vec,
            "avg_hard_mse_chunk": avg_hard_mse_chk,
            "avg_hard_mse_traj": avg_hard_mse_trj,
            "hard_sample_fraction": hard_fraction,
            "total_hard_samples": int(total_hard_samples),
        }

        with open(
            os.path.join(args.output_dir, "sigma_eval_report.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(report, f, indent=2)
        print("[DONE] Saved report:", report)

if __name__ == "__main__":
    main()
modeling_pi05.py
ADDED
@@ -0,0 +1,1264 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict

import torch
import torch.nn.functional as F  # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack

from lerobot.utils.import_utils import _transformers_available

# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
    from transformers.models.auto import CONFIG_MAPPING
    from transformers.models.gemma import modeling_gemma
    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
    CONFIG_MAPPING = None
    modeling_gemma = None
    GemmaForCausalLM = None
    PaliGemmaForConditionalGeneration = None

from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
    ACTION,
    OBS_LANGUAGE_ATTENTION_MASK,
    OBS_LANGUAGE_TOKENS,
    OPENPI_ATTENTION_MASK_VALUE,
)
class ActionSelectKwargs(TypedDict, total=False):
    inference_delay: int | None
    prev_chunk_left_over: Tensor | None
    execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
    """Get a safe dtype for the given device type."""
    if device_type == "mps" and target_dtype == torch.float64:
        return torch.float32
    if device_type == "cpu":
        # CPU doesn't support bfloat16, use float32 instead
        if target_dtype == torch.bfloat16:
            return torch.float32
        if target_dtype == torch.float64:
            return torch.float64
    return target_dtype
def create_sinusoidal_pos_embedding(  # see openpi `create_sinusoidal_pos_embedding` (exact copy)
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")

    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
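A quick shape check for the embedding above. The period bounds are illustrative, and note that `device` must be a `torch.device` since the function reads `device.type`:

t = torch.rand(4)  # one scalar position per batch element
emb = create_sinusoidal_pos_embedding(
    t, 1024, min_period=4e-3, max_period=4.0, device=torch.device("cpu")
)
assert emb.shape == (4, 1024)  # first half sines, second half cosines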
def sample_beta(alpha, beta, bsize, device):  # see openpi `sample_beta` (exact copy)
    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
    dist = torch.distributions.Beta(alpha_t, beta_t)
    return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks):  # see openpi `make_att_2d_masks` (exact copy)
    """Copied from big_vision.

    Tokens can attend to valid input tokens whose cumulative mask_ar is
    smaller than or equal to theirs. This way `mask_ar` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend to all previous blocks and all tokens of the same block.

    Args:
        pad_masks: bool[B, N] true if part of the input, false if padding.
        att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
            it and 0 where it shares the same attention mask as the previous token.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    return att_2d_masks & pad_2d_masks
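To make the cumulative-sum construction concrete, here is the prefix-lm pattern from the docstring evaluated by hand:

pad = torch.ones(1, 6, dtype=torch.bool)
ar = torch.tensor([[0, 0, 0, 1, 1, 1]])
m = make_att_2d_masks(pad, ar)
# cumsum(ar) = [0, 0, 0, 1, 2, 3]; token i may attend to j iff cumsum[j] <= cumsum[i],
# so the first three tokens attend bidirectionally and the last three are causal.
assert m[0, 0].tolist() == [True, True, True, False, False, False]
assert m[0, 5].tolist() == [True] * 6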

def pad_vector(vector, new_dim):
    """Pad the last dimension of a vector to new_dim with zeros.

    Can be (batch_size x sequence_length x features_dimension)
    or (batch_size x features_dimension).
    """
    if vector.shape[-1] >= new_dim:
        return vector
    return F.pad(vector, (0, new_dim - vector.shape[-1]))
def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without
    distortion by padding with black. If the image is float32, it must be in the range [-1, 1].

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with the same shape format as the input
    """
    # Check if the input is channels-last [*b, h, w, c] or channels-first [*b, c, h, w]
    if images.shape[-1] <= 4:  # assume channels-last format
        channels_last = True
        if images.dim() == 3:
            images = images.unsqueeze(0)  # add batch dimension
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]
    else:
        channels_last = False
        if images.dim() == 3:
            images = images.unsqueeze(0)  # add batch dimension

    batch_size, channels, cur_height, cur_width = images.shape

    # Calculate the resize ratio
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Resize
    resized_images = F.interpolate(
        images,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    # Handle dtype-specific clipping
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Calculate padding
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # Pad
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to the original format if needed
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

    return padded_images
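A quick usage sketch, letterboxing a float image (already in [-1, 1], as the docstring requires) to a square target; the 224x224 size is illustrative:

img = torch.rand(1, 3, 480, 640) * 2 - 1   # float32 in [-1, 1], channels-first
out = resize_with_pad_torch(img, 224, 224)
assert out.shape == (1, 3, 224, 224)       # aspect ratio kept; borders padded with -1.0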

# The complete per-layer computation, factored out so it can be wrapped in gradient checkpointing
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    models = [paligemma.language_model, gemma_expert.model]
    query_states = []
    key_states = []
    value_states = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(query_state)
        key_states.append(key_state)
        value_states.append(value_state)

    # Concatenate both streams along the sequence axis and run joint attention
    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)
    dummy_tensor = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )
    batch_size = query_states.shape[0]
    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
    # Attention computation
    att_output, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        query_states,
        key_states,
        value_states,
        attention_mask,
        scaling,
    )
    # Get head_dim from the current layer, not from the model
    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
    att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)  # both Gemma variants use 8 heads
    # Process layer outputs
    outputs_embeds = []
    start_pos = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end_pos = start_pos + hidden_states.shape[1]
        if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
            att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
        # First residual
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first_residual = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # Second residual
        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds
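The key move in `compute_layer_complete` is concatenating the two experts' query/key/value streams along the sequence axis so that a single attention call spans the joint sequence. A minimal shape sketch with illustrative lengths:

B, H, D = 2, 8, 256                    # batch, heads, head_dim (8 heads per the configs below)
q_vlm = torch.randn(B, H, 10, D)       # prefix stream (image + language tokens)
q_expert = torch.randn(B, H, 5, D)     # suffix stream (action tokens)
q_joint = torch.cat([q_vlm, q_expert], dim=2)
assert q_joint.shape == (B, H, 15, D)  # attention now runs over all 15 positions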
class GemmaConfig:  # see openpi `gemma.py: Config`
    """Configuration for Gemma model variants."""

    def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
        self.width = width
        self.depth = depth
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
def get_gemma_config(variant: str) -> GemmaConfig:  # see openpi `gemma.py: get_config`
    """Returns config for specified gemma variant."""
    if variant == "gemma_300m":
        return GemmaConfig(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    elif variant == "gemma_2b":
        return GemmaConfig(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    else:
        raise ValueError(f"Unknown variant: {variant}")
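For orientation, the two variants differ only in width and MLP size; both share the depth that the joint layer-by-layer pass requires:

vlm_cfg = get_gemma_config("gemma_2b")
expert_cfg = get_gemma_config("gemma_300m")
assert (vlm_cfg.width, vlm_cfg.mlp_dim) == (2048, 16_384)
assert (expert_cfg.width, expert_cfg.mlp_dim) == (1024, 4096)
assert vlm_cfg.depth == expert_cfg.depth == 18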
class PaliGemmaWithExpertModel(nn.Module):
    """PaliGemma model with an action expert for PI05.

    See openpi `gemma_pytorch.py: PaliGemmaWithExpertModel`; this class is almost an exact copy.
    """

    def __init__(
        self,
        vlm_config,
        action_expert_config,
        use_adarms=None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
    ):
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()

        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
        vlm_config_hf.image_token_index = 257152
        vlm_config_hf.text_config.hidden_size = vlm_config.width
        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
        vlm_config_hf.text_config.head_dim = vlm_config.head_dim
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.torch_dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.torch_dtype = "float32"

        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,
            hidden_size=action_expert_config.width,
            intermediate_size=action_expert_config.mlp_dim,
            num_attention_heads=action_expert_config.num_heads,
            num_hidden_layers=action_expert_config.depth,
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )

        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)

    def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
        if precision == "bfloat16":
            self.to(dtype=torch.bfloat16)
        elif precision == "float32":
            self.to(dtype=torch.float32)
            return
        else:
            raise ValueError(f"Invalid precision: {precision}")

        # Keep numerically sensitive parameters in float32.
        params_to_keep_float32 = [
            "vision_tower.vision_model.embeddings.patch_embedding.weight",
            "vision_tower.vision_model.embeddings.patch_embedding.bias",
            "vision_tower.vision_model.embeddings.position_embedding.weight",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",
        ]

        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_keep_float32):
                param.data = param.data.to(dtype=torch.float32)
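After construction, the selective cast leaves the numerically sensitive parameters in float32 while the bulk of the weights are bfloat16. A sanity check, assuming the patched transformers build this file guards for is installed (note this allocates the full model):

model = PaliGemmaWithExpertModel(get_gemma_config("gemma_2b"), get_gemma_config("gemma_300m"))
for name, p in model.named_parameters():
    if any(s in name for s in ("input_layernorm", "post_attention_layernorm", "model.norm")):
        assert p.dtype == torch.float32  # kept in full precision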
    def embed_image(self, image: torch.Tensor):
        return self.paligemma.model.get_image_features(image)

    def embed_language_tokens(self, tokens: torch.Tensor):
        return self.paligemma.language_model.embed_tokens(tokens)
    def forward(
        self,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: list[torch.FloatTensor] | None = None,
        use_cache: bool | None = None,
        adarms_cond: list[torch.Tensor] | None = None,
    ):
        if adarms_cond is None:
            adarms_cond = [None, None]
        if inputs_embeds[1] is None:
            # Prefix-only pass: run the VLM and build the KV cache.
            prefix_output = self.paligemma.language_model.forward(
                inputs_embeds=inputs_embeds[0],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
            )
            prefix_past_key_values = prefix_output.past_key_values
            prefix_output = prefix_output.last_hidden_state
            suffix_output = None
        elif inputs_embeds[0] is None:
            # Suffix-only pass: run the action expert against the cached prefix.
            suffix_output = self.gemma_expert.model.forward(
                inputs_embeds=inputs_embeds[1],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
            )
            suffix_output = suffix_output.last_hidden_state
            prefix_output = None
            prefix_past_key_values = None
        else:
            # Joint pass: interleave both experts layer by layer.
            models = [self.paligemma.language_model, self.gemma_expert.model]
            num_layers = self.paligemma.config.text_config.num_hidden_layers

            # Check if gradient checkpointing is enabled for any of the models
            use_gradient_checkpointing = (
                hasattr(self.gemma_expert.model, "gradient_checkpointing")
                and self.gemma_expert.model.gradient_checkpointing
                and self.training
            ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)

            # Process all layers, with gradient checkpointing if enabled
            for layer_idx in range(num_layers):
                if use_gradient_checkpointing:
                    inputs_embeds = torch.utils.checkpoint.checkpoint(
                        compute_layer_complete,
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        use_reentrant=False,
                        preserve_rng_state=False,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )
                else:
                    inputs_embeds = compute_layer_complete(
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )

            # Final norm
            def compute_final_norms(inputs_embeds, adarms_cond):
                outputs_embeds = []
                for i, hidden_states in enumerate(inputs_embeds):
                    out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                    outputs_embeds.append(out_emb)
                return outputs_embeds

            # Apply gradient checkpointing to the final norm if enabled
            if use_gradient_checkpointing:
                outputs_embeds = torch.utils.checkpoint.checkpoint(
                    compute_final_norms,
                    inputs_embeds,
                    adarms_cond,
                    use_reentrant=False,
                    preserve_rng_state=False,
                )
            else:
                outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)

            prefix_output = outputs_embeds[0]
            suffix_output = outputs_embeds[1]
            prefix_past_key_values = None

        return [prefix_output, suffix_output], prefix_past_key_values
class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
    """Core PI05 PyTorch model."""

    def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
        super().__init__()
        self.config = config
        self.rtc_processor = rtc_processor

        paligemma_config = get_gemma_config(config.paligemma_variant)
        action_expert_config = get_gemma_config(config.action_expert_variant)

        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            paligemma_config,
            action_expert_config,
            use_adarms=[False, True],
            precision=config.dtype,
        )

        self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
        self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)

        self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
        self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)

        # Initialize gradient checkpointing flag
        self.gradient_checkpointing_enabled = False

        # Compile the model if requested
        if config.compile_model:
            torch.set_float32_matmul_precision("high")
            self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)

        msg = "An incorrect transformers version is used; please create an issue on https://github.com/huggingface/lerobot/issues"

        # PATCH: make the transformers version guard non-fatal and robust across versions
        try:
            from transformers.models.siglip import check

            if hasattr(check, "check_whether_transformers_replace_is_installed_correctly"):
                ok = check.check_whether_transformers_replace_is_installed_correctly()
                if not ok:
                    logging.warning("[pi05] %s", msg)
            else:
                logging.warning(
                    "[patch_pi05] SigLIP check helper missing; skipping strict transformers version guard."
                )
        except Exception as e:  # noqa: BLE001
            logging.warning(
                "[patch_pi05] Could not run transformers version guard (%s). "
                "Continuing without strict transformers check. %s",
                e,
                msg,
            )
    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self.gradient_checkpointing_enabled = True
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
        logging.info("Enabled gradient checkpointing for PI05Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing_enabled = False
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
        logging.info("Disabled gradient checkpointing for PI05Pytorch model")

    def _rtc_enabled(self):
        return self.config.rtc_config is not None and self.config.rtc_config.enabled

    def _apply_checkpoint(self, func, *args, **kwargs):
        """Helper method to apply gradient checkpointing if enabled."""
        if self.gradient_checkpointing_enabled and self.training:
            return torch.utils.checkpoint.checkpoint(
                func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
            )
        return func(*args, **kwargs)

    def _prepare_attention_masks_4d(self, att_2d_masks):
        """Helper method to prepare 4D attention masks for the transformer."""
        att_2d_masks_4d = att_2d_masks[:, None, :, :]
        return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)

    def sample_noise(self, shape, device):
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )
    def sample_time(self, bsize, device):
        time_beta = sample_beta(
            self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
        )
        time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
        return time.to(dtype=torch.float32, device=device)
    def embed_prefix(
        self, images, img_masks, tokens, masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with the embedding layer."""
        embs = []
        pad_masks = []
        att_masks = []

        # Process images
        for img, img_mask in zip(images, img_masks, strict=True):

            def image_embed_func(img):
                return self.paligemma_with_expert.embed_image(img)

            img_emb = self._apply_checkpoint(image_embed_func, img)
            bsize, num_img_embs = img_emb.shape[:2]

            embs.append(img_emb)
            pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
            att_masks += [0] * num_img_embs

        # Process language tokens
        def lang_embed_func(tokens):
            lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
            lang_emb_dim = lang_emb.shape[-1]
            return lang_emb * math.sqrt(lang_emb_dim)

        lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
        embs.append(lang_emb)
        pad_masks.append(masks)

        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)

        bsize = pad_masks.shape[0]
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks
    def embed_suffix(self, noisy_actions, timestep):
        """Embed noisy_actions and timestep to prepare for expert Gemma processing."""
        embs = []
        pad_masks = []
        att_masks = []

        # Embed the timestep using sine-cosine positional encoding
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.action_in_proj.out_features,
            min_period=self.config.min_period,
            max_period=self.config.max_period,
            device=timestep.device,
        )
        time_emb = time_emb.type(dtype=timestep.dtype)

        # Fuse timestep + action information using an MLP
        def action_proj_func(noisy_actions):
            return self.action_in_proj(noisy_actions)

        action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)

        def time_mlp_func(time_emb):
            x = self.time_mlp_in(time_emb)
            x = F.silu(x)
            x = self.time_mlp_out(x)
            return F.silu(x)

        time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
        action_time_emb = action_emb
        adarms_cond = time_emb

        embs.append(action_time_emb)
        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.chunk_size - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks, adarms_cond
    def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
        """Do a full training forward pass and compute the loss."""
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)

        if (
            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
            == torch.bfloat16
        ):
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)

        def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
            (_, suffix_out), _ = self.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            return suffix_out

        suffix_out = self._apply_checkpoint(
            forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
        )

        suffix_out = suffix_out[:, -self.config.chunk_size :]
        suffix_out = suffix_out.to(dtype=torch.float32)

        def action_out_proj_func(suffix_out):
            return self.action_out_proj(suffix_out)

        v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)

        return F.mse_loss(u_t, v_t, reduction="none")
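The loss above is the standard flow-matching setup: the model observes the interpolation x_t = t * noise + (1 - t) * actions and must predict the constant velocity u_t = noise - actions. A scalar sanity check of the algebra:

actions = torch.tensor([0.0])
noise = torch.tensor([1.0])
t = torch.tensor([0.25])
x_t = t * noise + (1 - t) * actions                # 0.25, a quarter of the way to noise
u_t = noise - actions                              # 1.0, pointing from actions toward noise
assert torch.allclose(x_t + (1 - t) * u_t, noise)  # integrating forward reaches pure noise
assert torch.allclose(x_t - t * u_t, actions)      # integrating backward recovers the actions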
    @torch.no_grad()  # see openpi `sample_actions` (slightly adapted)
    def sample_actions(
        self,
        images,
        img_masks,
        tokens,
        masks,
        noise=None,
        num_steps=None,
        **kwargs: Unpack[ActionSelectKwargs],
    ) -> Tensor:
        """Do a full inference forward pass and compute the action."""
        if num_steps is None:
            num_steps = self.config.num_inference_steps

        bsize = tokens.shape[0]
        device = tokens.device

        if noise is None:
            # Sample noise with the padded dimension expected by action_in_proj
            actions_shape = (
                bsize,
                self.config.chunk_size,
                self.config.max_action_dim,
            )  # use the config max_action_dim for internal processing
            noise = self.sample_noise(actions_shape, device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001

        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=True,
        )

        dt = -1.0 / num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)

            # Define a closure to properly capture expanded_time.
            # This avoids the lambda expression (E731) and loop variable binding (B023) issues.
            def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
                return self.denoise_step(
                    prefix_pad_masks=prefix_pad_masks,
                    past_key_values=past_key_values,
                    x_t=input_x_t,
                    timestep=current_timestep,
                )

            if self._rtc_enabled():
                inference_delay = kwargs.get("inference_delay")
                prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
                execution_horizon = kwargs.get("execution_horizon")

                v_t = self.rtc_processor.denoise_step(
                    x_t=x_t,
                    prev_chunk_left_over=prev_chunk_left_over,
                    inference_delay=inference_delay,
                    time=time,
                    original_denoise_step_partial=denoise_step_partial_call,
                    execution_horizon=execution_horizon,
                )
            else:
                v_t = denoise_step_partial_call(x_t)

            # Euler step
            x_t += dt * v_t

            # Record x_t and v_t after the Euler step
            if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
                self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)

            time += dt

        return x_t
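The sampling loop above is plain Euler integration of the learned velocity field from t = 1 (pure noise) down to t = 0 (actions), with dt = -1/num_steps. Because the true field in this toy check is constant (u = noise - actions), Euler lands on the target up to float32 accumulation error:

num_steps = 10
dt = -1.0 / num_steps
actions = torch.tensor([0.0, 2.0])
noise = torch.tensor([1.0, -1.0])
x_t = noise.clone()
t = 1.0
while t >= -dt / 2:
    v_t = noise - actions  # stand-in for the model's predicted velocity
    x_t += dt * v_t
    t += dt
assert torch.allclose(x_t, actions, atol=1e-6)  # ten Euler steps reach the actions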
def denoise_step(
|
| 845 |
+
self,
|
| 846 |
+
prefix_pad_masks,
|
| 847 |
+
past_key_values,
|
| 848 |
+
x_t,
|
| 849 |
+
timestep,
|
| 850 |
+
):
|
| 851 |
+
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
| 852 |
+
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
|
| 853 |
+
|
| 854 |
+
suffix_len = suffix_pad_masks.shape[1]
|
| 855 |
+
batch_size = prefix_pad_masks.shape[0]
|
| 856 |
+
prefix_len = prefix_pad_masks.shape[1]
|
| 857 |
+
|
| 858 |
+
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
| 859 |
+
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
| 860 |
+
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
| 861 |
+
|
| 862 |
+
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
| 863 |
+
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
| 864 |
+
|
| 865 |
+
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
| 866 |
+
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
| 867 |
+
|
| 868 |
+
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
| 869 |
+
attention_mask=full_att_2d_masks_4d,
|
| 870 |
+
position_ids=position_ids,
|
| 871 |
+
past_key_values=past_key_values,
|
| 872 |
+
inputs_embeds=[None, suffix_embs],
|
| 873 |
+
use_cache=False,
|
| 874 |
+
adarms_cond=[None, adarms_cond],
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
suffix_out = outputs_embeds[1]
|
| 878 |
+
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
| 879 |
+
suffix_out = suffix_out.to(dtype=torch.float32)
|
| 880 |
+
return self.action_out_proj(suffix_out)
|
| 881 |
+
|
| 882 |
+
|
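# Shape sketch for the mask assembly in denoise_step (illustration only, with
# made-up sizes): for batch=2, prefix_len=5 cached tokens, suffix_len=3,
#
#     prefix_pad_2d_masks: (2, 3, 5)  every suffix token may attend to every
#                                     non-padded prefix token
#     suffix_att_2d_masks: (2, 3, 3)  suffix-internal attention pattern
#     full_att_2d_masks:   (2, 3, 8)  concatenated along the key axis, matching
#                                     the cached prefix key/values
#
# and suffix position ids continue where the prefix left off:
#
#     prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]  # (2, 1)
#     position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1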

class PI05Policy(PreTrainedPolicy):
    """PI05 Policy for LeRobot."""

    config_class = PI05Config
    name = "pi05"

    def __init__(
        self,
        config: PI05Config,
    ):
        """
        Args:
            config: Policy configuration class instance.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        # Initialize the core PI05 model
        self.init_rtc_processor()
        self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)

        # Enable gradient checkpointing if requested
        if config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self.model.to(config.device)

        self.reset()

    @classmethod
    def from_pretrained(
        cls: builtins.type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = True,
        **kwargs,
    ) -> T:
        """Override from_pretrained to handle key remapping and display an important disclaimer."""
        print(
            "The PI05 model is a direct port of the OpenPI implementation. \n"
            "This implementation follows the original OpenPI structure for compatibility. \n"
            "Original implementation: https://github.com/Physical-Intelligence/openpi"
        )
        if pretrained_name_or_path is None:
            raise ValueError("pretrained_name_or_path is required")

        # Use the provided config if available; otherwise create a default config
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )

        # Initialize the model without loading weights
        # (dataset_stats, if any, are forwarded through kwargs)
        model = cls(config, **kwargs)

        # Now manually load and remap the state dict
        try:
            # Try to load the pytorch_model.bin or model.safetensors file
            print(f"Loading model from: {pretrained_name_or_path}")
            try:
                from transformers.utils import cached_file

                # Try safetensors first. Forward the explicit keyword-only
                # arguments directly; they are captured by the signature above
                # and therefore never appear in **kwargs.
                resolved_file = cached_file(
                    pretrained_name_or_path,
                    "model.safetensors",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
                from safetensors.torch import load_file

                original_state_dict = load_file(resolved_file)
                print("✓ Loaded state dict from model.safetensors")
            except Exception as e:  # noqa: BLE001
                print(f"Could not load state dict from remote files: {e}")
                print("Returning model without loading pretrained weights")
                return model

            # First, fix any key differences  # see openpi `model.py, _fix_pytorch_state_dict_keys`
            fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

            # Then add the "model." prefix to all keys that don't already have it
            remapped_state_dict = {}
            remap_count = 0

            for key, value in fixed_state_dict.items():
                if not key.startswith("model."):
                    new_key = f"model.{key}"
                    remapped_state_dict[new_key] = value
                    remap_count += 1
                    if remap_count <= 10:  # Only print the first 10 to avoid spam
                        print(f"Remapped: {key} -> {new_key}")
                else:
                    remapped_state_dict[key] = value

            if remap_count > 0:
                print(f"Remapped {remap_count} state dict keys")

            # Load the remapped state dict into the model
            missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)

            # --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---
            if any("embed_tokens.weight" in k for k in missing_keys):
                try:
                    with torch.no_grad():
                        embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens
                        lm_head = model.model.paligemma_with_expert.paligemma.lm_head
                        if embed is not None and lm_head is not None:
                            embed.weight = lm_head.weight
                except Exception as _e:  # noqa: BLE001
                    print("[patch_pi05] Could not tie embed_tokens to lm_head:", _e)

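            # Weight tying in a nutshell (illustration only): assigning the same
            # nn.Parameter to both modules makes them share storage, so the input
            # embedding is recovered from the output head for free.
            #
            #     import torch.nn as nn
            #     lm_head = nn.Linear(16, 100, bias=False)  # hidden -> vocab
            #     embed = nn.Embedding(100, 16)             # vocab -> hidden
            #     embed.weight = lm_head.weight             # both are (100, 16)
            #     assert embed.weight.data_ptr() == lm_head.weight.data_ptr()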
            if missing_keys:
                print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
                if len(missing_keys) <= 5:
                    for key in missing_keys:
                        print(f"  - {key}")
                else:
                    for key in missing_keys[:5]:
                        print(f"  - {key}")
                    print(f"  ... and {len(missing_keys) - 5} more")

            if unexpected_keys:
                print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
                if len(unexpected_keys) <= 5:
                    for key in unexpected_keys:
                        print(f"  - {key}")
                else:
                    for key in unexpected_keys[:5]:
                        print(f"  - {key}")
                    print(f"  ... and {len(unexpected_keys) - 5} more")

            if not missing_keys and not unexpected_keys:
                print("All keys loaded successfully!")

        except Exception as e:  # noqa: BLE001
            print(f"Warning: Could not remap state dict keys: {e}")

        return model

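    # Typical usage (illustration only; "user/pi05-checkpoint" is a placeholder
    # repo id). strict=False tolerates checkpoints that omit embed_tokens.weight,
    # which the tie-patch above then repairs:
    #
    #     policy = PI05Policy.from_pretrained("user/pi05-checkpoint", strict=False)
    #     policy.reset()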
    def _fix_pytorch_state_dict_keys(
        self, state_dict, model_config
    ):  # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
        """Fix state dict keys to match the current model architecture."""
        import re

        fixed_state_dict = {}

        for key, value in state_dict.items():
            new_key = key

            # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
            # For gemma expert layers
            if re.match(
                r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
                key,
            ):
                # Check if the model actually has adaRMS enabled for the expert
                expert_uses_adarms = getattr(
                    self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
                )
                if expert_uses_adarms:
                    logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
                    continue

            if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
                # Check if the model actually has adaRMS enabled for the expert
                expert_uses_adarms = getattr(
                    self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
                )
                if expert_uses_adarms:
                    logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
                    continue

            # Handle MLP naming changes for pi05:
            # the pi05 model expects time_mlp_*, but the checkpoint might have action_time_mlp_*
            if key.startswith("action_time_mlp_in."):
                new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
            elif key.startswith("action_time_mlp_out."):
                new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
            # Also drop state_proj, which shouldn't exist in pi05
            if key.startswith("state_proj."):
                logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
                continue

            # Handle potential differences in the vision tower embedding layer
            if "patch_embedding" in key:
                # Some checkpoints might have this, but the current model expects a different structure
                logging.warning(f"Vision embedding key might need handling: {key}")

            fixed_state_dict[new_key] = value

        return fixed_state_dict

    def get_optim_params(self) -> dict:
        return self.parameters()

    def reset(self):
        """Reset internal state - called when the environment resets."""
        self._action_queue = deque(maxlen=self.config.n_action_steps)
        self._queues = {
            ACTION: deque(maxlen=self.config.n_action_steps),
        }

    def init_rtc_processor(self):
        """Initialize the RTC processor if an RTC config is provided.

        Even when RTC is not enabled, the processor can still track denoising data.
        """
        self.rtc_processor = None

        if self.config.rtc_config is not None:
            self.rtc_processor = RTCProcessor(self.config.rtc_config)

        model_value = getattr(self, "model", None)
        if model_value is not None:
            model_value.rtc_processor = self.rtc_processor

    def _rtc_enabled(self) -> bool:
        return self.config.rtc_config is not None and self.config.rtc_config.enabled

    def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
        """Preprocess images for the model.

        Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
        PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
        """
        images = []
        img_masks = []

        # Get device from model parameters
        device = next(self.parameters()).device

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. "
                f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
            )

        # Preprocess the image features present in the batch
        for key in present_img_keys:
            img = batch[key]

            # Ensure the tensor is on the same device as the model
            if img.device != device:
                img = img.to(device)

            # Ensure float32 dtype for consistency
            if img.dtype != torch.float32:
                img = img.to(torch.float32)

            # from openpi preprocess_observation_pytorch: handle both [B, C, H, W] and [B, H, W, C] formats
            is_channels_first = img.shape[1] == 3  # Check if channels are in dimension 1

            if is_channels_first:
                # Convert [B, C, H, W] to [B, H, W, C] for processing
                img = img.permute(0, 2, 3, 1)

            # from openpi preprocess_observation_pytorch: resize with padding if needed
            if img.shape[1:3] != self.config.image_resolution:
                img = resize_with_pad_torch(img, *self.config.image_resolution)

            # Normalize from [0, 1] to [-1, 1] as expected by SigLIP
            img = img * 2.0 - 1.0

            # from openpi preprocess_observation_pytorch: convert back to [B, C, H, W] if originally channels-first
            if is_channels_first:
                img = img.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

            images.append(img)
            # Create mask (all ones for real images)
            bsize = img.shape[0]
            mask = torch.ones(bsize, dtype=torch.bool, device=device)
            img_masks.append(mask)

        # Create the image features not present in the batch as fully padded images
        for _num_empty_cameras in range(len(missing_img_keys)):
            img = torch.ones_like(img) * -1  # Padded with -1 for SigLIP
            mask = torch.zeros_like(mask)  # Mask is zero for empty cameras
            images.append(img)
            img_masks.append(mask)

        return images, img_masks

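    # Preprocessing sketch (illustration only; 224x224 stands in for
    # config.image_resolution): a [0, 1] channels-first image is converted,
    # padded-resized, and pushed into SigLIP's expected [-1, 1] range.
    #
    #     img = torch.rand(1, 3, 480, 640)   # [B, C, H, W] in [0, 1]
    #     img = img.permute(0, 2, 3, 1)      # -> [B, H, W, C]
    #     img = resize_with_pad_torch(img, 224, 224)
    #     img = img * 2.0 - 1.0              # [0, 1] -> [-1, 1]
    #     img = img.permute(0, 3, 1, 2)      # back to [B, C, H, W]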
    def prepare_action(self, batch):
        """Pad actions to the model's max_action_dim."""
        actions = pad_vector(batch[ACTION], self.config.max_action_dim)
        return actions

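    # Padding sketch (illustration only; 32 stands in for config.max_action_dim):
    # actions are zero-padded on the last dim so one model can serve robots with
    # fewer degrees of freedom.
    #
    #     a = torch.ones(4, 50, 7)                          # 7-DoF action chunk
    #     padded = torch.nn.functional.pad(a, (0, 32 - 7))  # -> (4, 50, 32)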
    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations."""
        assert not self._rtc_enabled(), (
            "RTC is not supported for select_action; use it with predict_action_chunk"
        )

        self.eval()

        # Action queue logic for n_action_steps > 1:
        # when the queue is empty, predict a new chunk and enqueue it
        if len(self._action_queue) == 0:
            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
            # Transpose to get shape (n_action_steps, batch_size, action_dim)
            self._action_queue.extend(actions.transpose(0, 1))

        return self._action_queue.popleft()

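    # Rollout sketch (illustration only; get_observation() is a hypothetical
    # helper): select_action pops one action per call and only re-runs the
    # expensive chunk prediction when the queue drains.
    #
    #     policy.reset()
    #     for _ in range(100):
    #         batch = get_observation()
    #         action = policy.select_action(batch)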
    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
        """Predict a chunk of actions given environment observations."""
        self.eval()

        # Prepare inputs
        images, img_masks = self._preprocess_images(batch)
        tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]

        # Sample actions using the model (pass through RTC kwargs; no separate state needed for PI05)
        actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)

        # Unpad actions to the actual action dimension
        original_action_dim = self.config.output_features[ACTION].shape[0]
        actions = actions[:, :, :original_action_dim]

        return actions

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training."""
        # Prepare inputs
        images, img_masks = self._preprocess_images(batch)
        tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]

        actions = self.prepare_action(batch)

        # Compute loss (no separate state needed for PI05)
        losses = self.model.forward(images, img_masks, tokens, masks, actions)

        # Truncate losses to the actual action dimensions
        original_action_dim = self.config.output_features[ACTION].shape[0]
        losses = losses[:, :, :original_action_dim]

        loss = losses.mean()

        loss_dict = {
            "loss": loss.item(),
            "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
        }

        return loss, loss_dict
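    # Training sketch (illustration only): forward returns a scalar loss plus a
    # per-dimension breakdown, so a minimal optimization step reduces to:
    #
    #     loss, loss_dict = policy.forward(batch)
    #     loss.backward()
    #     optimizer.step()
    #     optimizer.zero_grad()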

# PATCH: downgrade transformer version guard
patch_sigma_env.py
ADDED
@@ -0,0 +1,432 @@
#!/usr/bin/env python3
"""
patch_sigma_env.py

Idempotent patcher for Sigma VLA experiments.

Patch goals:
1) LeRobot PI05Policy (modeling_pi05.py):
   1.1 If the ckpt omits embed_tokens.weight, tie embed_tokens.weight to
       lm_head.weight *after* load_state_dict runs.
   1.2 Ensure torch is imported if the target file lacks it.
   1.3 Downgrade the "incorrect transformer version" hard guard (ValueError)
       to a WARNING so new GPU environments don't crash.
       IMPORTANT: preserve indentation and patch only the intended guard.

2) LeRobot policies __init__ (lerobot/policies/__init__.py):
   2.1 Make ONLY Groot/Diffusers-related imports optional (wrapped in
       try/except), leaving all other exports untouched.
       This prevents errors like "No module named 'triton.ops'" or
       diffusers/peft chain issues on fresh GPUs.

3) eval_sigma_vla_rollout.py (your /workspace eval script):
   3.1 Force strict=False for PI05Policy.from_pretrained calls:
       - strict=True -> strict=False
       - if a PI05Policy load call has no strict arg, add strict=False
   3.2 Ensure randomized subset evaluation is possible:
       - add a --shuffle arg if missing
       - change DataLoader shuffle=False -> shuffle=getattr(args, "shuffle", False)

Safe to run multiple times; no-op if already patched.
"""

import os
import re
import sys
import pathlib
from typing import Optional, Tuple, List


# -------------------------
# Utilities
# -------------------------

def _read_text(p: pathlib.Path) -> str:
    return p.read_text(encoding="utf-8")

def _write_text(p: pathlib.Path, s: str) -> None:
    p.write_text(s, encoding="utf-8")

def _search_file(
    roots: List[os.PathLike],
    filename: str,
    must_contain: Optional[str] = None
) -> Optional[pathlib.Path]:
    for r in roots:
        r = pathlib.Path(r)
        if not r.exists():
            continue
        for p in r.rglob(filename):
            if must_contain and must_contain not in str(p):
                continue
            return p
    return None

def _default_roots():
    return [
        "/workspace/lerobot/src",
        "/workspace/lerobot",
        pathlib.Path(sys.prefix)
        / "lib"
        / f"python{sys.version_info.major}.{sys.version_info.minor}"
        / "site-packages",
    ]

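# Usage sketch for the helpers above (illustration only): locate a file under
# the default roots, then round-trip it through read/write. A real patch
# function edits `text` in between.
#
#     target = _search_file(_default_roots(), "modeling_pi05.py", must_contain="/pi05/")
#     if target is not None:
#         text = _read_text(target)
#         _write_text(target, text)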

# -------------------------
# Patch 1: PI05Policy (LeRobot)
# -------------------------

def find_pi05_file() -> pathlib.Path:
    env = os.getenv("PI05_FILE")
    if env:
        p = pathlib.Path(env)
        if p.exists():
            return p

    p = _search_file(_default_roots(), "modeling_pi05.py", must_contain="/pi05/")
    if p and p.exists():
        return p

    raise FileNotFoundError("modeling_pi05.py not found. Set PI05_FILE env var to its path.")


def ensure_torch_import(s: str) -> str:
    if re.search(r"(?m)^\s*import\s+torch\b", s) or re.search(r"(?m)^\s*from\s+torch\b", s):
        return s

    lines = s.splitlines(True)
    insert_idx = 0

    if lines and lines[0].startswith("#!"):
        insert_idx = 1

    # skip the module docstring block if present
    if insert_idx < len(lines) and lines[insert_idx].lstrip().startswith('"""'):
        i = insert_idx + 1
        while i < len(lines) and '"""' not in lines[i]:
            i += 1
        if i < len(lines):
            insert_idx = i + 1

    lines.insert(insert_idx, "import torch  # PATCH: required for embed/lm_head tying\n")
    return "".join(lines)


def patch_pi05_embed_tie(p: pathlib.Path) -> Tuple[bool, str]:
    s = _read_text(p)
    s = ensure_torch_import(s)

    marker = "PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens"
    if marker in s:
        _write_text(p, s)
        return False, f"PI05 embed-tie patch already present: {p}"

    pat = r"(?m)^(\s*)missing_keys,\s*unexpected_keys\s*=\s*model\.load_state_dict\(\s*remapped_state_dict\s*,\s*strict\s*=\s*strict\s*\)\s*$"
    m = re.search(pat, s)
    if not m:
        _write_text(p, s)
        return False, f"Could not find load_state_dict line to patch in PI05 file: {p}"

    indent = m.group(1)
    inject = (
        f"\n{indent}# --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---\n"
        f"{indent}if any('embed_tokens.weight' in k for k in missing_keys):\n"
        f"{indent}    try:\n"
        f"{indent}        with torch.no_grad():\n"
        f"{indent}            embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens\n"
        f"{indent}            lm_head = model.model.paligemma_with_expert.paligemma.lm_head\n"
        f"{indent}            if embed is not None and lm_head is not None:\n"
        f"{indent}                embed.weight = lm_head.weight  # {marker}\n"
        f"{indent}    except Exception as _e:\n"
        f"{indent}        print('[patch_pi05] Could not tie embed_tokens to lm_head:', _e)\n"
    )

    s2 = re.sub(pat, lambda mm: mm.group(0) + inject, s, count=1)
    _write_text(p, s2)
    return True, f"Patched PI05 embed-tie in: {p}"

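# Every regex-injection patch in this file follows the same marker-guarded
# skeleton; a minimal sketch (illustration only, not called anywhere):
#
#     def patch_once(path, marker, pattern, inject):
#         s = _read_text(path)
#         if marker in s:
#             return False  # already patched -> idempotent no-op
#         s2 = re.sub(pattern, lambda m: m.group(0) + inject, s, count=1)
#         _write_text(path, s2)
#         return s2 != s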

def patch_pi05_transformers_guard(p: pathlib.Path) -> Tuple[bool, str]:
    """
    Downgrade ONLY the PI05 hard guard:
        ValueError: An incorrect transformer version is used...
    to a WARNING print, preserving indentation.

    Strategy:
      - Find `raise ValueError(msg) from None` lines.
      - Only patch the one whose nearby context contains
        "incorrect transformer version".
    """
    s = _read_text(p)
    marker = "PATCH: downgrade transformer version guard"
    if marker in s:
        return False, f"PI05 transformers-guard patch already present: {p}"

    if "incorrect transformer version" not in s:
        return False, f"No transformers guard message found to patch in: {p}"

    lines = s.splitlines(True)
    raise_pat = re.compile(r"^(\s*)raise\s+ValueError\(\s*msg\s*\)\s*from\s*None\s*$")

    target_idx = None
    target_indent = ""

    for i, line in enumerate(lines):
        m = raise_pat.match(line)
        if not m:
            continue
        # look back a few lines for the specific guard text
        window_start = max(0, i - 8)
        window = "".join(lines[window_start:i + 1]).lower()
        if "incorrect transformer version" in window:
            target_idx = i
            target_indent = m.group(1)
            break

    if target_idx is None:
        return False, f"Guard raise line with context not found in: {p}"

    repl = (
        f"{target_indent}# --- PATCH: downgrade transformer version guard ---\n"
        f"{target_indent}print('[patch_pi05] WARNING:', msg)  # {marker}\n"
        f"{target_indent}# continues execution despite version mismatch\n"
    )

    lines[target_idx] = repl
    s2 = "".join(lines)
    _write_text(p, s2)
    return True, f"Patched PI05 transformers guard (raise->warn) in: {p}"


# -------------------------
# Patch 2: LeRobot policies optional imports
# -------------------------

def find_policies_init() -> pathlib.Path:
    env = os.getenv("POLICIES_INIT_FILE")
    if env:
        p = pathlib.Path(env)
        if p.exists():
            return p

    p = _search_file(_default_roots(), "__init__.py", must_contain="/lerobot/policies/")
    if p and p.exists():
        return p

    raise FileNotFoundError("lerobot/policies/__init__.py not found. Set POLICIES_INIT_FILE env var.")


def patch_policies_optional_imports(p: pathlib.Path) -> Tuple[bool, str]:
    """
    Make ONLY Groot/Diffusers imports optional.
    This avoids wrapping unrelated exports/imports.
    """
    s = _read_text(p)
    marker = "PATCH: optional Groot/Diffusers imports"
    if marker in s:
        return False, f"Policies optional-import patch already present: {p}"

    lines = s.splitlines(True)

    def is_groot_line(line: str) -> bool:
        # strict filter: only lines that import the groot submodule
        return bool(re.search(r"^\s*from\s+\.\s*groot\b|^\s*from\s+\.groot\b|^\s*import\s+.*\bgroot\b", line))

    idxs = [i for i, l in enumerate(lines) if is_groot_line(l)]
    if not idxs:
        return False, f"No Groot imports found to wrap in: {p}"

    # group consecutive indices
    groups = []
    start = prev = idxs[0]
    for i in idxs[1:]:
        if i == prev + 1:
            prev = i
        else:
            groups.append((start, prev))
            start = prev = i
    groups.append((start, prev))

    new_lines = []
    last_end = -1
    for (a, b) in groups:
        # copy the lines before this group
        new_lines.extend(lines[last_end + 1:a])

        # wrap the group
        new_lines.append("# --- PATCH: optional Groot/Diffusers imports ---\n")
        new_lines.append(f"try:  # {marker}\n")
        for j in range(a, b + 1):
            new_lines.append("    " + lines[j].lstrip())
        new_lines.append("except Exception as _e:\n")
        new_lines.append("    print('[policies_init] WARNING: optional groot deps missing:', _e)\n")

        last_end = b

    # copy the rest
    new_lines.extend(lines[last_end + 1:])

    s2 = "".join(new_lines)
    if s2 == s:
        return False, f"Policies file unchanged after optional-import attempt: {p}"

    _write_text(p, s2)
    return True, f"Patched policies __init__ optional imports in: {p}"

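# After patching, a Groot import group in lerobot/policies/__init__.py reads
# roughly like this (illustration only; the imported name is hypothetical):
#
#     # --- PATCH: optional Groot/Diffusers imports ---
#     try:  # PATCH: optional Groot/Diffusers imports
#         from .groot import GrootPolicy
#     except Exception as _e:
#         print('[policies_init] WARNING: optional groot deps missing:', _e)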

# -------------------------
# Patch 3: eval_sigma_vla_rollout.py
# -------------------------

def find_eval_file() -> pathlib.Path:
    env = os.getenv("EVAL_FILE")
    if env:
        p = pathlib.Path(env)
        if p.exists():
            return p

    p = pathlib.Path("/workspace/eval_sigma_vla_rollout.py")
    if p.exists():
        return p

    pp = _search_file(["/workspace", "/workspace/lerobot"], "eval_sigma_vla_rollout.py")
    if pp and pp.exists():
        return pp

    raise FileNotFoundError("eval_sigma_vla_rollout.py not found. Set EVAL_FILE env var.")


def patch_eval_force_strict_false(p: pathlib.Path) -> Tuple[bool, str]:
    s = _read_text(p)
    marker = "PATCH: force strict=False for PI05Policy"

    # 1) strict=True -> strict=False in PI05 loads
    pat_strict_true = r"(policy_cls\.from_pretrained\([^)]*strict\s*=\s*)True(\s*[^)]*\))"
    s2, n_true = re.subn(pat_strict_true, r"\1False\2", s)

    # 2) add strict=False if missing on PI05 loads
    def _add_strict_false_call(match: re.Match) -> str:
        call = match.group(0)
        if "strict" in call:
            return call
        return call[:-1] + ", strict=False)"

    pat_no_strict_1 = r"policy_cls\.from_pretrained\(\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"
    pat_no_strict_2 = r"policy_cls\.from_pretrained\(\s*pretrained_name_or_path\s*=\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"

    s3, n_add1 = re.subn(pat_no_strict_1, _add_strict_false_call, s2)
    s4, n_add2 = re.subn(pat_no_strict_2, _add_strict_false_call, s3)

    changed = (n_true + n_add1 + n_add2) > 0
    if not changed:
        if marker in s:
            return False, f"Eval strict patch already present: {p}"
        return False, f"Eval already strict=False or no PI05 strict targets found: {p}"

    if marker not in s4:
        # annotate the first strict=False we introduced / touched
        s4 = s4.replace("strict=False)", f"strict=False)  # {marker}", 1)

    _write_text(p, s4)
    return True, f"Patched eval PI05 strict=False in: {p}"

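# Effect on the eval script (illustration only):
#
#     # before: policy_cls.from_pretrained(repo_id, token=hf_token)
#     # after:  policy_cls.from_pretrained(repo_id, token=hf_token,
#     #             strict=False)  # PATCH: force strict=False for PI05Policy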

def patch_eval_shuffle_support(p: pathlib.Path) -> Tuple[bool, str]:
    s = _read_text(p)
    marker_arg = "PATCH: add --shuffle arg"
    marker_dl = "PATCH: DataLoader shuffle uses args.shuffle"

    changed = False

    # 1) add the CLI arg --shuffle if absent
    if re.search(r'add_argument\(\s*["\']--shuffle["\']', s) is None:
        # find the last parser.add_argument(...) to insert after
        arg_pat = re.compile(r"(?m)^\s*parser\.add_argument\(.+?\)\s*$")
        matches = list(arg_pat.finditer(s))
        if matches:
            last = matches[-1]
            insert_pos = last.end()
            insert_text = (
                "\nparser.add_argument("
                "\"--shuffle\", action=\"store_true\", "
                "help=\"Shuffle dataset order to sample different subsets per seed.\")"
                f"  # {marker_arg}\n"
            )
            s = s[:insert_pos] + insert_text + s[insert_pos:]
            changed = True

    # 2) DataLoader(... shuffle=False ...) -> args.shuffle
    if marker_dl not in s:
        def _dl_repl(m: re.Match) -> str:
            prefix = m.group(1)
            return prefix + f'getattr(args, "shuffle", False)  # {marker_dl}'

        # replace only the literal shuffle=False
        pat_dl = re.compile(r"(?s)(DataLoader\([\s\S]{0,1200}?shuffle\s*=\s*)False")
        if pat_dl.search(s):
            s = pat_dl.sub(_dl_repl, s, count=1)
            changed = True

    if changed:
        _write_text(p, s)
        return True, f"Patched eval shuffle support in: {p}"

    return False, f"Eval shuffle support already present or no targets found: {p}"

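# Resulting behavior (illustration only): with --shuffle passed, each seed
# draws a different random subset instead of the same dataset prefix.
#
#     loader = DataLoader(dataset, batch_size=8,
#                         shuffle=getattr(args, "shuffle", False))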

# -------------------------
# Main
# -------------------------

def main():
    changed_any = False

    try:
        pi05_file = find_pi05_file()
        changed, msg = patch_pi05_embed_tie(pi05_file)
        print(msg)
        changed_any |= changed
    except Exception as e:
        print("[patch_sigma_env] PI05 embed-tie patch skipped:", e)

    try:
        pi05_file = find_pi05_file()
        changed, msg = patch_pi05_transformers_guard(pi05_file)
        print(msg)
        changed_any |= changed
    except Exception as e:
        print("[patch_sigma_env] PI05 transformers-guard patch skipped:", e)

    try:
        policies_init = find_policies_init()
        changed, msg = patch_policies_optional_imports(policies_init)
        print(msg)
        changed_any |= changed
    except Exception as e:
        print("[patch_sigma_env] policies __init__ patch skipped:", e)

    try:
        eval_file = find_eval_file()
        changed, msg = patch_eval_force_strict_false(eval_file)
        print(msg)
        changed_any |= changed
    except Exception as e:
        print("[patch_sigma_env] Eval strict patch skipped:", e)

    try:
        eval_file = find_eval_file()
        changed, msg = patch_eval_shuffle_support(eval_file)
        print(msg)
        changed_any |= changed
    except Exception as e:
        print("[patch_sigma_env] Eval shuffle patch skipped:", e)

    if changed_any:
        print("[patch_sigma_env] Done. Patches applied.")
    else:
        print("[patch_sigma_env] Done. Nothing to change (already patched).")


if __name__ == "__main__":
    main()
pi05_embed_tie.patch
ADDED
@@ -0,0 +1,19 @@
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index b017bbc5..d6290da6 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -989,6 +989,13 @@ class PI05Policy(PreTrainedPolicy):
         if remap_count > 0:
             print(f"Remapped {remap_count} state dict keys")
 
         # Load the remapped state dict into the model
         missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+
+        # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt ---
+        if any("embed_tokens.weight" in k for k in missing_keys):
+            with torch.no_grad():
+                embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens
+                lm_head = model.model.paligemma_with_expert.paligemma.lm_head
+                embed.weight = lm_head.weight
 
         return model
requirements.txt
ADDED
@@ -0,0 +1,33 @@
--index-url https://download.pytorch.org/whl/cu121
--extra-index-url https://pypi.org/simple

torch==2.5.1+cu121
torchvision==0.20.1+cu121

transformers==4.44.2
accelerate==1.10.1
peft==0.17.0
safetensors==0.4.5
huggingface_hub[cli,hf-transfer]==0.35.3
datasets==4.1.1
sentencepiece==0.2.0
einops==0.8.0
bitsandbytes==0.43.3

numpy==2.0.2
pandas==2.2.3
pyarrow==21.0.0
tqdm==4.66.5

opencv-python-headless==4.10.0.84
pillow==10.4.0
av==15.1.0
imageio==2.36.0
imageio-ffmpeg==0.5.1

hydra-core==1.3.2
omegaconf==2.3.0
pyyaml==6.0.2
packaging==24.2
python-dotenv==1.0.1
wandb==0.21.1
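# Quick post-install sanity check (illustration only; run in Python):
#   import torch, transformers
#   print(torch.__version__, transformers.__version__)  # expect 2.5.1+cu121 / 4.44.2
#   assert torch.cuda.is_available()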