|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Train a network across multiple GPUs. |
|
|
""" |
|
|
|
|
|
import contextlib |
|
|
from itertools import chain |
|
|
import logging |
|
|
import sys |
|
|
import time |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import torch |
|
|
|
|
|
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils |
|
|
from fairseq.file_io import PathManager |
|
|
from fairseq.logging import meters, metrics |
|
|
from fairseq.nan_detector import NanDetector |
|
|
from fairseq.optim import lr_scheduler |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class Trainer(object): |
|
|
"""Main class for data parallel training. |
|
|
|
|
|
This class supports synchronous distributed data parallel training, |
|
|
where multiple workers each have a full model replica and gradients |
|
|
are accumulated across workers before each update. We use |
|
|
:class:`~torch.nn.parallel.DistributedDataParallel` to handle |
|
|
communication of the gradients across workers. |
|
|
""" |
|
|
|
|
|
def __init__(self, args, task, model, criterion, quantizer=None):
    """Set up devices, cast/move the model and criterion, and reset state.

    Args:
        args: parsed options; this method reads ``tpu`` (optional), ``cpu``,
            ``fp16`` and ``bf16``, plus the distributed rank/world size via
            the ``data_parallel_*`` properties.
        task: task object, stored as ``self.task``.
        model: model to train; it is cast to the requested dtype and moved
            to the selected device.
        criterion: loss module; cast and moved the same way as the model.
        quantizer: optional object; when given, ``set_trainer(self)`` is
            called on it.
    """
    self.args = args
    self.task = task

    # Record which attribute paths refer to the same parameter object
    # *before* the model is cast/moved, so the sharing can be restored below.
    shared_params = _catalog_shared_params(model)

    # Device selection: TPU wins over CUDA, and --cpu disables CUDA.
    self.tpu = getattr(args, 'tpu', False)
    self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
    if self.cuda:
        self.device = torch.device('cuda')
    elif self.tpu:
        self.device = utils.get_tpu_device(args)
    else:
        self.device = torch.device('cpu')

    # Cast the model and criterion to the requested dtype, then move them to
    # the selected device.
    self._criterion = criterion
    self._model = model
    if self.tpu:
        # Imported lazily so torch_xla is only required for TPU runs.
        import torch_xla.core.xla_model as xm
        self._model = xm.send_cpu_data_to_device(self._model, self.device)
    if args.fp16:
        self._criterion = self._criterion.half()
        self._model = self._model.half()
    elif args.bf16:
        self._criterion = self._criterion.to(dtype=torch.bfloat16)
        self._model = self._model.to(dtype=torch.bfloat16)
    self._criterion = self._criterion.to(device=self.device)
    self._model = self._model.to(device=self.device)

    # Re-point every alias path at the first path's object so parameters that
    # were shared before the cast/move are shared afterwards as well.
    for shared_param in shared_params:
        ref = _get_module_by_path(self._model, shared_param[0])
        for path in shared_param[1:]:
            logger.info(
                'detected shared parameter: {} <- {}'.format(shared_param[0], path)
            )
            _set_module_by_path(self._model, path, ref)

    # "DUMMY" is a sentinel meaning no real batch has been seen yet; it is
    # replaced by the first batch passed to train_step/valid_step.
    self._dummy_batch = "DUMMY"
    self._lr_scheduler = None
    self._num_updates = 0
    self._num_xla_compiles = 0
    self._optim_history = None
    self._optimizer = None
    self._warn_once = set()
    self._wrapped_criterion = None
    self._wrapped_model = None

    # Buffer with one slot per worker, used by _check_grad_norms; only
    # allocated for multi-worker CUDA training.
    if self.cuda and self.data_parallel_world_size > 1:
        self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
    else:
        self._grad_norm_buf = None

    self.quantizer = quantizer
    if self.quantizer is not None:
        self.quantizer.set_trainer(self)

    # Collect the CUDA environment of every worker; rank 0 prints the list.
    if self.cuda:
        self.cuda_env = utils.CudaEnvironment()
        if self.data_parallel_world_size > 1:
            self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env)
        else:
            self.cuda_env_arr = [self.cuda_env]
        if self.data_parallel_rank == 0:
            utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
    else:
        self.cuda_env = None
        self.cuda_env_arr = None

    metrics.log_start_time("wall", priority=790, round=0)

    # Wall-clock bookkeeping used by cumulative_training_time().
    self._start_time = time.time()
    self._previous_training_time = 0
    self._cumulative_training_time = None
|
|
|
|
|
def reinitialize(self):
    """Forget all lazily built components (e.g. after model params change).

    The LR scheduler, optimizer and wrapped criterion/model are cached on
    first access; clearing the caches makes the matching properties rebuild
    them the next time they are read.
    """
    for cached_attr in (
        "_lr_scheduler",
        "_optimizer",
        "_wrapped_criterion",
        "_wrapped_model",
    ):
        setattr(self, cached_attr, None)
|
|
|
|
|
@property
def data_parallel_world_size(self):
    """Number of data-parallel workers, read from ``args.distributed_world_size``."""
    world_size = self.args.distributed_world_size
    return world_size
|
|
|
|
|
@property
def data_parallel_process_group(self):
    """Process group handle: ``('tpu', None)`` on TPU, otherwise ``None``."""
    return ('tpu', None) if self.tpu else None
|
|
|
|
|
@property
def data_parallel_rank(self):
    """Rank of this worker, read from ``args.distributed_rank``."""
    rank = self.args.distributed_rank
    return rank
|
|
|
|
|
@property
def is_data_parallel_master(self):
    """Whether ``distributed_utils.is_master`` reports this worker as master."""
    master = distributed_utils.is_master(self.args)
    return master
|
|
|
|
|
@property
def criterion(self):
    """Return the criterion, wrapping it for distributed training on first use.

    The criterion is wrapped in ``models.DistributedFairseqModel`` only when
    it has parameters, more than one worker is used, and neither BMUF nor TPU
    is enabled. The result is cached in ``self._wrapped_criterion``.
    """
    if self._wrapped_criterion is not None:
        return self._wrapped_criterion

    needs_wrapper = (
        utils.has_parameters(self._criterion)
        and self.data_parallel_world_size > 1
        and not self.args.use_bmuf
        and not self.tpu
    )
    if needs_wrapper:
        self._wrapped_criterion = models.DistributedFairseqModel(
            self.args,
            self._criterion,
            process_group=self.data_parallel_process_group,
        )
    else:
        self._wrapped_criterion = self._criterion
    return self._wrapped_criterion
|
|
|
|
|
@property
def model(self):
    """Return the model, wrapping it for distributed training on first use.

    The model is wrapped in ``models.DistributedFairseqModel`` when more than
    one worker is used and neither BMUF nor TPU is enabled. The result is
    cached in ``self._wrapped_model``.
    """
    if self._wrapped_model is not None:
        return self._wrapped_model

    needs_wrapper = (
        self.data_parallel_world_size > 1
        and not self.args.use_bmuf
        and not self.tpu
    )
    if needs_wrapper:
        self._wrapped_model = models.DistributedFairseqModel(
            self.args,
            self._model,
            process_group=self.data_parallel_process_group,
        )
    else:
        self._wrapped_model = self._model
    return self._wrapped_model
|
|
|
|
|
@property
def optimizer(self):
    """Return the optimizer, building it on first access."""
    not_built_yet = self._optimizer is None
    if not_built_yet:
        self._build_optimizer()
    return self._optimizer
|
|
|
|
|
@property
def lr_scheduler(self):
    """Return the LR scheduler, building it (with the optimizer) on first access."""
    not_built_yet = self._lr_scheduler is None
    if not_built_yet:
        # The scheduler is created as part of building the optimizer.
        self._build_optimizer()
    return self._lr_scheduler
|
|
|
|
|
def _build_optimizer(self):
    """Create the optimizer and LR scheduler for all trainable parameters.

    Collects every parameter of the (wrapped) model and criterion that
    requires gradients, picks the FP16/memory-efficient/regular optimizer
    according to ``args``, optionally wraps it in BMUF, and finally builds
    the LR scheduler and resets it to update 0.
    """
    trainable_params = [
        p
        for p in chain(self.model.parameters(), self.criterion.parameters())
        if p.requires_grad
    ]

    if self.args.fp16 or self.args.bf16:
        if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
            logger.info(
                "NOTE: your device does NOT support faster training with --fp16, "
                "please switch to FP32 which is likely to be faster"
            )
        memory_efficient = (
            self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16
        )
        if memory_efficient:
            optimizer_cls = optim.MemoryEfficientFP16Optimizer
        else:
            optimizer_cls = optim.FP16Optimizer
        self._optimizer = optimizer_cls.build_optimizer(self.args, trainable_params)
    else:
        if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
            logger.info("NOTE: your device may support faster training with --fp16")
        self._optimizer = optim.build_optimizer(self.args, trainable_params)

    if self.args.use_bmuf:
        self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)

    # The scheduler is built after the optimizer because it is constructed
    # from it; step_update(0) sets the learning rate for the first update.
    scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
    self._lr_scheduler = scheduler
    self._lr_scheduler.step_update(0)
|
|
|
|
|
def save_checkpoint(self, filename, extra_state):
    """Write the full training state to *filename* (master worker only).

    On the master worker, the current metrics and the cumulative training
    time are added to *extra_state* (mutating the caller's dict) before the
    state is handed to ``checkpoint_utils.save_state``. Other workers do
    nothing.
    """
    if not self.is_data_parallel_master:
        return

    extra_state["metrics"] = metrics.state_dict()
    extra_state["previous_training_time"] = self.cumulative_training_time()
    model_state = self.get_model().state_dict()
    checkpoint_utils.save_state(
        filename,
        self.args,
        model_state,
        self.get_criterion(),
        self.optimizer,
        self.lr_scheduler,
        self.get_num_updates(),
        self._optim_history,
        extra_state,
    )
|
|
|
|
|
def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """Restore training state from *filename* if that file exists.

    Args:
        filename: checkpoint path, checked with ``PathManager.isfile``.
        reset_optimizer: when True, the saved optimizer state is ignored.
        reset_lr_scheduler: when True, the saved LR-scheduler state is ignored.
        optimizer_overrides: forwarded to the optimizer's ``load_state_dict``.
        reset_meters: when True, saved metrics are not restored.

    Returns:
        The checkpoint's ``extra_state`` entry, or ``None`` when the file does
        not exist.

    Raises:
        Exception: if the model/criterion parameters cannot be loaded.
    """
    extra_state = None
    last_optim_state = None
    self._optim_history = []

    checkpoint_found = PathManager.isfile(filename)
    if checkpoint_found:
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # Load the model (and criterion, if it has parameters) strictly.
        try:
            self.get_model().load_state_dict(
                state["model"], strict=True, args=self.args
            )
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(
                    state["criterion"], strict=True
                )
        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(filename)
            )

        extra_state = state["extra_state"]
        self._optim_history = state["optimizer_history"]
        last_optim_state = state.get("last_optimizer_state", None)

    restore_optimizer = last_optim_state is not None and not reset_optimizer
    if restore_optimizer:
        # Rebuild the optimizer so the saved state is loaded into a fresh one.
        self._build_optimizer()

        # Only reload optimizer and LR-scheduler state if they match the
        # ones recorded in the most recent history entry.
        last_optim = self._optim_history[-1]
        assert (
            last_optim["criterion_name"] == self.get_criterion().__class__.__name__
        ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
        assert (
            last_optim["optimizer_name"] == self.optimizer.__class__.__name__
        ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim["num_updates"])

    if extra_state is None:
        logger.info("no existing checkpoint found {}".format(filename))
        return extra_state

    epoch = extra_state["train_iterator"]["epoch"]
    logger.info(
        "loaded checkpoint {} (epoch {} @ {} updates)".format(
            filename, epoch, self.get_num_updates()
        )
    )

    if "previous_training_time" in extra_state:
        # Continue the training-time count from the saved value.
        self._previous_training_time = extra_state["previous_training_time"]
        self._start_time = time.time()

    self.lr_step(epoch)

    if "metrics" in extra_state and not reset_meters:
        metrics.load_state_dict(extra_state["metrics"])

        # Time meters restored from disk are reset so they restart from now.
        for meter in metrics.get_meters("default"):
            if isinstance(meter, meters.TimeMeter):
                meter.reset()

    return extra_state
|
|
|
|
|
def get_train_iterator(
    self,
    epoch,
    combine=True,
    load_dataset=True,
    data_selector=None,
    shard_batch_itr=True,
):
    """Build the batch iterator over the training subset for *epoch*.

    When *load_dataset* is True the training subset is (re)loaded through the
    task first. When *shard_batch_itr* is False the iterator is not sharded
    across workers (one shard, shard id 0).
    """
    if load_dataset:
        logger.info("loading train data for epoch {}".format(epoch))
        self.task.load_dataset(
            self.args.train_subset,
            epoch=epoch,
            combine=combine,
            data_selector=data_selector,
        )

    train_dataset = self.task.dataset(self.args.train_subset)
    position_limits = utils.resolve_max_positions(
        self.task.max_positions(),
        self.model.max_positions(),
        self.args.max_tokens,
    )
    if shard_batch_itr:
        shard_count = self.data_parallel_world_size
        shard_index = self.data_parallel_rank
    else:
        shard_count, shard_index = 1, 0

    return self.task.get_batch_iterator(
        dataset=train_dataset,
        max_tokens=self.args.max_tokens,
        max_sentences=self.args.max_sentences,
        max_positions=position_limits,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=self.args.required_batch_size_multiple,
        seed=self.args.seed,
        num_shards=shard_count,
        shard_id=shard_index,
        num_workers=self.args.num_workers,
        epoch=epoch,
    )
|
|
|
|
|
def get_valid_iterator(
    self,
    subset,
):
    """Build the batch iterator over the validation subset *subset*.

    The iterator is always sharded across the data-parallel workers and uses
    the ``*_valid`` batch-size options from ``args``.
    """
    valid_dataset = self.task.dataset(subset)
    position_limits = utils.resolve_max_positions(
        self.task.max_positions(),
        self.model.max_positions(),
    )
    return self.task.get_batch_iterator(
        dataset=valid_dataset,
        max_tokens=self.args.max_tokens_valid,
        max_sentences=self.args.max_sentences_valid,
        max_positions=position_limits,
        ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=self.args.required_batch_size_multiple,
        seed=self.args.seed,
        num_shards=self.data_parallel_world_size,
        shard_id=self.data_parallel_rank,
        num_workers=self.args.num_workers,
    )
|
|
|
|
|
def begin_epoch(self, epoch):
    """Notify the quantizer (if any) and the task that *epoch* is starting."""
    quantizer = self.quantizer
    if quantizer is not None:
        quantizer.begin_epoch(epoch)

    # The task receives the unwrapped model.
    unwrapped_model = self.get_model()
    self.task.begin_epoch(epoch, unwrapped_model)
|
|
|
|
|
@metrics.aggregate("train")
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update.

    Args:
        samples: list of mini-batches; gradients are accumulated over all of
            them before a single optimizer step.
        raise_oom: when True, an out-of-memory error in the forward/backward
            pass is re-raised instead of being recovered from.

    Returns:
        The reduced logging output for this update, or ``None`` when the
        update was skipped (recovered OOM on a single worker, or a gradient
        overflow outside of SlowMo).

    Raises:
        FloatingPointError: re-raised after re-running the step under
            ``NanDetector``.
        RuntimeError: for unrecoverable errors, including OOM during the
            optimizer step.
    """
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # Forward and backward pass over every mini-batch.
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # Empty batch on this worker: run the dummy batch instead and
            # ignore the resulting gradients.
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                # ExitStack() with nothing registered is a no-op context.
                return contextlib.ExitStack()

        try:
            with maybe_no_sync():
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # Release cached CUDA memory once, after the very first batch.
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
                if self.cuda:
                    torch.cuda.empty_cache()
                if self.args.distributed_world_size == 1:
                    return None
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # Mark a step after every accumulated batch except the last one;
            # the last one is marked after the optimizer step below.
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # Gather logging outputs and summed statistics from all workers.
    if self._sync_stats():
        train_time = self._local_cumulative_training_time()
        logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
        )
        self._cumulative_training_time = total_train_time / self.data_parallel_world_size

    if hasattr(self.model, 'all_reduce'):
        self.model.all_reduce()

    overflow = False
    try:
        if self.tpu and self.data_parallel_world_size > 1:
            import torch_xla.core.xla_model as xm
            gradients = xm._fetch_gradients(self.optimizer.optimizer)
            xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size)

        with torch.autograd.profiler.record_function("multiply-grads"):
            # Rescale gradients by (world size / sample_size).
            if not self.args.use_bmuf:
                self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size)
            elif sample_size > 0:
                # With BMUF the sample size may be zero, so guard the division.
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

        with torch.autograd.profiler.record_function("clip-grads"):
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # Check that gradient norms agree across workers.
        if (
            not self.args.use_bmuf
            and self.args.distributed_wrapper != 'SlowMo'
            and not self.tpu
        ):
            self._check_grad_norms(grad_norm)

        with torch.autograd.profiler.record_function("optimizer"):
            self.optimizer.step()
    except FloatingPointError:
        # Re-run the last batch under NanDetector before propagating the
        # error.
        with NanDetector(self.model):
            self.task.train_step(
                sample, self.model, self.criterion, self.optimizer, self.get_num_updates(),
                ignore_grad=False
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        # Create the placeholder on the trainer's device; calling .cuda()
        # unconditionally would fail on CPU/TPU runs.
        grad_norm = torch.tensor(0., device=self.device)
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Give wrappers that define this hook access to the optimizer after the
    # step.
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
        else:
            self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)

    # Reset explicitly: ``logging_output`` still holds the raw output of the
    # last mini-batch from the loop above, which must not be returned when
    # the update is skipped because of an overflow.
    logging_output = None
    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            import torch_xla.core.xla_model as xm
            xm.mark_step()

            # On TPU, stats are only reduced every ``log_interval`` updates.
            logging_output = {}
            if self.get_num_updates() % self.args.log_interval == 0:
                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

            # Log whenever a new XLA compilation has happened.
            self._check_xla_compilation()
        else:
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

            # Periodically release cached CUDA memory.
            if (
                self.cuda
                and self.args.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
                    % self.args.empty_cache_freq
                ) == 0
            ):
                torch.cuda.empty_cache()

    if self.args.fp16:
        metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)

    metrics.log_stop_time("train_wall")

    return logging_output
