Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Tests for multitask.train_lib.""" | |
| from absl.testing import parameterized | |
| import tensorflow as tf, tf_keras | |
| from tensorflow.python.distribute import combinations | |
| from tensorflow.python.distribute import strategy_combinations | |
| from official.core import task_factory | |
| from official.modeling.hyperparams import params_dict | |
| from official.modeling.multitask import configs | |
| from official.modeling.multitask import multitask | |
| from official.modeling.multitask import test_utils | |
| from official.modeling.multitask import train_lib | |
class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for the experiment runners in `train_lib`.

  Each test builds an experiment config from the mock task configs in
  `test_utils`, overrides its trainer section with a very short schedule,
  and runs it once per generated combination of distribution strategy and
  run mode (plus optimizer type for `test_end_to_end`).
  """

  def setUp(self):
    super().setUp()
    # Trainer overrides merged (non-strictly) into each experiment config.
    # The schedule is kept tiny (10 train steps, 5 validation steps) so that
    # every parameterized combination finishes quickly.
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  # `combinations.generate` injects `distribution_strategy`, `optimizer` and
  # `flag_mode`. Without it unittest calls the method with only `self` and
  # the test errors out with a TypeError before doing any work.
  # NOTE(review): TPU/GPU strategies are expected to be skipped by the
  # combinations framework when that hardware is absent -- confirm.
  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          optimizer=['sgd_experimental', 'sgd'],
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end(self, distribution_strategy, optimizer, flag_mode):
    """Runs `train_lib.run_experiment` on a two-task ('foo', 'bar') MultiTask.

    Args:
      distribution_strategy: Strategy under which the task and model are
        built and the experiment is run.
      optimizer: Optimizer type string written into the trainer's
        optimizer config, replacing the 'sgd' default from `setUp`.
      flag_mode: Run mode forwarded as `mode` to `run_experiment`.
    """
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiTaskExperimentConfig(
        task=configs.MultiTaskConfig(
            task_routines=(
                configs.TaskRoutine(
                    task_name='foo', task_config=test_utils.FooConfig()),
                configs.TaskRoutine(
                    task_name='bar', task_config=test_utils.BarConfig()))))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    experiment_config.trainer.optimizer_config.optimizer.type = optimizer
    # Task and model must be created inside the strategy scope so their
    # variables are placed by the strategy under test.
    with distribution_strategy.scope():
      test_multitask = multitask.MultiTask.from_config(experiment_config.task)
      model = test_utils.MockMultiTaskModel()
    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=test_multitask,
        model=model,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)

  # Same parameter-injection requirement as `test_end_to_end`; this test has
  # no `optimizer` dimension.
  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
    """Runs `run_experiment_with_multitask_eval` with one train task.

    The single train task is built from `FooConfig`. Two eval tasks ('foo'
    with 2 eval steps, 'bar' with 3 eval steps) are built via `task_factory`.

    Args:
      distribution_strategy: Strategy under which the tasks are built and
        the experiment is run.
      flag_mode: Run mode forwarded as `mode` to the runner.
    """
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
        eval_tasks=(configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
                    configs.TaskRoutine(
                        task_name='bar',
                        task_config=test_utils.BarConfig(),
                        eval_steps=3)))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
      eval_tasks = [
          task_factory.get_task(config.task_config, name=config.task_name)
          for config in experiment_config.eval_tasks
      ]
    train_lib.run_experiment_with_multitask_eval(
        distribution_strategy=distribution_strategy,
        train_task=train_task,
        eval_tasks=eval_tasks,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner only when this file is executed
  # directly (not when it is imported).
  tf.test.main()