Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Utils to sample tasks for interleaved optimization.""" | |
| import abc | |
| from typing import Union, Dict, Text | |
| import tensorflow as tf, tf_keras | |
| from official.modeling.multitask import configs | |
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer.

  Subclasses must implement `task_cumulative_distribution`, which maps the
  current training step to a cumulative distribution over the tasks.
  """

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Initializes the sampler.

    Args:
      task_weights: Mapping from task name to its sampling weight. How (or
        whether) the weight values are used is up to the subclass.
    """
    self._task_weights = task_weights

  def task_weights(self):
    """Returns the task-name -> weight mapping given at construction."""
    return self._task_weights

  # Marked abstract so the ABCMeta metaclass actually prevents instantiating
  # the base class, whose implementation would otherwise return None.
  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Compute cumulative distribution to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution with respect to which to be sampled against.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A rank-1 float tensor of shape (num_tasks,) that represents the
      cumulative sampling distribution, one entry per task.
    """
    pass
class UniformTaskSampler(TaskSampler):
  """Sample all tasks uniformly."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Initializes the sampler.

    Args:
      task_weights: Mapping from task name to weight. Only the number of
        tasks is used; the weight values are ignored.

    Raises:
      ValueError: If `task_weights` is empty.
    """
    super(UniformTaskSampler, self).__init__(task_weights=task_weights)
    num_tasks = len(self._task_weights)
    if num_tasks == 0:
      raise ValueError('`task_weights` must contain at least one task.')
    # Build the CDF as [1/n, 2/n, ..., n/n] directly rather than summing n
    # float32 copies of 1/n: each entry is a single correctly rounded
    # division, so the final entry is exactly 1.0 with no accumulated error.
    self._uniform_cumulative = (
        tf.range(1, num_tasks + 1, dtype=tf.float32) /
        tf.cast(num_tasks, dtype=tf.float32))

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the uniform CDF of shape (num_tasks,), fixed at construction."""
    del global_step  # Uniform sampling does not depend on training progress.
    return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
  """Sample tasks proportional to task weights."""

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    """Initializes the sampler.

    Args:
      task_weights: Mapping from task name to a non-negative sampling weight.
        The dict's iteration order defines the order of the returned CDF.
      alpha: Exponent applied to every weight before normalization; task `i`
        is sampled with probability proportional to `weight_i ** alpha`.
        If every `weight_i ** alpha` is zero the distribution is undefined
        (NaN).

    Raises:
      ValueError: If `task_weights` is empty or contains a negative weight.
    """
    super(ProportionalTaskSampler, self).__init__(task_weights=task_weights)
    weights = list(self._task_weights.values())
    if not weights:
      raise ValueError('`task_weights` must contain at least one task.')
    if any(weight < 0 for weight in weights):
      raise ValueError(
          f'Task weights must be non-negative, got {self._task_weights}.')
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    task_sizes = tf.math.pow(
        tf.constant(weights, dtype=tf.float32), self._alpha)
    # Normalize the running sum by its own last entry so the CDF ends at
    # exactly 1.0 (x / x == 1 in IEEE arithmetic), instead of drifting by
    # float32 rounding when each probability is divided and then re-summed.
    cumulative_sizes = tf.math.cumsum(task_sizes)
    self._porportional_cumulative = cumulative_sizes / cumulative_sizes[-1]

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the proportional CDF of shape (num_tasks,), fixed at init."""
    del global_step  # The distribution does not depend on training progress.
    return self._porportional_cumulative
class AnnealingTaskSampler(TaskSampler):
  """Sample tasks according to task weights as well as training progress.

  See http://proceedings.mlr.press/v97/stickland19a/stickland19a.pdf

  Task `i` is sampled with probability proportional to `weight_i ** alpha`,
  where `alpha = 1 - 0.8 * (e - 1) / (E - 1)` for epoch number `e` out of
  `E` total epochs. Alpha is therefore 1.0 in the first epoch (e = 1) and
  0.2 in the last (e = E). It is clamped to [0.2, 1.0] outside that range.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    """Initializes the sampler.

    Args:
      task_weights: Mapping from task name to a non-negative sampling weight.
        The dict's iteration order defines the order of the returned CDF.
      steps_per_epoch: Number of training steps per epoch; must be positive.
      total_steps: Total number of training steps. `total_steps /
        steps_per_epoch` is used as the total number of epochs `E`.

    Raises:
      ValueError: If `task_weights` is empty or has a negative weight, or if
        `steps_per_epoch` is not positive.
    """
    super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
    weights = list(self._task_weights.values())
    if not weights:
      raise ValueError('`task_weights` must contain at least one task.')
    if any(weight < 0 for weight in weights):
      raise ValueError(
          f'Task weights must be non-negative, got {self._task_weights}.')
    if steps_per_epoch <= 0:
      raise ValueError(
          f'`steps_per_epoch` must be positive, got {steps_per_epoch}.')
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)
    # The weights are fixed once the sampler is built, so create the tensor
    # once here instead of on every call.
    self._task_weight_tensor = tf.constant(weights, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the annealed CDF of shape (num_tasks,) for `global_step`."""
    # 0-indexed epoch, i.e. exactly the paper's `e - 1` term. (Subtracting 1
    # again would start alpha above 1.0 and never reach 0.2.)
    completed_epochs = tf.math.floor(
        tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
    # The 1e-10 keeps the denominator nonzero when training lasts one epoch.
    alpha = 1.0 - 0.8 * completed_epochs / (self._total_epochs - 1 + 1e-10)
    # Steps past `total_steps`, or a fractional final epoch, would push alpha
    # below 0.2 or even negative (inverting the weighting toward the smallest
    # tasks); keep it within the schedule's range.
    alpha = tf.clip_by_value(alpha, 0.2, 1.0)
    task_sizes = tf.math.pow(self._task_weight_tensor, alpha)
    # Normalize by the last running-sum entry so the CDF ends at exactly 1.0.
    cumulative_sizes = tf.math.cumsum(task_sizes)
    return cumulative_sizes / cumulative_sizes[-1]
def get_task_sampler(
    config: configs.TaskSamplingConfig,
    task_weights: Dict[Text, Union[float, int]]) -> TaskSampler:
  """Utils to create task sampler with configuration and task weights.

  Args:
    config: Task sampling config. `config.type` selects the sampler, and the
      sub-config returned by `config.get()` supplies its parameters (`alpha`
      for 'proportional'; `steps_per_epoch` and `total_steps` for
      'annealing').
    task_weights: Mapping from task name to sampling weight.

  Returns:
    The `TaskSampler` selected by `config.type`.

  Raises:
    RuntimeError: If `config.type` is not 'uniform', 'proportional' or
      'annealing'.
  """
  oneof_config = config.get()
  sampler_type = config.type
  if sampler_type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  elif sampler_type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=oneof_config.alpha)
  elif sampler_type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=oneof_config.steps_per_epoch,
        total_steps=oneof_config.total_steps)
  else:
    # Keep the original message as a prefix so existing matches still work.
    raise RuntimeError(f'Task sampler type not supported: {sampler_type!r}')