|
|
|
|
|
@metrics.aggregate("valid")
def valid_step(self, sample, raise_oom=False):
    """Run one forward pass in evaluation mode and log validation stats.

    Args:
        sample: the validation mini-batch.
        raise_oom: when False, an out-of-memory error triggers one retry of
            the same batch (with ``raise_oom=True``); when True it is raised.

    Returns:
        The reduced logging output for this batch.
    """
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = sample
    if self.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous('valid_step')
        xm.mark_step()

    with torch.no_grad():
        self.model.eval()
        self.criterion.eval()

        sample = self._prepare_sample(sample)
        is_dummy_batch = sample is None
        if is_dummy_batch:
            # Empty batch on this worker: evaluate the dummy batch instead.
            sample = self._prepare_sample(self._dummy_batch)

        try:
            _loss, sample_size, logging_output = self.task.valid_step(
                sample, self.model, self.criterion
            )
        except RuntimeError as err:
            if "out of memory" in str(err):
                self._log_oom(err)
                if not raise_oom:
                    logger.warning(
                        "ran out of memory in validation step, retrying batch"
                    )
                    # Drop any stored gradients to free memory before retrying.
                    for param in self.model.parameters():
                        if param.grad is not None:
                            param.grad = None
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
            raise err

        logging_outputs = [logging_output]
        if is_dummy_batch:
            # A dummy batch must not contribute to the sample count.
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.

    # Gather logging outputs from all workers.
    if self.data_parallel_world_size > 1:
        logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ignore=is_dummy_batch,
        )

    reduced_output = self._reduce_and_log_stats(logging_outputs, sample_size)
    return reduced_output
