# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional, Sequence

import logging

logger = logging.getLogger(__name__)


def uniform(dataset_sizes: List[int]):
    """Return an equal weight of 1.0 for every dataset.

    The sizes themselves are ignored; only their count matters.
    """
    return [1.0] * len(dataset_sizes)


def temperature_sampling(dataset_sizes: Sequence[int], temp: float) -> List[float]:
    """Return temperature-scaled sampling weights.

    Each weight is ``(size / total_size) ** (1 / temp)``. The results are NOT
    renormalized after exponentiation, so they sum to 1 only when
    ``temp == 1``. Larger ``temp`` flattens the distribution toward uniform.

    Args:
        dataset_sizes: Number of examples in each dataset.
        temp: Sampling temperature.

    Returns:
        One weight per dataset, in input order (empty for empty input).

    Raises:
        ZeroDivisionError: if the sizes sum to zero or ``temp == 0``, and
            ``dataset_sizes`` is non-empty.
    """
    total_size = sum(dataset_sizes)
    # The exponent stays inside the comprehension so that empty input
    # returns [] without evaluating 1.0 / temp.
    return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]


def make_temperature_sampling(temp: float = 1.0) -> Callable[[Sequence[int]], List[float]]:
    """Return a sampling function bound to a fixed temperature ``temp``."""

    def sampling_func(dataset_sizes):
        return temperature_sampling(dataset_sizes, temp)

    return sampling_func


def make_ratio_sampling(ratios):
    """Return a sampling function that always yields the given ``ratios``.

    The returned function ignores ``dataset_sizes``. It returns the very same
    ``ratios`` object on every call, and no length check is performed.
    """

    def sampling_func(dataset_sizes):
        return ratios

    return sampling_func


class SamplingMethod:
    """Select a per-dataset sampling-weight function from command-line args."""

    @staticmethod
    def add_arguments(parser):
        """Register ``--sampling-method`` and ``--sampling-temperature`` on ``parser``."""
        parser.add_argument(
            '--sampling-method',
            choices=['uniform', 'temperature', 'concat', 'RoundRobin', ],
            type=str,
            default='concat',
            help='The method to sample data per language pairs')
        parser.add_argument('--sampling-temperature', default=1.5, type=float,
                            help='only work with --sampling-method temperature')

    @staticmethod
    def build_sampler(args, task):
        """Construct a ``SamplingMethod`` (always this base class) for ``args``/``task``."""
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        self.args = args
        self.task = task

    def is_adaptive(self):
        """Return whether temperature sampling is forced; always False in the base class."""
        return False

    def sampling_method_selector(self) -> Optional[Callable]:
        """Return the sampling-weight function selected by ``self.args``.

        Returns:
            - ``uniform`` for ``--sampling-method uniform``;
            - a temperature sampler using ``args.sampling_temperature`` for
              ``temperature``, or whenever ``is_adaptive()`` is True;
            - ``None`` otherwise (``concat``, ``RoundRobin``, ...). The
              original code treats ``None`` as "concatenate all datasets";
              any RoundRobin handling presumably lives in the caller. TODO
              confirm.
        """
        args = self.args
        # Lazy %-style arguments: formatting is skipped if INFO is disabled.
        logger.info('selected sampler: %s', args.sampling_method)
        if args.sampling_method == 'uniform':
            return uniform
        elif args.sampling_method == 'temperature' or self.is_adaptive():
            return make_temperature_sampling(float(args.sampling_temperature))
        else:
            # default to concating all data set together
            return None