# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""
import contextlib
from itertools import chain
import logging
import sys
import time
from typing import Any, Dict, List
import torch
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
logger = logging.getLogger(__name__)
class Trainer(object):
"""Main class for data parallel training.
This class supports synchronous distributed data parallel training,
where multiple workers each have a full model replica and gradients
are accumulated across workers before each update. We use
:class:`~torch.nn.parallel.DistributedDataParallel` to handle
communication of the gradients across workers.
"""
def __init__(self, args, task, model, criterion, quantizer=None):
self.args = args
self.task = task
# catalog shared parameters
shared_params = _catalog_shared_params(model)
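        # shared parameters (e.g. tied embeddings) are recorded here so they
        # can be re-tied after the model is cast/moved below, in case the
        # device/dtype transfer breaks the sharing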
self.tpu = getattr(args, 'tpu', False)
self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
if self.cuda:
self.device = torch.device('cuda')
elif self.tpu:
self.device = utils.get_tpu_device(args)
else:
self.device = torch.device('cpu')
# copy model and criterion to current device/dtype
self._criterion = criterion
self._model = model
if self.tpu:
import torch_xla.core.xla_model as xm
self._model = xm.send_cpu_data_to_device(self._model, self.device)
if args.fp16:
self._criterion = self._criterion.half()
self._model = self._model.half()
elif args.bf16:
self._criterion = self._criterion.to(dtype=torch.bfloat16)
self._model = self._model.to(dtype=torch.bfloat16)
self._criterion = self._criterion.to(device=self.device)
self._model = self._model.to(device=self.device)
# check that shared parameters are preserved after device transfer
for shared_param in shared_params:
ref = _get_module_by_path(self._model, shared_param[0])
for path in shared_param[1:]:
logger.info(
'detected shared parameter: {} <- {}'.format(shared_param[0], path)
)
_set_module_by_path(self._model, path, ref)
self._dummy_batch = "DUMMY" # indicates we don't have a dummy batch at first
self._lr_scheduler = None
self._num_updates = 0
self._num_xla_compiles = 0 # for TPUs
self._optim_history = None
self._optimizer = None
self._warn_once = set()
self._wrapped_criterion = None
self._wrapped_model = None
# TODO(myleott): support tpu
if self.cuda and self.data_parallel_world_size > 1:
self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
else:
self._grad_norm_buf = None
self.quantizer = quantizer
if self.quantizer is not None:
self.quantizer.set_trainer(self)
# get detailed cuda environment
if self.cuda:
self.cuda_env = utils.CudaEnvironment()
if self.data_parallel_world_size > 1:
self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env)
else:
self.cuda_env_arr = [self.cuda_env]
if self.data_parallel_rank == 0:
utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
else:
self.cuda_env = None
self.cuda_env_arr = None
metrics.log_start_time("wall", priority=790, round=0)
self._start_time = time.time()
self._previous_training_time = 0
self._cumulative_training_time = None
def reinitialize(self):
"""Reinitialize the Trainer, typically after model params change."""
self._lr_scheduler = None
self._optimizer = None
self._wrapped_criterion = None
self._wrapped_model = None
@property
def data_parallel_world_size(self):
return self.args.distributed_world_size
@property
def data_parallel_process_group(self):
if self.tpu:
return ('tpu', None)
else:
return None
@property
def data_parallel_rank(self):
return self.args.distributed_rank
@property
def is_data_parallel_master(self):
return distributed_utils.is_master(self.args)
@property
def criterion(self):
if self._wrapped_criterion is None:
if (
utils.has_parameters(self._criterion)
and self.data_parallel_world_size > 1
and not self.args.use_bmuf
and not self.tpu
):
self._wrapped_criterion = models.DistributedFairseqModel(
self.args, self._criterion,
process_group=self.data_parallel_process_group
)
else:
self._wrapped_criterion = self._criterion
return self._wrapped_criterion
@property
def model(self):
if self._wrapped_model is None:
if (
self.data_parallel_world_size > 1
and not self.args.use_bmuf
and not self.tpu
):
self._wrapped_model = models.DistributedFairseqModel(
self.args, self._model,
process_group=self.data_parallel_process_group
)
else:
self._wrapped_model = self._model
return self._wrapped_model
@property
def optimizer(self):
if self._optimizer is None:
self._build_optimizer()
return self._optimizer
@property
def lr_scheduler(self):
if self._lr_scheduler is None:
self._build_optimizer() # this will initialize self._lr_scheduler
return self._lr_scheduler
def _build_optimizer(self):
params = list(
filter(
lambda p: p.requires_grad,
chain(self.model.parameters(), self.criterion.parameters()),
)
)
if self.args.fp16 or self.args.bf16:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
logger.info(
"NOTE: your device does NOT support faster training with --fp16, "
"please switch to FP32 which is likely to be faster"
)
if self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16:
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
self.args, params
)
else:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
else:
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
logger.info("NOTE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.args, params)
if self.args.use_bmuf:
self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
# We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set.
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self._lr_scheduler.step_update(0)
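        # step_update(0) applies the schedule at update 0 (e.g. warmup), so
        # training starts from the scheduler's intended initial learning rate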
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if self.is_data_parallel_master: # only save one checkpoint
extra_state["metrics"] = metrics.state_dict()
extra_state["previous_training_time"] = self.cumulative_training_time()
checkpoint_utils.save_state(
filename,
self.args,
self.get_model().state_dict(),
self.get_criterion(),
self.optimizer,
self.lr_scheduler,
self.get_num_updates(),
self._optim_history,
extra_state,
)
def load_checkpoint(
self,
filename,
reset_optimizer=False,
reset_lr_scheduler=False,
optimizer_overrides=None,
reset_meters=False,
):
"""Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = None, [], None
bexists = PathManager.isfile(filename)
if bexists:
state = checkpoint_utils.load_checkpoint_to_cpu(filename)
# load model parameters
try:
self.get_model().load_state_dict(
state["model"], strict=True, args=self.args
)
if utils.has_parameters(self.get_criterion()):
self.get_criterion().load_state_dict(
state["criterion"], strict=True
)
except Exception:
raise Exception(
"Cannot load model parameters from checkpoint {}; "
"please ensure that the architectures match.".format(filename)
)
extra_state = state["extra_state"]
self._optim_history = state["optimizer_history"]
last_optim_state = state.get("last_optimizer_state", None)
if last_optim_state is not None and not reset_optimizer:
# rebuild optimizer after loading model, since params may have changed
self._build_optimizer()
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
assert (
last_optim["criterion_name"] == self.get_criterion().__class__.__name__
), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
assert (
last_optim["optimizer_name"] == self.optimizer.__class__.__name__
), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."
if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self.set_num_updates(last_optim["num_updates"])
if extra_state is not None:
epoch = extra_state["train_iterator"]["epoch"]
logger.info(
"loaded checkpoint {} (epoch {} @ {} updates)".format(
filename, epoch, self.get_num_updates()
)
)
if "previous_training_time" in extra_state:
self._previous_training_time = extra_state["previous_training_time"]
self._start_time = time.time()
self.lr_step(epoch)
if "metrics" in extra_state and not reset_meters:
metrics.load_state_dict(extra_state["metrics"])
# reset TimeMeters, since their start times don't make sense anymore
for meter in metrics.get_meters("default"):
if isinstance(meter, meters.TimeMeter):
meter.reset()
else:
logger.info("no existing checkpoint found {}".format(filename))
return extra_state
def get_train_iterator(
self,
epoch,
combine=True,
load_dataset=True,
data_selector=None,
shard_batch_itr=True,
):
"""Return an EpochBatchIterator over the training set for a given epoch."""
if load_dataset:
logger.info("loading train data for epoch {}".format(epoch))
self.task.load_dataset(
self.args.train_subset,
epoch=epoch,
combine=combine,
data_selector=data_selector,
)
return self.task.get_batch_iterator(
dataset=self.task.dataset(self.args.train_subset),
max_tokens=self.args.max_tokens,
max_sentences=self.args.max_sentences,
max_positions=utils.resolve_max_positions(
self.task.max_positions(),
self.model.max_positions(),
self.args.max_tokens,
),
ignore_invalid_inputs=True,
required_batch_size_multiple=self.args.required_batch_size_multiple,
seed=self.args.seed,
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
num_workers=self.args.num_workers,
epoch=epoch
)
def get_valid_iterator(
self,
subset,
):
"""Return an EpochBatchIterator over given validation subset for a given epoch."""
return self.task.get_batch_iterator(
dataset=self.task.dataset(subset),
max_tokens=self.args.max_tokens_valid,
max_sentences=self.args.max_sentences_valid,
max_positions=utils.resolve_max_positions(
self.task.max_positions(),
self.model.max_positions(),
),
ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.args.required_batch_size_multiple,
seed=self.args.seed,
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
num_workers=self.args.num_workers
)
def begin_epoch(self, epoch):
"""Called at the beginning of each epoch."""
if self.quantizer is not None:
self.quantizer.begin_epoch(epoch)
# task specific setup per epoch
self.task.begin_epoch(epoch, self.get_model())
@metrics.aggregate("train")
def train_step(self, samples, raise_oom=False):
"""Do forward, backward and parameter update."""
if self._dummy_batch == "DUMMY":
self._dummy_batch = samples[0]
self._set_seed()
self.model.train()
self.criterion.train()
self.zero_grad()
metrics.log_start_time("train_wall", priority=800, round=0)
# forward and backward pass
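        # *samples* typically contains one mini-batch per gradient accumulation
        # step (--update-freq); gradients from all of them are accumulated
        # locally and a single optimizer step is taken below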
logging_outputs, sample_size, ooms = [], 0, 0
for i, sample in enumerate(samples):
sample = self._prepare_sample(sample)
if sample is None:
# when sample is None, run forward/backward on a dummy batch
# and ignore the resulting gradients
sample = self._prepare_sample(self._dummy_batch)
is_dummy_batch = True
else:
is_dummy_batch = False
def maybe_no_sync():
"""
Whenever *samples* contains more than one mini-batch, we
want to accumulate gradients locally and only call
all-reduce in the last backwards pass.
"""
if (
self.data_parallel_world_size > 1
and hasattr(self.model, "no_sync")
and i < len(samples) - 1
):
return self.model.no_sync()
else:
return contextlib.ExitStack() # dummy contextmanager
try:
with maybe_no_sync():
# forward and backward
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
model=self.model,
criterion=self.criterion,
optimizer=self.optimizer,
update_num=self.get_num_updates(),
ignore_grad=is_dummy_batch,
)
del loss
logging_outputs.append(logging_output)
sample_size += sample_size_i
# emptying the CUDA cache after the first step can
# reduce the chance of OOM
if self.cuda and self.get_num_updates() == 0:
torch.cuda.empty_cache()
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
if raise_oom:
raise e
logger.warning(
"attempting to recover from OOM in forward/backward pass"
)
ooms += 1
self.zero_grad()
if self.cuda:
torch.cuda.empty_cache()
if self.args.distributed_world_size == 1:
return None
else:
raise e
if self.tpu and i < len(samples) - 1:
# tpu-comment: every XLA operation before marking step is
# appended to the IR graph, and processing too many batches
# before marking step can lead to OOM errors.
# To handle gradient accumulation use case, we explicitly
# mark step here for every forward pass without a backward pass
import torch_xla.core.xla_model as xm
xm.mark_step()
if is_dummy_batch:
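            # zero the sample size so the dummy batch contributes nothing to
            # the gradient normalization or the aggregated logging stats below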
if torch.is_tensor(sample_size):
sample_size.zero_()
else:
sample_size *= 0.
if torch.is_tensor(sample_size):
sample_size = sample_size.float()
else:
sample_size = float(sample_size)
# gather logging outputs from all replicas
if self._sync_stats():
train_time = self._local_cumulative_training_time()
logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs(
logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
)
self._cumulative_training_time = total_train_time / self.data_parallel_world_size
if hasattr(self.model, 'all_reduce'):
self.model.all_reduce()
overflow = False
try:
if self.tpu and self.data_parallel_world_size > 1:
import torch_xla.core.xla_model as xm
gradients = xm._fetch_gradients(self.optimizer.optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size)
with torch.autograd.profiler.record_function("multiply-grads"):
# multiply gradients by (# GPUs / sample_size) since DDP
# already normalizes by the number of GPUs. Thus we get
# (sum_of_gradients / sample_size).
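                    # e.g. with 8 workers and sample_size summing to 4096
                    # (tokens or sentences, depending on the criterion), grads
                    # are scaled by 8 / 4096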
if not self.args.use_bmuf:
self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size)
elif sample_size > 0: # BMUF needs to check sample size
num = self.data_parallel_world_size if self._sync_stats() else 1
self.optimizer.multiply_grads(num / sample_size)
with torch.autograd.profiler.record_function("clip-grads"):
# clip grads
grad_norm = self.clip_grad_norm(self.args.clip_norm)
# check that grad norms are consistent across workers
if (
not self.args.use_bmuf
and self.args.distributed_wrapper != 'SlowMo'
and not self.tpu
):
self._check_grad_norms(grad_norm)
with torch.autograd.profiler.record_function("optimizer"):
# take an optimization step
self.optimizer.step()
except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print
# out where it fails
with NanDetector(self.model):
self.task.train_step(
sample, self.model, self.criterion, self.optimizer, self.get_num_updates(),
ignore_grad=False
)
raise
except OverflowError as e:
overflow = True
logger.info("NOTE: overflow detected, " + str(e))
grad_norm = torch.tensor(0.).cuda()
self.zero_grad()
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
logger.error("OOM during optimization, irrecoverable")
raise e
# Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
if hasattr(self.model, 'perform_additional_optimizer_actions'):
if hasattr(self.optimizer, 'fp32_params'):
self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
else:
self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)
if not overflow or self.args.distributed_wrapper == 'SlowMo':
self.set_num_updates(self.get_num_updates() + 1)
if self.tpu:
# mark step on TPUs
import torch_xla.core.xla_model as xm
xm.mark_step()
# only log stats every log_interval steps
# this causes wps to be misreported when log_interval > 1
logging_output = {}
if self.get_num_updates() % self.args.log_interval == 0:
logging_output = self._reduce_and_log_stats(
logging_outputs, sample_size, grad_norm,
)
# log whenever there's an XLA compilation, since these
# slow down training and may indicate opportunities for
# optimization
self._check_xla_compilation()
else:
# log stats
logging_output = self._reduce_and_log_stats(
logging_outputs, sample_size, grad_norm,
)
# clear CUDA cache to reduce memory fragmentation
if (
self.cuda
and self.args.empty_cache_freq > 0
and (
(self.get_num_updates() + self.args.empty_cache_freq - 1)
% self.args.empty_cache_freq
) == 0
):
torch.cuda.empty_cache()
if self.args.fp16:
metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)
metrics.log_stop_time("train_wall")
return logging_output
@metrics.aggregate("valid")
def valid_step(self, sample, raise_oom=False):
"""Do forward pass in evaluation mode."""
if self._dummy_batch == "DUMMY":
self._dummy_batch = sample
if self.tpu:
import torch_xla.core.xla_model as xm
xm.rendezvous('valid_step') # wait for all workers
xm.mark_step()
with torch.no_grad():
self.model.eval()
self.criterion.eval()
sample = self._prepare_sample(sample)
if sample is None:
sample = self._prepare_sample(self._dummy_batch)
is_dummy_batch = True
else:
is_dummy_batch = False
try:
_loss, sample_size, logging_output = self.task.valid_step(
sample, self.model, self.criterion
)
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
if not raise_oom:
logger.warning(
"ran out of memory in validation step, retrying batch"
)
for p in self.model.parameters():
if p.grad is not None:
p.grad = None # free some memory
if self.cuda:
torch.cuda.empty_cache()
return self.valid_step(sample, raise_oom=True)
raise e
logging_outputs = [logging_output]
if is_dummy_batch:
if torch.is_tensor(sample_size):
sample_size.zero_()
else:
sample_size *= 0.
# gather logging outputs from all replicas
if self.data_parallel_world_size > 1:
logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
logging_outputs, sample_size, ignore=is_dummy_batch,
)
# log validation stats
logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
return logging_output
def zero_grad(self):
self.optimizer.zero_grad()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate at the end of the epoch."""
self.lr_scheduler.step(epoch, val_loss)
# prefer updating the LR based on the number of steps
return self.lr_step_update()
def lr_step_update(self):
"""Update the learning rate after each update."""
new_lr = self.lr_scheduler.step_update(self.get_num_updates())
metrics.log_scalar("lr", new_lr, weight=0, priority=300)
return new_lr
def get_lr(self):
"""Get the current learning rate."""
return self.optimizer.get_lr()
def get_model(self):
"""Get the (non-wrapped) model instance."""
return self._model
def get_criterion(self):
"""Get the (non-wrapped) criterion instance."""
return self._criterion
def get_meter(self, name):
"""[deprecated] Get a specific meter by name."""
from fairseq import meters
if 'get_meter' not in self._warn_once:
self._warn_once.add('get_meter')
utils.deprecation_warning(
'Trainer.get_meter is deprecated. Please use fairseq.metrics instead.'
)
train_meters = metrics.get_meters("train")
if train_meters is None:
train_meters = {}
if name == "train_loss" and "loss" in train_meters:
return train_meters["loss"]
elif name == "train_nll_loss":
# support for legacy train.py, which assumed this meter is
# always initialized
m = train_meters.get("nll_loss", None)
return m or meters.AverageMeter()
elif name == "wall":
# support for legacy train.py, which assumed this meter is
# always initialized
m = metrics.get_meter("default", "wall")
return m or meters.TimeMeter()
elif name == "wps":
m = metrics.get_meter("train", "wps")
return m or meters.TimeMeter()
elif name in {"valid_loss", "valid_nll_loss"}:
# support for legacy train.py, which assumed these meters
# are always initialized
k = name[len("valid_"):]
m = metrics.get_meter("valid", k)
return m or meters.AverageMeter()
elif name == "oom":
return meters.AverageMeter()
elif name in train_meters:
return train_meters[name]
return None
def get_num_updates(self):
"""Get the number of parameters updates."""
return self._num_updates
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
self._num_updates = num_updates
self.lr_step_update()
if self.quantizer:
self.quantizer.step_update(self._num_updates)
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
def clip_grad_norm(self, clip_norm):
return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None)
def cumulative_training_time(self):
if self._cumulative_training_time is None:
# single GPU
return self._local_cumulative_training_time()
else:
return self._cumulative_training_time
def _local_cumulative_training_time(self):
"""Aggregate training time in seconds."""
return time.time() - self._start_time + self._previous_training_time
def _prepare_sample(self, sample):
if sample == "DUMMY":
raise Exception(
"Trying to use an uninitialized 'dummy' batch. This usually indicates "
"that the total number of batches is smaller than the number of "
"participating GPUs. Try reducing the batch size or using fewer GPUs."
)
if sample is None or len(sample) == 0:
return None
if self.cuda:
sample = utils.move_to_cuda(sample)
def apply_half(t):
if t.dtype is torch.float32:
return t.half()
return t
def apply_bfloat16(t):
if t.dtype is torch.float32:
return t.to(dtype=torch.bfloat16)
return t
if self.args.fp16:
sample = utils.apply_to_sample(apply_half, sample)
if self.args.bf16:
sample = utils.apply_to_sample(apply_bfloat16, sample)
return sample
def _set_seed(self):
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
utils.set_torch_seed(seed)
def _sync_stats(self):
        # Return True if gradients/statistics should be synchronized on this
        # step: with plain DDP and multiple workers we sync on every update,
        # while with BMUF we only sync on designated sync steps once the
        # warmup iterations have completed.
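        # e.g. with --global-sync-iter 50 and --warmup-iterations 500, BMUF
        # syncs whenever (num_updates + 1) is a multiple of 50 and > 500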
if self.data_parallel_world_size == 1:
return False
elif self.args.use_bmuf:
return (
(self.get_num_updates() + 1) % self.args.global_sync_iter == 0
and (self.get_num_updates() + 1) > self.args.warmup_iterations
)
else:
return True
def _log_oom(self, exc):
msg = "OOM: Ran out of memory with exception: {}".format(exc)
logger.warning(msg)
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
for device_idx in range(torch.cuda.device_count()):
logger.warning(torch.cuda.memory_summary(device=device_idx))
sys.stderr.flush()
def _aggregate_logging_outputs(
self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
):
if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()):
return self._fast_stat_sync_sum(
logging_outputs, *extra_stats_to_sum, ignore=ignore
)
else:
return self._all_gather_list_sync(
logging_outputs, *extra_stats_to_sum, ignore=ignore
)
def _all_gather_list_sync(
self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
):
"""
Sync logging outputs across workers. all_gather_list_sync is
suitable when logging outputs are complex types.
"""
if self.tpu:
raise NotImplementedError
if ignore:
logging_outputs = []
results = list(zip(
*distributed_utils.all_gather_list(
[logging_outputs] + list(extra_stats_to_sum),
max_size=getattr(self.args, 'all_gather_list_size', 16384),
group=self.data_parallel_process_group,
)
))
logging_outputs, extra_stats_to_sum = results[0], results[1:]
logging_outputs = list(chain.from_iterable(logging_outputs))
extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
return logging_outputs, extra_stats_to_sum
def _fast_stat_sync_sum(
self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
):
"""
Sync logging outputs across workers. fast_stat_sync_sum is
faster than all_gather_list_sync, but is only suitable when
logging outputs are scalars and can be summed. Note that
*logging_outputs* cannot contain any nested dicts/lists.
"""
data = {}
for i, stat in enumerate(extra_stats_to_sum):
data['extra_stats_' + str(i)] = stat
if len(logging_outputs) > 0:
log_keys = list(logging_outputs[0].keys())
for k in log_keys:
if not ignore:
v = sum(log[k] for log in logging_outputs if k in log)
else:
v = logging_outputs[0][k]
v = torch.zeros_like(v) if torch.is_tensor(v) else 0
data['logging_outputs_' + k] = v
else:
log_keys = None
data = distributed_utils.all_reduce_dict(
data,
device=self.device,
group=self.data_parallel_process_group
)
extra_stats_to_sum = [
data['extra_stats_' + str(i)] for i in range(len(extra_stats_to_sum))
]
if log_keys is not None:
logging_outputs = [{k: data['logging_outputs_' + k] for k in log_keys}]
else:
logging_outputs = []
return logging_outputs, extra_stats_to_sum
def _check_grad_norms(self, grad_norm):
"""Check that grad norms are consistent across workers."""
if self._grad_norm_buf is not None:
self._grad_norm_buf.zero_()
self._grad_norm_buf[self.data_parallel_rank] = grad_norm
distributed_utils.all_reduce(
self._grad_norm_buf,
group=self.data_parallel_process_group
)
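            # after the sum all-reduce, every rank holds every worker's grad
            # norm (all other slots in each rank's buffer were zero)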
def is_consistent(tensor):
max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
return (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all()
if not is_consistent(self._grad_norm_buf):
pretty_detail = "\n".join(
"rank {:3d} = {:.8f}".format(r, n)
for r, n in enumerate(self._grad_norm_buf.tolist())
)
error_detail = "grad_norm across the workers:\n{}\n".format(pretty_detail)
raise RuntimeError(
"Fatal error: gradients are inconsistent between workers. "
"Try --ddp-backend=no_c10d. "
"Or are you mixing up different generation of GPUs in training?"
+ "\n"
+ "-" * 80
+ "\n{}\n".format(error_detail)
+ "-" * 80
)
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
if grad_norm is not None:
metrics.log_speed("ups", 1., priority=100, round=2)
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
if self.args.clip_norm > 0:
metrics.log_scalar(
"clip",
torch.where(
grad_norm > self.args.clip_norm,
grad_norm.new_tensor(100),
grad_norm.new_tensor(0),
),
priority=500,
round=1,
)
with metrics.aggregate() as agg:
if logging_outputs is not None:
self.task.reduce_metrics(logging_outputs, self.get_criterion())
del logging_outputs
# extra warning for criterions that don't properly log a loss value
if "loss" not in agg:
if "loss" not in self._warn_once:
self._warn_once.add("loss")
logger.warning(
"Criterion.reduce_metrics did not log a 'loss' value, "
"which may break some functionality"
)
metrics.log_scalar("loss", -1)
# support legacy interface
if self.tpu:
logging_output = {}
else:
logging_output = agg.get_smoothed_values()
logging_output["sample_size"] = sample_size
for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
if key_to_delete in logging_output:
del logging_output[key_to_delete]
return logging_output
def _check_xla_compilation(self, message=None):
import torch_xla.debug.metrics as met
compile_stats = met.metric_data("CompileTime")
if compile_stats is None:
return
num_xla_compiles = compile_stats[0]
if num_xla_compiles > self._num_xla_compiles:
if message is None:
message = (
"too many of these can lead to slow training, "
"but we expect a few in the beginning"
)
logging.info("NOTE: XLA compilation detected; {}".format(message))
self._num_xla_compiles = num_xla_compiles
def _catalog_shared_params(module, memo=None, prefix=''):
if memo is None:
first_call = True
memo = {}
else:
first_call = False
for name, param in module._parameters.items():
param_prefix = prefix + ('.' if prefix else '') + name
if param not in memo:
memo[param] = []
memo[param].append(param_prefix)
for name, m in module._modules.items():
if m is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
_catalog_shared_params(m, memo, submodule_prefix)
if first_call:
return [x for x in memo.values() if len(x) > 1]
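# Example: with tied embeddings (e.g. --share-all-embeddings) the catalog might
# return something like [['encoder.embed_tokens.weight',
# 'decoder.embed_tokens.weight']]; the Trainer constructor then uses
# _get_module_by_path/_set_module_by_path to re-tie these parameters after
# moving the model to the target device/dtype.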
def _get_module_by_path(module, path):
path = path.split('.')
for name in path:
module = getattr(module, name)
return module
def _set_module_by_path(module, path, value):
path = path.split('.')
for name in path[:-1]:
module = getattr(module, name)
setattr(module, path[-1], value)