| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import logging |
| | import math |
| | import random |
| | from typing import Optional, Tuple, Union |
| |
|
| | |
| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor |
| | from torch.cuda.amp import custom_bwd, custom_fwd |
| |
|
| |
|
| | def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: |
| | max_value = torch.max(x, y) |
| | diff = torch.abs(x - y) |
| | return max_value + torch.log1p(torch.exp(-diff)) |
| |
|
| | class JointCodebookLoss(torch.nn.Module): |
| | def __init__( |
| | self, |
| | input_dim: int = 512, |
| | num_codebooks: int = 16, |
| | codebook_size: int = 256, |
| | ignore_index: int = -100, |
| | reduction: str = "none" |
| | ): |
| | super().__init__() |
| | self.input_dim = input_dim |
| | self.num_codebooks = num_codebooks |
| | self.codebook_size = codebook_size |
| | self.reduction = reduction |
| | self.ignore_index = ignore_index |
| | |
| | self.proj = nn.Linear(input_dim, num_codebooks * codebook_size) |
| | |
| | def forward_logprobs(self, input: torch.Tensor): |
| | B,T,_ = input.shape |
| | logits = self.proj(input) |
| | logits = logits.view(B, T, self.num_codebooks, self.codebook_size) |
| | log_probs = F.log_softmax(logits, dim=-1) |
| | return log_probs |
| | |
| | |
| | def forward(self, input, target, return_log_probs: bool = False): |
| | |
| | |
| |
|
| | B,T,_ = input.shape |
| | logits = self.proj(input) |
| | logits = logits.view(B, T, self.num_codebooks, self.codebook_size) |
| | |
| | loss = F.cross_entropy( |
| | logits.reshape(-1, self.codebook_size), |
| | target.reshape(-1), |
| | ignore_index=self.ignore_index, |
| | reduction=self.reduction |
| | ) |
| | log_probs = None |
| | if return_log_probs: |
| | log_probs = F.log_softmax(logits, dim=-1) |
| | |
| | if self.reduction == "none": |
| | loss = loss.view(B, T, self.num_codebooks) |
| |
|
| | if return_log_probs: |
| | return loss, log_probs |
| | |
| | return loss |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | def logaddexp(x: Tensor, y: Tensor) -> Tensor: |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | if torch.jit.is_scripting(): |
| | |
| | |
| | return torch.logaddexp(x, y) |
| | elif torch.onnx.is_in_onnx_export(): |
| | return logaddexp_onnx(x, y) |
| | else: |
| | |
| | return torch.logaddexp(x, y) |
| |
|
| |
|
| | class PiecewiseLinear(object): |
| | """ |
| | Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with |
| | the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] |
| | respectively. |
| | """ |
| |
|
| | def __init__(self, *args): |
| | assert len(args) >= 1, len(args) |
| | if len(args) == 1 and isinstance(args[0], PiecewiseLinear): |
| | self.pairs = list(args[0].pairs) |
| | else: |
| | self.pairs = [(float(x), float(y)) for x, y in args] |
| | for x, y in self.pairs: |
| | assert isinstance(x, (float, int)), type(x) |
| | assert isinstance(y, (float, int)), type(y) |
| |
|
| | for i in range(len(self.pairs) - 1): |
| | assert self.pairs[i + 1][0] > self.pairs[i][0], ( |
| | i, |
| | self.pairs[i], |
| | self.pairs[i + 1], |
| | ) |
| |
|
| | def __str__(self): |
| | |
| | return f"PiecewiseLinear({str(self.pairs)[1:-1]})" |
| |
|
| | def __call__(self, x): |
| | if x <= self.pairs[0][0]: |
| | return self.pairs[0][1] |
| | elif x >= self.pairs[-1][0]: |
| | return self.pairs[-1][1] |
| | else: |
| | cur_x, cur_y = self.pairs[0] |
| | for i in range(1, len(self.pairs)): |
| | next_x, next_y = self.pairs[i] |
| | if x >= cur_x and x <= next_x: |
| | return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) |
| | cur_x, cur_y = next_x, next_y |
| | assert False |
| |
|
| | def __mul__(self, alpha): |
| | return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) |
| |
|
| | def __add__(self, x): |
| | if isinstance(x, (float, int)): |
| | return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) |
| | s, x = self.get_common_basis(x) |
| | return PiecewiseLinear( |
| | *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] |
| | ) |
| |
|
| | def max(self, x): |
| | if isinstance(x, (float, int)): |
| | x = PiecewiseLinear((0, x)) |
| | s, x = self.get_common_basis(x, include_crossings=True) |
| | return PiecewiseLinear( |
| | *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] |
| | ) |
| |
|
| | def min(self, x): |
| | if isinstance(x, float) or isinstance(x, int): |
| | x = PiecewiseLinear((0, x)) |
| | s, x = self.get_common_basis(x, include_crossings=True) |
| | return PiecewiseLinear( |
| | *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] |
| | ) |
| |
|
| | def __eq__(self, other): |
| | return self.pairs == other.pairs |
| |
|
| | def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): |
| | """ |
| | Returns (self_mod, p_mod) which are equivalent piecewise linear |
| | functions to self and p, but with the same x values. |
| | |
| | p: the other piecewise linear function |
| | include_crossings: if true, include in the x values positions |
| | where the functions indicate by this and p cross. |
| | """ |
| | assert isinstance(p, PiecewiseLinear), type(p) |
| |
|
| | |
| | x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) |
| | y_vals1 = [self(x) for x in x_vals] |
| | y_vals2 = [p(x) for x in x_vals] |
| |
|
| | if include_crossings: |
| | extra_x_vals = [] |
| | for i in range(len(x_vals) - 1): |
| | if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): |
| | |
| | diff_cur = abs(y_vals1[i] - y_vals2[i]) |
| | diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) |
| | |
| | |
| | pos = diff_cur / (diff_cur + diff_next) |
| | extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) |
| | extra_x_vals.append(extra_x_val) |
| | if len(extra_x_vals) > 0: |
| | x_vals = sorted(set(x_vals + extra_x_vals)) |
| | y_vals1 = [self(x) for x in x_vals] |
| | y_vals2 = [p(x) for x in x_vals] |
| | return ( |
| | PiecewiseLinear(*zip(x_vals, y_vals1)), |
| | PiecewiseLinear(*zip(x_vals, y_vals2)), |
| | ) |
| |
|
| |
|
| | class ScheduledFloat(torch.nn.Module): |
| | """ |
| | This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); |
| | it does not have a working forward() function. You are supposed to cast it to float, as |
| | in, float(parent_module.whatever), and use it as something like a dropout prob. |
| | |
| | It is a floating point value whose value changes depending on the batch count of the |
| | training loop. It is a piecewise linear function where you specify the (x,y) pairs |
| | in sorted order on x; x corresponds to the batch index. For batch-index values before the |
| | first x or after the last x, we just use the first or last y value. |
| | |
| | Example: |
| | self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) |
| | |
| | `default` is used when self.batch_count is not set or not in training mode or in |
| | torch.jit scripting mode. |
| | """ |
| |
|
| | def __init__(self, *args, default: float = 0.0): |
| | super().__init__() |
| | |
| | self.batch_count = None |
| | self.name = None |
| | self.default = default |
| | self.schedule = PiecewiseLinear(*args) |
| |
|
| | def extra_repr(self) -> str: |
| | return ( |
| | f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" |
| | ) |
| |
|
| | def __float__(self): |
| | batch_count = self.batch_count |
| | if ( |
| | batch_count is None |
| | or not self.training |
| | or torch.jit.is_scripting() |
| | or torch.jit.is_tracing() |
| | ): |
| | return float(self.default) |
| | else: |
| | ans = self.schedule(self.batch_count) |
| | if random.random() < 0.0002: |
| | logging.info( |
| | f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" |
| | ) |
| | return ans |
| |
|
| | def __add__(self, x): |
| | if isinstance(x, float) or isinstance(x, int): |
| | return ScheduledFloat(self.schedule + x, default=self.default) |
| | else: |
| | return ScheduledFloat( |
| | self.schedule + x.schedule, default=self.default + x.default |
| | ) |
| |
|
| | def max(self, x): |
| | if isinstance(x, float) or isinstance(x, int): |
| | return ScheduledFloat(self.schedule.max(x), default=self.default) |
| | else: |
| | return ScheduledFloat( |
| | self.schedule.max(x.schedule), default=max(self.default, x.default) |
| | ) |
| |
|
| |
|
| | FloatLike = Union[float, ScheduledFloat] |
| |
|
| |
|
| | def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: |
| | """ |
| | A randomized way of casting a floating point value to half precision. |
| | """ |
| | if x.dtype == torch.float16: |
| | return x |
| | x_abs = x.abs() |
| | is_too_small = x_abs < min_abs |
| | |
| | |
| | |
| | random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) |
| | return torch.where(is_too_small, random_val, x).to(torch.float16) |
| |
|
| |
|
| | class CutoffEstimator: |
| | """ |
| | Estimates cutoffs of an arbitrary numerical quantity such that a specified |
| | proportion of items will be above the cutoff on average. |
| | |
| | p is the proportion of items that should be above the cutoff. |
| | """ |
| |
|
| | def __init__(self, p: float): |
| | self.p = p |
| | |
| | self.count = 0 |
| | |
| | self.count_above = 0 |
| | |
| | self.cutoff = 0 |
| |
|
| | def __call__(self, x: float) -> bool: |
| | """ |
| | Returns true if x is above the cutoff. |
| | """ |
| | ans = x > self.cutoff |
| | self.count += 1 |
| | if ans: |
| | self.count_above += 1 |
| | cur_p = self.count_above / self.count |
| | delta_p = cur_p - self.p |
| | if (delta_p > 0) == ans: |
| | q = abs(delta_p) |
| | self.cutoff = x * q + self.cutoff * (1 - q) |
| | return ans |
| |
|
| |
|
| | class SoftmaxFunction(torch.autograd.Function): |
| | """ |
| | Tries to handle half-precision derivatives in a randomized way that should |
| | be more accurate for training than the default behavior. |
| | """ |
| |
|
| | @staticmethod |
| | def forward(ctx, x: Tensor, dim: int): |
| | ans = x.softmax(dim=dim) |
| | |
| | |
| | |
| | if torch.is_autocast_enabled(): |
| | ans = ans.to(torch.get_autocast_gpu_dtype()) |
| | ctx.save_for_backward(ans) |
| | ctx.x_dtype = x.dtype |
| | ctx.dim = dim |
| | return ans |
| |
|
| | @staticmethod |
| | def backward(ctx, ans_grad: Tensor): |
| | (ans,) = ctx.saved_tensors |
| | with torch.cuda.amp.autocast(enabled=False): |
| | ans_grad = ans_grad.to(torch.float32) |
| | ans = ans.to(torch.float32) |
| | x_grad = ans_grad * ans |
| | x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) |
| | return x_grad, None |
| |
|
| |
|
| | def softmax(x: Tensor, dim: int): |
| | if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): |
| | return x.softmax(dim=dim) |
| |
|
| | return SoftmaxFunction.apply(x, dim) |
| |
|
| |
|
| | class MaxEigLimiterFunction(torch.autograd.Function): |
| | @staticmethod |
| | def forward( |
| | ctx, |
| | x: Tensor, |
| | coeffs: Tensor, |
| | direction: Tensor, |
| | channel_dim: int, |
| | grad_scale: float, |
| | ) -> Tensor: |
| | ctx.channel_dim = channel_dim |
| | ctx.grad_scale = grad_scale |
| | ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) |
| | return x |
| |
|
| | @staticmethod |
| | def backward(ctx, x_grad, *args): |
| | with torch.enable_grad(): |
| | (x_orig, coeffs, new_direction) = ctx.saved_tensors |
| | x_orig.requires_grad = True |
| | num_channels = x_orig.shape[ctx.channel_dim] |
| | x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) |
| | new_direction.requires_grad = False |
| | x = x - x.mean(dim=0) |
| | x_var = (x**2).mean() |
| | x_residual = x - coeffs * new_direction |
| | x_residual_var = (x_residual**2).mean() |
| | |
| | |
| | variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) |
| | variance_proportion.backward() |
| | x_orig_grad = x_orig.grad |
| | x_extra_grad = ( |
| | x_orig.grad |
| | * ctx.grad_scale |
| | * x_grad.norm() |
| | / (x_orig_grad.norm() + 1.0e-20) |
| | ) |
| | return x_grad + x_extra_grad.detach(), None, None, None, None |
| |
|
| |
|
| | class BiasNormFunction(torch.autograd.Function): |
| | |
| | |
| | |
| | |
| | |
| | |
| | @staticmethod |
| | def forward( |
| | ctx, |
| | x: Tensor, |
| | bias: Tensor, |
| | log_scale: Tensor, |
| | channel_dim: int, |
| | store_output_for_backprop: bool, |
| | ) -> Tensor: |
| | assert bias.ndim == 1 |
| | if channel_dim < 0: |
| | channel_dim = channel_dim + x.ndim |
| | ctx.store_output_for_backprop = store_output_for_backprop |
| | ctx.channel_dim = channel_dim |
| | for _ in range(channel_dim + 1, x.ndim): |
| | bias = bias.unsqueeze(-1) |
| | scales = ( |
| | torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 |
| | ) * log_scale.exp() |
| | ans = x * scales |
| | ctx.save_for_backward( |
| | ans.detach() if store_output_for_backprop else x, |
| | scales.detach(), |
| | bias.detach(), |
| | log_scale.detach(), |
| | ) |
| | return ans |
| |
|
| | @staticmethod |
| | def backward(ctx, ans_grad: Tensor) -> Tensor: |
| | ans_or_x, scales, bias, log_scale = ctx.saved_tensors |
| | if ctx.store_output_for_backprop: |
| | x = ans_or_x / scales |
| | else: |
| | x = ans_or_x |
| | x = x.detach() |
| | x.requires_grad = True |
| | bias.requires_grad = True |
| | log_scale.requires_grad = True |
| | with torch.enable_grad(): |
| | |
| | scales = ( |
| | torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 |
| | ) * log_scale.exp() |
| | ans = x * scales |
| | ans.backward(gradient=ans_grad) |
| | return x.grad, bias.grad.flatten(), log_scale.grad, None, None |
| |
|
| |
|
| | class BiasNorm(torch.nn.Module): |
| | """ |
| | This is intended to be a simpler, and hopefully cheaper, replacement for |
| | LayerNorm. The observation this is based on, is that Transformer-type |
| | networks, especially with pre-norm, sometimes seem to set one of the |
| | feature dimensions to a large constant value (e.g. 50), which "defeats" |
| | the LayerNorm because the output magnitude is then not strongly dependent |
| | on the other (useful) features. Presumably the weight and bias of the |
| | LayerNorm are required to allow it to do this. |
| | |
| | Instead, we give the BiasNorm a trainable bias that it can use when |
| | computing the scale for normalization. We also give it a (scalar) |
| | trainable scale on the output. |
| | |
| | |
| | Args: |
| | num_channels: the number of channels, e.g. 512. |
| | channel_dim: the axis/dimension corresponding to the channel, |
| | interpreted as an offset from the input's ndim if negative. |
| | This is NOT the num_channels; it should typically be one of |
| | {-2, -1, 0, 1, 2, 3}. |
| | log_scale: the initial log-scale that we multiply the output by; this |
| | is learnable. |
| | log_scale_min: FloatLike, minimum allowed value of log_scale |
| | log_scale_max: FloatLike, maximum allowed value of log_scale |
| | store_output_for_backprop: only possibly affects memory use; recommend |
| | to set to True if you think the output of this module is more likely |
| | than the input of this module to be required to be stored for the |
| | backprop. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | num_channels: int, |
| | channel_dim: int = -1, |
| | log_scale: float = 1.0, |
| | log_scale_min: float = -1.5, |
| | log_scale_max: float = 1.5, |
| | store_output_for_backprop: bool = False, |
| | ) -> None: |
| | super(BiasNorm, self).__init__() |
| | self.num_channels = num_channels |
| | self.channel_dim = channel_dim |
| | self.log_scale = nn.Parameter(torch.tensor(log_scale)) |
| | self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) |
| |
|
| | self.log_scale_min = log_scale_min |
| | self.log_scale_max = log_scale_max |
| |
|
| | self.store_output_for_backprop = store_output_for_backprop |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | assert x.shape[self.channel_dim] == self.num_channels |
| |
|
| | if torch.jit.is_scripting() or torch.jit.is_tracing(): |
| | channel_dim = self.channel_dim |
| | if channel_dim < 0: |
| | channel_dim += x.ndim |
| | bias = self.bias |
| | for _ in range(channel_dim + 1, x.ndim): |
| | bias = bias.unsqueeze(-1) |
| | scales = ( |
| | torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 |
| | ) * self.log_scale.exp() |
| | return x * scales |
| |
|
| | log_scale = limit_param_value( |
| | self.log_scale, |
| | min=float(self.log_scale_min), |
| | max=float(self.log_scale_max), |
| | training=self.training, |
| | ) |
| |
|
| | return BiasNormFunction.apply( |
| | x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop |
| | ) |
| |
|
| |
|
| | def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: |
| | """ |
| | Behaves like a constructor of a modified version of nn.Linear |
| | that gives an easy way to set the default initial parameter scale. |
| | |
| | Args: |
| | Accepts the standard args and kwargs that nn.Linear accepts |
| | e.g. in_features, out_features, bias=False. |
| | |
| | initial_scale: you can override this if you want to increase |
| | or decrease the initial magnitude of the module's output |
| | (affects the initialization of weight_scale and bias_scale). |
| | Another option, if you want to do something like this, is |
| | to re-initialize the parameters. |
| | """ |
| | ans = nn.Linear(*args, **kwargs) |
| | with torch.no_grad(): |
| | ans.weight[:] *= initial_scale |
| | if ans.bias is not None: |
| | torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) |
| | return ans |
| |
|
| |
|
| | def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: |
| | """ |
| | Behaves like a constructor of a modified version of nn.Conv1d |
| | that gives an easy way to set the default initial parameter scale. |
| | |
| | Args: |
| | Accepts the standard args and kwargs that nn.Linear accepts |
| | e.g. in_features, out_features, bias=False. |
| | |
| | initial_scale: you can override this if you want to increase |
| | or decrease the initial magnitude of the module's output |
| | (affects the initialization of weight_scale and bias_scale). |
| | Another option, if you want to do something like this, is |
| | to re-initialize the parameters. |
| | """ |
| | ans = nn.Conv1d(*args, **kwargs) |
| | with torch.no_grad(): |
| | ans.weight[:] *= initial_scale |
| | if ans.bias is not None: |
| | torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) |
| | return ans |
| |
|
| |
|
| | def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: |
| | """ |
| | Behaves like a constructor of a modified version of nn.Conv2d |
| | that gives an easy way to set the default initial parameter scale. |
| | |
| | Args: |
| | Accepts the standard args and kwargs that nn.Linear accepts |
| | e.g. in_features, out_features, bias=False, but: |
| | NO PADDING-RELATED ARGS. |
| | |
| | initial_scale: you can override this if you want to increase |
| | or decrease the initial magnitude of the module's output |
| | (affects the initialization of weight_scale and bias_scale). |
| | Another option, if you want to do something like this, is |
| | to re-initialize the parameters. |
| | """ |
| | ans = nn.Conv2d(*args, **kwargs) |
| | with torch.no_grad(): |
| | ans.weight[:] *= initial_scale |
| | if ans.bias is not None: |
| | torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) |
| | return ans |
| |
|
| |
|
| | class ChunkCausalDepthwiseConv1d(torch.nn.Module): |
| | """ |
| | Behaves like a depthwise 1d convolution, except that it is causal in |
| | a chunkwise way, as if we had a block-triangular attention mask. |
| | The chunk size is provided at test time (it should probably be |
| | kept in sync with the attention mask). |
| | |
| | This has a little more than twice the parameters of a conventional |
| | depthwise conv1d module: we implement it by having one |
| | depthwise convolution, of half the width, that is causal (via |
| | right-padding); and one depthwise convolution that is applied only |
| | within chunks, that we multiply by a scaling factor which depends |
| | on the position within the chunk. |
| | |
| | Args: |
| | Accepts the standard args and kwargs that nn.Linear accepts |
| | e.g. in_features, out_features, bias=False. |
| | |
| | initial_scale: you can override this if you want to increase |
| | or decrease the initial magnitude of the module's output |
| | (affects the initialization of weight_scale and bias_scale). |
| | Another option, if you want to do something like this, is |
| | to re-initialize the parameters. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | channels: int, |
| | kernel_size: int, |
| | initial_scale: float = 1.0, |
| | bias: bool = True, |
| | ): |
| | super().__init__() |
| | assert kernel_size % 2 == 1 |
| |
|
| | half_kernel_size = (kernel_size + 1) // 2 |
| | |
| | self.causal_conv = nn.Conv1d( |
| | in_channels=channels, |
| | out_channels=channels, |
| | groups=channels, |
| | kernel_size=half_kernel_size, |
| | padding=0, |
| | bias=True, |
| | ) |
| |
|
| | self.chunkwise_conv = nn.Conv1d( |
| | in_channels=channels, |
| | out_channels=channels, |
| | groups=channels, |
| | kernel_size=kernel_size, |
| | padding=kernel_size // 2, |
| | bias=bias, |
| | ) |
| |
|
| | |
| | |
| | |
| | self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) |
| | self.kernel_size = kernel_size |
| |
|
| | with torch.no_grad(): |
| | self.causal_conv.weight[:] *= initial_scale |
| | self.chunkwise_conv.weight[:] *= initial_scale |
| | if bias: |
| | torch.nn.init.uniform_( |
| | self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale |
| | ) |
| |
|
| | def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: |
| | """Forward function. |
| | |
| | Args: |
| | x: a Tensor of shape (batch_size, channels, seq_len) |
| | chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. |
| | """ |
| | (batch_size, num_channels, seq_len) = x.shape |
| |
|
| | |
| | |
| | |
| | |
| | left_pad = self.kernel_size // 2 |
| |
|
| | if chunk_size < 0 or chunk_size > seq_len: |
| | chunk_size = seq_len |
| | right_pad = -seq_len % chunk_size |
| |
|
| | x = torch.nn.functional.pad(x, (left_pad, right_pad)) |
| |
|
| | x_causal = self.causal_conv(x[..., : left_pad + seq_len]) |
| | assert x_causal.shape == (batch_size, num_channels, seq_len) |
| |
|
| | x_chunk = x[..., left_pad:] |
| | num_chunks = x_chunk.shape[2] // chunk_size |
| | x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) |
| | x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( |
| | batch_size * num_chunks, num_channels, chunk_size |
| | ) |
| | x_chunk = self.chunkwise_conv(x_chunk) |
| |
|
| | chunk_scale = self._get_chunk_scale(chunk_size) |
| |
|
| | x_chunk = x_chunk * chunk_scale |
| | x_chunk = x_chunk.reshape( |
| | batch_size, num_chunks, num_channels, chunk_size |
| | ).permute(0, 2, 1, 3) |
| | x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ |
| | ..., :seq_len |
| | ] |
| |
|
| | return x_chunk + x_causal |
| |
|
| | def _get_chunk_scale(self, chunk_size: int): |
| | """Returns tensor of shape (num_channels, chunk_size) that will be used to |
| | scale the output of self.chunkwise_conv.""" |
| | left_edge = self.chunkwise_conv_scale[0] |
| | right_edge = self.chunkwise_conv_scale[1] |
| | if chunk_size < self.kernel_size: |
| | left_edge = left_edge[:, :chunk_size] |
| | right_edge = right_edge[:, -chunk_size:] |
| | else: |
| | t = chunk_size - self.kernel_size |
| | channels = left_edge.shape[0] |
| | pad = torch.zeros( |
| | channels, t, device=left_edge.device, dtype=left_edge.dtype |
| | ) |
| | left_edge = torch.cat((left_edge, pad), dim=-1) |
| | right_edge = torch.cat((pad, right_edge), dim=-1) |
| | return 1.0 + (left_edge + right_edge) |
| |
|
| | def streaming_forward( |
| | self, |
| | x: Tensor, |
| | cache: Tensor, |
| | ) -> Tuple[Tensor, Tensor]: |
| | """Streaming Forward function. |
| | |
| | Args: |
| | x: a Tensor of shape (batch_size, channels, seq_len) |
| | cache: cached left context of shape (batch_size, channels, left_pad) |
| | """ |
| | (batch_size, num_channels, seq_len) = x.shape |
| |
|
| | |
| | |
| | |
| | left_pad = self.kernel_size // 2 |
| |
|
| | |
| | assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) |
| | x = torch.cat([cache, x], dim=2) |
| | |
| | cache = x[..., -left_pad:] |
| |
|
| | x_causal = self.causal_conv(x) |
| | assert x_causal.shape == (batch_size, num_channels, seq_len) |
| |
|
| | x_chunk = x[..., left_pad:] |
| | x_chunk = self.chunkwise_conv(x_chunk) |
| |
|
| | chunk_scale = self._get_chunk_scale(chunk_size=seq_len) |
| | x_chunk = x_chunk * chunk_scale |
| |
|
| | return x_chunk + x_causal, cache |
| |
|
| |
|
| | class BalancerFunction(torch.autograd.Function): |
| | @staticmethod |
| | def forward( |
| | ctx, |
| | x: Tensor, |
| | min_mean: float, |
| | max_mean: float, |
| | min_rms: float, |
| | max_rms: float, |
| | grad_scale: float, |
| | channel_dim: int, |
| | ) -> Tensor: |
| | if channel_dim < 0: |
| | channel_dim += x.ndim |
| | ctx.channel_dim = channel_dim |
| | ctx.save_for_backward(x) |
| | ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) |
| | return x |
| |
|
| | @staticmethod |
| | def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: |
| | (x,) = ctx.saved_tensors |
| | (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config |
| |
|
| | try: |
| | with torch.enable_grad(): |
| | with torch.cuda.amp.autocast(enabled=False): |
| | x = x.to(torch.float32) |
| | x = x.detach() |
| | x.requires_grad = True |
| | mean_dims = [i for i in range(x.ndim) if i != channel_dim] |
| | uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) |
| | mean = x.mean(dim=mean_dims, keepdim=True) |
| | stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() |
| | rms = uncentered_var.clamp(min=1.0e-20).sqrt() |
| |
|
| | m = mean / stddev |
| | |
| | m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() |
| |
|
| | |
| | |
| | rms_clamped = rms.clamp(min=min_rms, max=max_rms) |
| | r_loss = (rms_clamped / rms).log().abs() |
| |
|
| | loss = m_loss + r_loss |
| |
|
| | loss.backward(gradient=torch.ones_like(loss)) |
| | loss_grad = x.grad |
| | loss_grad_rms = ( |
| | (loss_grad**2) |
| | .mean(dim=mean_dims, keepdim=True) |
| | .sqrt() |
| | .clamp(min=1.0e-20) |
| | ) |
| |
|
| | loss_grad = loss_grad * (grad_scale / loss_grad_rms) |
| |
|
| | x_grad_float = x_grad.to(torch.float32) |
| | |
| | |
| | |
| | x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) |
| | x_grad = x_grad_mod.to(x_grad.dtype) |
| | except Exception as e: |
| | logging.info( |
| | f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." |
| | ) |
| |
|
| | return x_grad, None, None, None, None, None, None |
| |
|
| |
|
| | class Balancer(torch.nn.Module): |
| | """ |
| | Modifies the backpropped derivatives of a function to try to encourage, for |
| | each channel, that it is positive at least a proportion `threshold` of the |
| | time. It does this by multiplying negative derivative values by up to |
| | (1+max_factor), and positive derivative values by up to (1-max_factor), |
| | interpolated from 1 at the threshold to those extremal values when none |
| | of the inputs are positive. |
| | |
| | Args: |
| | num_channels: the number of channels |
| | channel_dim: the dimension/axis corresponding to the channel, e.g. |
| | -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. |
| | min_positive: the minimum, per channel, of the proportion of the time |
| | that (x > 0), below which we start to modify the derivatives. |
| | max_positive: the maximum, per channel, of the proportion of the time |
| | that (x > 0), above which we start to modify the derivatives. |
| | scale_gain_factor: determines the 'gain' with which we increase the |
| | change in gradient once the constraints on min_abs and max_abs |
| | are violated. |
| | min_abs: the minimum average-absolute-value difference from the mean |
| | value per channel, which we allow, before we start to modify |
| | the derivatives to prevent this. |
| | max_abs: the maximum average-absolute-value difference from the mean |
| | value per channel, which we allow, before we start to modify |
| | the derivatives to prevent this. |
| | prob: determines the minimum probability with which we modify the |
| | gradients for the {min,max}_positive and {min,max}_abs constraints, |
| | on each forward(). This is done randomly to prevent all layers |
| | from doing it at the same time. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | num_channels: int, |
| | channel_dim: int, |
| | min_positive: FloatLike = 0.05, |
| | max_positive: FloatLike = 0.95, |
| | min_abs: FloatLike = 0.2, |
| | max_abs: FloatLike = 100.0, |
| | grad_scale: FloatLike = 0.04, |
| | prob: Optional[FloatLike] = None, |
| | ): |
| | super().__init__() |
| |
|
| | if prob is None: |
| | prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) |
| | self.prob = prob |
| | |
| | |
| | self.mem_cutoff = CutoffEstimator(0.05) |
| |
|
| | |
| | self.num_channels = num_channels |
| | self.channel_dim = channel_dim |
| | self.min_positive = min_positive |
| | self.max_positive = max_positive |
| | self.min_abs = min_abs |
| | self.max_abs = max_abs |
| | self.grad_scale = grad_scale |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | if ( |
| | torch.jit.is_scripting() |
| | or not x.requires_grad |
| | or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) |
| | ): |
| | return _no_op(x) |
| |
|
| | prob = float(self.prob) |
| | if random.random() < prob: |
| | |
| | |
| | |
| | def _abs_to_rms(x): |
| | |
| | |
| | return 1.25331413732 * x |
| |
|
| | def _proportion_positive_to_mean(x): |
| | def _atanh(x): |
| | eps = 1.0e-10 |
| | |
| | |
| | return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 |
| |
|
| | def _approx_inverse_erf(x): |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | return 0.8139535143 * _atanh(x) |
| |
|
| | |
| | |
| | x = -1 + (2 * x) |
| | return _approx_inverse_erf(x) |
| |
|
| | min_mean = _proportion_positive_to_mean(float(self.min_positive)) |
| | max_mean = _proportion_positive_to_mean(float(self.max_positive)) |
| | min_rms = _abs_to_rms(float(self.min_abs)) |
| | max_rms = _abs_to_rms(float(self.max_abs)) |
| | grad_scale = float(self.grad_scale) |
| |
|
| | assert x.shape[self.channel_dim] == self.num_channels |
| |
|
| | return BalancerFunction.apply( |
| | x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim |
| | ) |
| | else: |
| | return _no_op(x) |
| |
|
| |
|
| | def penalize_abs_values_gt( |
| | x: Tensor, limit: float, penalty: float, name: str = None |
| | ) -> Tensor: |
| | """ |
| | Returns x unmodified, but in backprop will put a penalty for the excess of |
| | the absolute values of elements of x over the limit "limit". E.g. if |
| | limit == 10.0, then if x has any values over 10 it will get a penalty. |
| | |
| | Caution: the value of this penalty will be affected by grad scaling used |
| | in automatic mixed precision training. For this reasons we use this, |
| | it shouldn't really matter, or may even be helpful; we just use this |
| | to disallow really implausible values of scores to be given to softmax. |
| | |
| | The name is for randomly printed debug info. |
| | """ |
| | x_sign = x.sign() |
| | over_limit = (x.abs() - limit) > 0 |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) |
| | |
| | |
| | x = with_loss(x, aux_loss, name) |
| | |
| | return x |
| |
|
| |
|
| | def _diag(x: Tensor): |
| | if x.ndim == 2: |
| | return x.diag() |
| | else: |
| | (batch, dim, dim) = x.shape |
| | x = x.reshape(batch, dim * dim) |
| | x = x[:, :: dim + 1] |
| | assert x.shape == (batch, dim) |
| | return x |
| |
|
| |
|
| | def _whitening_metric(x: Tensor, num_groups: int): |
| | """ |
| | Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of |
| | of the centered feature covariance are the same within each group's covariance matrix |
| | and also between groups. |
| | Args: |
| | x: a Tensor of shape (*, num_channels) |
| | num_groups: the number of groups of channels, a number >=1 that divides num_channels |
| | Returns: |
| | Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and |
| | greater than 1.0 otherwise. |
| | """ |
| | assert x.dtype != torch.float16 |
| | x = x.reshape(-1, x.shape[-1]) |
| | (num_frames, num_channels) = x.shape |
| | assert num_channels % num_groups == 0 |
| | channels_per_group = num_channels // num_groups |
| | x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) |
| | |
| | |
| | |
| | |
| | |
| | x = x - x.mean(dim=1, keepdim=True) |
| | |
| | x_covar = torch.matmul(x.transpose(1, 2), x) |
| | x_covar_mean_diag = _diag(x_covar).mean() |
| | |
| | |
| | |
| | x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) |
| | |
| | metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) |
| | return metric |
| |
|
| |
|
| | class WhiteningPenaltyFunction(torch.autograd.Function): |
| | @staticmethod |
| | def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: |
| | ctx.save_for_backward(x) |
| | ctx.module = module |
| | return x |
| |
|
| | @staticmethod |
| | def backward(ctx, x_grad: Tensor): |
| | (x_orig,) = ctx.saved_tensors |
| | w = ctx.module |
| |
|
| | try: |
| | with torch.enable_grad(): |
| | with torch.cuda.amp.autocast(enabled=False): |
| | x_detached = x_orig.to(torch.float32).detach() |
| | x_detached.requires_grad = True |
| |
|
| | metric = _whitening_metric(x_detached, w.num_groups) |
| |
|
| | if random.random() < 0.005 or __name__ == "__main__": |
| | logging.info( |
| | f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " |
| | f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" |
| | ) |
| |
|
| | if metric < float(w.whitening_limit): |
| | w.prob = w.min_prob |
| | return x_grad, None |
| | else: |
| | w.prob = w.max_prob |
| | metric.backward() |
| | penalty_grad = x_detached.grad |
| | scale = float(w.grad_scale) * ( |
| | x_grad.to(torch.float32).norm() |
| | / (penalty_grad.norm() + 1.0e-20) |
| | ) |
| | penalty_grad = penalty_grad * scale |
| | return x_grad + penalty_grad.to(x_grad.dtype), None |
| | except Exception as e: |
| | logging.info( |
| | f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." |
| | ) |
| | return x_grad, None |
| |
|
| |
|
| | class Whiten(nn.Module): |
| | def __init__( |
| | self, |
| | num_groups: int, |
| | whitening_limit: FloatLike, |
| | prob: Union[float, Tuple[float, float]], |
| | grad_scale: FloatLike, |
| | ): |
| | """ |
| | Args: |
| | num_groups: the number of groups to divide the channel dim into before |
| | whitening. We will attempt to make the feature covariance |
| | within each group, after mean subtraction, as "white" as possible, |
| | while having the same trace across all groups. |
| | whitening_limit: a value greater than 1.0, that dictates how much |
| | freedom we have to violate the constraints. 1.0 would mean perfectly |
| | white, with exactly the same trace across groups; larger values |
| | give more freedom. E.g. 2.0. |
| | prob: the probability with which we apply the gradient modification |
| | (also affects the grad scale). May be supplied as a float, |
| | or as a pair (min_prob, max_prob) |
| | |
| | grad_scale: determines the scale on the gradient term from this object, |
| | relative to the rest of the gradient on the attention weights. |
| | E.g. 0.02 (you may want to use smaller values than this if prob is large) |
| | """ |
| | super(Whiten, self).__init__() |
| | assert num_groups >= 1 |
| | assert float(whitening_limit) >= 1 |
| | assert float(grad_scale) >= 0 |
| | self.num_groups = num_groups |
| | self.whitening_limit = whitening_limit |
| | self.grad_scale = grad_scale |
| |
|
| | if isinstance(prob, float): |
| | prob = (prob, prob) |
| | (self.min_prob, self.max_prob) = prob |
| | assert 0 < self.min_prob <= self.max_prob <= 1 |
| | self.prob = self.max_prob |
| | self.name = None |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | """ |
| | In the forward pass, this function just returns the input unmodified. |
| | In the backward pass, it will modify the gradients to ensure that the |
| | distribution in each group has close to (lambda times I) as the covariance |
| | after mean subtraction, with the same lambda across groups. |
| | For whitening_limit > 1, there will be more freedom to violate this |
| | constraint. |
| | |
| | Args: |
| | x: the input of shape (*, num_channels) |
| | |
| | Returns: |
| | x, unmodified. You should make sure |
| | you use the returned value, or the graph will be freed |
| | and nothing will happen in backprop. |
| | """ |
| | grad_scale = float(self.grad_scale) |
| | if not x.requires_grad or random.random() > self.prob or grad_scale == 0: |
| | return _no_op(x) |
| | else: |
| | return WhiteningPenaltyFunction.apply(x, self) |
| |
|
| |
|
| | class WithLoss(torch.autograd.Function): |
| | @staticmethod |
| | def forward(ctx, x: Tensor, y: Tensor, name: str): |
| | ctx.y_shape = y.shape |
| | if random.random() < 0.002 and name is not None: |
| | loss_sum = y.sum().item() |
| | logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") |
| | return x |
| |
|
| | @staticmethod |
| | def backward(ctx, ans_grad: Tensor): |
| | return ( |
| | ans_grad, |
| | torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), |
| | None, |
| | ) |
| |
|
| |
|
| | def with_loss(x, y, name): |
| | |
| | return WithLoss.apply(x, y, name) |
| |
|
| |
|
| | class ScaleGradFunction(torch.autograd.Function): |
| | @staticmethod |
| | def forward(ctx, x: Tensor, alpha: float) -> Tensor: |
| | ctx.alpha = alpha |
| | return x |
| |
|
| | @staticmethod |
| | def backward(ctx, grad: Tensor): |
| | return grad * ctx.alpha, None |
| |
|
| |
|
| | def scale_grad(x: Tensor, alpha: float): |
| | return ScaleGradFunction.apply(x, alpha) |
| |
|
| |
|
| | class ScaleGrad(nn.Module): |
| | def __init__(self, alpha: float): |
| | super().__init__() |
| | self.alpha = alpha |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: |
| | return x |
| | return scale_grad(x, self.alpha) |
| |
|
| |
|
| | class LimitParamValue(torch.autograd.Function): |
| | @staticmethod |
| | def forward(ctx, x: Tensor, min: float, max: float): |
| | ctx.save_for_backward(x) |
| | assert max >= min |
| | ctx.min = min |
| | ctx.max = max |
| | return x |
| |
|
| | @staticmethod |
| | def backward(ctx, x_grad: Tensor): |
| | (x,) = ctx.saved_tensors |
| | |
| | |
| | x_grad = x_grad * torch.where( |
| | torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 |
| | ) |
| | |
| | |
| | x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) |
| | return x_grad, None, None |
| |
|
| |
|
| | def limit_param_value( |
| | x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True |
| | ): |
| | |
| | |
| | |
| | |
| | |
| | if training and random.random() < prob: |
| | return LimitParamValue.apply(x, min, max) |
| | else: |
| | return x |
| |
|
| |
|
| | def _no_op(x: Tensor) -> Tensor: |
| | if torch.jit.is_scripting() or torch.jit.is_tracing(): |
| | return x |
| | else: |
| | |
| | |
| | return x.chunk(1, dim=-1)[0] |
| |
|
| |
|
| | class Identity(torch.nn.Module): |
| | def __init__(self): |
| | super(Identity, self).__init__() |
| |
|
| | def forward(self, x): |
| | return _no_op(x) |
| |
|
| |
|
| | class DoubleSwishFunction(torch.autograd.Function): |
| | """ |
| | double_swish(x) = x * torch.sigmoid(x-1) |
| | |
| | This is a definition, originally motivated by its close numerical |
| | similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). |
| | |
| | Memory-efficient derivative computation: |
| | double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) |
| | double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). |
| | Now, s'(x) = s(x) * (1-s(x)). |
| | double_swish'(x) = x * s'(x) + s(x). |
| | = x * s(x) * (1-s(x)) + s(x). |
| | = double_swish(x) * (1-s(x)) + s(x) |
| | ... so we just need to remember s(x) but not x itself. |
| | """ |
| |
|
| | @staticmethod |
| | def forward(ctx, x: Tensor) -> Tensor: |
| | requires_grad = x.requires_grad |
| | if x.dtype == torch.float16 or x.dtype == torch.bfloat16: |
| | x = x.to(torch.float32) |
| |
|
| | s = torch.sigmoid(x - 1.0) |
| | y = x * s |
| |
|
| | if requires_grad: |
| | deriv = y * (1 - s) + s |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | floor = -0.044 |
| | ceil = 1.2 |
| | d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( |
| | deriv |
| | ) |
| | if __name__ == "__main__": |
| | |
| | assert d_scaled.min() >= 0.0 |
| | assert d_scaled.max() < 256.0 |
| | d_int = d_scaled.to(torch.uint8) |
| | ctx.save_for_backward(d_int) |
| | if x.dtype == torch.float16 or torch.is_autocast_enabled(): |
| | y = y.to(torch.float16) |
| | return y |
| |
|
| | @staticmethod |
| | def backward(ctx, y_grad: Tensor) -> Tensor: |
| | (d,) = ctx.saved_tensors |
| | |
| | floor = -0.043637 |
| | ceil = 1.2 |
| |
|
| | d = d * ((ceil - floor) / 255.0) + floor |
| | return y_grad * d |
| |
|
| |
|
| | class DoubleSwish(torch.nn.Module): |
| | def __init__(self): |
| | super().__init__() |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | """Return double-swish activation function which is an approximation to Swish(Swish(x)), |
| | that we approximate closely with x * sigmoid(x-1). |
| | """ |
| | if torch.jit.is_scripting() or torch.jit.is_tracing(): |
| | return x * torch.sigmoid(x - 1.0) |
| | return DoubleSwishFunction.apply(x) |
| |
|
| |
|
| | |
| | class Dropout2(nn.Module): |
| | def __init__(self, p: FloatLike): |
| | super().__init__() |
| | self.p = p |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) |
| |
|
| |
|
| | class MulForDropout3(torch.autograd.Function): |
| | |
| | |
| | @staticmethod |
| | @custom_fwd |
| | def forward(ctx, x, y, alpha): |
| | assert not y.requires_grad |
| | ans = x * y * alpha |
| | ctx.save_for_backward(ans) |
| | ctx.alpha = alpha |
| | return ans |
| |
|
| | @staticmethod |
| | @custom_bwd |
| | def backward(ctx, ans_grad): |
| | (ans,) = ctx.saved_tensors |
| | x_grad = ctx.alpha * ans_grad * (ans != 0) |
| | return x_grad, None, None |
| |
|
| |
|
| | |
| | |
| | class Dropout3(nn.Module): |
| | def __init__(self, p: FloatLike, shared_dim: int): |
| | super().__init__() |
| | self.p = p |
| | self.shared_dim = shared_dim |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | p = float(self.p) |
| | if not self.training or p == 0: |
| | return _no_op(x) |
| | scale = 1.0 / (1 - p) |
| | rand_shape = list(x.shape) |
| | rand_shape[self.shared_dim] = 1 |
| | mask = torch.rand(*rand_shape, device=x.device) > p |
| | ans = MulForDropout3.apply(x, mask, scale) |
| | return ans |
| |
|
| |
|
| | class SwooshLFunction(torch.autograd.Function): |
| | """ |
| | swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 |
| | """ |
| |
|
| | @staticmethod |
| | def forward(ctx, x: Tensor) -> Tensor: |
| | requires_grad = x.requires_grad |
| | if x.dtype == torch.float16 or x.dtype == torch.bfloat16: |
| | x = x.to(torch.float32) |
| |
|
| | zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) |
| |
|
| | coeff = -0.08 |
| |
|
| | with torch.cuda.amp.autocast(enabled=False): |
| | with torch.enable_grad(): |
| | x = x.detach() |
| | x.requires_grad = True |
| | y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 |
| |
|
| | if not requires_grad: |
| | return y |
| |
|
| | y.backward(gradient=torch.ones_like(y)) |
| |
|
| | grad = x.grad |
| | floor = coeff |
| | ceil = 1.0 + coeff + 0.005 |
| |
|
| | d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( |
| | grad |
| | ) |
| | if __name__ == "__main__": |
| | |
| | assert d_scaled.min() >= 0.0 |
| | assert d_scaled.max() < 256.0 |
| |
|
| | d_int = d_scaled.to(torch.uint8) |
| | ctx.save_for_backward(d_int) |
| | if x.dtype == torch.float16 or torch.is_autocast_enabled(): |
| | y = y.to(torch.get_autocast_gpu_dtype()) |
| | return y |
| |
|
| | @staticmethod |
| | def backward(ctx, y_grad: Tensor) -> Tensor: |
| | (d,) = ctx.saved_tensors |
| | |
| |
|
| | coeff = -0.08 |
| | floor = coeff |
| | ceil = 1.0 + coeff + 0.005 |
| | d = d * ((ceil - floor) / 255.0) + floor |
| | return y_grad * d |
| |
|
| |
|
| | class SwooshL(torch.nn.Module): |
| | def forward(self, x: Tensor) -> Tensor: |
| | """Return Swoosh-L activation.""" |
| | return SwooshLFunction.apply(x) |
| | if torch.jit.is_scripting() or torch.jit.is_tracing(): |
| | zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) |
| | return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 |
| | if not x.requires_grad: |
| | return k2.swoosh_l_forward(x) |
| | else: |
| | return k2.swoosh_l(x) |
| | |
| |
|
| |
|
| | class SwooshLOnnx(torch.nn.Module): |
| | def forward(self, x: Tensor) -> Tensor: |
| | """Return Swoosh-L activation.""" |
| | zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) |
| | return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 |
| |
|
| |
|
| | class SwooshRFunction(torch.autograd.Function): |
| | """ |
| | swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 |
| | |
| | derivatives are between -0.08 and 0.92. |
| | """ |
| |
|
| | @staticmethod |
| | def forward(ctx, x: Tensor) -> Tensor: |
| | requires_grad = x.requires_grad |
| |
|
| | if x.dtype == torch.float16 or x.dtype == torch.bfloat16: |
| | x = x.to(torch.float32) |
| |
|
| | zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) |
| |
|
| | with torch.cuda.amp.autocast(enabled=False): |
| | with torch.enable_grad(): |
| | x = x.detach() |
| | x.requires_grad = True |
| | y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 |
| |
|
| | if not requires_grad: |
| | return y |
| | y.backward(gradient=torch.ones_like(y)) |
| |
|
| | grad = x.grad |
| | floor = -0.08 |
| | ceil = 0.925 |
| |
|
| | d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( |
| | grad |
| | ) |
| | if __name__ == "__main__": |
| | |
| | assert d_scaled.min() >= 0.0 |
| | assert d_scaled.max() < 256.0 |
| |
|
| | d_int = d_scaled.to(torch.uint8) |
| | ctx.save_for_backward(d_int) |
| | if x.dtype == torch.float16 or torch.is_autocast_enabled(): |
| | y = y.to(torch.get_autocast_gpu_dtype()) |
| | return y |
| |
|
| | @staticmethod |
| | def backward(ctx, y_grad: Tensor) -> Tensor: |
| | (d,) = ctx.saved_tensors |
| | |
| | floor = -0.08 |
| | ceil = 0.925 |
| | d = d * ((ceil - floor) / 255.0) + floor |
| | return y_grad * d |
| |
|
| |
|
| | class SwooshR(torch.nn.Module): |
| | def forward(self, x: Tensor) -> Tensor: |
| | """Return Swoosh-R activation.""" |
| | |
| | return SwooshRFunction.apply(x) |
| | if True: |
| | zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) |
| | return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 |
| | if not x.requires_grad: |
| | return k2.swoosh_r_forward(x) |
| | else: |
| | return k2.swoosh_r(x) |
| | |
| |
|
| |
|
| | class SwooshROnnx(torch.nn.Module): |
| | def forward(self, x: Tensor) -> Tensor: |
| | """Return Swoosh-R activation.""" |
| | zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) |
| | return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 |
| |
|
| |
|
| | |
| | |
| | def SwooshLForward(x: Tensor): |
| | x_offset = x - 4.0 |
| | |
| | log_sum = torch.nn.functional.softplus(x_offset) |
| | log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) |
| | return log_sum - 0.08 * x - 0.035 |
| |
|
| |
|
| | |
| | |
| | def SwooshRForward(x: Tensor): |
| | x_offset = x - 1.0 |
| | |
| | log_sum = torch.nn.functional.softplus(x_offset) |
| | log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) |
| | return log_sum - 0.08 * x - 0.313261687 |
| |
|
| |
|
| | class ActivationDropoutAndLinearFunction(torch.autograd.Function): |
| | @staticmethod |
| | @custom_fwd |
| | def forward( |
| | ctx, |
| | x: Tensor, |
| | weight: Tensor, |
| | bias: Optional[Tensor], |
| | activation: str, |
| | dropout_p: float, |
| | dropout_shared_dim: Optional[int], |
| | ): |
| | if dropout_p != 0.0: |
| | dropout_shape = list(x.shape) |
| | if dropout_shared_dim is not None: |
| | dropout_shape[dropout_shared_dim] = 1 |
| | |
| | dropout_mask = (1.0 / (1.0 - dropout_p)) * ( |
| | torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p |
| | ) |
| | else: |
| | dropout_mask = None |
| |
|
| | ctx.save_for_backward(x, weight, bias, dropout_mask) |
| |
|
| | ctx.activation = activation |
| |
|
| | forward_activation_dict = { |
| | "SwooshL": k2.swoosh_l_forward, |
| | "SwooshR": k2.swoosh_r_forward, |
| | } |
| | |
| | |
| | activation_func = forward_activation_dict[activation] |
| | x = activation_func(x) |
| | if dropout_mask is not None: |
| | x = x * dropout_mask |
| | x = torch.nn.functional.linear(x, weight, bias) |
| | return x |
| |
|
| | @staticmethod |
| | @custom_bwd |
| | def backward(ctx, ans_grad: Tensor): |
| | saved = ctx.saved_tensors |
| | (x, weight, bias, dropout_mask) = saved |
| |
|
| | forward_and_deriv_activation_dict = { |
| | "SwooshL": k2.swoosh_l_forward_and_deriv, |
| | "SwooshR": k2.swoosh_r_forward_and_deriv, |
| | } |
| | |
| | |
| | func = forward_and_deriv_activation_dict[ctx.activation] |
| |
|
| | y, func_deriv = func(x) |
| | if dropout_mask is not None: |
| | y = y * dropout_mask |
| | |
| | |
| | (out_channels, in_channels) = weight.shape |
| |
|
| | in_channels = y.shape[-1] |
| | g = ans_grad.reshape(-1, out_channels) |
| | weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) |
| | y_deriv = torch.matmul(ans_grad, weight) |
| | bias_deriv = None if bias is None else g.sum(dim=0) |
| | x_deriv = y_deriv * func_deriv |
| | if dropout_mask is not None: |
| | |
| | x_deriv = x_deriv * dropout_mask |
| |
|
| | return x_deriv, weight_deriv, bias_deriv, None, None, None |
| |
|
| |
|
| | class ActivationDropoutAndLinear(torch.nn.Module): |
| | """ |
| | This merges an activation function followed by dropout and then a nn.Linear module; |
| | it does so in a memory efficient way so that it only stores the input to the whole |
| | module. If activation == SwooshL and dropout_shared_dim != None, this will be |
| | equivalent to: |
| | nn.Sequential(SwooshL(), |
| | Dropout3(dropout_p, shared_dim=dropout_shared_dim), |
| | ScaledLinear(in_channels, out_channels, bias=bias, |
| | initial_scale=initial_scale)) |
| | If dropout_shared_dim is None, the dropout would be equivalent to |
| | Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout |
| | mask is smaller. |
| | |
| | Args: |
| | in_channels: number of input channels, e.g. 256 |
| | out_channels: number of output channels, e.g. 256 |
| | bias: if true, have a bias |
| | activation: the activation function, for now just support SwooshL. |
| | dropout_p: the dropout probability or schedule (happens after nonlinearity). |
| | dropout_shared_dim: the dimension, if any, across which the dropout mask is |
| | shared (e.g. the time dimension). If None, this may be less memory |
| | efficient if there are modules before this one that cache the input |
| | for their backprop (e.g. Balancer or Whiten). |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | in_channels: int, |
| | out_channels: int, |
| | bias: bool = True, |
| | activation: str = "SwooshL", |
| | dropout_p: FloatLike = 0.0, |
| | dropout_shared_dim: Optional[int] = -1, |
| | initial_scale: float = 1.0, |
| | ): |
| | super().__init__() |
| | |
| | |
| | l = ScaledLinear( |
| | in_channels, out_channels, bias=bias, initial_scale=initial_scale |
| | ) |
| |
|
| | self.weight = l.weight |
| | |
| | |
| | |
| | |
| | self.register_parameter("bias", l.bias) |
| |
|
| | self.activation = activation |
| | self.dropout_p = dropout_p |
| | self.dropout_shared_dim = dropout_shared_dim |
| |
|
| | def forward(self, x: Tensor): |
| | |
| | if True: |
| | if self.activation == "SwooshL": |
| | x = SwooshLForward(x) |
| | elif self.activation == "SwooshR": |
| | x = SwooshRForward(x) |
| | else: |
| | assert False, self.activation |
| | return torch.nn.functional.linear(x, self.weight, self.bias) |
| |
|
| | return ActivationDropoutAndLinearFunction.apply( |
| | x, |
| | self.weight, |
| | self.bias, |
| | self.activation, |
| | float(self.dropout_p), |
| | self.dropout_shared_dim, |
| | ) |
| |
|
| |
|
| | def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: |
| | if num_channels <= x.shape[-1]: |
| | return x[..., :num_channels] |
| | else: |
| | shape = list(x.shape) |
| | shape[-1] = num_channels - shape[-1] |
| | zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) |
| | return torch.cat((x, zeros), dim=-1) |
| |
|
| |
|
| | def _test_whiten(): |
| | for proportion in [0.1, 0.5, 10.0]: |
| | logging.info(f"_test_whiten(): proportion = {proportion}") |
| | x = torch.randn(100, 128) |
| | direction = torch.randn(128) |
| | coeffs = torch.randn(100, 1) |
| | x += proportion * direction * coeffs |
| |
|
| | x.requires_grad = True |
| |
|
| | m = Whiten( |
| | 1, 5.0, prob=1.0, grad_scale=0.1 |
| | ) |
| |
|
| | for _ in range(4): |
| | y = m(x) |
| |
|
| | y_grad = torch.randn_like(x) |
| | y.backward(gradient=y_grad) |
| |
|
| | if proportion < 0.2: |
| | assert torch.allclose(x.grad, y_grad) |
| | elif proportion > 1.0: |
| | assert not torch.allclose(x.grad, y_grad) |
| |
|
| |
|
| | def _test_balancer_sign(): |
| | probs = torch.arange(0, 1, 0.01) |
| | N = 1000 |
| | x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) |
| | x = x.detach() |
| | x.requires_grad = True |
| | m = Balancer( |
| | probs.numel(), |
| | channel_dim=0, |
| | min_positive=0.05, |
| | max_positive=0.95, |
| | min_abs=0.0, |
| | prob=1.0, |
| | ) |
| |
|
| | y_grad = torch.sign(torch.randn(probs.numel(), N)) |
| |
|
| | y = m(x) |
| | y.backward(gradient=y_grad) |
| | print("_test_balancer_sign: x = ", x) |
| | print("_test_balancer_sign: y grad = ", y_grad) |
| | print("_test_balancer_sign: x grad = ", x.grad) |
| |
|
| |
|
| | def _test_balancer_magnitude(): |
| | magnitudes = torch.arange(0, 1, 0.01) |
| | N = 1000 |
| | x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) |
| | x = x.detach() |
| | x.requires_grad = True |
| | m = Balancer( |
| | magnitudes.numel(), |
| | channel_dim=0, |
| | min_positive=0.0, |
| | max_positive=1.0, |
| | min_abs=0.2, |
| | max_abs=0.7, |
| | prob=1.0, |
| | ) |
| |
|
| | y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) |
| |
|
| | y = m(x) |
| | y.backward(gradient=y_grad) |
| | print("_test_balancer_magnitude: x = ", x) |
| | print("_test_balancer_magnitude: y grad = ", y_grad) |
| | print("_test_balancer_magnitude: x grad = ", x.grad) |
| |
|
| |
|
| | def _test_double_swish_deriv(): |
| | x = torch.randn(10, 12, dtype=torch.double) * 3.0 |
| | x.requires_grad = True |
| | m = DoubleSwish() |
| |
|
| | tol = (1.2 - (-0.043637)) / 255.0 |
| | torch.autograd.gradcheck(m, x, atol=tol) |
| |
|
| | |
| | x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 |
| | x.requires_grad = True |
| | y = m(x) |
| |
|
| |
|
| | def _test_swooshl_deriv(): |
| | x = torch.randn(10, 12, dtype=torch.double) * 3.0 |
| | x.requires_grad = True |
| | m = SwooshL() |
| |
|
| | tol = 1.0 / 255.0 |
| | torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) |
| |
|
| | |
| | x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 |
| | x.requires_grad = True |
| | y = m(x) |
| |
|
| |
|
| | def _test_swooshr_deriv(): |
| | x = torch.randn(10, 12, dtype=torch.double) * 3.0 |
| | x.requires_grad = True |
| | m = SwooshR() |
| |
|
| | tol = 1.0 / 255.0 |
| | torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) |
| |
|
| | |
| | x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 |
| | x.requires_grad = True |
| | y = m(x) |
| |
|
| |
|
| | def _test_softmax(): |
| | a = torch.randn(2, 10, dtype=torch.float64) |
| | b = a.clone() |
| | a.requires_grad = True |
| | b.requires_grad = True |
| | a.softmax(dim=1)[:, 0].sum().backward() |
| | print("a grad = ", a.grad) |
| | softmax(b, dim=1)[:, 0].sum().backward() |
| | print("b grad = ", b.grad) |
| | assert torch.allclose(a.grad, b.grad) |
| |
|
| |
|
| | def _test_piecewise_linear(): |
| | p = PiecewiseLinear((0, 10.0)) |
| | for x in [-100, 0, 100]: |
| | assert p(x) == 10.0 |
| | p = PiecewiseLinear((0, 10.0), (1, 0.0)) |
| | for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: |
| | print("x, y = ", x, y) |
| | assert p(x) == y, (x, p(x), y) |
| |
|
| | q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) |
| | x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] |
| | pq = p.max(q) |
| | for x in x_vals: |
| | y1 = max(p(x), q(x)) |
| | y2 = pq(x) |
| | assert abs(y1 - y2) < 0.001 |
| | pq = p.min(q) |
| | for x in x_vals: |
| | y1 = min(p(x), q(x)) |
| | y2 = pq(x) |
| | assert abs(y1 - y2) < 0.001 |
| | pq = p + q |
| | for x in x_vals: |
| | y1 = p(x) + q(x) |
| | y2 = pq(x) |
| | assert abs(y1 - y2) < 0.001 |
| |
|
| |
|
| | def _test_activation_dropout_and_linear(): |
| | in_channels = 20 |
| | out_channels = 30 |
| |
|
| | for bias in [True, False]: |
| | |
| | |
| | |
| | |
| | for dropout_p in [0.0]: |
| | for activation in ["SwooshL", "SwooshR"]: |
| | m1 = nn.Sequential( |
| | SwooshL() if activation == "SwooshL" else SwooshR(), |
| | Dropout3(p=dropout_p, shared_dim=-1), |
| | ScaledLinear( |
| | in_channels, out_channels, bias=bias, initial_scale=0.5 |
| | ), |
| | ) |
| | m2 = ActivationDropoutAndLinear( |
| | in_channels, |
| | out_channels, |
| | bias=bias, |
| | initial_scale=0.5, |
| | activation=activation, |
| | dropout_p=dropout_p, |
| | ) |
| | with torch.no_grad(): |
| | m2.weight[:] = m1[2].weight |
| | if bias: |
| | m2.bias[:] = m1[2].bias |
| | |
| | x1 = torch.randn(10, in_channels) |
| | x1.requires_grad = True |
| |
|
| | |
| | assert torch.allclose( |
| | SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 |
| | ) |
| |
|
| | x2 = x1.clone().detach() |
| | x2.requires_grad = True |
| | seed = 10 |
| | torch.manual_seed(seed) |
| | y1 = m1(x1) |
| | y_grad = torch.randn_like(y1) |
| | y1.backward(gradient=y_grad) |
| | torch.manual_seed(seed) |
| | y2 = m2(x2) |
| | y2.backward(gradient=y_grad) |
| |
|
| | print( |
| | f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" |
| | ) |
| | print("y1 = ", y1) |
| | print("y2 = ", y2) |
| | assert torch.allclose(y1, y2, atol=0.02) |
| | assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) |
| | if bias: |
| | assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) |
| | print("x1.grad = ", x1.grad) |
| | print("x2.grad = ", x2.grad) |
| |
|
| | def isclose(a, b): |
| | |
| | return (a * b).sum() > 0.9 * ( |
| | (a**2).sum() * (b**2).sum() |
| | ).sqrt() |
| |
|
| | |
| | |
| | assert isclose(x1.grad, x2.grad) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | logging.getLogger().setLevel(logging.INFO) |
| | torch.set_num_threads(1) |
| | torch.set_num_interop_threads(1) |
| | _test_piecewise_linear() |
| | _test_softmax() |
| | _test_whiten() |
| | _test_balancer_sign() |
| | _test_balancer_magnitude() |
| | _test_double_swish_deriv() |
| | _test_swooshr_deriv() |
| | _test_swooshl_deriv() |
| | _test_activation_dropout_and_linear() |