# pi0/src/openpi/models/pi0_test.py
# s3y's picture
# Upload folder using huggingface_hub
# 1be5b40 verified
import flax.nnx as nnx
import jax
import openpi.models.pi0_config as _pi0_config
def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:
    """Return the flattened state of the parameters that ``config`` marks as frozen.

    The model is built through ``nnx.eval_shape`` from ``config.create`` with a fixed
    PRNG key, then its state is narrowed to entries that are both ``nnx.Param`` and
    selected by ``config.get_freeze_filter()``.

    Args:
        config: Model configuration providing ``create`` and ``get_freeze_filter``.

    Returns:
        The result of ``flat_state()`` on the filtered state.
    """
    shape_only_model = nnx.eval_shape(config.create, jax.random.key(0))
    frozen_param_filter = nnx.All(nnx.Param, config.get_freeze_filter())
    return nnx.state(shape_only_model, frozen_param_filter).flat_state()
def test_pi0_full_finetune():
    """The default config should not freeze any parameter."""
    frozen = _get_frozen_state(_pi0_config.Pi0Config())
    assert len(frozen) == 0
def test_pi0_gemma_lora():
    """Check which parameters are frozen when only the PaliGemma variant uses LoRA.

    Every frozen parameter must be a non-LoRA parameter under ``llm`` whose path
    carries no ``_1`` marker (the marker the action-expert test looks for).
    """
    config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
    state = _get_frozen_state(config)
    assert len(state) == 9
    # Each key is a tuple of path parts, so `"lora" not in path` would compare against
    # whole parts only and could never fail for a part that merely contains "lora".
    # Look inside every part instead.
    assert all(all("lora" not in str(part) for part in path) for path in state)
    # Exact part match is intended here: one path component must be "llm".
    assert all("llm" in path for path in state)
    # Same substring reasoning as for "lora": no part may contain "_1".
    assert all(all("_1" not in str(part) for part in path) for path in state)
def test_pi0_action_expert_lora():
    """Check which parameters are frozen when only the action expert uses LoRA.

    Every frozen parameter must be a non-LoRA parameter under ``llm`` whose path
    contains the ``_1`` marker.
    """
    config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora")
    state = _get_frozen_state(config)
    # excluding embedder, rest of the params should be same as gemma_lora.
    assert len(state) == 8
    # Each key is a tuple of path parts, so `"lora" not in path` would compare against
    # whole parts only and could never fail for a part that merely contains "lora".
    # Look inside every part instead.
    assert all(all("lora" not in str(part) for part in path) for path in state)
    # Exact part match is intended here: one path component must be "llm".
    assert all("llm" in path for path in state)
    # all frozen params should have _1 in their path since it's the action expert.
    # str() keeps the substring test valid even if a path part is not a string.
    assert all(any("_1" in str(part) for part in path) for path in state)
def test_pi0_all_lora():
    """Check which parameters are frozen when both PaliGemma and the action expert use LoRA.

    Every frozen parameter must be a non-LoRA parameter under ``llm``.
    """
    config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora")
    state = _get_frozen_state(config)
    # sum of gemma_lora and action_expert_lora's frozen params.
    assert len(state) == 17
    # Each key is a tuple of path parts, so `"lora" not in path` would compare against
    # whole parts only and could never fail for a part that merely contains "lora".
    # Look inside every part instead.
    assert all(all("lora" not in str(part) for part in path) for path in state)
    # Exact part match is intended here: one path component must be "llm".
    assert all("llm" in path for path in state)