|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Utils to sample tasks for interleaved optimization."""
|
| import abc
|
| from typing import Union, Dict, Text
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.modeling.multitask import configs
|
|
|
|
|
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer.

  Concrete samplers turn the per-task weights held by this class into a
  cumulative sampling distribution via `task_cumulative_distribution`.
  """

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Stores the per-task weights.

    Args:
      task_weights: A mapping with string keys (one entry per task) and
        numeric weights as values.
    """
    # Initialize `tf.Module` first so that module state such as `name` and
    # `name_scope` exists; skipping this leaves those attributes undefined.
    super().__init__()
    self._task_weights = task_weights

  @property
  def task_weights(self):
    """The task weights mapping stored at construction time."""
    return self._task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Compute cumulative distribution to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution with respect to which tasks are to be sampled.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A float tensor that represents the cumulative sampling distribution,
      with one entry per task.
    """
|
|
|
|
|
class UniformTaskSampler(TaskSampler):
  """Task sampler that gives every task the same sampling probability."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Precomputes the uniform cumulative distribution.

    Args:
      task_weights: Mapping of per-task weights. Only the number of entries
        is used here; the weight values do not affect the distribution.
    """
    super().__init__(task_weights=task_weights)
    num_tasks = len(self._task_weights)
    per_task_probability = 1.0 / num_tasks
    uniform_distribution = tf.constant(
        [per_task_probability] * num_tasks, dtype=tf.float32)
    # The distribution never changes, so its running sum is computed once.
    self._uniform_cumulative = tf.math.cumsum(uniform_distribution)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the precomputed uniform cumulative distribution.

    Args:
      global_step: Unused; the distribution does not depend on progress.

    Returns:
      The float32 tensor built in the constructor.
    """
    del global_step  # Unused.
    return self._uniform_cumulative
|
|
|
|
|
class ProportionalTaskSampler(TaskSampler):
  """Task sampler whose probabilities follow the task weights.

  Each task is sampled with probability proportional to `weight ** alpha`.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    """Precomputes the weight-proportional cumulative distribution.

    Args:
      task_weights: Mapping of per-task weights; values are read in the
        mapping's iteration order.
      alpha: Exponent applied to every weight before normalization.
    """
    super().__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    weights = tf.constant(list(self._task_weights.values()), dtype=tf.float32)
    scaled_weights = tf.math.pow(weights, self._alpha)
    normalizer = tf.reduce_sum(scaled_weights)
    # The distribution is static, so its running sum is computed only once.
    self._porportional_cumulative = tf.math.cumsum(scaled_weights / normalizer)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the precomputed proportional cumulative distribution.

    Args:
      global_step: Unused; the distribution does not depend on progress.

    Returns:
      The float32 tensor built in the constructor.
    """
    del global_step  # Unused.
    return self._porportional_cumulative
|
|
|
|
|
class AnnealingTaskSampler(TaskSampler):
  """Sample tasks according to task weights as well as training progress.

  Task probabilities are proportional to `weight ** alpha`, where `alpha` is
  recomputed from the current epoch on every call.

  See http://proceedings.mlr.press/v97/stickland19a/stickland19a.pdf
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    """Stores the schedule parameters as float32 tensors.

    Args:
      task_weights: Mapping of per-task weights; values are read in the
        mapping's iteration order.
      steps_per_epoch: Number of training steps that make up one epoch.
      total_steps: Total number of training steps; divided by
        `steps_per_epoch` to obtain the total number of epochs.
    """
    super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Computes the cumulative task distribution for the current epoch.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A float32 tensor with the cumulative sum of the per-task sampling
      probabilities, one entry per task weight.
    """
    # `cur_epoch` counts from 0, so while `total_epochs > 1` the exponent is
    # `1 + 0.8 / (total_epochs - 1)` during the first epoch and reaches 0.2
    # once `cur_epoch == total_epochs`.
    cur_epoch = tf.math.floor(
        tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
    # A single-epoch schedule makes the denominator exactly zero. A tiny
    # epsilon there blows alpha up to ~1e10 and `pow` overflows to inf/NaN;
    # `divide_no_nan` yields 0 instead, i.e. alpha == 1.0 (plain
    # weight-proportional sampling).
    alpha = 1.0 - tf.math.divide_no_nan(
        0.8 * (cur_epoch - 1), self._total_epochs - 1)
    task_weight_dict_ordered_list = list(self._task_weights.values())
    task_sizes = tf.math.pow(
        tf.constant(task_weight_dict_ordered_list, dtype=tf.float32),
        tf.cast(alpha, dtype=tf.float32))
    dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    return tf.math.cumsum(dynamic_task_distribution)
|
|
|
|
|
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Builds the task sampler selected by `config.type`.

  Args:
    config: Task sampling configuration. `config.type` picks the sampler and
      `config.get()` supplies the options read for that sampler.
    task_weights: Per-task weights forwarded to the sampler constructor.

  Returns:
    A `UniformTaskSampler`, `ProportionalTaskSampler` or
    `AnnealingTaskSampler` instance.

  Raises:
    RuntimeError: If `config.type` is none of 'uniform', 'proportional' or
      'annealing'.
  """
  sampler_options = config.get()
  sampler_type = config.type

  if sampler_type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)

  if sampler_type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=sampler_options.alpha)

  if sampler_type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=sampler_options.steps_per_epoch,
        total_steps=sampler_options.total_steps)

  raise RuntimeError('Task sampler type not supported')
|
|
|