from typing import Dict, List

import torch
from torch.cuda.amp import GradScaler

# torch._six.container_abcs was removed in torch 1.9, so fall back to
# collections.abc on newer versions. Compare numeric version components
# instead of raw strings: a plain string comparison breaks for versions
# like '1.10', since '1.10' < '1.9' is True lexicographically.
if tuple(int(v) for v in torch.__version__.split('.')[:2]) < (1, 9):
    Iterable = torch._six.container_abcs.Iterable
else:
    import collections
    Iterable = collections.abc.Iterable
|
|
class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            # First request for this device: copy the master tensor over and cache it.
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval
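

# Usage sketch for _MultiDeviceReplicator (a minimal example, assuming at least
# one CUDA device; illustrative only, not executed by this module):
#
#     master = torch.ones(1, device="cuda:0")
#     rep = _MultiDeviceReplicator(master)
#     a = rep.get(torch.device("cuda:0"))  # first call copies to the device
#     b = rep.get(torch.device("cuda:0"))  # second call returns the cached copy
#     assert a is b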
|
|
|
|
class MaxClipGradScaler(GradScaler):
    """A :class:`GradScaler` whose loss scale grows as usual but is clamped at ``max_scale``."""

    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        super().__init__(init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        if self.get_scale() == self.max_scale:
            # At the cap: stop growing.
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            # Below the cap: allow the usual doubling every ``growth_interval`` steps.
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            # Overshot the cap (e.g. init_scale > max_scale): clamp back down.
            # ``self._scale`` is created lazily on the first ``scale()`` call, so
            # guard against it still being None here.
            if self._scale is not None:
                self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)
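
    # Worked example of the clamping schedule: with growth factor 2 and
    # growth_interval=100, the scale doubles every 100 consecutive steps that
    # produce finite gradients, until it reaches max_scale and then stays flat.
    # E.g. init_scale=512, max_scale=4096 gives 512 -> 1024 -> 2048 -> 4096 -> 4096 -> ...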
|
|
    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or iterable of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not
        enabled, outputs are returned unmodified.

        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()

        # Short-circuit for the common single-tensor case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
|
|
        # For iterables, share one _MultiDeviceReplicator across all recursive
        # calls; ``stash`` lets the closure create it lazily on first use.
        stash: List[_MultiDeviceReplicator] = []

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, Iterable):
                iterable = map(apply_scale, val)
                # Materialize lists and tuples eagerly; other iterables stay lazy.
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
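

# Hedged smoke test: a minimal sketch, not part of the original module. It
# assumes a CUDA device is available and only exercises scale(); a real
# training loop would also call step() and update() on the scaler.
if __name__ == "__main__":
    if torch.cuda.is_available():
        scaler = MaxClipGradScaler(init_scale=512, max_scale=2 ** 16, growth_interval=100)
        loss = torch.ones(4, device="cuda", requires_grad=True).sum()
        scaler.scale(loss).backward()
        # Iterable inputs are scaled recursively, preserving the container type.
        scaled = scaler.scale([torch.ones(2, device="cuda"), torch.ones(3, device="cuda")])
        print("scale:", scaler.get_scale(), "first scaled value:", scaled[0][0].item())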
|
|