# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect

import torch.nn as nn

from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
from fairseq.models import BaseFairseqModel

_GOSSIP_DISABLED = False
try:
    import gossip
except ImportError:
    _GOSSIP_DISABLED = True


def DistributedFairseqModel(args, model, process_group=None):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
        process_group (optional): process group for distributed
            communication; if None, the default (global) process group
            is used
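
    Example (illustrative sketch; assumes ``args`` was produced by the
    fairseq command-line parser, ``model`` is an already-built model, and
    ``sample`` is a fairseq training batch)::

        wrapped = DistributedFairseqModel(args, model)
        net_output = wrapped(**sample['net_input'])  # forward() behaves like DDP
        wrapped.max_positions()  # missing attributes fall through to *model*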
"""
    # determine which DDP class to extend
    assert isinstance(model, nn.Module)
    if args.distributed_wrapper == 'DDP' and args.ddp_backend == 'c10d':
        ddp_class = nn.parallel.DistributedDataParallel
        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            bucket_cap_mb=args.bucket_cap_mb,
            process_group=process_group,
        )
        # Maintain backward compatibility: only pass kwargs that the installed
        # version of DistributedDataParallel actually accepts
        if 'check_reduction' in inspect.getfullargspec(ddp_class.__init__)[0]:
            init_kwargs['check_reduction'] = True
        if 'find_unused_parameters' in inspect.getfullargspec(ddp_class.__init__)[0]:
            init_kwargs['find_unused_parameters'] = args.find_unused_parameters
    elif args.distributed_wrapper == 'DDP' and args.ddp_backend == 'no_c10d':
        ddp_class = LegacyDistributedDataParallel
        init_kwargs = dict(
            module=model,
            world_size=args.distributed_world_size,
            buffer_size=2**28,  # size of the flat gradient buffer all-reduced at a time
            process_group=process_group,
        )
    elif args.distributed_wrapper == 'SlowMo':
        if _GOSSIP_DISABLED:
            raise ImportError(
                'Cannot find gossip library. Please install from: '
                'github.com/facebookresearch/stochastic_gradient_push'
            )
        ddp_class = gossip.GossipDataParallel

        # The values of slowmo_momentum below were obtained by tuning on the
        # En-De 16 dataset by training the transformer_wmt_en_de_large model
        if args.slowmo_momentum is None:
            if args.distributed_world_size <= 16:
                args.slowmo_momentum = 0.0
            elif args.distributed_world_size <= 32:
                args.slowmo_momentum = 0.2
            elif args.distributed_world_size <= 64:
                args.slowmo_momentum = 0.5
            else:
                args.slowmo_momentum = 0.6

        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            nprocs_per_node=args.nprocs_per_node,
            slowmo_momentum=args.slowmo_momentum,
            localsgd=(args.slowmo_algorithm == 'LocalSGD'),
            localsgd_frequency=args.localsgd_frequency,
        )
    else:
        raise ValueError(
            'Unknown combination of --distributed-wrapper and --ddp-backend: '
            + args.distributed_wrapper + '/' + args.ddp_backend
        )

    class _DistributedFairseqModel(ddp_class):
        """Extend DistributedDataParallel to check for missing attributes in the
        wrapped module."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def __getattr__(self, name):
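            # nn.Module.__getattr__ resolves registered submodules (including
            # the wrapped ``module``) from its internal dictionaries, so look
            # the wrapped module up via the parent class and forward the
            # attribute request to it when possible.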
            wrapped_module = super().__getattr__('module')
            if hasattr(wrapped_module, name):
                return getattr(wrapped_module, name)
            return super().__getattr__(name)

    return _DistributedFairseqModel(**init_kwargs)
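

# Illustrative usage (a sketch, not part of this module): fairseq's Trainer
# wraps its model roughly like this when more than one worker is configured;
# the variable names below are hypothetical.
#
#   if args.distributed_world_size > 1:
#       model = DistributedFairseqModel(args, model, process_group=process_group)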