File size: 1,631 Bytes
1be5b40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import flax.nnx as nnx
import jax
import openpi.models.pi0_config as _pi0_config
def _get_frozen_state(config: _pi0_config.Pi0Config) -> dict:
    """Return the parameters selected by ``config``'s freeze filter, keyed by path.

    The model is built through ``nnx.eval_shape``, so only its abstract
    structure is produced rather than concrete weights.

    Args:
        config: The Pi0 model configuration under test.

    Returns:
        A dict mapping each frozen parameter's path (a tuple of path components)
        to its abstract state. Iterating over the dict yields the paths.
    """
    abstract_model = nnx.eval_shape(config.create, jax.random.key(0))
    freeze_filter = config.get_freeze_filter()
    # Keep only ``nnx.Param`` variables that also match the config's freeze filter.
    frozen = nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter))
    # ``flat_state()`` yields a path-keyed mapping; some Flax releases return a
    # sequence of ``(path, value)`` pairs instead. ``dict(...)`` accepts both, so
    # callers can always iterate over paths.
    # NOTE(review): confirm against the pinned Flax version.
    return dict(frozen.flat_state())
def test_pi0_full_finetune():
    """With the default config (full fine-tuning), no parameter is frozen."""
    frozen = _get_frozen_state(_pi0_config.Pi0Config())
    assert not frozen
def test_pi0_gemma_lora():
    """LoRA on the PaliGemma backbone freezes only its base LLM weights.

    The frozen set must contain no LoRA parameters and nothing belonging to the
    action expert, whose parameter paths contain ``_1``.
    """
    config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
    state = _get_frozen_state(config)
    assert len(state) == 9
    # Each ``path`` is a tuple of components, so ``"lora" in path`` would only match
    # a component named exactly "lora" and could never fail. Check substrings of
    # every component instead.
    assert all(not any("lora" in str(part) for part in path) for path in state)
    # "llm" is a whole path component, so tuple membership is the right check here.
    assert all("llm" in path for path in state)
    assert all(not any("_1" in str(part) for part in path) for path in state)
def test_pi0_action_expert_lora():
    """LoRA on the action expert freezes only the action expert's base LLM weights."""
    config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora")
    state = _get_frozen_state(config)
    # excluding embedder, rest of the params should be same as gemma_lora.
    assert len(state) == 8
    # Substring check per path component; ``"lora" in path`` on the tuple would
    # only match a component named exactly "lora".
    assert all(not any("lora" in str(part) for part in path) for path in state)
    assert all("llm" in path for path in state)
    # all frozen params should have _1 in their path since it's the action expert.
    assert all(any("_1" in str(part) for part in path) for path in state)
def test_pi0_all_lora():
    """LoRA on both backbone and action expert freezes both sets of base LLM weights."""
    config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora")
    state = _get_frozen_state(config)
    # sum of gemma_lora (9) and action_expert_lora (8) frozen params.
    assert len(state) == 17
    # Substring check per path component; ``"lora" in path`` on the tuple would
    # only match a component named exactly "lora".
    assert all(not any("lora" in str(part) for part in path) for path in state)
    assert all("llm" in path for path in state)
    # Verify the split, not just the total: 8 action-expert params (paths with
    # ``_1``); the remaining 9 belong to the backbone.
    num_action_expert = sum(any("_1" in str(part) for part in path) for path in state)
    assert num_action_expert == 8
|