|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Tests for tensorflow_models.core.trainers.trainer."""
|
|
|
| import gc
|
| import multiprocessing
|
| import os
|
| import sys
|
|
|
| from absl.testing import parameterized
|
| import orbit
|
| import portpicker
|
| import tensorflow as tf, tf_keras
|
|
|
| from tensorflow.python.distribute import combinations
|
| from tensorflow.python.distribute import strategy_combinations
|
| from official.core import base_trainer as trainer_lib
|
| from official.core import config_definitions as cfg
|
| from official.core import train_lib
|
| from official.utils.testing import mock_task
|
|
|
# Each flag is True when the corresponding marker appears in the name this
# test binary was invoked under (sys.argv[0]).
TPU_TEST, GPU_TEST = (
    marker in sys.argv[0] for marker in ('test_tpu', 'test_gpu'))
|
|
|
|
|
def all_strategy_combinations():
  """Returns the `distribution` combinations used by the parameterized tests.

  The combinations cover the default strategy, the cloud TPU strategy and the
  single-GPU one-device strategy.
  """
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies)
|
|
|
|
|
def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers and returns the cluster_resolver.

  Args:
    num_workers: Number of tasks to create for the 'worker' job.
    num_ps: Number of tasks to create for the 'ps' job. The 'ps' job is only
      added to the cluster spec when this is greater than zero.

  Returns:
    A `SimpleClusterResolver` wrapping the cluster spec, using the 'grpc' RPC
    layer.
  """

  def pick_local_addresses(count):
    # One freshly picked unused port per task, all bound to localhost.
    return [
        'localhost:%s' % portpicker.pick_unused_port() for _ in range(count)
    ]

  worker_addresses = pick_local_addresses(num_workers)
  ps_addresses = pick_local_addresses(num_ps)

  cluster_dict = {'worker': worker_addresses}
  if num_ps > 0:
    cluster_dict['ps'] = ps_addresses
  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Only override the inter-op thread count when the host reports fewer CPUs
  # than `num_workers + 1`.
  worker_config = tf.compat.v1.ConfigProto()
  required_threads = num_workers + 1
  if multiprocessing.cpu_count() < required_threads:
    worker_config.inter_op_parallelism_threads = required_threads

  for task_index in range(num_workers):
    tf.distribute.Server(
        cluster_spec,
        job_name='worker',
        task_index=task_index,
        config=worker_config,
        protocol='grpc')

  for task_index in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name='ps', task_index=task_index, protocol='grpc')

  return tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer='grpc')