|
|
|
|
|
def zero_grad(self):
    """Clear the gradients held by the optimizer."""
    optimizer = self.optimizer
    optimizer.zero_grad()
|
|
|
|
|
def lr_step(self, epoch, val_loss=None):
    """Step the LR scheduler for *epoch* and return the resulting rate.

    The scheduler's per-epoch ``step`` is followed by a per-update refresh
    (``lr_step_update``), whose value is returned.
    """
    scheduler = self.lr_scheduler
    scheduler.step(epoch, val_loss)
    # Prefer updating the LR based on the number of steps.
    updated_lr = self.lr_step_update()
    return updated_lr
|
|
|
|
|
def lr_step_update(self):
    """Refresh the learning rate for the current update count and log it.

    Returns:
        The learning rate reported by the scheduler's ``step_update``.
    """
    num_updates = self.get_num_updates()
    current_lr = self.lr_scheduler.step_update(num_updates)
    metrics.log_scalar("lr", current_lr, weight=0, priority=300)
    return current_lr
|
|
|
|
|
def get_lr(self):
    """Return the learning rate currently reported by the optimizer."""
    current_lr = self.optimizer.get_lr()
    return current_lr
|
|
|
|
|
def get_model(self):
    """Return the underlying model, without any distributed wrapper."""
    unwrapped = self._model
    return unwrapped
