from itertools import chain

import torch

from fairseq import optim, utils

from .dynamic_loss_scaler import DynamicLossScaler


class _FP16OptimizerMixin(object):

    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in MRO (method resolution order)
        super().__init__(*args, **kwargs)

    @property
    def has_flat_params(self):
        # the FP32 copy is either a single flat tensor or a list of per-param copies
        return torch.is_tensor(self.fp32_params)

    @classmethod
    def build_fp32_params(cls, params, flatten=True):
        # create an FP32 copy of the parameters and grads
        if flatten:
            total_param_size = sum(p.data.numel() for p in params)
            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=params[0].device)
            offset = 0
            for p in params:
                numel = p.data.numel()
                fp32_params[offset:offset+numel].copy_(p.data.view(-1))
                offset += numel
            fp32_params = torch.nn.Parameter(fp32_params)
            # preallocate a zeroed grad buffer of the same flat size
            # (new_zeros, not new(), to avoid an uninitialized buffer)
            fp32_params.grad = fp32_params.data.new_zeros(total_param_size)
            return fp32_params
        else:
            fp32_params = []
            for p in params:
                p32 = torch.nn.Parameter(p.data.float())
                p32.grad = torch.zeros_like(p32.data)
                fp32_params.append(p32)
            return fp32_params
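
    # Layout sketch (illustrative numbers): for two params of shapes (3, 4)
    # and (5,), ``build_fp32_params(params, flatten=True)`` returns one flat
    # FP32 Parameter of 12 + 5 = 17 elements holding both tensors end-to-end,
    # plus a zeroed flat ``.grad`` buffer of the same size.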

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.fp32_optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scale'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if 'loss_scale' in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict['loss_scale']
        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
        # FP16 grads must be synced into the FP32 copy before the next step
        self._needs_sync = True

    def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
        if self._needs_sync:
            if self.scaler is not None:
                # correct for dynamic loss scaler
                multiply_grads /= self.scaler.loss_scale

            # copy FP16 grads into the FP32 copy
            if self.has_flat_params:
                offset = 0
                for p in self.fp16_params:
                    if not p.requires_grad:
                        continue
                    grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
                    numel = grad_data.numel()
                    self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
                    offset += numel
                self.fp32_params.grad.data.mul_(multiply_grads)
            else:
                for p, p32 in zip(self.fp16_params, self.fp32_params):
                    if not p.requires_grad:
                        continue
                    if p.grad is not None:
                        p32.grad.data.copy_(p.grad.data)
                        p32.grad.data.mul_(multiply_grads)
                    else:
                        p32.grad = torch.zeros_like(p.data, dtype=torch.float)

            self._needs_sync = False

    def _sync_fp32_params_to_fp16(self):
        # copy the updated FP32 params back into the FP16 model params
        # (renamed from _sync_fp32_grads_to_fp16: this method copies params,
        # not grads)
        if self.has_flat_params:
            offset = 0
            for p in self.fp16_params:
                if not p.requires_grad:
                    continue
                numel = p.data.numel()
                p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
                offset += numel
        else:
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue
                p.data.copy_(p32.data)

    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        if self._needs_sync:
            self._sync_fp16_grads_to_fp32(c)
        elif self.has_flat_params:
            self.fp32_params.grad.data.mul_(c)
        else:
            for p32 in self.fp32_params:
                p32.grad.data.mul_(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm, aggregate_norm_fn)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            self.scaler.check_overflow(grad_norm)

        return grad_norm

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._sync_fp16_grads_to_fp32()
        self.fp32_optimizer.step(closure)

        if self.scaler is not None:
            self.scaler.update()

        # copy FP32 params back into the FP16 model
        self._sync_fp32_params_to_fp16()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        if self.has_flat_params:
            self.fp32_params.grad.zero_()
        else:
            for p32 in self.fp32_params:
                p32.grad.zero_()
        self._needs_sync = False


class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """
|
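    # Typical update loop (sketch; flag names such as args.clip_norm are the
    # usual fairseq options and are assumptions here):
    #
    #   optimizer = FP16Optimizer.build_optimizer(args, list(model.half().parameters()))
    #   optimizer.zero_grad()
    #   optimizer.backward(loss)                  # scales the loss before backward
    #   optimizer.clip_grad_norm(args.clip_norm)  # syncs grads to FP32, checks overflow
    #   optimizer.step()                          # FP32 update, then copy back to FP16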

    def __init__(self, args, params, fp32_optimizer, fp32_params):
        super().__init__(args)
        self.fp16_params = params
        self.fp32_optimizer = fp32_optimizer
        self.fp32_params = fp32_params

        if getattr(args, 'fp16_scale_window', None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    '--fp16-scale-window must be given explicitly when using a '
                    'custom --update-freq schedule'
                )
            data_parallel_size = int(args.distributed_world_size / args.model_parallel_size)
            scale_window = int(2**14 / data_parallel_size / args.update_freq[0])
        else:
            scale_window = args.fp16_scale_window
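        # example (illustrative numbers): with the default formula above,
        # 8 data-parallel workers and --update-freq 1 give a scale window of
        # 2**14 / 8 = 2048 updates between loss-scale adjustments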

        if not getattr(args, 'bf16', False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale
            )
        else:
            # loss scaling is unnecessary for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, args, params):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        flatten = not getattr(args, 'fp16_no_flatten_grads', False)
        if getattr(args, 'bf16', False):
            flatten = False
        fp32_params = cls.build_fp32_params(params, flatten=flatten)
        if flatten:
            fp32_optimizer = optim.build_optimizer(args, [fp32_params])
        else:
            fp32_optimizer = optim.build_optimizer(args, fp32_params)
        if flatten and not fp32_optimizer.supports_flat_params:
            raise RuntimeError(
                'chosen optimizer does not support flat params, '
                'please set --fp16-no-flatten-grads'
            )
        return cls(args, params, fp32_optimizer, fp32_params)

    @property
    def optimizer(self):
        return self.fp32_optimizer.optimizer

    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config

    def get_lr(self):
        return self.fp32_optimizer.get_lr()

    def set_lr(self, lr):
        self.fp32_optimizer.set_lr(lr)


class _MemoryEfficientFP16OptimizerMixin(object):

    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in MRO (method resolution order)
        super().__init__(*args, **kwargs)

    @property
    def has_flat_params(self):
        return False

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.wrapped_optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scale'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if 'loss_scale' in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict['loss_scale']

        self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)

        # Hack: PyTorch automatically casts the loaded optimizer state to
        # match the dtype of the current params, which here are FP16. Undo
        # the cast by mapping each saved param id to the corresponding
        # current param and copying the original (typically FP32) state back.
        groups = self.optimizer.param_groups
        saved_groups = state_dict['param_groups']
        id_map = {
            old_id: p
            for old_id, p in zip(
                chain(*(g['params'] for g in saved_groups)),
                chain(*(g['params'] for g in groups))
            )
        }
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                self.optimizer.state[param] = v

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
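        # note: unlike _FP16OptimizerMixin.backward, there is no FP16-to-FP32
        # grad sync here; grads stay attached to the FP16 params and the
        # wrapped optimizer converts them to FP32 on the fly during step()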

    def _unscale_grads(self):
        if self._multiply_factor != 1.:
            self.wrapped_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.

    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        self._multiply_factor *= c
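
    # The deferred-scaling trick (illustrative numbers): with a loss scale of
    # 128, zero_grad() sets _multiply_factor to 1/128; a later
    # multiply_grads(0.5) makes it 1/256. The grads themselves are touched
    # only once, in _unscale_grads() or via step(scale=...), instead of on
    # every call.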

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        max_norm = float(max_norm)
        # grads are still scaled, so compute the norm of the scaled grads
        # (max_norm=0 means "norm only, no clipping") and unscale the result
        grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(0, aggregate_norm_fn)

        if self.scaler is not None:
            grad_norm_cpu = float(grad_norm)
            if grad_norm_cpu > max_norm > 0.:
                self._multiply_factor *= max_norm / grad_norm_cpu

            # detect overflow and adjust loss scale
            self.scaler.check_overflow(grad_norm_cpu)
        else:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None):
        """Performs a single optimization step."""
        if self.supports_step_with_scale:
            # let the wrapped optimizer apply the un-scaling during its update
            self.wrapped_optimizer.step(closure, scale=(1. / self._multiply_factor))
        else:
            self._unscale_grads()
            self.wrapped_optimizer.step(closure)

        if self.scaler is not None:
            self.scaler.update()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self.wrapped_optimizer.zero_grad()
        if self.scaler is not None:
            self._multiply_factor = 1. / float(self.scaler.loss_scale)
        else:
            # no loss scaling (e.g., bf16); reset the factor so it is always
            # defined before multiply_grads() or clip_grad_norm() adjust it
            self._multiply_factor = 1.


class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
    maintain an FP32 copy of the model. We instead expect the optimizer to
    convert the gradients to FP32 internally and sync the results back to the
    FP16 model params. This significantly reduces memory usage but slightly
    increases the time spent in the optimizer.

    Since this wrapper depends on specific functionality in the wrapped
    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
    optimizers can be wrapped. This is determined by the
    *supports_memory_efficient_fp16* property.
    """
|
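    # Usage mirrors FP16Optimizer (sketch; assumes the wrapped optimizer's
    # supports_memory_efficient_fp16 property is True and args.clip_norm is
    # the usual fairseq flag):
    #
    #   optimizer = MemoryEfficientFP16Optimizer.build_optimizer(args, params)
    #   optimizer.zero_grad()      # also resets the deferred _multiply_factor
    #   optimizer.backward(loss)   # scales the loss first
    #   optimizer.clip_grad_norm(args.clip_norm)
    #   optimizer.step()           # un-scales grads just before the update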

    def __init__(self, args, params, optimizer):
        if not optimizer.supports_memory_efficient_fp16:
            raise ValueError(
                'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
            )

        super().__init__(args)
        self.wrapped_optimizer = optimizer

        if getattr(args, 'fp16_scale_window', None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    '--fp16-scale-window must be given explicitly when using a '
                    'custom --update-freq schedule'
                )
            data_parallel_size = int(args.distributed_world_size / args.model_parallel_size)
            # wrap in int() for consistency with FP16Optimizer.__init__
            scale_window = int(2**14 / data_parallel_size / args.update_freq[0])
        else:
            scale_window = args.fp16_scale_window

        if not getattr(args, 'bf16', False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale
            )
        else:
            # loss scaling is unnecessary for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, args, params):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        fp16_optimizer = optim.build_optimizer(args, params)
        return cls(args, params, fp16_optimizer)
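
    # note: unlike FP16Optimizer.build_optimizer, no FP32 copy of the params
    # is built here; the wrapped optimizer is constructed directly over the
    # FP16 params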

    @property
    def optimizer(self):
        return self.wrapped_optimizer.optimizer

    @property
    def optimizer_config(self):
        return self.wrapped_optimizer.optimizer_config

    def get_lr(self):
        return self.wrapped_optimizer.get_lr()

    def set_lr(self, lr):
        self.wrapped_optimizer.set_lr(lr)