Spaces:
Runtime error
Runtime error
Dit-document-layout-analysis
/
unilm
/decoding
/GAD
/fairseq
/data
/multilingual
/sampling_method.py
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| from typing import List | |
| logger = logging.getLogger(__name__) | |
def uniform(dataset_sizes: List[int]):
    """Return an equal weight of 1.0 for every dataset.

    Args:
        dataset_sizes: one size per dataset; only the number of entries
            is used, the sizes themselves are ignored.

    Returns:
        A list with as many ``1.0`` entries as ``dataset_sizes`` has items.
    """
    return [1.0 for _ in range(len(dataset_sizes))]
def temperature_sampling(dataset_sizes, temp):
    """Weight each dataset by its share of the total size, tempered by ``temp``.

    Each weight is ``(size / sum(dataset_sizes)) ** (1.0 / temp)``. The
    weights are returned as computed; they are not re-normalized here.

    Args:
        dataset_sizes: sizes of the individual datasets.
        temp: sampling temperature used as the divisor of the exponent.

    Returns:
        A list of weights, one per entry of ``dataset_sizes``, in order.
    """
    grand_total = sum(dataset_sizes)
    weights = []
    for n_examples in dataset_sizes:
        share = n_examples / grand_total
        # The exponent is evaluated per item (not hoisted) so that an empty
        # input never touches ``temp`` at all.
        weights.append(share ** (1.0 / temp))
    return weights
def make_temperature_sampling(temp=1.0):
    """Build a sampling function with the temperature fixed to ``temp``.

    Args:
        temp: temperature forwarded to ``temperature_sampling`` on every call.

    Returns:
        A function taking ``dataset_sizes`` and returning the result of
        ``temperature_sampling(dataset_sizes, temp)``.
    """

    def tempered_sampling(dataset_sizes):
        # ``temp`` is captured from the enclosing call.
        return temperature_sampling(dataset_sizes, temp)

    return tempered_sampling
def make_ratio_sampling(ratios):
    """Build a sampling function that always answers with ``ratios``.

    Args:
        ratios: the object to hand back on every call.

    Returns:
        A function taking ``dataset_sizes`` (which it ignores) and returning
        the very same ``ratios`` object that was passed in here.
    """

    def fixed_ratio_sampling(dataset_sizes):
        # The sizes are deliberately unused; the ratios are predetermined.
        return ratios

    return fixed_ratio_sampling
class SamplingMethod:
    """Picks the dataset sampling function named by the command-line options.

    The instance stores the parsed ``args`` and the ``task`` it was built
    with; ``sampling_method_selector`` reads ``args.sampling_method`` and
    ``args.sampling_temperature`` to choose a sampling function.
    """

    @staticmethod
    def add_arguments(parser):
        """Register ``--sampling-method`` and ``--sampling-temperature``.

        Declared as a static method because it takes no instance; this lets
        it be called on the class or on an instance alike.

        Args:
            parser: an argparse-style parser exposing ``add_argument``.
        """
        parser.add_argument(
            "--sampling-method",
            choices=[
                "uniform",
                "temperature",
                "concat",
                "RoundRobin",
            ],
            type=str,
            default="concat",
            help="The method to sample data per language pairs",
        )
        parser.add_argument(
            "--sampling-temperature",
            default=1.5,
            type=float,
            help="only work with --sampling-method temperature",
        )

    @staticmethod
    def build_sampler(args, task):
        """Create a ``SamplingMethod`` from ``args`` and ``task``.

        Declared as a static method because it takes no instance.
        """
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        """Store ``args`` and ``task`` for later use."""
        self.args = args
        self.task = task

    def is_adaptive(self):
        """Return whether sampling is adaptive; always ``False`` here."""
        return False

    def sampling_method_selector(self):
        """Return the sampling function selected by ``args.sampling_method``.

        Returns:
            ``uniform`` for ``"uniform"``; a temperature sampling function
            built from ``args.sampling_temperature`` for ``"temperature"``
            (or whenever ``is_adaptive()`` is true); ``None`` otherwise.
        """
        args = self.args

        logger.info(f"selected sampler: {args.sampling_method}")
        if args.sampling_method == "uniform":
            return uniform
        elif args.sampling_method == "temperature" or self.is_adaptive():
            return make_temperature_sampling(float(args.sampling_temperature))
        else:
            # default to concatenating all datasets together
            return None