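"""Zero-centered ("1 + weight") layer normalization built on apex's fused
FastLayerNorm kernel, with a stub fallback when apex is not installed.
(Module docstring added for clarity; wording inferred from the code below.)"""
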
from torch import nn

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm as OrigFastLayerNorm
    from apex.contrib.layer_norm.layer_norm import _fast_layer_norm
    from apex.transformer.layers.layer_norm import FastLayerNorm

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


if HAVE_APEX:

    class LayerNorm1P(FastLayerNorm):
        """Fused layer norm with a zero-centered weight: the stored weight is an
        offset from 1, so the effective scale applied in forward() is weight + 1."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            assert isinstance(
                self, OrigFastLayerNorm
            ), 'LayerNorm1P implemented only as an apex.contrib.layer_norm.FastLayerNorm extension'

        def reset_parameters(self):
            # Zero-init both parameters; with the +1 offset applied in forward(),
            # the effective initial scale is 1 and the initial bias is 0.
            nn.init.zeros_(self.weight)
            nn.init.zeros_(self.bias)

        def forward(self, x):
            # Run the fused apex kernel with the effective weight (weight + 1).
            return _fast_layer_norm(x, self.weight + 1, self.bias, self.epsilon)


else:

    class LayerNorm1P(nn.Module):
        def __init__(self, *args, **kwargs):
            # Without apex there is no fused kernel to extend; fail at construction.
            raise NotImplementedError('LayerNorm1P available only with apex installed')
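

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: exercises
    # LayerNorm1P on a small fp16 tensor. Assumes apex was built with the
    # fused contrib layer_norm extension and that a CUDA device is available;
    # the hidden size and tensor shape below are illustrative only.
    import torch

    if HAVE_APEX and torch.cuda.is_available():
        hidden = 1024
        ln = LayerNorm1P(hidden).cuda().half()
        x = torch.randn(2, 8, hidden, dtype=torch.float16, device='cuda')
        y = ln(x)
        # Right after construction the stored weight is all zeros, so the
        # effective scale is exactly 1 and the output matches a plain layer norm.
        print(y.shape, y.dtype)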