from typing import Dict, List

import torch

# torch._six.container_abcs was removed in PyTorch 1.9; fall back to the standard
# library there. Compare numeric version components rather than raw strings so that
# e.g. '1.10' is not mistakenly treated as older than '1.9'.
if tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) < (1, 9):
    Iterable = torch._six.container_abcs.Iterable
else:
    import collections.abc
    Iterable = collections.abc.Iterable

from torch.cuda.amp import GradScaler

class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval
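
# A minimal usage sketch (illustrative only, not part of the original module):
# assuming at least one CUDA device, a per-device copy of the master tensor is
# made on the first request and served from the cache afterwards.
#
#     master = torch.full((), 65536.0, device="cuda:0")
#     rep = _MultiDeviceReplicator(master)
#     s0 = rep.get(torch.device("cuda:0"))        # copied once, then cached
#     s0_again = rep.get(torch.device("cuda:0"))  # same cached tensor object
#     assert s0 is s0_again
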
class MaxClipGradScaler(GradScaler):
    """GradScaler variant that never lets the loss scale grow beyond ``max_scale``."""

    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        super().__init__(init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        # Freeze growth once the scale reaches the cap, keep doubling while it is
        # below the cap, and clamp it back down (and freeze) if it has overshot.
        if self.get_scale() == self.max_scale:
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)
    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()

        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Recursively scale every tensor found in (possibly nested) iterables.
        # The per-device replicator is created lazily from the first tensor seen.
        stash: List[_MultiDeviceReplicator] = []

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
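
# A minimal training-loop sketch (illustrative, not part of the original module)
# showing how MaxClipGradScaler would typically be used with torch.cuda.amp.
# The model, data loader, and hyper-parameter values below are placeholders.
#
#     model = MyModel().cuda()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scaler = MaxClipGradScaler(init_scale=512.0, max_scale=65536.0, growth_interval=100)
#
#     for images, labels in loader:
#         optimizer.zero_grad()
#         with torch.cuda.amp.autocast():
#             logits = model(images.cuda())
#             loss = torch.nn.functional.cross_entropy(logits, labels.cuda())
#         scaler.scale(loss).backward()  # loss scale is capped at max_scale here
#         scaler.step(optimizer)
#         scaler.update()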