Spaces:
Build error
Build error
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Multitask trainer that interleaves each task's train step.""" | |
| from typing import Union | |
| import gin | |
| import orbit | |
| import tensorflow as tf, tf_keras | |
| from official.modeling.multitask import base_model | |
| from official.modeling.multitask import base_trainer | |
| from official.modeling.multitask import multitask | |
| from official.modeling.multitask import task_sampler as sampler | |
| class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer): | |
| """MultiTask trainer that interleaves task update.""" | |
| def __init__(self, | |
| multi_task: multitask.MultiTask, | |
| multi_task_model: Union[tf_keras.Model, | |
| base_model.MultiTaskBaseModel], | |
| optimizer: Union[tf.optimizers.Optimizer, | |
| tf_keras.optimizers.experimental.Optimizer, | |
| tf_keras.optimizers.legacy.Optimizer], | |
| task_sampler: sampler.TaskSampler, | |
| trainer_options=None): | |
| super().__init__( | |
| multi_task=multi_task, | |
| multi_task_model=multi_task_model, | |
| optimizer=optimizer, | |
| trainer_options=trainer_options) | |
| self._task_sampler = task_sampler | |
| # Build per task train step. | |
| def _get_task_step(task_name, task): | |
| def step_fn(inputs): | |
| if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel): | |
| task_model = self.multi_task_model.sub_tasks[task_name] | |
| else: | |
| task_model = self.multi_task_model | |
| task_logs = task.train_step( | |
| inputs, | |
| model=task_model, | |
| optimizer=self.optimizer, | |
| metrics=self.training_metrics[task_name]) | |
| self.training_losses[task_name].update_state(task_logs[task.loss]) | |
| return step_fn | |
| self._task_train_step_map = { | |
| name: _get_task_step(name, task) | |
| for name, task in self.multi_task.tasks.items() | |
| } | |
| # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging | |
| # on TensorBoard. | |
| self._task_step_counters = { | |
| name: orbit.utils.create_global_step() for name in self.multi_task.tasks | |
| } | |
| # If the new Keras optimizer is used, we require all model variables are | |
| # created before the training and let the optimizer to create the slot | |
| # variable all together. | |
| if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer): | |
| multi_task_model.build() | |
| optimizer.build(multi_task_model.trainable_variables) | |
| def task_step_counter(self, name): | |
| return self._task_step_counters[name] | |
| def train_step(self, iterator_map): | |
| # Sample one task to train according to a multinomial distribution | |
| rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step)) | |
| cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution( | |
| self.global_step) | |
| # Prepend a [0.0] for indexing convenience. | |
| cumulative_sample_distribution = tf.concat( | |
| [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution], | |
| axis=0) | |
| for idx, (name, _) in enumerate(self.multi_task.tasks.items()): | |
| begin = cumulative_sample_distribution[idx] | |
| end = cumulative_sample_distribution[idx + 1] | |
| if rn >= begin and rn < end: | |
| self._strategy.run( | |
| self._task_train_step_map[name], args=(next(iterator_map[name]),)) | |
| self.global_step.assign_add(1) | |
| self.task_step_counter(name).assign_add(1) | |
| def train_loop_end(self): | |
| """Record loss and metric values per task.""" | |
| result = super().train_loop_end() | |
| # Interleaving training does not have a good semantic for `total_loss`. In | |
| # fact, it is always zero. To avoid confusion, we filter the `total_loss` | |
| # from the result logs. | |
| if 'total_loss' in result: | |
| result.pop('total_loss') | |
| return result | |