File size: 3,913 Bytes
1327f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for encoder_decoder_model.py."""

from absl.testing import absltest
from flax import jax_utils
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from scenic.model_lib.base_models import encoder_decoder_model

VOCAB_SIZE = 4_000
TARGET_LENGTH = 32
BATCH_SIZE = 4


class FakeEncoderDecoderModel(encoder_decoder_model.EncoderDecoderModel):
  """Minimal concrete `EncoderDecoderModel` used only by the tests below.

  It is configured with an empty `ConfigDict` and dataset metadata declaring
  `VOCAB_SIZE` classes with integer (non one-hot) targets. The two Flax hooks
  are stubs that return None.
  """

  def __init__(self):
    meta = dict(num_classes=VOCAB_SIZE, target_is_onehot=False)
    empty_config = ml_collections.ConfigDict()
    super().__init__(empty_config, meta)

  def build_flax_model(self):
    """Stub: these tests never build a Flax module."""
    return None

  def default_flax_model_config(self):
    """Stub: no default model config is needed for these tests."""
    return None


def get_fake_batch_output(seed=None):
  """Generates a fake `batch` together with matching fake model `output`.

  Args:
    seed: Optional integer seed. When None (the default), values are drawn
      from NumPy's global random state. Otherwise a dedicated
      `np.random.RandomState(seed)` is used, so results are reproducible.

  Returns:
    A `(batch, output)` tuple:
      `batch`: dict with `'inputs'` set to None and `'label'` holding integer
        targets in `[0, VOCAB_SIZE)` of shape `(BATCH_SIZE, TARGET_LENGTH)`,
        stored as a `jnp` array.
      `output`: NumPy array of fake logits of shape
        `(BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE)` with values in `[0, 1)`.
  """
  # The `np.random` module and a `RandomState` instance expose the same
  # `randint` / `random` methods. The default path therefore keeps using the
  # global generator, while a seed gives an isolated, reproducible stream.
  rng = np.random if seed is None else np.random.RandomState(seed)
  labels = rng.randint(VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH))
  batch = {
      'inputs': None,
      'label': jnp.array(labels),
  }
  output = rng.random(size=(BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE))
  return batch, output


class TestEncoderDecoderModel(absltest.TestCase):
  """Checks loss and metric functions of `EncoderDecoderModel` under pmap."""

  def is_valid(self, t, value_name):
    """Fails the test if array `t` contains any NaN or infinite entry.

    Args:
      t: Array-like value to inspect.
      value_name: Human-readable label used in the failure message.
    """
    has_nan = jnp.isnan(t).any()
    self.assertFalse(has_nan, msg=f"Found nan's in {t} for {value_name}")
    has_inf = jnp.isinf(t).any()
    self.assertFalse(has_inf, msg=f"Found inf's in {t} for {value_name}")

  def test_loss_function(self):
    """Runs `loss_function` under pmap and checks the result is finite."""
    model = FakeEncoderDecoderModel()
    batch, logits = get_fake_batch_output()
    # Add a leading device axis so the inputs can be fed to `jax.pmap`.
    replicated_batch = jax_utils.replicate(batch)
    replicated_logits = jax_utils.replicate(logits)

    pmapped_loss_fn = jax.pmap(model.loss_function, axis_name='batch')
    loss = pmapped_loss_fn(replicated_logits, replicated_batch)
    self.is_valid(jax_utils.unreplicate(loss), value_name='loss')

  def test_metric_function(self):
    """Runs the metrics fn under pmap; checks its keys and finite values."""
    model = FakeEncoderDecoderModel()
    batch, logits = get_fake_batch_output()
    # Add a leading device axis so the inputs can be fed to `jax.pmap`.
    replicated_batch = jax_utils.replicate(batch)
    replicated_logits = jax_utils.replicate(logits)

    pmapped_metrics_fn = jax.pmap(model.get_metrics_fn(), axis_name='batch')
    metrics = pmapped_metrics_fn(replicated_logits, replicated_batch)
    self.assertSameElements(['accuracy', 'loss', 'perplexity'],
                            metrics.keys())

    # Each metric value is indexed as a (numerator, denominator) pair.
    for name, pair in jax_utils.unreplicate(metrics).items():
      self.is_valid(pair[0], value_name=f'numerator of {name}')
      self.is_valid(pair[1], value_name=f'denominator of {name}')


if __name__ == '__main__':
  absltest.main()