import logging
import math

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F

import openpi.models.gemma as _gemma
from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
import openpi.models_pytorch.preprocessing_pytorch as _preprocessing


def get_safe_dtype(target_dtype, device_type):
    """Return a dtype that is safe to use on the given device type.

    Many CPU kernels lack bfloat16 support, so fall back to float32 there.
    """
    if device_type == "cpu":
        if target_dtype == torch.bfloat16:
            return torch.float32
        if target_dtype == torch.float64:
            return torch.float64
    return target_dtype


def create_sinusoidal_pos_embedding(
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")
    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")

    device = torch.device(device)  # Accept both strings and torch.device objects.
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Outer product between the angular frequencies and the scalar times.
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
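# Illustrative example (shapes only): for a batch of 2 scalar times and
# dimension=4, the embedding concatenates sines and cosines over two
# geometrically spaced periods:
#
#   t = torch.tensor([0.0, 1.0])
#   emb = create_sinusoidal_pos_embedding(t, 4, min_period=4e-3, max_period=4.0, device=torch.device("cpu"))
#   emb.shape  # torch.Size([2, 4])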
def sample_beta(alpha, beta, bsize, device):
    """Draw `bsize` samples from a Beta(alpha, beta) distribution on `device`."""
    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
    dist = torch.distributions.Beta(alpha_t, beta_t)
    return dist.sample((bsize,))


def make_att_2d_masks(pad_masks, att_masks):
    """Copied from big_vision.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if it's part of the input, false if padding.
      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    # A token may attend to another token iff the other token's cumulative
    # attention-group index is not larger than its own, and neither is padding.
    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    return att_2d_masks & pad_2d_masks
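# Worked example: with att_masks = [[0, 0, 1]] and no padding, the cumulative
# sums are [0, 0, 1], so the resulting 2D mask is
#
#   [[ True,  True, False],
#    [ True,  True, False],
#    [ True,  True,  True]]
#
# i.e. the first two tokens attend bidirectionally to each other while the
# third token attends to everything before it (prefix-LM behaviour).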
class PI0Pytorch(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pi05 = config.pi05

        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)

        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            paligemma_config,
            action_expert_config,
            use_adarms=[False, True] if self.pi05 else [False, False],
            precision=config.dtype,
        )

        # Project 32-dimensional actions into and out of the expert's width.
        self.action_in_proj = nn.Linear(32, action_expert_config.width)
        self.action_out_proj = nn.Linear(action_expert_config.width, 32)

        if self.pi05:
            # pi0.5 conditions the expert on time via adaRMS, so only a time MLP is needed.
            self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
            self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
        else:
            # pi0 embeds the state as an extra token and fuses action and time embeddings.
            self.state_proj = nn.Linear(32, action_expert_config.width)
            self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
            self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)

        torch.set_float32_matmul_precision("high")
        self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")

        self.gradient_checkpointing_enabled = False

        msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`."
        try:
            from transformers.models.siglip import check

            if not check.check_whether_transformers_replace_is_installed_correctly():
                raise ValueError(msg)
        except ImportError:
            raise ValueError(msg) from None

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self.gradient_checkpointing_enabled = True
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True

        logging.info("Enabled gradient checkpointing for PI0Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing_enabled = False
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False

        logging.info("Disabled gradient checkpointing for PI0Pytorch model")

    def is_gradient_checkpointing_enabled(self):
        """Check if gradient checkpointing is enabled."""
        return self.gradient_checkpointing_enabled

    def _apply_checkpoint(self, func, *args, **kwargs):
        """Helper method to apply gradient checkpointing if enabled."""
        if self.gradient_checkpointing_enabled and self.training:
            return torch.utils.checkpoint.checkpoint(
                func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
            )
        return func(*args, **kwargs)

    def _prepare_attention_masks_4d(self, att_2d_masks):
        """Helper method to prepare 4D attention masks for the transformer."""
        att_2d_masks_4d = att_2d_masks[:, None, :, :]
        # Additive mask: 0 where attention is allowed, a large negative value
        # (effectively -inf after softmax) where it is not.
        return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)

    def _preprocess_observation(self, observation, *, train=True):
        """Helper method to preprocess an observation."""
        observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
        return (
            list(observation.images.values()),
            list(observation.image_masks.values()),
            observation.tokenized_prompt,
            observation.tokenized_prompt_mask,
            observation.state,
        )

    def sample_noise(self, shape, device):
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )

    def sample_time(self, bsize, device):
        # Beta(1.5, 1.0) samples lie in (0, 1); rescale them into (0.001, 1.0).
        time_beta = sample_beta(1.5, 1.0, bsize, device)
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)
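    # Note: Beta(1.5, 1.0) has density proportional to sqrt(t), so sampled
    # timesteps are biased toward 1, the high-noise end of the flow.
    # Presumably this emphasizes training on heavily noised actions; the skew
    # direction follows directly from the density.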
    def embed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with the embedding layer
        to prepare for PaliGemma transformer processing.
        """
        embs = []
        pad_masks = []
        att_masks = []

        for img, img_mask in zip(images, img_masks, strict=True):

            def image_embed_func(img):
                return self.paligemma_with_expert.embed_image(img)

            img_emb = self._apply_checkpoint(image_embed_func, img)

            bsize, num_img_embs = img_emb.shape[:2]

            embs.append(img_emb)
            pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))

            # Image tokens all share attention group 0.
            att_masks += [0] * num_img_embs

        def lang_embed_func(lang_tokens):
            lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
            lang_emb_dim = lang_emb.shape[-1]
            # Scale by sqrt(dim), matching Gemma's embedding convention.
            return lang_emb * math.sqrt(lang_emb_dim)

        lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # Language tokens join the image tokens' attention group.
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)

        bsize = pad_masks.shape[0]
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks
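    # Because every prefix entry uses attention group 0, make_att_2d_masks
    # yields full bidirectional attention across all image and language tokens:
    # the prefix behaves like the "prefix" half of a prefix-LM mask.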
    def embed_suffix(self, state, noisy_actions, timestep):
        """Embed state, noisy_actions, and timestep to prepare for expert Gemma processing."""
        embs = []
        pad_masks = []
        att_masks = []

        if not self.pi05:
            if self.state_proj.weight.dtype == torch.float32:
                state = state.to(torch.float32)

            def state_proj_func(state):
                return self.state_proj(state)

            state_emb = self._apply_checkpoint(state_proj_func, state)

            embs.append(state_emb[:, None, :])
            bsize = state_emb.shape[0]
            device = state_emb.device

            state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
            pad_masks.append(state_mask)

            # The state token opens a new attention group: prefix tokens
            # cannot attend to it.
            att_masks += [1]

        time_emb = create_sinusoidal_pos_embedding(
            timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device
        )
        time_emb = time_emb.to(dtype=timestep.dtype)

        def action_proj_func(noisy_actions):
            return self.action_in_proj(noisy_actions)

        action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)

        if not self.pi05:
            # pi0: concatenate action and time embeddings and fuse them with an MLP.
            time_emb = time_emb[:, None, :].expand_as(action_emb)
            action_time_emb = torch.cat([action_emb, time_emb], dim=2)

            def mlp_func(action_time_emb):
                x = self.action_time_mlp_in(action_time_emb)
                x = F.silu(x)
                return self.action_time_mlp_out(x)

            action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
            adarms_cond = None
        else:
            # pi0.5: run the time embedding through an MLP and use it as the
            # adaRMS conditioning signal instead of mixing it into the tokens.

            def time_mlp_func(time_emb):
                x = self.time_mlp_in(time_emb)
                x = F.silu(x)
                x = self.time_mlp_out(x)
                return F.silu(x)

            time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
            action_time_emb = action_emb
            adarms_cond = time_emb

        embs.append(action_time_emb)

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
        pad_masks.append(action_time_mask)

        # The first action token opens a new attention group; the remaining
        # action tokens join it, so actions attend to each other bidirectionally.
        att_masks += [1] + ([0] * (self.config.action_horizon - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks, adarms_cond
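    # Suffix mask pattern: [1] for the state token (pi0 only), then
    # [1, 0, ..., 0] over the action horizon. Combined with the prefix mask in
    # forward(), this lets the actions see the full prefix and each other,
    # while prefix tokens never attend to the suffix.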
    def forward(self, observation, actions, noise=None, time=None) -> Tensor:
        """Do a full training forward pass and compute the per-element loss
        (batch_size x action_horizon x action_dim).
        """
        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)

        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        # Flow matching: interpolate between actions (t=0) and noise (t=1) and
        # regress the constant velocity pointing from actions to noise.
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
        if (
            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
            == torch.bfloat16
        ):
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)

        def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
            (_, suffix_out), _ = self.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            return suffix_out

        suffix_out = self._apply_checkpoint(
            forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
        )

        # Only the action tokens (the last `action_horizon` positions) carry predictions.
        suffix_out = suffix_out[:, -self.config.action_horizon :]
        suffix_out = suffix_out.to(dtype=torch.float32)

        def action_out_proj_func(suffix_out):
            return self.action_out_proj(suffix_out)

        v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)

        return F.mse_loss(u_t, v_t, reduction="none")
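    # Training objective in symbols: with a ~ data, eps ~ N(0, I), and
    # t ~ the time distribution, x_t = t * eps + (1 - t) * a, and the network
    # output v_t is regressed onto u_t = eps - a = d(x_t)/dt. At inference,
    # integrating dx/dt = v(x, t) from t = 1 down to t = 0 transports noise
    # back to an action sample.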
    @torch.no_grad()
    def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
        """Do a full inference forward pass and compute the actions
        (batch_size x action_horizon x action_dim).
        """
        bsize = observation.state.shape[0]
        if noise is None:
            actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
            noise = self.sample_noise(actions_shape, device)

        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        # Run the prefix once and cache its key/value states; every denoising
        # step below reuses the cache instead of re-encoding images and language.
        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"

        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=True,
        )

        dt = -1.0 / num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                state,
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )

            # Euler integration step.
            x_t = x_t + dt * v_t
            time += dt
        return x_t
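    # With dt = -1 / num_steps and the stopping condition time >= -dt / 2, the
    # Euler loop runs exactly num_steps iterations, stepping from t = 1 down
    # to t = 0 while avoiding floating-point issues at the boundary.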
    def denoise_step(
        self,
        state,
        prefix_pad_masks,
        past_key_values,
        x_t,
        timestep,
    ):
        """Apply one denoising step to the noisy actions `x_t` at a given timestep."""
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)

        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]

        # Suffix tokens may attend to every valid (non-padding) prefix token...
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)

        # ...and to each other according to the suffix attention pattern.
        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

        # Position ids for the suffix continue where the prefix left off.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks_4d,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=False,
            adarms_cond=[None, adarms_cond],
        )

        suffix_out = outputs_embeds[1]
        suffix_out = suffix_out[:, -self.config.action_horizon :]
        suffix_out = suffix_out.to(dtype=torch.float32)
        return self.action_out_proj(suffix_out)
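# Usage sketch (hypothetical setup; the actual config and observation objects
# come from openpi's config and data pipeline, not from this file):
#
#   config = ...  # a model config providing paligemma_variant, action_expert_variant,
#                 # action_horizon, action_dim, dtype, and pi05
#   model = PI0Pytorch(config).to("cuda")
#   losses = model(observation, actions)                               # training: per-element MSE
#   actions = model.sample_actions("cuda", observation, num_steps=10)  # inference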