|
|
|
|
|
def get_criterion(self):
    """Return the underlying criterion, without any distributed wrapper."""
    unwrapped = self._criterion
    return unwrapped
|
|
|
|
|
def get_meter(self, name):
    """[deprecated] Look up a meter by its legacy *name*.

    Emits a deprecation warning the first time it is called. Returns the
    matching meter, a fresh empty meter for some known legacy names that
    have not been logged, or ``None`` when the name is unknown.
    """
    from fairseq import meters

    if 'get_meter' not in self._warn_once:
        self._warn_once.add('get_meter')
        utils.deprecation_warning(
            'Trainer.get_meter is deprecated. Please use fairseq.metrics instead.'
        )

    train_meters = metrics.get_meters("train")
    if train_meters is None:
        train_meters = {}

    if name == "train_loss" and "loss" in train_meters:
        return train_meters["loss"]

    if name == "train_nll_loss":
        # Falls back to an empty meter when no "nll_loss" meter exists.
        found = train_meters.get("nll_loss", None)
        return found or meters.AverageMeter()

    if name == "wall":
        # Looked up in the "default" aggregator rather than "train".
        found = metrics.get_meter("default", "wall")
        return found or meters.TimeMeter()

    if name == "wps":
        found = metrics.get_meter("train", "wps")
        return found or meters.TimeMeter()

    if name in {"valid_loss", "valid_nll_loss"}:
        # Strip the "valid_" prefix to get the key used by the "valid"
        # aggregator.
        valid_key = name[len("valid_"):]
        found = metrics.get_meter("valid", valid_key)
        return found or meters.AverageMeter()

    if name == "oom":
        return meters.AverageMeter()

    if name in train_meters:
        return train_meters[name]

    return None
