diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index b017bbc5..d6290da6 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -989,6 +989,13 @@ class PI05Policy(PreTrainedPolicy): if remap_count > 0: print(f"Remapped {remap_count} state dict keys") # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt --- + if any("embed_tokens.weight" in k for k in missing_keys): + with torch.no_grad(): + embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens + lm_head = model.model.paligemma_with_expert.paligemma.lm_head + embed.weight = lm_head.weight return model