File size: 2,874 Bytes
fc0f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `update.py`."""

from absl.testing import absltest

import chex
import jax
import jax.numpy as jnp

from optax._src import update


class UpdateTest(chex.TestCase):
  """Tests for the helpers exposed by the `update` module.

  Every test is decorated with `chex.all_variants` and calls the function under
  test through `self.variant(...)`.
  """

  @chex.all_variants
  def test_apply_updates(self):
    """Expects `apply_updates(params, grads)` to equal `params + grads`."""
    params = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
    grads = jax.tree_util.tree_map(lambda t: 2 * t, params)
    # ones + 2 * ones == 3 * ones for every leaf of the tree.
    exp_params = jax.tree_util.tree_map(lambda t: 3 * t, params)
    new_params = self.variant(update.apply_updates)(params, grads)

    chex.assert_trees_all_close(
        exp_params, new_params, atol=1e-10, rtol=1e-5)

  @chex.all_variants
  def test_apply_updates_mixed_precision(self):
    """Expects bfloat16 params to stay bfloat16 when updates are float32."""
    params = (
        {'a': jnp.ones((3, 2), dtype=jnp.bfloat16)},
        jnp.ones((1,), dtype=jnp.bfloat16))
    # The updates deliberately use a wider dtype than the params.
    grads = jax.tree_util.tree_map(
        lambda t: (2 * t).astype(jnp.float32), params)
    new_params = self.variant(update.apply_updates)(params, grads)

    for leaf in jax.tree_util.tree_leaves(new_params):
      # Use `assertEqual` instead of a bare `assert`: bare asserts are removed
      # under `python -O` (the test would then check nothing), and
      # `assertEqual` reports both dtypes on failure.
      self.assertEqual(leaf.dtype, jnp.bfloat16)

  @chex.all_variants
  def test_incremental_update(self):
    """Expects a 0.5 step to land halfway between the two parameter trees."""
    params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
    params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1)
    # Midpoint of ones (params_1) and 2 * ones (params_2) is 1.5 * ones.
    exp_params = jax.tree_util.tree_map(lambda t: 1.5 * t, params_1)
    new_params = self.variant(
        update.incremental_update)(params_2, params_1, 0.5)

    chex.assert_trees_all_close(
        exp_params, new_params, atol=1e-10, rtol=1e-5)

  @chex.all_variants
  def test_periodic_update(self):
    """Expects the first tree only on steps that are multiples of the period.

    On every other step the second tree must be returned unchanged.
    """
    params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
    params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1)

    update_period = 5
    update_fn = self.variant(update.periodic_update)

    # Walk through three consecutive periods.
    for j in range(3):
      # Steps strictly inside the period: result must match `params_1`.
      for i in range(1, update_period):
        new_params = update_fn(
            params_2, params_1, j*update_period+i, update_period)
        chex.assert_trees_all_close(
            params_1, new_params, atol=1e-10, rtol=1e-5)

      # Step that is an exact multiple of the period: result must match
      # `params_2`.
      new_params = update_fn(
          params_2, params_1, (j+1)*update_period, update_period)
      chex.assert_trees_all_close(
          params_2, new_params, atol=1e-10, rtol=1e-5)


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()