|
|
|
|
|
def get_num_updates(self):
    """Return how many parameter updates have been performed so far."""
    update_count = self._num_updates
    return update_count
|
|
|
|
|
def set_num_updates(self, num_updates):
    """Record the number of parameter updates and propagate it.

    After storing the count, the learning rate is refreshed, the quantizer
    (if any) is notified, and the count is logged as ``num_updates``.
    """
    self._num_updates = num_updates
    self.lr_step_update()

    quantizer = self.quantizer
    if quantizer:
        quantizer.step_update(self._num_updates)

    metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
|
|
|
|
|
def clip_grad_norm(self, clip_norm):
    """Clip gradients via the optimizer and return what it returns.

    No cross-worker norm aggregation function is supplied
    (``aggregate_norm_fn=None``).
    """
    grad_norm = self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None)
    return grad_norm
|
|
|
|
|
def cumulative_training_time(self):
    """Return the cumulative training time in seconds.

    Uses the value averaged across workers when one has been stored by
    ``train_step``; otherwise falls back to this worker's local measurement.
    """
    synced_time = self._cumulative_training_time
    if synced_time is not None:
        return synced_time
    # No synchronized value yet: report the local wall-clock time.
    return self._local_cumulative_training_time()
|
|
|
|
|
def _local_cumulative_training_time(self):
    """Return this worker's training time in seconds.

    This is the wall-clock time since ``self._start_time`` plus the time
    carried over in ``self._previous_training_time``.
    """
    elapsed_this_run = time.time() - self._start_time
    return elapsed_this_run + self._previous_training_time
