# Source: maotao/fairseq/data/multilingual/sampling_method.py
# Uploaded by julse ("Upload 551 files"), commit be611b4 (verified).
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import logging
# Module-level logger; SamplingMethod.sampling_method_selector reports the
# selected sampler through it.
logger = logging.getLogger(name=__name__)
def uniform(dataset_sizes: List[int]):
    """Return an equal weight of 1.0 for every dataset, ignoring the sizes."""
    num_datasets = len(dataset_sizes)
    return [1.0 for _ in range(num_datasets)]
def temperature_sampling(dataset_sizes, temp):
    """Return temperature-scaled, unnormalized sampling weights.

    Each weight is ``(size / total) ** (1 / temp)``. With ``temp == 1`` the
    weights are proportional to dataset size; a larger ``temp`` flattens them
    toward uniform, so small datasets are sampled relatively more often.

    Args:
        dataset_sizes: number of examples in each dataset.
        temp: sampling temperature; must be positive.

    Returns:
        A list with one float weight per dataset, in input order. The weights
        are not normalized to sum to 1 (except when ``temp == 1``). An empty
        input yields an empty list.

    Raises:
        ValueError: if ``temp`` is not positive, or if a non-empty
            ``dataset_sizes`` sums to zero or less.
    """
    # The temperature usually comes straight from --sampling-temperature,
    # which argparse accepts as any float: 0 would divide by zero and a
    # negative value would silently invert the weighting.
    if temp <= 0:
        raise ValueError(f'sampling temperature must be positive, got {temp}')
    # Materialize once so the sizes can be summed and then iterated again
    # (safe for one-shot iterables and numpy arrays alike).
    sizes = list(dataset_sizes)
    if not sizes:
        return []
    total_size = sum(sizes)
    if total_size <= 0:
        raise ValueError(
            f'dataset sizes must sum to a positive number, got {total_size}')
    exponent = 1.0 / temp
    return [(size / total_size) ** exponent for size in sizes]
def make_temperature_sampling(temp=1.0):
    """Bind *temp* into a one-argument sampler over dataset sizes.

    The returned callable forwards its ``dataset_sizes`` argument together
    with the captured ``temp`` to ``temperature_sampling``; no weights are
    computed until it is called.
    """
    return lambda dataset_sizes: temperature_sampling(dataset_sizes, temp)
def make_ratio_sampling(ratios):
    """Build a sampler that returns fixed, user-supplied ratios.

    The dataset sizes passed to the returned function are used only to check
    that there is exactly one ratio per dataset; the ratios themselves are
    returned unchanged.

    Args:
        ratios: one sampling weight per dataset, in the same order as the
            ``dataset_sizes`` later passed to the returned function. The
            values are snapshotted when this function is called.

    Returns:
        A callable ``sampling_func(dataset_sizes) -> list`` that returns a
        fresh list copy of the ratios on every call.
    """
    # Snapshot so later mutation of the caller's ``ratios`` object cannot
    # change the sampler's behavior (and one-shot iterables work too).
    fixed_ratios = list(ratios)

    def sampling_func(dataset_sizes):
        if len(fixed_ratios) != len(dataset_sizes):
            raise ValueError(
                f'expected {len(dataset_sizes)} sampling ratios (one per '
                f'dataset), got {len(fixed_ratios)}')
        # Hand out a copy: a consumer that normalizes the weights in place
        # must not corrupt the ratios returned by later calls.
        return list(fixed_ratios)

    return sampling_func
class SamplingMethod:
    """Choose the per-language-pair data sampling strategy from CLI args.

    ``sampling_method_selector`` returns either a callable that maps dataset
    sizes to sampling weights, or ``None`` to mean "concatenate all datasets".
    """

    @staticmethod
    def add_arguments(parser):
        """Register ``--sampling-method`` and ``--sampling-temperature`` on *parser*."""
        parser.add_argument(
            '--sampling-method',
            type=str,
            default='concat',
            choices=['uniform', 'temperature', 'concat', 'RoundRobin'],
            help='The method to sample data per language pairs',
        )
        parser.add_argument(
            '--sampling-temperature',
            type=float,
            default=1.5,
            help='only work with --sampling-method temperature',
        )

    @staticmethod
    def build_sampler(args, task):
        """Construct a sampler for *args*/*task*.

        Always returns a base ``SamplingMethod`` instance, even when invoked
        through a subclass, because this is a staticmethod.
        """
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        # Parsed CLI namespace; read by sampling_method_selector.
        self.args = args
        # Stored for callers/subclasses; not used by this class itself.
        self.task = task

    def is_adaptive(self):
        """Return whether this sampler is adaptive (always False here).

        Presumably overridden by subclasses; when it returns True, the
        selector falls back to temperature sampling for methods other than
        'uniform'.
        """
        return False

    def sampling_method_selector(self):
        """Return the sampling function selected by ``args.sampling_method``.

        Returns:
            ``uniform`` for 'uniform'; a temperature sampler built from
            ``args.sampling_temperature`` for 'temperature' (or whenever
            ``is_adaptive()`` is true for any other method); otherwise
            ``None``, so 'concat' and 'RoundRobin' both yield ``None`` here.
        """
        method = self.args.sampling_method
        logger.info(f'selected sampler: {method}')
        if method == 'uniform':
            return uniform
        # is_adaptive() is consulted only when the method is neither
        # 'uniform' nor 'temperature' (short-circuit on the first operand).
        if method == 'temperature' or self.is_adaptive():
            return make_temperature_sampling(float(self.args.sampling_temperature))
        # default to concating all data set together
        return None