File size: 10,639 Bytes
09d8e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optax.transforms._conditionality."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
import jax.numpy as jnp

from optax._src import alias
from optax._src import base
from optax._src import combine
from optax._src import transform
from optax._src import update
from optax.transforms import _conditionality
from optax.transforms import _constraining


def _build_sgd():
  """Builds the baseline optimiser: `alias.sgd` with a step size of one."""
  learning_rate = 1.
  return alias.sgd(learning_rate)


def _build_stateful_sgd():
  """Builds `alias.sgd` with a step size of one and an explicit zero momentum.

  Passing `momentum=0.` rather than leaving it as `None` means the momentum
  terms are still calculated, but they do not change the results. The
  optimiser therefore behaves like `_build_sgd` while also exercising the
  optimiser state.
  """
  learning_rate = 1.
  return alias.sgd(learning_rate, momentum=0.)


def _build_sgd_extra_args():
  """Builds SGD chained with a pass-through transform accepting extra args.

  The second link of the chain is a `GradientTransformationExtraArgs` whose
  update function takes a keyword-only `foo` plus arbitrary extra keyword
  arguments, ignores all of them and returns the gradients and state as is.
  """

  def _init(params):
    del params  # The state is a constant and ignores the parameters.
    return {'foo': 1}

  def _update(grads, state, params=None, *, foo=None, **extra_args):
    del params, foo, extra_args  # Accepted by the signature but unused.
    return grads, state

  passthrough = base.GradientTransformationExtraArgs(_init, _update)
  return combine.chain(_build_sgd(), passthrough)


class ConditionalityTest(parameterized.TestCase):
  """Tests for the apply_if_finite wrapper."""

  @chex.variants(with_jit=True, without_jit=True, with_pmap=True)
  @parameterized.named_parameters(
      ('sgd', _build_sgd),
      ('stateful_sgd', _build_stateful_sgd),
      ('sgd_extra_args', _build_sgd_extra_args),
  )
  def test_apply_if_finite(self, opt_builder):
    """Alternates finite and NaN gradients and checks params and counters."""
    finite_input = jnp.array(1.)
    nan_input = jnp.array(jnp.nan)

    def loss(p, x):
      return p * x

    opt = _conditionality.apply_if_finite(opt_builder(), 2)
    grads_fn = jax.grad(self.variant(loss))

    def run_step(params, state, x):
      # One full step: gradients -> wrapped update -> new parameters.
      grads = grads_fn(params, x)
      updates, state = opt.update(grads, state, params)
      return update.apply_updates(params, updates), state

    def first_leaf(tree):
      return jax.tree_util.tree_flatten(tree)[0][0]

    params = jnp.array(0.)
    state = opt.init(params)

    # Two rounds of: one successful update followed by two rejected ones.
    # The expected parameter value is known exactly because every builder is
    # effectively plain sgd.
    for expected_params in (-1., -2.):
      params, state = run_step(params, state, finite_input)
      self.assertEqual(expected_params, float(first_leaf(params)))
      self.assertTrue(bool(getattr(state, 'last_finite')))
      for rejected_so_far in (1, 2):
        params, state = run_step(params, state, nan_input)
        # A rejected update leaves the parameters where they were.
        self.assertEqual(expected_params, float(first_leaf(params)))
        self.assertFalse(bool(getattr(state, 'last_finite')))
        self.assertEqual(
            rejected_so_far, int(getattr(state, 'notfinite_count')))

    # The next NaN update is accepted since the maximum has been reached.
    params, state = run_step(params, state, nan_input)
    self.assertTrue(bool(jnp.isnan(first_leaf(params))))
    self.assertEqual(5, int(getattr(state, 'total_notfinite')))

  def test_apply_if_finite_pmap(self):
    """Same scenario, with the gradient and the optimisation under pmap.

    Unlike in `test_apply_if_finite`:
    * pmap is applied to the gradient computation and the optimisation;
    * the NaNs are caused inside the function and do not come from the inputs.
    """
    finite_x = jnp.ones([1]) / 2.
    nan_x = jnp.ones([1]) * 2.  # Causes a NaN in arctanh

    def loss(p, x):
      return jnp.arctanh(x) * p

    opt = _conditionality.apply_if_finite(alias.sgd(1.), 2)

    def sharded_step(params, opt_state, x):
      grads = jax.lax.psum(jax.grad(loss)(params, x), axis_name='i')
      updates, new_opt_state = opt.update(grads, opt_state, params)
      return update.apply_updates(params, updates), new_opt_state

    sharded_step = jax.pmap(sharded_step, axis_name='i')

    def add_leading_axis(x):
      return x[None]

    params = jnp.array(0.)
    opt_state = opt.init(params)
    # Give every leaf the leading axis that pmap maps over.
    params = jax.tree_util.tree_map(add_leading_axis, params)
    opt_state = jax.tree_util.tree_map(add_leading_axis, opt_state)

    # Two rounds of: one successful update followed by two rejected ones.
    for _ in range(2):
      params, opt_state = sharded_step(params, opt_state, finite_x)
      self.assertTrue(bool(opt_state.last_finite))
      for rejected_so_far in (1, 2):
        params, opt_state = sharded_step(params, opt_state, nan_x)
        self.assertFalse(bool(opt_state.last_finite))
        self.assertEqual(rejected_so_far, opt_state.notfinite_count.item())

    # The next NaN update is accepted since the maximum has been reached.
    _, opt_state = sharded_step(params, opt_state, nan_x)
    self.assertEqual(5, opt_state.total_notfinite.item())