|
|
|
|
|
def dataset_fn(input_context=None):
  """Builds a repeating dataset whose every element is a (1, 1) zero tensor.

  Args:
    input_context: Unused; discarded immediately.

  Returns:
    A `tf.data.Dataset` of float32 zeros with shape (1, 1).
  """
  del input_context

  def to_zeros(_):
    return tf.zeros((1, 1), dtype=tf.float32)

  return tf.data.Dataset.range(1).repeat().map(
      to_zeros, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
|
|
|
|
class MockAsyncTrainer(trainer_lib._AsyncTrainer):
  """Mock AsyncTrainer whose train/eval steps only increment step counters."""

  def __init__(self):
    self._strategy = tf.distribute.get_strategy()
    self.init_async()

    def new_step_counter(name):
      # Non-trainable int64 counter starting at zero.
      return tf.Variable(
          0,
          dtype=tf.int64,
          name=name,
          trainable=False,
          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    self.global_step = new_step_counter('global_step')
    self.eval_global_step = new_step_counter('eval_global_step')

    # The orbit base initializers are invoked explicitly, each with its own
    # distributed copy of the dummy dataset.
    orbit.StandardTrainer.__init__(
        self,
        self.distribute_dataset(dataset_fn),
        options=orbit.StandardTrainerOptions())
    orbit.StandardEvaluator.__init__(
        self,
        self.distribute_dataset(dataset_fn),
        options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))

  def train_loop_begin(self):
    """Resets the training step counter to zero."""
    self.global_step.assign(0)

  def train_step(self, iterator):
    """Consumes one element and increments the training step counter."""

    def increment_train_counter(_):
      self.global_step.assign_add(1)

    element = next(iterator)
    self._strategy.run(increment_train_counter, args=(element,))

  def train_loop_end(self):
    """Joins pending work and returns the training step count."""
    self.join()
    return self.global_step.numpy()

  def eval_begin(self):
    """Resets the evaluation step counter to zero."""
    self.eval_global_step.assign(0)

  def eval_step(self, iterator):
    """Consumes one element and increments the evaluation step counter."""

    def increment_eval_counter(_):
      self.eval_global_step.assign_add(1)

    element = next(iterator)
    self._strategy.run(increment_eval_counter, args=(element,))

  def eval_end(self):
    """Joins pending work and returns the evaluation step count."""
    self.join()
    return self.eval_global_step.numpy()
|
|
|
|
|
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `Trainer` and `_AsyncTrainer` from `official.core.base_trainer`."""

  def setUp(self):
    super().setUp()
    self._config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            optimizer_config=self._sgd_optimization_config()))

  def tearDown(self):
    gc.collect()
    # `gc.garbage` lists objects the collector found unreachable but could not
    # free, so a non-empty list means a test left uncollectable objects behind.
    self.assertEmpty(gc.garbage)
    super().tearDown()

  def _sgd_optimization_config(self):
    """Returns the SGD / constant learning rate config shared by the tests."""
    return cfg.OptimizationConfig({
        'optimizer': {
            'type': 'sgd'
        },
        'learning_rate': {
            'type': 'constant'
        },
    })

  def _skip_if_accelerator_test(self):
    """Skips the current test when running as a TPU or GPU test binary."""
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')

  def _create_parameter_server_strategy(self, num_workers=3, num_ps=2):
    """Starts an in-process cluster and returns a strategy bound to it.

    Args:
      num_workers: Number of worker tasks in the in-process cluster.
      num_ps: Number of parameter server tasks in the in-process cluster.

    Returns:
      A `tf.distribute.experimental.ParameterServerStrategy`.
    """
    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
    return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

  def create_test_trainer(self, config, model_dir=None, task=None):
    """Builds a `Trainer` for `config`.

    Args:
      config: The experiment config handed to the trainer.
      model_dir: Optional directory used as the mock task's logging dir and
        for the best-checkpoint exporter.
      task: Optional task; a `MockTask` is created when none is given.

    Returns:
      The constructed `trainer_lib.Trainer`.
    """
    task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
    return trainer_lib.Trainer(
        config,
        task,
        model=task.build_model(),
        optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                        config.runtime),
        checkpoint_exporter=ckpt_exporter)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_train(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_passing_datasets(self, distribution):
    with distribution.scope():
      task = mock_task.MockTask(self._config)
      train_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.train_data)
      validation_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.validation_data)
      # Clear the data configs so the trainer has to use the datasets that
      # are passed in explicitly.
      self._config.task.train_data = None
      self._config.task.validation_data = None
      trainer = trainer_lib.Trainer(
          self._config,
          task,
          model=task.build_model(),
          optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
                                          self._config.runtime),
          train_dataset=train_dataset,
          validation_dataset=validation_dataset)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('validation_loss', logs)

  def test_base_async_trainer(self):
    self._skip_if_accelerator_test()
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      trainer = MockAsyncTrainer()
      # NOTE(review): MockAsyncTrainer.__init__ already calls init_async();
      # the second call is kept from the original test.
      trainer.init_async()
      self.assertIsInstance(
          trainer._coordinator,
          tf.distribute.experimental.coordinator.ClusterCoordinator)
      self.assertEqual(trainer.train(tf.constant(10)), 10)
      self.assertEqual(trainer.evaluate(tf.constant(11)), 11)

  def test_async_trainer_train(self):
    self._skip_if_accelerator_test()
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  def test_async_trainer_validate(self):
    self._skip_if_accelerator_test()
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('acc', logs)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate_without_loss(self, distribution):

    class MockTaskWithoutValidationLoss(mock_task.MockTask):
      """Mock task whose validation logs have the loss entry removed."""

      def validation_step(self, inputs, model, metrics=None):
        logs = super().validation_step(inputs, model)
        del logs[self.loss]
        return logs

    with distribution.scope():
      task = MockTaskWithoutValidationLoss()
      trainer = self.create_test_trainer(self._config, task=task)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertNotIn('validation_loss', logs)

  @combinations.generate(
      combinations.combine(
          mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
          loss_scale=[None, 'dynamic', 128, 256],
      ))
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
        trainer=cfg.TrainerConfig(
            optimizer_config=self._sgd_optimization_config()))
    trainer = self.create_test_trainer(config)
    if mixed_precision_dtype == 'float16':
      # Only float16 is expected to wrap the optimizer for loss scaling.
      self.assertIsInstance(trainer.optimizer,
                            tf_keras.mixed_precision.LossScaleOptimizer)
      if loss_scale in (None, 'dynamic'):
        self.assertTrue(trainer.optimizer.dynamic)
      else:
        self.assertFalse(trainer.optimizer.dynamic)
        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
    else:
      self.assertIsInstance(
          trainer.optimizer,
          (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)

  def test_export_best_ckpt(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            best_checkpoint_export_subdir='best_ckpt',
            best_checkpoint_eval_metric='acc',
            optimizer_config=self._sgd_optimization_config()))
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

  def test_model_with_compiled_loss(self):
    task = mock_task.MockTask()
    model = task.build_model()
    model.compile(loss=tf_keras.losses.CategoricalCrossentropy())
    trainer = trainer_lib.Trainer(
        self._config,
        task,
        model=model,
        optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)
|
|
|
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
|
|
|