| | from flax import nnx |
| | import jax |
| | import pytest |
| |
|
| | from openpi.models import model as _model |
| | from openpi.models import pi0 |
| | from openpi.models import pi0_fast |
| | from openpi.shared import download |
| | from openpi.shared import nnx_utils |
| |
|
| |
|
def test_pi0_model():
    """Check output shapes of a pi0 model built from the default config.

    Creates the model from ``pi0.Pi0Config()``, feeds it the config's fake
    observations/actions, and asserts the shapes returned by the jitted
    ``compute_loss`` and ``sample_actions`` calls.
    """
    rng = jax.random.key(0)
    cfg = pi0.Pi0Config()
    model = cfg.create(rng)

    n_batch = 2
    fake_obs = cfg.fake_obs(n_batch)
    fake_act = cfg.fake_act(n_batch)

    # Loss is expected per batch element and per action-horizon step.
    jitted_loss = nnx_utils.module_jit(model.compute_loss)
    loss = jitted_loss(rng, fake_obs, fake_act)
    assert loss.shape == (n_batch, cfg.action_horizon)

    # Sampled actions are expected as (batch, horizon, action_dim).
    jitted_sample = nnx_utils.module_jit(model.sample_actions)
    sampled = jitted_sample(rng, fake_obs, num_steps=10)
    assert sampled.shape == (n_batch, model.action_horizon, model.action_dim)
| |
|
| |
|
def test_pi0_lora_model():
    """Check output shapes of a pi0 model using the ``gemma_2b_lora`` variant.

    Same checks as the default-config test, but the config is built with
    ``paligemma_variant="gemma_2b_lora"``.
    """
    rng = jax.random.key(0)
    cfg = pi0.Pi0Config(paligemma_variant="gemma_2b_lora")
    model = cfg.create(rng)

    n_batch = 2
    fake_obs = cfg.fake_obs(n_batch)
    fake_act = cfg.fake_act(n_batch)

    # Loss is expected per batch element and per action-horizon step.
    jitted_loss = nnx_utils.module_jit(model.compute_loss)
    loss = jitted_loss(rng, fake_obs, fake_act)
    assert loss.shape == (n_batch, cfg.action_horizon)

    # Sampled actions are expected as (batch, horizon, action_dim).
    jitted_sample = nnx_utils.module_jit(model.sample_actions)
    sampled = jitted_sample(rng, fake_obs, num_steps=10)
    assert sampled.shape == (n_batch, model.action_horizon, model.action_dim)
| |
|
| |
|
def test_pi0_fast_model():
    """Check output shapes of a pi0-FAST model built from the default config.

    Creates the model from ``pi0_fast.Pi0FASTConfig()``, feeds it the
    config's fake observations/actions, and asserts the shapes returned by
    the jitted ``compute_loss`` and ``sample_actions`` calls.
    """
    rng = jax.random.key(0)
    cfg = pi0_fast.Pi0FASTConfig()
    model = cfg.create(rng)

    n_batch = 2
    fake_obs = cfg.fake_obs(n_batch)
    fake_act = cfg.fake_act(n_batch)

    # Loss is expected as one value per batch element.
    jitted_loss = nnx_utils.module_jit(model.compute_loss)
    loss = jitted_loss(rng, fake_obs, fake_act)
    assert loss.shape == (n_batch,)

    # NOTE(review): 256 is hard-coded; presumably the default output length
    # of sample_actions -- confirm against the model implementation.
    jitted_sample = nnx_utils.module_jit(model.sample_actions)
    sampled = jitted_sample(rng, fake_obs)
    assert sampled.shape == (n_batch, 256)
| |
|
| |
|
def test_pi0_fast_lora_model():
    """Check a pi0-FAST model using the ``gemma_2b_lora`` variant.

    Asserts the shapes returned by the jitted ``compute_loss`` and
    ``sample_actions`` calls, then verifies that the model state contains at
    least one entry whose path matches ``.*lora.*``.
    """
    rng = jax.random.key(0)
    cfg = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
    model = cfg.create(rng)

    n_batch = 2
    fake_obs = cfg.fake_obs(n_batch)
    fake_act = cfg.fake_act(n_batch)

    # Loss is expected as one value per batch element.
    jitted_loss = nnx_utils.module_jit(model.compute_loss)
    loss = jitted_loss(rng, fake_obs, fake_act)
    assert loss.shape == (n_batch,)

    # NOTE(review): 256 is hard-coded; presumably the default output length
    # of sample_actions -- confirm against the model implementation.
    jitted_sample = nnx_utils.module_jit(model.sample_actions)
    sampled = jitted_sample(rng, fake_obs)
    assert sampled.shape == (n_batch, 256)

    # The LoRA variant must expose at least one state entry on a "lora" path.
    lora_filter = nnx_utils.PathRegex(".*lora.*")
    lora_entries = list(nnx.state(model).filter(lora_filter))
    assert len(lora_entries) > 0
| |
|
| |
|
@pytest.mark.manual
def test_model_restore():
    """Check output shapes of a pi0 model loaded from restored parameters.

    The parameters come from the ``pi0_base`` checkpoint path, resolved with
    ``download.maybe_download`` and read with ``_model.restore_params``.
    Unlike the other tests here, ``compute_loss`` and ``sample_actions`` are
    called directly rather than through ``nnx_utils.module_jit``.
    """
    rng = jax.random.key(0)
    cfg = pi0.Pi0Config()

    n_batch = 2
    fake_obs = cfg.fake_obs(n_batch)
    fake_act = cfg.fake_act(n_batch)

    # Resolve the checkpoint location, restore its params, then build the model.
    params_path = download.maybe_download("s3://openpi-assets/checkpoints/pi0_base/params")
    restored_params = _model.restore_params(params_path)
    model = cfg.load(restored_params)

    loss = model.compute_loss(rng, fake_obs, fake_act)
    assert loss.shape == (n_batch, cfg.action_horizon)

    sampled = model.sample_actions(rng, fake_obs, num_steps=10)
    assert sampled.shape == (n_batch, model.action_horizon, model.action_dim)
| |
|