File size: 5,419 Bytes
1327f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for multilabel_classification_model.py."""

from absl.testing import absltest
from flax import jax_utils
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from scenic.model_lib.base_models import multilabel_classification_model

# Number of classes reported in the fake dataset metadata; also the default
# size of the last dimension of the fake labels and outputs.
NUM_CLASSES = 1000
# Default size of the leading dimension of the fake labels and outputs.
BATCH_SIZE = 4


class FakeMultiLabelClassificationModel(
    multilabel_classification_model.MultiLabelClassificationModel):
  """Minimal multi-label classification model used only by these tests.

  It passes an empty config and fixed dataset metadata to the base class and
  stubs out the Flax-model hooks, which simply return `None`.
  """

  def __init__(self):
    """Initializes the base model with `NUM_CLASSES` one-hot targets."""
    meta_data = dict(num_classes=NUM_CLASSES, target_is_onehot=True)
    empty_config = ml_collections.ConfigDict()
    super().__init__(empty_config, meta_data)

  def build_flax_model(self):
    """Stub: no Flax model is built for the fake model."""
    return None

  def default_flax_model_config(self):
    """Stub: the fake model has no default Flax model config."""
    return None


def get_fake_batch_output(array_size=(BATCH_SIZE, NUM_CLASSES)):
  """Creates a random fake `batch` together with a fake model output.

  Args:
    array_size: Shape shared by the fake label array and the fake output array.

  Returns:
    A `(batch, output)` tuple. `batch` is a dictionary with `None` inputs and
    random 0/1 ground truth labels; `output` is an array of random values in
    [0, 1) standing in for the output logits.
  """
  # Draw the labels before the outputs so the order in which the global NumPy
  # random state is consumed stays: labels first, outputs second.
  fake_labels = np.random.randint(2, size=array_size)
  fake_logits = np.random.random(size=array_size)
  batch = dict(inputs=None, label=jnp.array(fake_labels))
  return batch, jnp.array(fake_logits)


class TestMultiLabelClassificationModel(absltest.TestCase):
  """Test cases covering the loss and metrics of the multi-label model."""

  def is_valid(self, t, value_name):
    """Asserts that tensor `t` contains neither `nan` nor `inf` entries.

    Args:
      t: Tensor to check.
      value_name: Name of the checked value, used in the failure message.
    """
    has_nan = jnp.isnan(t).any()
    self.assertFalse(has_nan, msg=f'Found nan\'s in {t} for {value_name}')
    has_inf = jnp.isinf(t).any()
    self.assertFalse(has_inf, msg=f'Found inf\'s in {t} for {value_name}')

  def test_loss_function(self):
    """Runs the pmapped loss_function and checks its value is finite."""
    model = FakeMultiLabelClassificationModel()
    batch, output = get_fake_batch_output()
    replicated_batch = jax_utils.replicate(batch)
    replicated_output = jax_utils.replicate(output)

    # Exercise the loss function under pmap.
    pmapped_loss_fn = jax.pmap(model.loss_function, axis_name='batch')
    replicated_loss = pmapped_loss_fn(replicated_output, replicated_batch)

    # The loss must not contain nan or inf.
    loss = jax_utils.unreplicate(replicated_loss)
    self.is_valid(loss, value_name='loss')

  def test_loss_function_masked(self):
    """Compares the loss under no mask, an all-ones mask and zero-masking."""
    num_positions = 50
    mask_shape = (BATCH_SIZE, num_positions)
    array_size = (BATCH_SIZE, num_positions, NUM_CLASSES)

    model = FakeMultiLabelClassificationModel()
    batch, output = get_fake_batch_output(array_size=array_size)

    # Reference value: the batch carries no mask at all.
    loss_without_mask = model.loss_function(output, batch)

    # A mask of only ones should behave exactly like having no mask.
    batch['batch_mask'] = jnp.ones(mask_shape)
    loss_with_ones_mask = model.loss_function(output, batch)
    self.assertAlmostEqual(
        float(loss_without_mask), float(loss_with_ones_mask))

    # Append random labels and outputs along axis 1, masked out with 0's.
    padding_labels = np.random.randint(2, size=array_size)
    padding_mask = np.zeros(mask_shape)
    padding_output = np.random.random(size=array_size)
    batch_extended = {
        'label': jnp.concatenate((batch['label'], padding_labels), axis=1),
        'batch_mask': jnp.concatenate(
            (batch['batch_mask'], padding_mask), axis=1),
    }
    output_extended = jnp.concatenate((output, padding_output), axis=1)
    loss_with_padding = model.loss_function(output_extended, batch_extended)

    # Test with `places=3` due to JAX issue: github.com/jax-ml/jax/issues/6553
    # TODO(robromijnders): follow up with JAX issue and remove `places=3`.
    self.assertAlmostEqual(
        float(loss_with_ones_mask), float(loss_with_padding), places=3)

  def test_metric_function(self):
    """Runs the pmapped metrics function; checks its keys and its values."""
    model = FakeMultiLabelClassificationModel()
    batch, output = get_fake_batch_output()
    replicated_batch = jax_utils.replicate(batch)
    replicated_output = jax_utils.replicate(output)

    # Exercise the metrics function under pmap.
    pmapped_metrics_fn = jax.pmap(model.get_metrics_fn(), axis_name='batch')
    replicated_metrics = pmapped_metrics_fn(replicated_output, replicated_batch)

    # Exactly the expected metric names must be reported.
    expected_metrics_keys = ['prec@1', 'loss']
    self.assertSameElements(expected_metrics_keys, replicated_metrics.keys())

    # Entry 0 and entry 1 of every metric must both be free of nan and inf.
    metrics = jax_utils.unreplicate(replicated_metrics)
    for metric_name, metric_value in metrics.items():
      numerator = metric_value[0]
      self.is_valid(numerator, value_name=f'numerator of {metric_name}')
      denominator = metric_value[1]
      self.is_valid(denominator, value_name=f'denominator of {metric_name}')


# Run all test cases through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()