|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Tests for VIT."""
|
|
|
| import math
|
|
|
| from absl.testing import parameterized
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.vision.modeling.backbones import vit
|
|
|
|
|
class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
  """Checks output shapes and parameter counts of `vit.VisionTransformer`."""

  @parameterized.parameters(
      (224, 85798656),
      (256, 85844736),
  )
  def test_network_creation(self, input_size, params_count):
    """Builds the default model and compares its parameter count.

    Args:
      input_size: Height and width of the square, 3-channel input.
      params_count: Expected value of `count_params()` for that input size.
    """
    tf_keras.backend.set_image_data_format('channels_last')
    image_dims = (input_size, input_size, 3)
    model = vit.VisionTransformer(
        input_specs=tf_keras.layers.InputSpec(shape=[2, *image_dims]))

    # Run the model on a Keras Input so it has been called before its
    # parameters are counted.
    model(tf_keras.Input(shape=image_dims, batch_size=1))
    self.assertEqual(model.count_params(), params_count)

  @parameterized.product(
      patch_size=[6, 4],
      output_2d_feature_maps=[True, False],
      pooler=['none', 'gap', 'token'],
  )
  def test_network_with_diferent_configs(
      self, patch_size, output_2d_feature_maps, pooler):
    """Checks endpoint shapes of a tiny model for each configuration.

    Args:
      patch_size: Patch edge length handed to the model.
      output_2d_feature_maps: Whether a 2D feature-map endpoint is requested.
      pooler: One of 'none', 'gap' or 'token'.
    """
    tf_keras.backend.set_image_data_format('channels_last')
    side = 24
    image_dims = (side, side, 3)
    grid = side // patch_size  # Patches along one edge of the image.
    # Key under which the 2D feature map is looked up: log2 of the patch
    # size, rounded to the nearest integer.
    feature_level = str(round(math.log2(patch_size)))

    model = vit.VisionTransformer(
        input_specs=tf_keras.layers.InputSpec(shape=[2, *image_dims]),
        patch_size=patch_size,
        pooler=pooler,
        hidden_size=8,
        mlp_dim=8,
        num_layers=1,
        num_heads=2,
        representation_size=16,
        output_2d_feature_maps=output_2d_feature_maps)
    outputs = model(tf_keras.Input(shape=image_dims, batch_size=1))

    if pooler != 'none':
      # Pooled variants are checked through their 'pre_logits' endpoint.
      self.assertEqual(outputs['pre_logits'].shape, [1, 1, 1, 16])
    else:
      # Without pooling, one token per patch is expected.
      self.assertEqual(
          outputs['encoded_tokens'].shape, [1, grid * grid, 16])

    if not output_2d_feature_maps:
      self.assertNotIn(feature_level, outputs)
      return

    self.assertIn(feature_level, outputs)
    self.assertIn(feature_level, model.output_specs)
    self.assertEqual(
        model.output_specs[feature_level][1:], [grid, grid, 8])

  def test_posembedding_interpolation(self):
    """Runs a model whose `pos_embed_shape` differs from its patch grid.

    A 256-pixel input with 16-pixel patches gives a 16x16 patch grid, while
    `pos_embed_shape` is (14, 14). Judging by the test name this is meant to
    exercise position-embedding interpolation; only the shape of the pooled
    output is asserted here.
    """
    tf_keras.backend.set_image_data_format('channels_last')
    side = 256
    image_dims = (side, side, 3)
    model = vit.VisionTransformer(
        input_specs=tf_keras.layers.InputSpec(shape=[2, *image_dims]),
        patch_size=16,
        pooler='gap',
        pos_embed_shape=(14, 14))

    endpoints = model(tf_keras.Input(shape=image_dims, batch_size=1))
    pooled = endpoints['pre_logits']
    self.assertEqual(pooled.shape, [1, 1, 1, 768])
|
|
|
|
|
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()
|
|
|