File size: 2,966 Bytes
40571aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
from flax import nnx
import jax
import pytest
from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.models import pi0_fast
from openpi.shared import download
from openpi.shared import nnx_utils
def test_pi0_model():
    """Build a default Pi0 model and check the shapes of its loss and sampled actions."""
    rng = jax.random.key(0)
    cfg = pi0_config.Pi0Config()
    pi0 = cfg.create(rng)

    batch = 2
    observation = cfg.fake_obs(batch)
    action_chunk = cfg.fake_act(batch)

    jitted_loss = nnx_utils.module_jit(pi0.compute_loss)
    per_step_loss = jitted_loss(rng, observation, action_chunk)
    # One loss entry per (batch element, action-horizon step).
    assert per_step_loss.shape == (batch, cfg.action_horizon)

    jitted_sample = nnx_utils.module_jit(pi0.sample_actions)
    sampled = jitted_sample(rng, observation, num_steps=10)
    assert sampled.shape == (batch, pi0.action_horizon, pi0.action_dim)
def test_pi0_lora_model():
    """Smoke-test Pi0 built with the ``gemma_2b_lora`` PaliGemma variant.

    Checks the loss and sampled-action shapes, then verifies that the model state
    contains at least one parameter whose path matches ``.*lora.*``. Without that
    last check the test would still pass if the LoRA variant were silently ignored
    and a plain (non-LoRA) model were built instead.
    """
    key = jax.random.key(0)
    config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
    model = config.create(key)

    batch_size = 2
    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
    assert loss.shape == (batch_size, config.action_horizon)

    actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)

    # Guard against the LoRA variant being a no-op: some state path must mention "lora".
    lora_state = nnx.state(model).filter(nnx_utils.PathRegex(".*lora.*"))
    assert len(list(lora_state)) > 0, "expected LoRA parameters in the model state"
def test_pi0_fast_model():
    """Build a default Pi0-FAST model and check the shapes of its loss and sampled output."""
    rng = jax.random.key(0)
    cfg = pi0_fast.Pi0FASTConfig()
    fast_model = cfg.create(rng)

    batch = 2
    observation = cfg.fake_obs(batch)
    action_chunk = cfg.fake_act(batch)

    jitted_loss = nnx_utils.module_jit(fast_model.compute_loss)
    loss_per_example = jitted_loss(rng, observation, action_chunk)
    # A single loss value per batch element.
    assert loss_per_example.shape == (batch,)

    jitted_sample = nnx_utils.module_jit(fast_model.sample_actions)
    sampled = jitted_sample(rng, observation)
    # NOTE(review): 256 is hard-coded; presumably the default number of decoded
    # tokens in sample_actions - confirm against pi0_fast.
    assert sampled.shape == (batch, 256)
def test_pi0_fast_lora_model():
    """Smoke-test Pi0-FAST with the ``gemma_2b_lora`` variant and confirm LoRA params exist."""
    rng = jax.random.key(0)
    cfg = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
    fast_model = cfg.create(rng)

    batch = 2
    observation = cfg.fake_obs(batch)
    action_chunk = cfg.fake_act(batch)

    jitted_loss = nnx_utils.module_jit(fast_model.compute_loss)
    loss_per_example = jitted_loss(rng, observation, action_chunk)
    assert loss_per_example.shape == (batch,)

    jitted_sample = nnx_utils.module_jit(fast_model.sample_actions)
    sampled = jitted_sample(rng, observation)
    # NOTE(review): 256 is hard-coded; presumably the default decode length - confirm.
    assert sampled.shape == (batch, 256)

    # The LoRA variant must contribute at least one state entry whose path mentions "lora".
    lora_paths = nnx_utils.PathRegex(".*lora.*")
    full_state = nnx.state(fast_model)
    matching = list(full_state.filter(lora_paths))
    assert matching
@pytest.mark.manual
def test_model_restore():
    """Restore the released pi0_base params and check loss / action shapes.

    Marked ``manual``, presumably so it is excluded from default runs (confirm
    against the pytest configuration). It fetches params from
    ``gs://openpi-assets/checkpoints/pi0_base/params`` through
    ``download.maybe_download``.
    """
    rng = jax.random.key(0)
    cfg = pi0_config.Pi0Config()

    batch = 2
    observation = cfg.fake_obs(batch)
    action_chunk = cfg.fake_act(batch)

    checkpoint_path = download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params")
    restored = cfg.load(_model.restore_params(checkpoint_path))

    # Methods are called directly here, without nnx_utils.module_jit.
    per_step_loss = restored.compute_loss(rng, observation, action_chunk)
    assert per_step_loss.shape == (batch, cfg.action_horizon)

    sampled = restored.sample_actions(rng, observation, num_steps=10)
    assert sampled.shape == (batch, restored.action_horizon, restored.action_dim)
|