| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Tests for optax.transforms._adding.""" |
|
|
| from absl.testing import absltest |
|
|
| import chex |
| from jax import tree_util as jtu |
| import jax.numpy as jnp |
|
|
| from optax.transforms import _adding |
|
|
# Number of consecutive noise updates compared in
# `test_add_noise_has_correct_variance_scaling` (steps 1..STEPS inclusive).
STEPS = 50
|
|
|
|
class AddingTest(chex.TestCase):
  """Tests for the transformations exposed by `_adding`."""

  def setUp(self):
    """Builds the parameter and update trees available to every test."""
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants
  def test_add_decayed_weights(self):
    """Checks that decay is applied only to the leaves selected by the mask.

    With all-zero incoming updates and all-one weights, the leaves whose mask
    entry is True are expected to come out as `0.1 * weights`, and the leaf
    whose mask entry is False is expected to stay at zero.
    """
    # The mask has the same tree structure as the updates and weights below.
    mask = (True, {"a": True, "b": False})
    tx = _adding.add_decayed_weights(0.1, mask=mask)

    zeros = jnp.zeros((2,), dtype=jnp.float32)
    ones = jnp.ones((2,), dtype=jnp.float32)

    # Zero updates, so whatever comes out is the decay contribution alone.
    updates = (zeros, {"a": zeros, "b": zeros})
    weights = (ones, {"a": ones, "b": ones})

    # Only the leaves masked with True pick up the 0.1 factor.
    expected_tx_updates = (0.1 * ones, {"a": 0.1 * ones, "b": zeros})

    state = tx.init(weights)
    transform_fn = self.variant(tx.update)
    new_updates, _ = transform_fn(updates, state, weights)

    chex.assert_trees_all_close(new_updates, expected_tx_updates)

  @chex.all_variants
  def test_add_noise_has_correct_variance_scaling(self):
    """Compares `add_noise(eta, gamma, seed)` with a rescaled reference.

    The reference is `add_noise(1.0, 0.0, seed)` built from the same seed. At
    every step `i` in `1..STEPS`, the output of the transformation under test
    must match the reference output multiplied by `sqrt(eta / i**gamma)`.
    """
    eta = 0.3
    gamma = 0.55
    seed = 314
    noise = _adding.add_noise(eta, gamma, seed)
    noise_unit = _adding.add_noise(1.0, 0.0, seed)

    params = self.init_params
    state = noise.init(params)
    state_unit = noise_unit.init(params)

    # Feeding zeros means each output tree is the injected noise by itself.
    updates = jtu.tree_map(jnp.zeros_like, params)

    # Wrap the update function once; the wrapper does not depend on the step,
    # so there is no need to rebuild it on every iteration.
    noise_update = self.variant(noise.update)

    for i in range(1, STEPS + 1):
      updates_i, state = noise_update(updates, state)
      updates_i_unit, state_unit = noise_unit.update(updates, state_unit)

      scale = jnp.sqrt(eta / i**gamma)

      # `s=scale` binds the current value and avoids a late-binding closure.
      updates_i_rescaled = jtu.tree_map(
          lambda g, s=scale: g * s, updates_i_unit)

      chex.assert_trees_all_close(updates_i, updates_i_rescaled, rtol=1e-4)
|
|
|
|
# Hand control to the absl test runner when this file is executed directly.
if __name__ == "__main__":
  absltest.main()
|
|