"""This code is copied from NVIDIA Apex
(https://github.com/NVIDIA/apex) with some changes."""

import numbers
import importlib

import torch
from torch.nn.parameter import Parameter
from torch.nn import init

from megatron.mpu import make_viewless_tensor

# The persistent layer norm kernel (FastLayerNormFN) is only available when
# apex is built with the contrib layer_norm extension.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except ImportError:
    HAVE_PERSIST_LAYER_NORM = False

# Handle to the fused mixed-precision layer norm CUDA extension; imported
# lazily in MixedFusedLayerNorm.__init__.
global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd wrapper around the fused mixed-precision layer norm kernel."""

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        # Save the inputs and the per-row statistics (mean, inverse variance)
        # for the backward pass.
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)
        # The trailing None gradients correspond to the non-tensor forward
        # arguments (normalized_shape, eps).
        return grad_input, grad_weight, grad_bias, None, None


class MixedFusedLayerNorm(torch.nn.Module):

    def __init__(self, normalized_shape, eps=1e-5,
                 no_persist_layer_norm=True,
                 sequence_parallel=False):
        super(MixedFusedLayerNorm, self).__init__()

        global fused_mix_prec_layer_norm_cuda
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
            "fused_mix_prec_layer_norm_cuda")

        # Hidden sizes supported by the persistent layer norm kernel; for any
        # other size, or when the persistent kernel is unavailable, fall back
        # to the non-persistent kernel.
        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
            24576, 25600, 30720, 32768, 40960, 49152, 65536]
        if normalized_shape not in persist_ln_hidden_sizes or \
                not HAVE_PERSIST_LAYER_NORM:
            no_persist_layer_norm = True

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # Flag the parameters so that gradient reductions handle them
        # correctly under sequence parallelism.
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

    def reset_parameters(self):
        # Standard layer norm initialization: unit scale, zero shift.
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        if self.no_persist_layer_norm:
            return FusedLayerNormAffineFunction.apply(
                input, self.weight, self.bias, self.normalized_shape, self.eps)
        else:
            output = FastLayerNormFN.apply(
                input, self.weight, self.bias, self.eps)

            # FastLayerNormFN returns a 'view' tensor. Convert it to a
            # viewless tensor so that downstream code (e.g. pipeline-parallel
            # buffer management) can safely treat the output as owning its
            # memory.
            output = make_viewless_tensor(inp=output,
                                          requires_grad=input.requires_grad,
                                          keep_graph=True)

            return output
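

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module). Assumes the fused_mix_prec_layer_norm_cuda extension has been
# built and a CUDA device is available; the hidden size and input shape are
# arbitrary.
#
#   ln = MixedFusedLayerNorm(1024).cuda()
#   x = torch.randn(8, 512, 1024, device='cuda', requires_grad=True)
#   y = ln(x)            # dispatches to the fused layer norm kernel
#   y.sum().backward()   # backward runs through FusedLayerNormAffineFunction
# ---------------------------------------------------------------------------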