#!/usr/bin/env python # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import builtins import logging import math from collections import deque from pathlib import Path from typing import TYPE_CHECKING, Literal, TypedDict import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn from typing_extensions import Unpack from lerobot.utils.import_utils import _transformers_available # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaForCausalLM from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration else: CONFIG_MAPPING = None modeling_gemma = None GemmaForCausalLM = None PaliGemmaForConditionalGeneration = None from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OPENPI_ATTENTION_MASK_VALUE, ) class ActionSelectKwargs(TypedDict, total=False): inference_delay: int | None prev_chunk_left_over: Tensor | None execution_horizon: int | None def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: return torch.float32 if device_type == "cpu": # CPU doesn't support bfloat16, use float32 instead if target_dtype == torch.bfloat16: return torch.float32 if target_dtype == torch.float64: return torch.float64 return target_dtype def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" ) -> Tensor: """Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: raise ValueError(f"dimension ({dimension}) must be divisible by 2") if time.ndim != 1: raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") dtype = get_safe_dtype(torch.float64, device.type) fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) period = min_period * (max_period / min_period) ** fraction # Compute the outer product scaling_factor = 1.0 / period * 2 * math.pi sin_input = scaling_factor[None, :] * time[:, None] return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) dist = torch.distributions.Beta(alpha_t, beta_t) return dist.sample((bsize,)) def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) """Copied from big_vision. Tokens can attend to valid inputs tokens which have a cumulative mask_ar smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to setup several types of attention, for example: [[1 1 1 1 1 1]]: pure causal attention. [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between themselves and the last 3 tokens have a causal attention. The first entry could also be a 1 without changing behaviour. [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a block can attend all previous blocks and all tokens on the same block. Args: input_mask: bool[B, N] true if its part of the input, false if padding. mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on it and 0 where it shares the same attention mask as the previous token. """ if att_masks.ndim != 2: raise ValueError(att_masks.ndim) if pad_masks.ndim != 2: raise ValueError(pad_masks.ndim) cumsum = torch.cumsum(att_masks, dim=1) att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] return att_2d_masks & pad_2d_masks def pad_vector(vector, new_dim): """Pad the last dimension of a vector to new_dim with zeros. Can be (batch_size x sequence_length x features_dimension) or (batch_size x features_dimension) """ if vector.shape[-1] >= new_dim: return vector return F.pad(vector, (0, new_dim - vector.shape[-1])) def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) images: torch.Tensor, height: int, width: int, mode: str = "bilinear", ) -> torch.Tensor: """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion by padding with black. If the image is float32, it must be in the range [-1, 1]. Args: images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] height: Target height width: Target width mode: Interpolation mode ('bilinear', 'nearest', etc.) Returns: Resized and padded tensor with same shape format as input """ # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] if images.shape[-1] <= 4: # Assume channels-last format channels_last = True if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] else: channels_last = False if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension batch_size, channels, cur_height, cur_width = images.shape # Calculate resize ratio ratio = max(cur_width / width, cur_height / height) resized_height = int(cur_height / ratio) resized_width = int(cur_width / ratio) # Resize resized_images = F.interpolate( images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None, ) # Handle dtype-specific clipping if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: resized_images = resized_images.clamp(-1.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") # Calculate padding pad_h0, remainder_h = divmod(height - resized_height, 2) pad_h1 = pad_h0 + remainder_h pad_w0, remainder_w = divmod(width - resized_width, 2) pad_w1 = pad_w0 + remainder_w # Pad constant_value = 0 if images.dtype == torch.uint8 else -1.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom mode="constant", value=constant_value, ) # Convert back to original format if needed if channels_last: padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] return padded_images # Define the complete layer computation function for gradient checkpointing def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): models = [paligemma.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states.append(query_state) key_states.append(key_state) value_states.append(value_state) # Concatenate and process attention query_states = torch.cat(query_states, dim=2) key_states = torch.cat(key_states, dim=2) value_states = torch.cat(value_states, dim=2) dummy_tensor = torch.zeros( query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype, ) cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling, ) # Get head_dim from the current layer, not from the model head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 after_first_residual = out_emb.clone() out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds class GemmaConfig: # see openpi `gemma.py: Config` """Configuration for Gemma model variants.""" def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim): self.width = width self.depth = depth self.mlp_dim = mlp_dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config` """Returns config for specified gemma variant.""" if variant == "gemma_300m": return GemmaConfig( width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256, ) elif variant == "gemma_2b": return GemmaConfig( width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256, ) else: raise ValueError(f"Unknown variant: {variant}") class PaliGemmaWithExpertModel( nn.Module ): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi """PaliGemma model with action expert for PI05.""" def __init__( self, vlm_config, action_expert_config, use_adarms=None, precision: Literal["bfloat16", "float32"] = "bfloat16", ): if use_adarms is None: use_adarms = [False, False] super().__init__() vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf._vocab_size = 257152 # noqa: SLF001 vlm_config_hf.image_token_index = 257152 vlm_config_hf.text_config.hidden_size = vlm_config.width vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads vlm_config_hf.text_config.head_dim = vlm_config.head_dim vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" vlm_config_hf.text_config.torch_dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" vlm_config_hf.vision_config.torch_dtype = "float32" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, hidden_size=action_expert_config.width, intermediate_size=action_expert_config.mlp_dim, num_attention_heads=action_expert_config.num_heads, num_hidden_layers=action_expert_config.depth, num_key_value_heads=action_expert_config.num_kv_heads, vocab_size=257152, hidden_activation="gelu_pytorch_tanh", torch_dtype="float32", use_adarms=use_adarms[1], adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, ) self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): if precision == "bfloat16": self.to(dtype=torch.bfloat16) elif precision == "float32": self.to(dtype=torch.float32) return else: raise ValueError(f"Invalid precision: {precision}") params_to_keep_float32 = [ "vision_tower.vision_model.embeddings.patch_embedding.weight", "vision_tower.vision_model.embeddings.patch_embedding.bias", "vision_tower.vision_model.embeddings.position_embedding.weight", "input_layernorm", "post_attention_layernorm", "model.norm", ] for name, param in self.named_parameters(): if any(selector in name for selector in params_to_keep_float32): param.data = param.data.to(dtype=torch.float32) def embed_image(self, image: torch.Tensor): return self.paligemma.model.get_image_features(image) def embed_language_tokens(self, tokens: torch.Tensor): return self.paligemma.language_model.embed_tokens(tokens) def forward( self, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: list[torch.FloatTensor] | None = None, use_cache: bool | None = None, adarms_cond: list[torch.Tensor] | None = None, ): if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: prefix_output = self.paligemma.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, adarms_cond=adarms_cond[0] if adarms_cond is not None else None, ) prefix_past_key_values = prefix_output.past_key_values prefix_output = prefix_output.last_hidden_state suffix_output = None elif inputs_embeds[0] is None: suffix_output = self.gemma_expert.model.forward( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, adarms_cond=adarms_cond[1] if adarms_cond is not None else None, ) suffix_output = suffix_output.last_hidden_state prefix_output = None prefix_past_key_values = None else: models = [self.paligemma.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers # Check if gradient checkpointing is enabled for any of the models use_gradient_checkpointing = ( hasattr(self.gemma_expert.model, "gradient_checkpointing") and self.gemma_expert.model.gradient_checkpointing and self.training ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) # Process all layers with gradient checkpointing if enabled for layer_idx in range(num_layers): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( compute_layer_complete, layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, use_reentrant=False, preserve_rng_state=False, paligemma=self.paligemma, gemma_expert=self.gemma_expert, ) else: inputs_embeds = compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma=self.paligemma, gemma_expert=self.gemma_expert, ) # final norm def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds # Apply gradient checkpointing to final norm if enabled if use_gradient_checkpointing: outputs_embeds = torch.utils.checkpoint.checkpoint( compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False, ) else: outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) prefix_output = outputs_embeds[0] suffix_output = outputs_embeds[1] prefix_past_key_values = None return [prefix_output, suffix_output], prefix_past_key_values class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI05 PyTorch model.""" def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): super().__init__() self.config = config self.rtc_processor = rtc_processor paligemma_config = get_gemma_config(config.paligemma_variant) action_expert_config = get_gemma_config(config.action_expert_variant) self.paligemma_with_expert = PaliGemmaWithExpertModel( paligemma_config, action_expert_config, use_adarms=[False, True], precision=config.dtype, ) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False # Compile model if requested if config.compile_model: torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" # PATCH: make transformers version guard non-fatal and robust across versions try: from transformers.models.siglip import check if hasattr(check, "check_whether_transformers_replace_is_installed_correctly"): ok = check.check_whether_transformers_replace_is_installed_correctly() if not ok: logging.warning("[pi05] %s", msg) else: logging.warning( "[patch_pi05] SigLIP check helper missing; skipping strict transformers version guard." ) except Exception as e: # noqa: BLE001 logging.warning( "[patch_pi05] Could not run transformers version guard (%s). " "Continuing without strict transformers check. %s", msg, e, ) def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True logging.info("Enabled gradient checkpointing for PI05Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI05Pytorch model") def _rtc_enabled(self): return self.config.rtc_config is not None and self.config.rtc_config.enabled def _apply_checkpoint(self, func, *args, **kwargs): """Helper method to apply gradient checkpointing if enabled.""" if self.gradient_checkpointing_enabled and self.training: return torch.utils.checkpoint.checkpoint( func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs ) return func(*args, **kwargs) def _prepare_attention_masks_4d(self, att_2d_masks): """Helper method to prepare 4D attention masks for transformer.""" att_2d_masks_4d = att_2d_masks[:, None, :, :] return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) def sample_noise(self, shape, device): return torch.normal( mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device, ) def sample_time(self, bsize, device): time_beta = sample_beta( self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device ) time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset return time.to(dtype=torch.float32, device=device) def embed_prefix( self, images, img_masks, tokens, masks ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Embed images with SigLIP and language tokens with embedding layer.""" embs = [] pad_masks = [] att_masks = [] # Process images for img, img_mask in zip(images, img_masks, strict=True): def image_embed_func(img): return self.paligemma_with_expert.embed_image(img) img_emb = self._apply_checkpoint(image_embed_func, img) bsize, num_img_embs = img_emb.shape[:2] embs.append(img_emb) pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) att_masks += [0] * num_img_embs # Process language tokens def lang_embed_func(tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) lang_emb_dim = lang_emb.shape[-1] return lang_emb * math.sqrt(lang_emb_dim) lang_emb = self._apply_checkpoint(lang_embed_func, tokens) embs.append(lang_emb) pad_masks.append(masks) num_lang_embs = lang_emb.shape[1] att_masks += [0] * num_lang_embs embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) bsize = pad_masks.shape[0] att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks def embed_suffix(self, noisy_actions, timestep): """Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" embs = [] pad_masks = [] att_masks = [] # Embed timestep using sine-cosine positional encoding time_emb = create_sinusoidal_pos_embedding( timestep, self.action_in_proj.out_features, min_period=self.config.min_period, max_period=self.config.max_period, device=timestep.device, ) time_emb = time_emb.type(dtype=timestep.dtype) # Fuse timestep + action information using an MLP def action_proj_func(noisy_actions): return self.action_in_proj(noisy_actions) action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) def time_mlp_func(time_emb): x = self.time_mlp_in(time_emb) x = F.silu(x) x = self.time_mlp_out(x) return F.silu(x) time_emb = self._apply_checkpoint(time_mlp_func, time_emb) action_time_emb = action_emb adarms_cond = time_emb embs.append(action_time_emb) bsize, action_time_dim = action_time_emb.shape[:2] action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) pad_masks.append(action_time_mask) # Set attention masks so that image, language and state inputs do not attend to action tokens att_masks += [1] + ([0] * (self.config.chunk_size - 1)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks, adarms_cond def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss.""" if noise is None: noise = self.sample_noise(actions.shape, actions.device) if time is None: time = self.sample_time(actions.shape[0], actions.device) time_expanded = time[:, None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) if ( self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): suffix_embs = suffix_embs.to(dtype=torch.bfloat16) prefix_embs = prefix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) att_2d_masks = make_att_2d_masks(pad_masks, att_masks) position_ids = torch.cumsum(pad_masks, dim=1) - 1 att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): (_, suffix_out), _ = self.paligemma_with_expert.forward( attention_mask=att_2d_masks_4d, position_ids=position_ids, past_key_values=None, inputs_embeds=[prefix_embs, suffix_embs], use_cache=False, adarms_cond=[None, adarms_cond], ) return suffix_out suffix_out = self._apply_checkpoint( forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) def action_out_proj_func(suffix_out): return self.action_out_proj(suffix_out) v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) return F.mse_loss(u_t, v_t, reduction="none") @torch.no_grad() # see openpi `sample_actions` (slightly adapted) def sample_actions( self, images, img_masks, tokens, masks, noise=None, num_steps=None, **kwargs: Unpack[ActionSelectKwargs], ) -> Tensor: """Do a full inference forward and compute the action.""" if num_steps is None: num_steps = self.config.num_inference_steps bsize = tokens.shape[0] device = tokens.device if noise is None: # Sample noise with padded dimension as expected by action_in_proj actions_shape = ( bsize, self.config.chunk_size, self.config.max_action_dim, ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( attention_mask=prefix_att_2d_masks_4d, position_ids=prefix_position_ids, past_key_values=None, inputs_embeds=[prefix_embs, None], use_cache=True, ) dt = -1.0 / num_steps dt = torch.tensor(dt, dtype=torch.float32, device=device) x_t = noise time = torch.tensor(1.0, dtype=torch.float32, device=device) while time >= -dt / 2: expanded_time = time.expand(bsize) # Define a closure function to properly capture expanded_time # This avoids the lambda expression (E731) and loop variable binding (B023) issues def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): return self.denoise_step( prefix_pad_masks=prefix_pad_masks, past_key_values=past_key_values, x_t=input_x_t, timestep=current_timestep, ) if self._rtc_enabled(): inference_delay = kwargs.get("inference_delay") prev_chunk_left_over = kwargs.get("prev_chunk_left_over") execution_horizon = kwargs.get("execution_horizon") v_t = self.rtc_processor.denoise_step( x_t=x_t, prev_chunk_left_over=prev_chunk_left_over, inference_delay=inference_delay, time=time, original_denoise_step_partial=denoise_step_partial_call, execution_horizon=execution_horizon, ) else: v_t = denoise_step_partial_call(x_t) # Euler step x_t += dt * v_t # Record x_t and v_t after Euler step if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) time += dt return x_t def denoise_step( self, prefix_pad_masks, past_key_values, x_t, timestep, ): """Apply one denoising step of the noise `x_t` at a given timestep.""" suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep) suffix_len = suffix_pad_masks.shape[1] batch_size = prefix_pad_masks.shape[0] prefix_len = prefix_pad_masks.shape[1] prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=[None, suffix_embs], use_cache=False, adarms_cond=[None, adarms_cond], ) suffix_out = outputs_embeds[1] suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) return self.action_out_proj(suffix_out) class PI05Policy(PreTrainedPolicy): """PI05 Policy for LeRobot.""" config_class = PI05Config name = "pi05" def __init__( self, config: PI05Config, ): """ Args: config: Policy configuration class instance. """ super().__init__(config) config.validate_features() self.config = config # Initialize the core PI05 model self.init_rtc_processor() self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor) # Enable gradient checkpointing if requested if config.gradient_checkpointing: self.model.gradient_checkpointing_enable() self.model.to(config.device) self.reset() @classmethod def from_pretrained( cls: builtins.type[T], pretrained_name_or_path: str | Path, *, config: PreTrainedConfig | None = None, force_download: bool = False, resume_download: bool | None = None, proxies: dict | None = None, token: str | bool | None = None, cache_dir: str | Path | None = None, local_files_only: bool = False, revision: str | None = None, strict: bool = True, **kwargs, ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( "The PI05 model is a direct port of the OpenPI implementation. \n" "This implementation follows the original OpenPI structure for compatibility. \n" "Original implementation: https://github.com/Physical-Intelligence/openpi" ) if pretrained_name_or_path is None: raise ValueError("pretrained_name_or_path is required") # Use provided config if available, otherwise create default config if config is None: config = PreTrainedConfig.from_pretrained( pretrained_name_or_path=pretrained_name_or_path, force_download=force_download, resume_download=resume_download, proxies=proxies, token=token, cache_dir=cache_dir, local_files_only=local_files_only, revision=revision, **kwargs, ) # Initialize model without loading weights # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) # Now manually load and remap the state dict try: # Try to load the pytorch_model.bin or model.safetensors file print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file # Try safetensors first resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", cache_dir=kwargs.get("cache_dir"), force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), use_auth_token=kwargs.get("use_auth_token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) from safetensors.torch import load_file original_state_dict = load_file(resolved_file) print("✓ Loaded state dict from model.safetensors") except Exception as e: # noqa: BLE001 print(f"Could not load state dict from remote files: {e}") print("Returning model without loading pretrained weights") return model # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it remapped_state_dict = {} remap_count = 0 for key, value in fixed_state_dict.items(): if not key.startswith("model."): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 if remap_count <= 10: # Only print first 10 to avoid spam print(f"Remapped: {key} -> {new_key}") else: remapped_state_dict[key] = value if remap_count > 0: print(f"Remapped {remap_count} state dict keys") # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) # --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens --- if any("embed_tokens.weight" in k for k in missing_keys): try: with torch.no_grad(): embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens lm_head = model.model.paligemma_with_expert.paligemma.lm_head if embed is not None and lm_head is not None: embed.weight = lm_head.weight except Exception as _e: # noqa: BLE001 print("[patch_pi05] Could not tie embed_tokens to lm_head:", _e) # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt --- if any("embed_tokens.weight" in k for k in missing_keys): with torch.no_grad(): embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens lm_head = model.model.paligemma_with_expert.paligemma.lm_head embed.weight = lm_head.weight if missing_keys: print(f"Missing keys when loading state dict: {len(missing_keys)} keys") if len(missing_keys) <= 5: for key in missing_keys: print(f" - {key}") else: for key in missing_keys[:5]: print(f" - {key}") print(f" ... and {len(missing_keys) - 5} more") if unexpected_keys: print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") if len(unexpected_keys) <= 5: for key in unexpected_keys: print(f" - {key}") else: for key in unexpected_keys[:5]: print(f" - {key}") print(f" ... and {len(unexpected_keys) - 5} more") if not missing_keys and not unexpected_keys: print("All keys loaded successfully!") except Exception as e: # noqa: BLE001 print(f"Warning: Could not remap state dict keys: {e}") return model def _fix_pytorch_state_dict_keys( self, state_dict, model_config ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` """Fix state dict keys to match current model architecture.""" import re fixed_state_dict = {} for key, value in state_dict.items(): new_key = key # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias # For gemma expert layers if re.match( r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", key, ): # Check if the model actually has adaRMS enabled for the expert expert_uses_adarms = getattr( self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False ) if expert_uses_adarms: logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") continue if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): # Check if the model actually has adaRMS enabled for the expert expert_uses_adarms = getattr( self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False ) if expert_uses_adarms: logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") continue # Handle MLP naming changes for pi05 # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_* if key.startswith("action_time_mlp_in."): new_key = key.replace("action_time_mlp_in.", "time_mlp_in.") elif key.startswith("action_time_mlp_out."): new_key = key.replace("action_time_mlp_out.", "time_mlp_out.") # Also handle state_proj which shouldn't exist in pi05 if key.startswith("state_proj."): logging.warning(f"Skipping state_proj key in pi05 mode: {key}") continue # Handle vision tower embedding layer potential differences if "patch_embedding" in key: # Some checkpoints might have this, but current model expects different structure logging.warning(f"Vision embedding key might need handling: {key}") fixed_state_dict[new_key] = value return fixed_state_dict def get_optim_params(self) -> dict: return self.parameters() def reset(self): """Reset internal state - called when environment resets.""" self._action_queue = deque(maxlen=self.config.n_action_steps) self._queues = { ACTION: deque(maxlen=self.config.n_action_steps), } def init_rtc_processor(self): """Initialize RTC processor if RTC is enabled in config.""" self.rtc_processor = None # Create processor if config provided # If RTC is not enabled - we can still track the denoising data if self.config.rtc_config is not None: self.rtc_processor = RTCProcessor(self.config.rtc_config) model_value = getattr(self, "model", None) if model_value is not None: model_value.rtc_processor = self.rtc_processor def _rtc_enabled(self) -> bool: return self.config.rtc_config is not None and self.config.rtc_config.enabled def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: """Preprocess images for the model. Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. """ images = [] img_masks = [] # Get device from model parameters device = next(self.parameters()).device present_img_keys = [key for key in self.config.image_features if key in batch] missing_img_keys = [key for key in self.config.image_features if key not in batch] if len(present_img_keys) == 0: raise ValueError( f"All image features are missing from the batch. At least one expected. " f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" ) # Preprocess image features present in the batch for key in present_img_keys: img = batch[key] # Ensure tensor is on the same device as the model if img.device != device: img = img.to(device) # Ensure float32 dtype for consistency if img.dtype != torch.float32: img = img.to(torch.float32) # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 if is_channels_first: # Convert [B, C, H, W] to [B, H, W, C] for processing img = img.permute(0, 2, 3, 1) # from openpi preprocess_observation_pytorch: Resize with padding if needed if img.shape[1:3] != self.config.image_resolution: img = resize_with_pad_torch(img, *self.config.image_resolution) # Normalize from [0,1] to [-1,1] as expected by siglip img = img * 2.0 - 1.0 # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first if is_channels_first: img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] images.append(img) # Create mask (all ones for real images) bsize = img.shape[0] mask = torch.ones(bsize, dtype=torch.bool, device=device) img_masks.append(mask) # Create image features not present in the batch as fully 0 padded images for _num_empty_cameras in range(len(missing_img_keys)): img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP mask = torch.zeros_like(mask) # Mask is zero for empty cameras images.append(img) img_masks.append(mask) return images, img_masks def prepare_action(self, batch): """Pad action""" actions = pad_vector(batch[ACTION], self.config.max_action_dim) return actions @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" assert not self._rtc_enabled(), ( "RTC is not supported for select_action, use it with predict_action_chunk" ) self.eval() # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Predict a chunk of actions given environment observations.""" self.eval() # Prepare inputs images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs) # Unpad actions to actual action dimension original_action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :original_action_dim] return actions def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training.""" # Prepare inputs images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) # Compute loss (no separate state needed for PI05) losses = self.model.forward(images, img_masks, tokens, masks, actions) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0] losses = losses[:, :, :original_action_dim] loss = losses.mean() loss_dict = { "loss": loss.item(), "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), } return loss, loss_dict # PATCH: downgrade transformer version guard