# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

from fairseq import distributed_utils
from fairseq.trainer import Trainer

try:
    from fairseq.model_parallel.megatron.mpu import (
        get_data_parallel_group,
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_group,
        get_model_parallel_src_rank,
    )
    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, args, task, model, criterion):
        if not has_megatron_submodule:
            raise ImportError(
                '\n\nPlease install the megatron submodule:'
                '\n\n  git submodule update --init '
                'fairseq/model_parallel/megatron'
            )
        super().__init__(args, task, model, criterion)

    @property
    def data_parallel_world_size(self):
        return get_data_parallel_world_size()

    @property
    def data_parallel_process_group(self):
        return get_data_parallel_group()

    @property
    def data_parallel_rank(self):
        return get_data_parallel_rank()

    @property
    def is_data_parallel_master(self):
        return get_model_parallel_src_rank() == 0

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            # Each model parallel rank computes the norm over its own parameter
            # shard only, so sum the squared norms across the model parallel
            # group before taking the square root.
            total_norm = total_norm ** 2
            distributed_utils.all_reduce(total_norm, group=get_model_parallel_group())
            total_norm = total_norm ** 0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )