|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
from typing import Callable, List, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def relu_squared(x: torch.Tensor):
    """Squared ReLU: ``max(x, 0) ** 2`` applied elementwise."""
    return torch.pow(F.relu(x), 2)
|
|
|
|
|
|
|
|
def gelu_accurate(x):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    The constant ``sqrt(2/pi)`` is computed once and cached on the function
    object as ``gelu_accurate._a``.
    """
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    coeff = gelu_accurate._a
    inner = coeff * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
|
|
|
|
|
|
|
|
def is_xla_tensor(tensor):
    """Return True only when ``tensor`` is a torch tensor living on an XLA device."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
|
|
|
|
|
|
|
|
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor`` at the positions selected by ``indices``.

    On non-XLA tensors this is a plain in-place ``tensor[indices] = value``.
    On XLA tensors the boolean-style mask ``indices`` is first unsqueezed to
    ``tensor``'s rank (and expanded if its last dim is smaller). The result is
    then built arithmetically as ``tensor * ~indices + value * indices``, which
    returns a new tensor rather than mutating the input.

    Returns:
        The updated tensor (the same object on the non-XLA path).
    """
    if not is_xla_tensor(tensor):
        tensor[indices] = value
        return tensor

    # Add trailing singleton dims until the mask has the tensor's rank.
    while indices.dim() < tensor.dim():
        indices = indices.unsqueeze(-1)
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    kept = torch.mul(tensor, ~indices)
    written = torch.mul(value, indices)
    return kept + written
|
|
|
|
|
|
|
|
def pad_to_multiple(x, multiple, dim=-1, value=0):
    """Right-pad ``x`` along ``dim`` so that its size there is a multiple of ``multiple``.

    Args:
        x: input tensor, or None.
        multiple: target size multiple along ``dim``.
        dim: dimension to pad; negative or non-negative indices are accepted.
        value: fill value for the padded entries.

    Returns:
        ``(padded, remainder)``, where ``remainder`` is the number of entries
        appended along ``dim``. Returns ``(x, 0)`` when no padding is needed,
        and ``(None, 0)`` when ``x`` is None.
    """
    if x is None:
        return None, 0
    tsz = x.size(dim)
    # Exact integer arithmetic: the amount needed to reach the next multiple.
    remainder = (-tsz) % multiple
    if remainder == 0:
        return x, 0
    # F.pad lists pads starting from the last dim, so the offset below needs a
    # negative index; a non-negative dim would otherwise yield an empty offset
    # and silently pad the last dimension instead.
    if dim >= 0:
        dim -= x.dim()
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
|
|
|
|
|
|
|
|
def gelu(x: torch.Tensor) -> torch.Tensor:
    """Exact GELU evaluated in float32, cast back to ``x``'s dtype."""
    activated = F.gelu(x.float())
    return activated.type_as(x)
|
|
|
|
|
|
|
|
def get_activation_fn(activation: str) -> Callable:
    """Return the activation function registered under ``activation``.

    Args:
        activation: one of ``"relu"``, ``"relu_squared"``, ``"gelu"``,
            ``"gelu_fast"`` (same function as ``"gelu_accurate"``),
            ``"gelu_accurate"``, ``"tanh"``, ``"linear"`` (identity) or
            ``"swish"``.

    Returns:
        A callable that takes a tensor and returns the activated tensor.

    Raises:
        RuntimeError: if ``activation`` is not one of the names above.
    """
    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # Return the functional form. Returning the nn.SiLU *class* would make
        # ``fn(x)`` construct a module (with ``inplace=x``) instead of
        # computing the activation.
        return F.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
|
|
|
|
|
|
|
|
class SamePad(nn.Module):
    """Drop trailing time steps left over by a padded 1-D convolution.

    ``causal=True`` removes ``kernel_size - 1`` trailing steps. Otherwise one
    step is removed when ``kernel_size`` is even and none when it is odd.
    Trimming is applied to dim 2 of the input.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        self.remove = kernel_size - 1 if causal else (1 if kernel_size % 2 == 0 else 0)

    def forward(self, x):
        trim = self.remove
        return x[:, :, :-trim] if trim > 0 else x
|
|
|
|
|
|
|
|
class SamePad2d(nn.Module):
    """Drop the last row and column of a 4-D input when ``kernel_size`` is even."""

    def __init__(self, kernel_size):
        super().__init__()
        is_even = kernel_size % 2 == 0
        self.remove = 1 if is_even else 0

    def forward(self, x):
        assert x.dim() == 4
        trim = self.remove
        if trim <= 0:
            return x
        return x[:, :, :-trim, :-trim]
|
|
|
|
|
|
|
|
class TransposeLast(nn.Module):
    """Swap dimension ``tranpose_dim`` (default -2) with the last dimension.

    When ``deconstruct_idx`` is not None, the input is first indexed as
    ``x[deconstruct_idx]`` and the transpose is applied to that element.
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        source = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return source.transpose(self.tranpose_dim, -1)
|
|
|
|
|
|
|
|
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        """apex ``FusedLayerNorm`` whose forward runs with ``x``'s CUDA device made current."""

        @torch.jit.unused
        def forward(self, x):
            if x.is_cuda:
                # Make x's device the current CUDA device for the duration of the call.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            return super().forward(x)

except ImportError:
    # apex is not installed; callers fall back to torch.nn.LayerNorm.
    has_fused_layernorm = False
