import flax.linen as nn
import jax
import jax.numpy as jnp

import openpi.models.lora as lora


def test_lora_einsum_params_shape():
    shape = (3, 8, 32, 4)  # (3KDH)
    einsum = lora.Einsum(shape)
    lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
    lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))

    key = jax.random.key(0)
    x = jax.random.normal(key, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"

    # Ensure that lora parameters are not initialized when LoRA is not used.
    params = einsum.init(key, eqn, x)
    assert "lora_a" not in params["params"]
    assert "lora_b" not in params["params"]

    # Check that default axes work.
    params_lora0 = lora0.init(key, eqn, x)
    assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
    assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)

    # Check that user provided axes work.
    params_lora1 = lora1.init(key, eqn, x)
    assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
    assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)
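# Illustrative sketch, not part of the original test file: it spells out the shape
# relationship the asserts above rely on. With rank r, lora_a and lora_b factor the weight
# over two axes (by default the last two of the (3, K, D, H) weight), so their contraction
# has the same shape as the base weight. The helper name below is hypothetical.
def _lora_factor_shape_sketch():
    a = jnp.zeros((3, 8, 32, 2))  # lora_a for the default axes, (3, K, D, rank)
    b = jnp.zeros((3, 8, 2, 4))   # lora_b for the default axes, (3, K, rank, H)
    delta = jnp.einsum("...dr,...rh->...dh", a, b)
    assert delta.shape == (3, 8, 32, 4)  # matches the base (3KDH) weight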
def test_lora_einsum_same_output():
    shape = (3, 8, 32, 4)  # (3KDH)
    einsum = lora.Einsum(shape)
    einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))

    key = jax.random.key(0)
    x = jax.random.normal(key, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"

    params = einsum.init(key, eqn, x)
    output = einsum.apply(params, eqn, x)

    params_lora = einsum_lora.init(key, eqn, x)
    output_lora = einsum_lora.apply(params_lora, eqn, x)

    # Results are the same since the LoRA parameters are initialized to zeros.
    assert jnp.allclose(output, output_lora)
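# Illustrative sketch, not part of the original tests: why zero-initialized LoRA factors
# reproduce the base layer exactly. With init_fn=zeros the low-rank delta a @ b is zero,
# so adding it to the base weight changes nothing. Names and shapes here are hypothetical.
def _zero_init_lora_is_identity_sketch():
    key = jax.random.key(0)
    w = jax.random.normal(key, (32, 4))  # base weight (D, H)
    a = jnp.zeros((32, 2))               # lora_a, zero-initialized (D, rank)
    b = jnp.zeros((2, 4))                # lora_b, zero-initialized (rank, H)
    x = jax.random.normal(key, (8, 32))  # input (B, D)
    assert jnp.allclose(x @ w, x @ (w + a @ b))  # delta is zero, outputs match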
def test_lora_ffn_params_shape():
    ffn = lora.FeedForward(features=8, hidden_dim=32)
    ffn_lora = lora.FeedForward(
        features=8,
        hidden_dim=32,
        lora_config=lora.LoRAConfig(rank=2),
    )

    key = jax.random.key(0)
    x = jax.random.normal(key, (2, 8))

    # Without LoRA, only the base weights are created.
    params = ffn.init(key, x)
    assert params["params"]["gating_einsum"].shape == (2, 8, 32)
    assert params["params"]["linear"].shape == (32, 8)

    # With LoRA, rank-2 factors are added alongside each base weight.
    params_lora = ffn_lora.init(key, x)
    assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
    assert params_lora["params"]["linear"].shape == (32, 8)
    assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
    assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
    assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
    assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
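# Sketch of how the FeedForward shapes above fit together (my reading of the asserts, not
# from the original file): gating_einsum stacks two (features, hidden_dim) projections,
# presumably for a gated (GEGLU-style) MLP, and linear maps hidden_dim back to features;
# each LoRA pair factors the corresponding (in, out) block as (in, rank) @ (rank, out).
def _ffn_lora_factor_shape_sketch():
    gating_a = jnp.zeros((2, 8, 2))   # gating_einsum_lora_a, (2, features, rank)
    gating_b = jnp.zeros((2, 2, 32))  # gating_einsum_lora_b, (2, rank, hidden_dim)
    assert jnp.einsum("nfr,nrh->nfh", gating_a, gating_b).shape == (2, 8, 32)

    linear_a = jnp.zeros((32, 2))  # linear_lora_a, (hidden_dim, rank)
    linear_b = jnp.zeros((2, 8))   # linear_lora_b, (rank, features)
    assert (linear_a @ linear_b).shape == (32, 8)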
def test_lora_ffn_same_output():
    ffn = lora.FeedForward(features=8, hidden_dim=32)
    ffn_lora = lora.FeedForward(
        features=8,
        hidden_dim=32,
        lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
    )

    key = jax.random.key(0)
    x = jax.random.normal(key, (2, 8))

    params = ffn.init(key, x)
    output = ffn.apply(params, x)

    params_lora = ffn_lora.init(key, x)
    output_lora = ffn_lora.apply(params_lora, x)

    # Results are the same since the LoRA parameters are initialized to zeros.
    assert jnp.allclose(output, output_lora)