| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import logging |
| import math |
| import random |
| from typing import Optional |
| from typing import Tuple |
| from typing import Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
|
|
|
|
class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)

    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
      double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
      double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
      Now, s'(x) = s(x) * (1-s(x)).
      double_swish'(x) = x * s'(x) + s(x)
                       = x * s(x) * (1-s(x)) + s(x)
                       = double_swish(x) * (1-s(x)) + s(x)
    ... so we just need to remember s(x) but not x itself.

    The forward pass goes one step further: the derivative is quantized to
    8 bits (uint8) with stochastic rounding, and only that quantized tensor is
    saved for backward.
    """

    # Range of the derivative that the uint8 quantization covers; shared by
    # forward() (encode) and backward() (decode) so the two cannot drift apart.
    # The derivative dips slightly below zero for negative x (minimum roughly
    # -0.0436) and stays below 1.2 (maximum roughly 1.199).
    _DERIV_FLOOR = -0.043637
    _DERIV_CEIL = 1.2

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        """Return x * sigmoid(x - 1), saving a uint8-quantized derivative.

        float16 inputs are computed in float32 and the result is cast back to
        float16; the result is also cast to float16 when autocast is enabled.
        """
        requires_grad = x.requires_grad
        x_dtype = x.dtype
        if x_dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            floor = DoubleSwishFunction._DERIV_FLOOR
            ceil = DoubleSwishFunction._DERIV_CEIL
            # Map [floor, ceil] onto [0, 255] and add uniform noise in [0, 1).
            # The uint8 cast truncates, so this is stochastic rounding: the
            # stored code equals the scaled derivative in expectation.
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # Self-test only: the quantization must stay in uint8 range.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        # Check the ORIGINAL dtype: `x` may have been upcast to float32 above,
        # so testing `x.dtype` here would never see float16.
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        """Dequantize the saved derivative and apply the chain rule."""
        (d,) = ctx.saved_tensors
        floor = DoubleSwishFunction._DERIV_FLOOR
        ceil = DoubleSwishFunction._DERIV_CEIL
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
|
|
|
|
class DoubleSwish(torch.nn.Module):
    """Module form of the double-swish activation, x * sigmoid(x - 1)."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply double-swish, a close approximation to Swish(Swish(x)).

        While TorchScript is scripting or tracing, the closed-form expression
        is evaluated directly; otherwise the memory-efficient custom autograd
        function DoubleSwishFunction is used.
        """
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Presumably the custom autograd.Function is not usable under
            # TorchScript, hence the plain-tensor fallback.
            return torch.sigmoid(x - 1.0) * x
        return DoubleSwishFunction.apply(x)
|
|
|
|
class ActivationBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in the backward pass.

    The incoming gradient g becomes ``g - |g| * factor`` with

        factor = [sign_factor +] scale_factor * ((x > 0) - 0.5)

    where the per-channel factors are broadcast along ``channel_dim``.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        # Store a non-negative axis so backward() can count trailing dims.
        ctx.channel_dim = channel_dim + x.ndim if channel_dim < 0 else channel_dim
        to_save = [x > 0, scale_factor]
        if sign_factor is not None:
            to_save.append(sign_factor)
        ctx.save_for_backward(*to_save)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        saved = ctx.saved_tensors
        positive_mask, scale_factor = saved[0], saved[1]
        sign_factor = saved[2] if len(saved) == 3 else None

        # Append one singleton axis per dimension after channel_dim so the
        # per-channel factors (presumably shape (C,)) broadcast against x_grad.
        for _ in range(x_grad.ndim - 1 - ctx.channel_dim):
            scale_factor = scale_factor.unsqueeze(-1)
            if sign_factor is not None:
                sign_factor = sign_factor.unsqueeze(-1)

        factor = scale_factor * (positive_mask.to(x_grad.dtype) - 0.5)
        if sign_factor is not None:
            factor = sign_factor + factor
        neg_delta_grad = x_grad.abs() * factor
        return x_grad - neg_delta_grad, None, None, None
|
|
|
|
| def _compute_scale_factor( |
| x: Tensor, |
| channel_dim: int, |
| min_abs: float, |
| max_abs: float, |
| gain_factor: float, |
| max_factor: float, |
| ) -> Tensor: |
| if channel_dim < 0: |
| channel_dim += x.ndim |
| sum_dims = [d for d in range(x.ndim) if d != channel_dim] |
| x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) |
|
|
| if min_abs == 0.0: |
| below_threshold = 0.0 |
| else: |
| |
| |
| below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( |
| min=0, max=max_factor |
| ) |
|
|
| above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( |
| min=0, max=max_factor |
| ) |
|
|
| return below_threshold - above_threshold |
|
|
|
|
| def _compute_sign_factor( |
| x: Tensor, |
| channel_dim: int, |
| min_positive: float, |
| max_positive: float, |
| gain_factor: float, |
| max_factor: float, |
| ) -> Tensor: |
| if channel_dim < 0: |
| channel_dim += x.ndim |
| sum_dims = [d for d in range(x.ndim) if d != channel_dim] |
| proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) |
| if min_positive == 0.0: |
| factor1 = 0.0 |
| else: |
| |
| |
| factor1 = ( |
| (min_positive - proportion_positive) * (gain_factor / min_positive) |
| ).clamp_(min=0, max=max_factor) |
|
|
| if max_positive == 1.0: |
| factor2 = 0.0 |
| else: |
| |
| |
| factor2 = ( |
| (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) |
| ).clamp_(min=0, max=max_factor) |
| sign_factor = factor1 - factor2 |
| |
| assert not isinstance(sign_factor, float) |
| return sign_factor |
|
|
|
|
def _no_op(x: Tensor) -> Tensor:
    """Return a tensor equal to ``x``; ActivationBalancer's pass-through path.

    Under TorchScript scripting or tracing, ``x`` itself is returned. Otherwise
    ``x.chunk(1, dim=-1)[0]`` is returned: a view covering all of ``x`` which,
    when ``x`` requires grad, has its own node in the autograd graph.
    NOTE: ActivationBalancer.forward depends on this helper; it was referenced
    but not defined in this module.
    """
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x
    return x.chunk(1, dim=-1)[0]


class ActivationBalancer(torch.nn.Module):
    """
    Identity in the forward pass. It modifies the backpropped derivatives to try
    to encourage, for each channel, that:

      * it is positive at least a proportion `min_positive` and at most a
        proportion `max_positive` of the time, and
      * its mean absolute value (over all non-channel axes) lies within
        [`min_abs`, `max_abs`].

    It does this by multiplying negative derivative values by up to
    (1+max_factor) and positive derivative values by up to (1-max_factor), or
    vice versa depending on the direction of the violation. The size of the
    change grows linearly once a constraint is violated, and each constraint's
    contribution is capped at `max_factor`.

    The modification is only applied on a random subset of forward() calls. The
    probability starts at 0.5, halves every 4000 calls, and never drops below
    `min_prob`. The gains are divided by that probability. When scripting,
    tracing, or when the input does not require grad, forward() is a pure
    pass-through.

    Args:
        num_channels: the number of channels (stored; not otherwise used here).
        channel_dim: the dimension/axis corresponding to the channel, e.g.
            -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
        min_positive: the minimum, per channel, of the proportion of the time
            that (x > 0), below which we start to modify the derivatives.
        max_positive: the maximum, per channel, of the proportion of the time
            that (x > 0), above which we start to modify the derivatives.
        max_factor: the maximum factor by which we modify the derivatives for
            either the sign constraint or the magnitude constraint;
            e.g. with max_factor=0.02, the derivatives would be multiplied by
            values in the range [0.98..1.02].
        sign_gain_factor: determines the 'gain' with which we increase the
            change in gradient once the constraints on min_positive and
            max_positive are violated.
        scale_gain_factor: determines the 'gain' with which we increase the
            change in gradient once the constraints on min_abs and max_abs
            are violated.
        min_abs: the minimum, per channel, of the mean absolute value of x,
            below which we start to modify the derivatives. 0 disables it.
        max_abs: the maximum, per channel, of the mean absolute value of x,
            above which we start to modify the derivatives.
        min_prob: determines the minimum probability with which we modify the
            gradients for the {min,max}_positive and {min,max}_abs constraints,
            on each forward(). This is done randomly to prevent all layers
            from doing it at the same time. Early in training we use
            higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # Number of forward() calls so far as a plain Python int (presumably to
        # avoid a device sync per call). forward() occasionally mirrors it into
        # the `count` buffer, which is part of the module's state_dict.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync the Python counter with the `count` buffer, so
            # the call count survives a save/load round trip of the state_dict.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # Application probability: 0.5 at first, halving every 4000 calls,
        # floored at min_prob.
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() >= prob:
            return _no_op(x)

        # Gains are divided by `prob`, presumably to compensate for the
        # balancer only being applied on a fraction `prob` of calls.
        if self.min_positive != 0.0 or self.max_positive != 1.0:
            sign_factor = _compute_sign_factor(
                x.detach(),
                self.channel_dim,
                self.min_positive,
                self.max_positive,
                gain_factor=self.sign_gain_factor / prob,
                max_factor=self.max_factor,
            )
        else:
            sign_factor = None

        scale_factor = _compute_scale_factor(
            x.detach(),
            self.channel_dim,
            min_abs=self.min_abs,
            max_abs=self.max_abs,
            gain_factor=self.scale_gain_factor / prob,
            max_factor=self.max_factor,
        )
        return ActivationBalancerFunction.apply(
            x,
            scale_factor,
            sign_factor,
            self.channel_dim,
        )
|
|
|
|
def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    Build ``ActivationBalancer -> DoubleSwish`` as an ``nn.Sequential``.

    Args:
        d_model: number of channels, passed to ActivationBalancer as
            ``num_channels``.
        channel_dim: axis of the input that holds the channels.
        max_abs: upper limit on the per-channel mean absolute value that the
            balancer encourages.
        min_prob: minimum probability with which the balancer modifies the
            gradients on a forward() call.
    """
    return nn.Sequential(
        ActivationBalancer(
            d_model,
            channel_dim=channel_dim,
            max_abs=max_abs,
            min_prob=min_prob,
        ),
        DoubleSwish(),
    )
|
|