|
|
|
|
|
def _prepare_sample(self, sample):
    """Move *sample* to the GPU and cast float32 tensors if requested.

    Returns:
        ``None`` for a missing or empty sample, otherwise the prepared sample.

    Raises:
        Exception: if *sample* is still the ``"DUMMY"`` sentinel, i.e. no
            real batch has been seen yet.
    """
    if sample == "DUMMY":
        raise Exception(
            "Trying to use an uninitialized 'dummy' batch. This usually indicates "
            "that the total number of batches is smaller than the number of "
            "participating GPUs. Try reducing the batch size or using fewer GPUs."
        )

    if sample is None or len(sample) == 0:
        return None

    if self.cuda:
        sample = utils.move_to_cuda(sample)

    # Only float32 tensors are cast; every other dtype passes through as is.
    def to_half(tensor):
        return tensor.half() if tensor.dtype is torch.float32 else tensor

    def to_bfloat16(tensor):
        if tensor.dtype is torch.float32:
            return tensor.to(dtype=torch.bfloat16)
        return tensor

    if self.args.fp16:
        sample = utils.apply_to_sample(to_half, sample)

    if self.args.bf16:
        sample = utils.apply_to_sample(to_bfloat16, sample)

    return sample
|
|
|
|
|
def _set_seed(self):
    """Seed torch with ``args.seed`` offset by the current update count.

    Because the seed depends on the number of updates, each update gets its
    own seed derived from the base seed.
    """
    update_seed = self.args.seed + self.get_num_updates()
    utils.set_torch_seed(update_seed)
