|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel |
|
|
from fairseq.models import BaseFairseqModel |
|
|
|
|
|
|
|
|
# Optional dependency: the SlowMo wrapper below needs the ``gossip`` package
# (github.com/facebookresearch/stochastic_gradient_push). Record whether it
# could be imported so DistributedFairseqModel can raise a helpful error
# only when SlowMo is actually requested.
try:
    import gossip
except ImportError:
    _GOSSIP_DISABLED = True
else:
    _GOSSIP_DISABLED = False
|
|
|
|
|
|
|
|
def DistributedFairseqModel(args, model, process_group=None):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    The wrapper class is selected from ``args.distributed_wrapper`` and
    ``args.ddp_backend``:

    * ``'DDP'`` + ``'c10d'``: ``torch.nn.parallel.DistributedDataParallel``
    * ``'DDP'`` + ``'no_c10d'``: ``LegacyDistributedDataParallel``
    * ``'SlowMo'``: ``gossip.GossipDataParallel`` (optional ``gossip`` package)

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
        process_group (optional): process group forwarded to the c10d and
            no_c10d wrappers; it is not passed to the SlowMo wrapper.

    Returns:
        An instance of a subclass of the selected wrapper class, created on
        the fly, that wraps *model*.

    Raises:
        ImportError: if SlowMo is requested but ``gossip`` is not installed.
        ValueError: if the wrapper/backend combination is not recognized.

    Note:
        For SlowMo, when ``args.slowmo_momentum`` is ``None`` a default chosen
        from ``args.distributed_world_size`` is written back into ``args``.
    """
    assert isinstance(model, nn.Module)

    if args.distributed_wrapper == 'DDP' and args.ddp_backend == 'c10d':
        ddp_class = nn.parallel.DistributedDataParallel
        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            bucket_cap_mb=args.bucket_cap_mb,
            process_group=process_group,
        )
        # Only pass optional keyword arguments that the installed PyTorch
        # version accepts. inspect.signature replaces inspect.getargspec,
        # which was removed in Python 3.11 and rejects annotated/keyword-only
        # signatures; on a class it reports the __init__ parameters.
        ddp_params = inspect.signature(ddp_class).parameters
        if 'check_reduction' in ddp_params:
            init_kwargs['check_reduction'] = True
        if 'find_unused_parameters' in ddp_params:
            init_kwargs['find_unused_parameters'] = args.find_unused_parameters

    elif args.distributed_wrapper == 'DDP' and args.ddp_backend == 'no_c10d':
        ddp_class = LegacyDistributedDataParallel
        init_kwargs = dict(
            module=model,
            world_size=args.distributed_world_size,
            buffer_size=2**28,
            process_group=process_group,
        )

    elif args.distributed_wrapper == 'SlowMo':
        if _GOSSIP_DISABLED:
            raise ImportError(
                'Cannot find gossip library. Please install from: '
                'github.com/facebookresearch/stochastic_gradient_push'
            )
        ddp_class = gossip.GossipDataParallel

        # If no momentum was given, pick a default that grows with the number
        # of workers. Note: this mutates args in place.
        if args.slowmo_momentum is None:
            if args.distributed_world_size <= 16:
                args.slowmo_momentum = 0.0
            elif args.distributed_world_size <= 32:
                args.slowmo_momentum = 0.2
            elif args.distributed_world_size <= 64:
                args.slowmo_momentum = 0.5
            else:
                args.slowmo_momentum = 0.6

        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            nprocs_per_node=args.nprocs_per_node,
            slowmo_momentum=args.slowmo_momentum,
            localsgd=(args.slowmo_algorithm == 'LocalSGD'),
            localsgd_frequency=args.localsgd_frequency
        )

    else:
        # Report both selectors: the failure may come from an unknown
        # wrapper, not just the backend, and either may be None.
        raise ValueError(
            'Unknown --ddp-backend: {!r} (distributed_wrapper={!r})'.format(
                args.ddp_backend, args.distributed_wrapper
            )
        )

    class _DistributedFairseqModel(ddp_class):
        """Extend DistributedDataParallel to check for missing
        attributes in the wrapped module."""

        def __getattr__(self, name):
            # Only reached when normal attribute lookup on the wrapper fails:
            # prefer the wrapped module's attribute, otherwise defer to the
            # base class lookup (which raises if nothing is found).
            wrapped_module = super().__getattr__('module')
            if hasattr(wrapped_module, name):
                return getattr(wrapped_module, name)
            return super().__getattr__(name)

    return _DistributedFairseqModel(**init_kwargs)
|
|
|