# Source: op-test/src/openpi/models/pi0_config.py
# Uploaded by s3y via huggingface_hub ("Upload folder using huggingface_hub"), commit 40571aa.
import dataclasses
from typing import TYPE_CHECKING
import flax.nnx as nnx
import jax
import jax.numpy as jnp
from typing_extensions import override
from openpi.models import model as _model
import openpi.models.gemma as _gemma
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils
if TYPE_CHECKING:
from openpi.models.pi0 import Pi0
@dataclasses.dataclass(frozen=True)
class Pi0Config(_model.BaseModelConfig):
    """Configuration for the Pi0 model and its Pi05 variant.

    This is a frozen dataclass. Fields left as ``None`` (``max_token_len`` and
    ``discrete_state_input``) are resolved in ``__post_init__`` from ``pi05``.
    """

    # Dtype name for the model (presumably the parameter/compute dtype consumed
    # by Pi0 -- confirm in openpi.models.pi0).
    dtype: str = "bfloat16"
    # Gemma variant for the PaliGemma backbone. If the name contains "lora",
    # get_freeze_filter freezes the backbone's non-LoRA parameters.
    paligemma_variant: _gemma.Variant = "gemma_2b"
    # Gemma variant for the action expert. If the name contains "lora",
    # get_freeze_filter freezes the action expert's non-LoRA parameters.
    action_expert_variant: _gemma.Variant = "gemma_300m"

    # Model-specific defaults.
    action_dim: int = 32
    action_horizon: int = 50
    # Tokenized prompt length. None means "choose in __post_init__":
    # 200 when pi05 is set, otherwise 48.
    max_token_len: int = None  # type: ignore
    # Pi05 differs from Pi0 in two ways:
    # - the state input is part of the discrete language tokens rather than a
    #   continuous input that is part of the suffix;
    # - the action expert uses adaRMSNorm to inject the flow-matching timestep.
    pi05: bool = False
    # Not read by the model itself; it is read by the ModelTransformFactory.
    # None means "same as pi05" (resolved in __post_init__).
    discrete_state_input: bool = None  # type: ignore

    def __post_init__(self):
        # The dataclass is frozen, so the pi05-dependent defaults have to be
        # written through object.__setattr__.
        if self.max_token_len is None:
            object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
        if self.discrete_state_input is None:
            object.__setattr__(self, "discrete_state_input", self.pi05)

    @property
    @override
    def model_type(self) -> _model.ModelType:
        """``ModelType.PI05`` when ``pi05`` is set, otherwise ``ModelType.PI0``."""
        return _model.ModelType.PI05 if self.pi05 else _model.ModelType.PI0

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0":
        """Build a ``Pi0`` model from this config, seeding ``nnx.Rngs`` with ``rng``."""
        # Imported here; at module level Pi0 is only imported under TYPE_CHECKING,
        # presumably to avoid a circular import with openpi.models.pi0.
        from openpi.models.pi0 import Pi0

        return Pi0(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Return shape/dtype specs for one batch of model inputs.

        Args:
            batch_size: Leading batch dimension of every spec.

        Returns:
            ``(observation_spec, action_spec)``, where every leaf is a
            ``jax.ShapeDtypeStruct``:
              * images ``base_0_rgb``, ``left_wrist_0_rgb`` and ``right_wrist_0_rgb``,
                each ``[batch_size, *IMAGE_RESOLUTION, 3]`` float32, with a
                ``[batch_size]`` bool mask per image;
              * state ``[batch_size, action_dim]`` float32;
              * tokenized prompt ``[batch_size, max_token_len]`` int32 and its
                bool mask of the same shape;
              * actions ``[batch_size, action_horizon, action_dim]`` float32.
        """
        camera_names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
        prompt_shape = [batch_size, self.max_token_len]

        # Runtime typechecking is disabled while building the Observation --
        # presumably because its fields expect concrete arrays and ShapeDtypeStruct
        # placeholders would not pass the check. TODO confirm.
        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={name: image_spec for name in camera_names},
                image_masks={name: image_mask_spec for name in camera_names},
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct(prompt_shape, jnp.int32),
                # jnp.bool_ to match image_mask_spec; same dtype as the builtin bool.
                tokenized_prompt_mask=jax.ShapeDtypeStruct(prompt_shape, jnp.bool_),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Return the filter selecting which parameters to freeze, based on the variants.

        * No LoRA variant: ``nnx.Nothing`` (nothing is frozen).
        * LoRA on the PaliGemma backbone only: freeze ``llm`` params, except the
          action expert's params and except LoRA params.
        * LoRA on both: freeze all ``llm`` params except LoRA params.
        * LoRA on the action expert only: freeze the action expert's params,
          except LoRA params.
        """
        paligemma_lora = "lora" in self.paligemma_variant
        action_expert_lora = "lora" in self.action_expert_variant
        if not (paligemma_lora or action_expert_lora):
            return nnx.Nothing

        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
        # Matches the action expert's params; assumes their paths contain "_1"
        # after "llm" -- confirm against the Pi0 / gemma module naming.
        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")

        filters = []
        if paligemma_lora:
            filters.append(gemma_params_filter)
            if not action_expert_lora:
                # Only the backbone is LoRA-tuned: keep the action expert fully trainable.
                filters.append(nnx.Not(action_expert_params_filter))
        else:
            filters.append(action_expert_params_filter)
        # Whenever LoRA is used, the LoRA params themselves stay trainable.
        filters.append(nnx.Not(nnx_utils.PathRegex(".*lora.*")))
        return nnx.All(*filters)