# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""
from fairseq import distributed_utils
from fairseq.trainer import Trainer
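# distributed_utils.all_reduce is fairseq's thin wrapper around
# torch.distributed.all_reduce; it is what the model parallel norm
# aggregation below relies on.
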
try:
    from fairseq.model_parallel.megatron.mpu import (
        get_data_parallel_group,
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_group,
        get_model_parallel_src_rank,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False
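
# ``has_megatron_submodule`` is only checked lazily, inside
# MegatronTrainer.__init__, so importing this module still succeeds when the
# Megatron submodule has not been cloned.
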

class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, args, task, model, criterion):
        if not has_megatron_submodule:
            raise ImportError(
                '\n\nPlease install the megatron submodule:'
                '\n\n  git submodule update --init '
                'fairseq/model_parallel/megatron'
            )
        super().__init__(args, task, model, criterion)
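
    # In fairseq's training entry point, this trainer is typically selected in
    # place of the base Trainer when --model-parallel-size is greater than 1.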

    @property
    def data_parallel_world_size(self):
        return get_data_parallel_world_size()

    @property
    def data_parallel_process_group(self):
        return get_data_parallel_group()

    @property
    def data_parallel_rank(self):
        return get_data_parallel_rank()

    @property
    def is_data_parallel_master(self):
        # The source rank of a model parallel group is the lowest global rank
        # in it, so this is True exactly for workers in global rank 0's model
        # parallel group, i.e. the workers at data parallel rank 0.
        return get_model_parallel_src_rank() == 0

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            # ``total_norm`` holds the norm of this worker's shard of the
            # gradients. Since the shards are disjoint, the global norm is the
            # root of the sum of the squared shard norms: square, sum across
            # the model parallel group, then take the square root.
            total_norm = total_norm ** 2
            distributed_utils.all_reduce(
                total_norm, group=get_model_parallel_group()
            )
            total_norm = total_norm ** 0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )
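

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of fairseq): the helper below is a
# hypothetical standalone version of the aggregation in clip_grad_norm above,
# showing why the square / all-reduce / sqrt sequence recovers the global
# gradient norm when each rank holds a disjoint shard of the parameters. It
# assumes an initialized torch.distributed process group standing in for the
# model parallel group.
def _example_sharded_grad_norm(local_grads, group=None):
    import torch.distributed as dist

    # Squared L2 norm of this rank's (non-empty) shard of the gradients.
    local_sq_norm = sum(g.pow(2).sum() for g in local_grads)
    # Sum the squared shard norms over every rank in the group...
    dist.all_reduce(local_sq_norm, group=group)
    # ...and take the square root: sqrt(sum_r ||g_r||^2) == ||g||.
    return local_sq_norm ** 0.5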