|
|
import flax.nnx as nnx |
|
|
import jax |
|
|
|
|
|
import openpi.models.pi0_config as _pi0_config |
|
|
|
|
|
|
|
|
def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:
    """Return the flattened set of parameters that `config` marks as frozen.

    The model is built abstractly via `nnx.eval_shape`, so no real weights are
    allocated; only the parameter structure is materialized.
    """
    shape_model = nnx.eval_shape(config.create, jax.random.key(0))
    # Select params that are both trainable-kind (nnx.Param) and matched by
    # the config's freeze filter, then flatten to {path_tuple: value}.
    # NOTE(review): despite the annotation, the returned object is the result
    # of .flat_state(), not a nested nnx.State — confirm the intended type.
    frozen = nnx.state(shape_model, nnx.All(nnx.Param, config.get_freeze_filter()))
    return frozen.flat_state()
|
|
|
|
|
|
|
|
def test_pi0_full_finetune():
    """With default variants (no LoRA), nothing should be frozen."""
    state = _get_frozen_state(_pi0_config.Pi0Config())
    assert not state
|
|
|
|
|
|
|
|
def test_pi0_gemma_lora():
    """LoRA on PaliGemma only: its non-LoRA LLM params are frozen."""
    config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
    state = _get_frozen_state(config)
    # Expected number of frozen (non-LoRA) PaliGemma LLM params.
    assert len(state) == 9
    # No frozen path component may mention LoRA params. Previously this was
    # `"lora" not in p` with `p` a path tuple — an exact-element membership
    # test that could never catch components such as "lora_a". Check each
    # component as a substring instead, consistent with
    # test_pi0_action_expert_lora; str() guards against non-string parts.
    assert all(all("lora" not in str(part) for part in path) for path in state)
    # All frozen params live under the LLM module.
    assert all("llm" in path for path in state)
    # None of the frozen params belong to the action expert, whose modules
    # carry the "_1" suffix (same substring-per-component fix as above).
    assert all(all("_1" not in str(part) for part in path) for path in state)
|
|
|
|
|
|
|
|
def test_pi0_action_expert_lora():
    """LoRA on the action expert only: its non-LoRA params are frozen."""
    config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora")
    state = _get_frozen_state(config)
    # Expected number of frozen (non-LoRA) action-expert params.
    assert len(state) == 8
    # No frozen path component may mention LoRA params. Previously this was
    # `"lora" not in p` with `p` a path tuple — an exact-element membership
    # test; check each component as a substring instead.
    assert all(all("lora" not in str(part) for part in path) for path in state)
    # All frozen params live under the LLM module.
    assert all("llm" in path for path in state)
    # Every frozen param belongs to the action expert: at least one path
    # component carries the "_1" suffix. str() guards non-string parts.
    assert all(any("_1" in str(part) for part in path) for path in state)
|
|
|
|
|
|
|
|
def test_pi0_all_lora():
    """LoRA on both experts: non-LoRA params of both are frozen."""
    config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora")
    state = _get_frozen_state(config)
    # 9 PaliGemma params + 8 action-expert params (see the sibling tests).
    assert len(state) == 17
    # No frozen path component may mention LoRA params. Previously this was
    # `"lora" not in p` with `p` a path tuple — an exact-element membership
    # test; check each component as a substring instead, consistent with the
    # other LoRA tests.
    assert all(all("lora" not in str(part) for part in path) for path in state)
    # All frozen params live under the LLM module.
    assert all("llm" in path for path in state)
|
|
|