|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import json
|
| import os
|
|
|
| from absl import flags
|
| from absl.testing import flagsaver
|
| import gin
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.projects.maxvit import train as train_lib
|
| from official.vision.dataloaders import tfexample_utils
|
|
|
|
|
# Module-level handle to absl's global flag registry; the test below sets
# flag values through it before invoking the trainer entry point.
FLAGS = flags.FLAGS
|
|
|
|
|
class TrainTest(tf.test.TestCase):
  """Smoke test driving the MaxViT trainer entry point in train and eval modes."""

  def setUp(self):
    """Creates a model directory and a small TFRecord file for the test run."""
    super().setUp()

    self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
    tf.io.gfile.makedirs(self._model_dir)
    self._test_tfrecord_file = os.path.join(
        self.get_temp_dir(), 'test.tfrecord'
    )
    num_samples = 3
    # Build a single 224x224 classification example and write it
    # `num_samples` times; the record contents are identical on purpose.
    example = tf.train.Example.FromString(
        tfexample_utils.create_classification_example(
            image_height=224, image_width=224
        )
    )
    examples = [example] * num_samples
    tfexample_utils.dump_to_tfrecord(
        record_file=self._test_tfrecord_file, tf_examples=examples
    )

  def test_run(self):
    """Runs `train.main` once with mode='train' and once with mode='eval'.

    The experiment is `maxvit_imagenet`, overridden to a tiny test model, a
    single train/validation step, and the TFRecord file written in `setUp`.
    Global flag values are snapshotted first and restored afterwards, even
    when the run raises, so flag mutations cannot leak into other tests.
    """
    saved_flag_values = flagsaver.save_flag_values()
    try:
      train_lib.tfm_flags.define_flags()
      FLAGS.mode = 'train'
      FLAGS.model_dir = self._model_dir
      FLAGS.experiment = 'maxvit_imagenet'

      params_override = json.dumps({
          'runtime': {
              'mixed_precision_dtype': 'float32',
          },
          'trainer': {
              'train_steps': 1,
              'validation_steps': 1,
              'optimizer_config': {
                  'ema': None,
              },
          },
          'task': {
              'init_checkpoint': '',
              'model': {
                  'backbone': {
                      'maxvit': {
                          'model_name': 'maxvit-tiny-for-test',
                          'representation_size': 64,
                          'add_gap_layer_norm': True,
                      }
                  },
                  'input_size': [224, 224, 3],
                  'num_classes': 3,
              },
              'train_data': {
                  'global_batch_size': 2,
                  'input_path': self._test_tfrecord_file,
              },
              'validation_data': {
                  'global_batch_size': 2,
                  'input_path': self._test_tfrecord_file,
              },
          },
      })
      FLAGS.params_override = params_override
      train_lib.train.main('unused_args')

      FLAGS.mode = 'eval'
      # NOTE(review): presumably the first main() call leaves the gin config
      # locked, so the second call must run with it unlocked -- confirm.
      with gin.unlock_config():
        train_lib.train.main('unused_args')
    finally:
      # Always undo flag changes; previously a failure above skipped this.
      flagsaver.restore_flag_values(saved_flag_values)
|
|
|
|
|
if __name__ == '__main__':
  # Executed as a script: hand control to the TensorFlow test runner.
  tf.test.main()
|
|
|