|
|
|
|
|
def _sync_stats(self):
    """Return True when stats should be synchronized across workers.

    Never with a single worker; always with several workers unless BMUF is
    enabled; with BMUF only on updates that are a multiple of
    ``global_sync_iter`` and past ``warmup_iterations``.
    """
    if self.data_parallel_world_size == 1:
        return False
    if not self.args.use_bmuf:
        return True

    on_sync_iteration = (
        (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
    )
    return (
        on_sync_iteration
        and (self.get_num_updates() + 1) > self.args.warmup_iterations
    )
|
|
|
|
|
def _log_oom(self, exc):
    """Log an out-of-memory error plus a memory summary per CUDA device."""
    logger.warning("OOM: Ran out of memory with exception: {}".format(exc))

    # memory_summary is only used when this torch build provides it.
    can_summarize = torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary")
    if can_summarize:
        device_count = torch.cuda.device_count()
        for device_index in range(device_count):
            logger.warning(torch.cuda.memory_summary(device=device_index))

    sys.stderr.flush()
|
|
|
|
|
def _aggregate_logging_outputs(
    self,
    logging_outputs: List[Dict[str, Any]],
    *extra_stats_to_sum,
    ignore=False,
):
    """Synchronize logging outputs and extra stats across workers.

    Uses the fast all-reduce path when the task says its logging outputs can
    be summed, and the slower all-gather path otherwise. Returns whatever
    the chosen helper returns.
    """
    summable = self.task.__class__.logging_outputs_can_be_summed(
        self.get_criterion()
    )
    if summable:
        sync_fn = self._fast_stat_sync_sum
    else:
        sync_fn = self._all_gather_list_sync
    return sync_fn(logging_outputs, *extra_stats_to_sum, ignore=ignore)
|
|
|
|
|
def _all_gather_list_sync(
    self,
    logging_outputs: List[Dict[str, Any]],
    *extra_stats_to_sum,
    ignore=False,
):
    """
    Sync logging outputs across workers. all_gather_list_sync is
    suitable when logging outputs are complex types.

    Returns a flat list with every worker's logging outputs and a list with
    each extra stat summed over all workers. Not implemented on TPU.
    """
    if self.tpu:
        raise NotImplementedError
    if ignore:
        # An ignored (dummy) batch contributes no logging outputs.
        logging_outputs = []

    payload = [logging_outputs] + list(extra_stats_to_sum)
    gathered = distributed_utils.all_gather_list(
        payload,
        max_size=getattr(self.args, 'all_gather_list_size', 16384),
        group=self.data_parallel_process_group,
    )
    # Transpose from one entry per worker to one entry per payload field.
    per_field = list(zip(*gathered))
    per_worker_outputs, per_worker_stats = per_field[0], per_field[1:]

    merged_outputs = list(chain.from_iterable(per_worker_outputs))
    summed_stats = [sum(values) for values in per_worker_stats]
    return merged_outputs, summed_stats
|
|
|
|
|
def _fast_stat_sync_sum(
    self,
    logging_outputs: List[Dict[str, Any]],
    *extra_stats_to_sum,
    ignore=False,
):
    """
    Sync logging outputs across workers. fast_stat_sync_sum is
    faster than all_gather_list_sync, but is only suitable when
    logging outputs are scalars and can be summed. Note that
    *logging_outputs* cannot contain any nested dicts/lists.

    Returns a list holding at most one dict of summed logging outputs and a
    list with each extra stat reduced over all workers.
    """
    num_extra = len(extra_stats_to_sum)
    data = {
        'extra_stats_' + str(idx): stat
        for idx, stat in enumerate(extra_stats_to_sum)
    }

    log_keys = None
    if len(logging_outputs) > 0:
        # The keys of the first logging output define what gets reduced.
        log_keys = list(logging_outputs[0].keys())
        for key in log_keys:
            if ignore:
                # Contribute a zero of the matching kind for a dummy batch.
                template = logging_outputs[0][key]
                value = torch.zeros_like(template) if torch.is_tensor(template) else 0
            else:
                value = sum(log[key] for log in logging_outputs if key in log)
            data['logging_outputs_' + key] = value

    data = distributed_utils.all_reduce_dict(
        data,
        device=self.device,
        group=self.data_parallel_process_group
    )

    reduced_stats = [data['extra_stats_' + str(idx)] for idx in range(num_extra)]
    if log_keys is None:
        reduced_outputs = []
    else:
        reduced_outputs = [
            {key: data['logging_outputs_' + key] for key in log_keys}
        ]
    return reduced_outputs, reduced_stats
|
|
|
|
|
def _check_grad_norms(self, grad_norm):
    """Check that grad norms are consistent across workers.

    Does nothing when no norm buffer was allocated. Otherwise every worker
    writes its norm into its own slot, the buffer is all-reduced, and a
    ``RuntimeError`` is raised if the slots disagree.
    """
    norm_buf = self._grad_norm_buf
    if norm_buf is None:
        return

    # After zeroing, each worker fills only its own slot, so the all-reduce
    # leaves every worker's norm in the buffer.
    norm_buf.zero_()
    norm_buf[self.data_parallel_rank] = grad_norm
    distributed_utils.all_reduce(
        norm_buf,
        group=self.data_parallel_process_group
    )

    def is_consistent(tensor):
        # Largest deviation from rank 0, relative to rank 0's norm.
        max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
        return (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all()

    if is_consistent(norm_buf):
        return

    per_rank_lines = [
        "rank {:3d} = {:.8f}".format(rank, norm)
        for rank, norm in enumerate(norm_buf.tolist())
    ]
    pretty_detail = "\n".join(per_rank_lines)
    error_detail = "grad_norm across the workers:\n{}\n".format(pretty_detail)
    separator = "-" * 80
    raise RuntimeError(
        "Fatal error: gradients are inconsistent between workers. "
        "Try --ddp-backend=no_c10d. "
        "Or are you mixing up different generation of GPUs in training?"
        + "\n"
        + separator
        + "\n{}\n".format(error_detail)
        + separator
    )
|
|
|
|
|
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
    """Reduce *logging_outputs* through the task and return smoothed values.

    When *grad_norm* is given, update speed, gradient norm and (if clipping
    is enabled) the clip indicator are logged first. On TPU an empty dict is
    returned; otherwise the smoothed values plus ``sample_size``, without
    the ``ppl``/``wps``/``wpb``/``bsz`` entries.
    """
    if grad_norm is not None:
        metrics.log_speed("ups", 1., priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.args.clip_norm > 0:
            # Logs 100 when the norm exceeded the clip threshold, else 0.
            was_clipped = torch.where(
                grad_norm > self.args.clip_norm,
                grad_norm.new_tensor(100),
                grad_norm.new_tensor(0),
            )
            metrics.log_scalar("clip", was_clipped, priority=500, round=1)

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())
            del logging_outputs

        # Warn once if the criterion never logged a loss; -1 is logged as a
        # placeholder.
        if "loss" not in agg:
            if "loss" not in self._warn_once:
                self._warn_once.add("loss")
                logger.warning(
                    "Criterion.reduce_metrics did not log a 'loss' value, "
                    "which may break some functionality"
                )
            metrics.log_scalar("loss", -1)

        if self.tpu:
            return {}

        smoothed = agg.get_smoothed_values()
        smoothed["sample_size"] = sample_size
        for dropped_key in ("ppl", "wps", "wpb", "bsz"):
            if dropped_key in smoothed:
                del smoothed[dropped_key]
        return smoothed
|
|
|
|
|
def _check_xla_compilation(self, message=None):
    """Log a note whenever new XLA compilations have happened.

    Reads the ``CompileTime`` metric from ``torch_xla.debug.metrics`` and
    compares its first element with the count remembered from the previous
    call.

    Args:
        message: optional text appended to the log line; a default
            explanation is used when omitted.
    """
    # Imported lazily so torch_xla is only required for TPU runs.
    import torch_xla.debug.metrics as met

    compile_stats = met.metric_data("CompileTime")
    if compile_stats is None:
        return

    num_xla_compiles = compile_stats[0]
    if num_xla_compiles > self._num_xla_compiles:
        if message is None:
            message = (
                "too many of these can lead to slow training, "
                "but we expect a few in the beginning"
            )
        # Use the module-level logger, like the rest of this file, instead
        # of the root logger.
        logger.info("NOTE: XLA compilation detected; {}".format(message))
    self._num_xla_compiles = num_xla_compiles
|
|
|
|
|
|
|
|
def _catalog_shared_params(module, memo=None, prefix=''):
    """Find parameters that are reachable under more than one attribute path.

    Walks ``module._parameters`` and, recursively, ``module._modules``,
    grouping dotted attribute paths by the parameter object they refer to.

    Args:
        module: the module to inspect.
        memo: internal accumulator mapping parameter -> list of paths; leave
            as ``None`` when calling from outside.
        prefix: internal dotted path of *module* relative to the root.

    Returns:
        On the outermost call, a list of path lists, one for each parameter
        seen under at least two paths. Recursive calls return ``None`` and
        only fill *memo*.
    """
    first_call = memo is None
    if first_call:
        memo = {}

    for name, param in module._parameters.items():
        if param is None:
            # Unset parameters (e.g. a disabled bias) are all the same
            # ``None`` key and would be reported as shared between
            # unrelated modules.
            continue
        param_prefix = prefix + ('.' if prefix else '') + name
        memo.setdefault(param, []).append(param_prefix)

    for name, m in module._modules.items():
        if m is None:
            continue
        submodule_prefix = prefix + ('.' if prefix else '') + name
        _catalog_shared_params(m, memo, submodule_prefix)

    if first_call:
        return [paths for paths in memo.values() if len(paths) > 1]
    return None
|
|
|
|
|
|
|
|
def _get_module_by_path(module, path):
    """Resolve the dotted attribute *path* starting from *module*.

    Each dot-separated component is looked up with ``getattr`` on the result
    of the previous lookup; the final object is returned.
    """
    current = module
    for attr_name in path.split('.'):
        current = getattr(current, attr_name)
    return current
|
|
|
|
|
|
|
|
def _set_module_by_path(module, path, value):
    """Assign *value* to the attribute at the dotted *path* under *module*.

    All components but the last are resolved with ``getattr``; the last one
    is set on the resulting object with ``setattr``.
    """
    *parent_names, leaf_name = path.split('.')
    owner = module
    for attr_name in parent_names:
        owner = getattr(owner, attr_name)
    setattr(owner, leaf_name, value)
|
|
|