import torch
import torch.nn as nn
import torch.nn.functional as F


try:
    # Use NVIDIA Apex's fused CUDA layer norm kernel when it is installed.
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                # Make sure the fused kernel runs on the tensor's own device.
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    # TorchScript cannot script the Apex module, so force the plain PyTorch path.
    if torch.jit.is_scripting():
        export = True
    if not export and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        # Normalize in fp32 for numerical stability, then cast back to the input dtype.
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
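
# Usage sketch (illustrative only; the feature size and tensor shape below are
# assumptions for demonstration, not part of the original module):
#
#     ln = LayerNorm(512)            # fused Apex kernel on CUDA when available
#     fp32_ln = Fp32LayerNorm(512)   # normalizes in fp32, returns the input dtype
#     x = torch.randn(8, 16, 512)
#     y = ln(x)
#     z = fp32_ln(x)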