class ConditionallyTransformTest(chex.TestCase):
  """Tests for the conditionally_transform wrapper."""

  NUM_STEPS = 3

  @chex.all_variants
  def test_stateless_inner(self):
    """Wraps `scale(2.)`: scaled while the condition holds, untouched after."""
    num_steps = self.NUM_STEPS

    def should_update(step):
      return step < num_steps

    opt = _conditionality.conditionally_transform(
        transform.scale(2.), should_update)
    unit_grads = jnp.ones([])
    opt_state = opt.init(jnp.zeros([]))
    apply_update = self.variant(opt.update)

    # While the condition holds the gradients are scaled by the inner
    # transformation.
    for _ in range(num_steps):
      updates, opt_state = apply_update(unit_grads, opt_state)
      self.assertEqual(updates, 2.)
    # Further updates stop calling the inner optimiser.
    for _ in range(5):
      updates, opt_state = apply_update(unit_grads, opt_state)
      self.assertEqual(updates, 1.)

  @chex.all_variants
  def test_statefull_inner(self):
    """Wraps `zero_nans()`: its state stops changing once the condition fails."""
    num_steps = self.NUM_STEPS

    def should_update(step):
      return step < num_steps

    opt = _conditionality.conditionally_transform(
        _constraining.zero_nans(), should_update)
    nan_grads = jnp.array(float('nan'))
    finite_grads = jnp.ones([])
    opt_state = opt.init(jnp.zeros([]))
    apply_update = self.variant(opt.update)

    # All but the last active step receive a NaN, which the inner
    # transformation zeroes out and records in its state.
    for _ in range(num_steps - 1):
      updates, opt_state = apply_update(nan_grads, opt_state)
      self.assertEqual(updates, 0.)
      self.assertEqual(opt_state.inner_state.found_nan, True)
    # The last active step receives a finite gradient.
    updates, opt_state = apply_update(finite_grads, opt_state)
    self.assertEqual(updates, 1.)
    self.assertEqual(opt_state.inner_state.found_nan, False)
    # Further updates stop calling the inner optimiser.
    for _ in range(5):
      updates, opt_state = apply_update(nan_grads, opt_state)
      # Warning: do not use assertEqual with a NaN as NaN == NaN returns False.
      self.assertTrue(jnp.isnan(updates))
      # The inner state must not have been updated.
      self.assertEqual(opt_state.inner_state.found_nan, False)


class ConditionallyMaskTest(chex.TestCase):
  """Tests for the conditionally_mask wrapper."""

  NUM_STEPS = 3
  MIN_LOSS = 0.1

  @chex.all_variants
  def test_stateless_inner(self):
    """Wraps `scale(2.)`: scaled while the condition holds, zero afterwards."""
    num_steps = self.NUM_STEPS

    def should_update(step):
      return step < num_steps

    opt = _conditionality.conditionally_mask(transform.scale(2.), should_update)
    unit_grads = jnp.ones([])
    opt_state = opt.init(jnp.zeros([]))
    apply_update = self.variant(opt.update)

    # While the condition holds the gradients are scaled by the inner
    # transformation.
    for _ in range(num_steps):
      updates, opt_state = apply_update(unit_grads, opt_state)
      self.assertEqual(updates, 2.)
    # Further updates stop calling the inner optimiser.
    for _ in range(5):
      updates, opt_state = apply_update(unit_grads, opt_state)
      self.assertEqual(updates, 0.)

  @chex.all_variants
  def test_statefull_inner(self):
    """Wraps `zero_nans()`: its state stops changing once the condition fails."""
    num_steps = self.NUM_STEPS

    def should_update(step):
      return step < num_steps

    opt = _conditionality.conditionally_mask(
        _constraining.zero_nans(), should_update)
    nan_grads = jnp.array(float('nan'))
    finite_grads = jnp.ones([])
    opt_state = opt.init(jnp.zeros([]))
    apply_update = self.variant(opt.update)

    # All but the last active step receive a NaN, which the inner
    # transformation zeroes out and records in its state.
    for _ in range(num_steps - 1):
      updates, opt_state = apply_update(nan_grads, opt_state)
      self.assertEqual(updates, 0.)
      self.assertEqual(opt_state.inner_state.found_nan, True)
    # The last active step receives a finite gradient.
    updates, opt_state = apply_update(finite_grads, opt_state)
    self.assertEqual(updates, 1.)
    self.assertEqual(opt_state.inner_state.found_nan, False)
    # Further updates stop calling the inner optimiser.
    for _ in range(5):
      updates, opt_state = apply_update(nan_grads, opt_state)
      self.assertEqual(updates, 0.)
      # The inner state must not have been updated.
      self.assertEqual(opt_state.inner_state.found_nan, False)

  @chex.all_variants
  def test_stateless_inner_with_extra_args(self):
    """Drives the condition from a `loss` extra argument, not the step."""
    min_loss = self.MIN_LOSS

    def should_update(step, loss, **extra_args):
      del step, extra_args  # Only the loss decides.
      return loss > min_loss

    opt = _conditionality.conditionally_mask(
        transform.scale(2.), should_update, forward_extra_args=True)
    unit_grads = jnp.ones([])
    opt_state = opt.init(jnp.zeros([]))
    apply_update = self.variant(opt.update)

    # A loss above the threshold lets the inner transformation run.
    for _ in range(self.NUM_STEPS):
      updates, opt_state = apply_update(unit_grads, opt_state, loss=0.2)
      self.assertEqual(updates, 2.)
    # Further updates stop calling the inner optimiser.
    for _ in range(5):
      updates, opt_state = apply_update(unit_grads, opt_state, loss=0.)
      self.assertEqual(updates, 0.)


# Hand over to the absltest runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()