|
|
|
|
|
|
|
|
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a layer-norm module, preferring apex's fused implementation.

    ``FusedLayerNorm`` is returned only when ``export`` is False, the call is
    not running under TorchScript scripting or tracing, CUDA is available and
    the apex import succeeded. Otherwise a plain ``torch.nn.LayerNorm`` is
    returned with the same arguments.
    """
    in_jit = torch.jit.is_scripting() or torch.jit.is_tracing()
    wants_fused = not export and not in_jit
    if wants_fused and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
|
|
|
|
|
|
|
class Fp32LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` computed in float32 and cast back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight32 = None if self.weight is None else self.weight.float()
        bias32 = None if self.bias is None else self.bias.float()
        normed = F.layer_norm(
            input.float(), self.normalized_shape, weight32, bias32, self.eps
        )
        return normed.type_as(input)
|
|
|
|
|
|
|
|
class Fp32GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` computed in float32 and cast back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight32 = None if self.weight is None else self.weight.float()
        bias32 = None if self.bias is None else self.bias.float()
        normed = F.group_norm(
            input.float(), self.num_groups, weight32, bias32, self.eps
        )
        return normed.type_as(input)
|
|
|
|
|
|
|
|
def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over ``dim`` with a float32 result.

    When ``onnx_trace`` is set the input is cast with ``.float()`` before the
    softmax; otherwise the cast is requested through the ``dtype`` argument.
    """
    if not onnx_trace:
        return F.softmax(x, dim=dim, dtype=torch.float32)
    return F.softmax(x.float(), dim=dim)
|
|
|
|
|
|
|
|
def quant_noise(module, p, block_size):
    """Register a training-time forward pre-hook that randomly zeroes weight blocks.

    Implements the block-dropping noise from "Training with Quantization Noise
    for Extreme Model Compression", intended to precede quantization with
    Iterative Product Quantization (iPQ).

    Args:
        module: an ``nn.Linear``, ``nn.Embedding`` or ``nn.Conv2d`` instance.
        p: probability of dropping each block. ``p <= 0`` disables the noise
            and returns ``module`` unchanged.
        block_size: number of consecutive weight entries per block.

    Returns:
        The same ``module`` object, with a forward pre-hook registered when
        ``p > 0``.

    Remarks:
        - Linear/Embedding: ``weight.size(1)`` must be divisible by ``block_size``.
        - 1x1 Conv2d: ``in_channels`` must be divisible by ``block_size``;
          blocks run along the input-channel axis.
        - Other Conv2d: ``kernel_size[0] * kernel_size[1]`` must be divisible
          by ``block_size``. The mask is drawn per ``(weight.size(0),
          weight.size(1))`` entry and covers the whole kernel. See "And the Bit
          Goes Down: Revisiting the Quantization of Neural Networks".
        - While ``mod.training`` is True, every forward pass assigns
          ``(1 / (1 - p)) * weight.masked_fill(mask, 0)`` to ``weight.data``,
          so the change is written into the parameter itself.
        - Size checks use ``assert`` and are skipped under ``python -O``.
        - NOTE(review): ``p >= 1`` is not rejected here; ``p == 1`` would raise
          ZeroDivisionError on the first training forward.
    """
    # No noise requested: hand the module back untouched.
    if p <= 0:
        return module

    # Supported modules.
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # Conv2d weights are 4-D; Linear/Embedding weights are 2-D.
    is_conv = module.weight.ndim == 4

    # 2-D weight: blocks run along dim 1.
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4-D weight.
    else:
        # 1x1 convolutions.
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # Regular convolutions.
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # No noise outside training mode.
        if mod.training:
            if not is_conv:
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # One Bernoulli draw per block of `block_size` consecutive
                # columns, expanded back to the (out, in) weight shape.
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    # Reshape to (out, in, 1, 1) to line up with the 4-D weight.
                    # A 2-D (out, in) mask would broadcast against the
                    # (out, in, 1, 1) weight in masked_fill and produce an
                    # (out, in, out, in) tensor.
                    mask = (
                        mask.repeat_interleave(block_size, -1)
                        .view(-1, in_channels)
                        .unsqueeze(2)
                        .unsqueeze(3)
                    )
                else:
                    # One draw per (dim0, dim1) weight entry, repeated over the
                    # full kernel window.
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # Zero the dropped blocks and rescale the survivors by 1 / (1 - p).
            mask = mask.to(
                torch.bool
            )
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
|
|
|
|
|
|
|
|
class FairseqDropout(nn.Module):
    """Dropout that can optionally stay active at inference time.

    Dropout is applied when ``p > 0`` and the module is either training or has
    ``apply_during_inference`` set. ``make_generation_fast_`` turns that flag
    on for modules selected via ``retain_dropout``/``retain_dropout_modules``.
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        active = self.training or self.apply_during_inference
        if not (self.p > 0 and active):
            return x
        return F.dropout(x, p=self.p, training=True, inplace=inplace)

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        if not retain_dropout:
            return
        if retain_dropout_modules is None:
            # No whitelist: keep dropout on for every module.
            self.apply_during_inference = True
        elif self.module_name is not None and self.module_name in retain_dropout_modules:
            self.apply_during_inference = True
        # NOTE(review): a whitelist combined with an unnamed module is ignored
        # silently here.
|
|
|
|
|
|
|
|
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``scale``.

    Use as ``GradMultiply.apply(x, scale)``. No gradient is returned for
    ``scale``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        scaled = grad * ctx.scale
        return scaled, None
|
|
|