# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 01/15/2020
#
import math
from packaging import version
import torch
from torch.nn import LayerNorm
from ..utils.jit_tracing import traceable
if version.Version(torch.__version__) >= version.Version('1.0.0'):
  from torch import _softmax_backward_data as _softmax_backward_data
else:
  from torch import softmax_backward_data as _softmax_backward_data
__all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']

@traceable
class XSoftmax(torch.autograd.Function):
  """ Masked softmax which is optimized for saving memory

  Args:
    input (:obj:`torch.tensor`): The input tensor on which softmax will be applied.
    mask (:obj:`torch.IntTensor`): The mask matrix where `0` indicates that the element will be ignored in the softmax calculation.
    dim (int): The dimension along which softmax will be applied.

  Example::

    import torch
    from DeBERTa.deberta import XSoftmax

    # Make a tensor
    x = torch.randn([4,20,100])

    # Create a mask
    mask = (x > 0).int()

    y = XSoftmax.apply(x, mask, -1)
  """

  @staticmethod
  def forward(self, input, mask, dim):
    self.dim = dim
    if version.Version(torch.__version__) >= version.Version('1.2.0a'):
      rmask = ~(mask.bool())
    else:
      rmask = (1 - mask).byte()  # This line is not supported by ONNX tracing.
    output = input.masked_fill(rmask, float('-inf'))
    output = torch.softmax(output, self.dim)
    output.masked_fill_(rmask, 0)
    self.save_for_backward(output)
    return output

  @staticmethod
  def backward(self, grad_output):
    output, = self.saved_tensors
    if version.Version(torch.__version__) >= version.Version('1.11.0a'):
      inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
    else:
      inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
    return inputGrad, None, None

  @staticmethod
  def symbolic(g, self, mask, dim):
    import torch.onnx.symbolic_helper as sym_help
    from torch.onnx.symbolic_opset9 import masked_fill, softmax

    mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    r_mask = g.op("Cast", g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
    output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float('-inf'))))
    output = softmax(g, output, dim)
    return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
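
# Export sketch (illustrative, not part of this module's API): because
# `symbolic` above maps XSoftmax onto plain ONNX ops (Cast/Sub/masked_fill/
# softmax), a module whose forward calls `XSoftmax.apply` can be exported
# directly. `MyAttention`, `x` and `mask` below are hypothetical:
#
#   model = MyAttention()
#   torch.onnx.export(model, (x, mask), 'attention.onnx')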

class DropoutContext(object):
  def __init__(self):
    self.dropout = 0
    self.mask = None
    self.scale = 1
    self.reuse_mask = True

def get_mask(input, local_context):
  if not isinstance(local_context, DropoutContext):
    dropout = local_context
    mask = None
  else:
    dropout = local_context.dropout
    dropout *= local_context.scale
    mask = local_context.mask if local_context.reuse_mask else None

  if dropout > 0 and mask is None:
    if version.Version(torch.__version__) >= version.Version('1.2.0a'):
      mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
    else:
      mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).byte()

  if isinstance(local_context, DropoutContext):
    if local_context.mask is None:
      local_context.mask = mask

  return mask, dropout
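
# Mask-reuse sketch (names are illustrative, not part of this module's API;
# `x` stands for any float tensor):
#
#   ctx = DropoutContext()
#   ctx.dropout = 0.1
#   m1, p1 = get_mask(x, ctx)  # draws a fresh mask and caches it on ctx
#   m2, p2 = get_mask(x, ctx)  # returns the cached mask since reuse_mask=True
#
# Passing a plain float instead of a DropoutContext, e.g. get_mask(x, 0.1),
# draws a fresh mask on every call.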

@traceable
class XDropout(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input, local_ctx):
    mask, dropout = get_mask(input, local_ctx)
    ctx.scale = 1.0 / (1 - dropout)
    if dropout > 0:
      ctx.save_for_backward(mask)
      return input.masked_fill(mask, 0) * ctx.scale
    else:
      return input

  @staticmethod
  def backward(ctx, grad_output):
    if ctx.scale > 1:
      mask, = ctx.saved_tensors
      return grad_output.masked_fill(mask, 0) * ctx.scale, None
    else:
      return grad_output, None
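
# XDropout is normally driven through StableDropout below, but it can also be
# applied with a plain dropout probability (sketch; `x` is a hypothetical
# input tensor):
#
#   y = XDropout.apply(x, 0.1)  # float context => a fresh mask on each call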

class StableDropout(torch.nn.Module):
  """ Optimized dropout module for stabilizing the training

  Args:
    drop_prob (float): the dropout probability
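
  Example::

    import torch
    from DeBERTa.deberta import StableDropout

    # A minimal usage sketch; the shape of `x` is illustrative.
    x = torch.randn([4, 20, 100])
    dropout = StableDropout(0.1)
    dropout.train()
    y = dropout(x)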
"""
def __init__(self, drop_prob):
super().__init__()
self.drop_prob = drop_prob
self.count = 0
self.context_stack = None
def forward(self, x):
""" Call the module
Args:
x (:obj:`torch.tensor`): The input tensor to apply dropout
"""
if self.training and self.drop_prob>0:
return XDropout.apply(x, self.get_context())
return x

  def clear_context(self):
    self.count = 0
    self.context_stack = None

  def init_context(self, reuse_mask=True, scale=1):
    if self.context_stack is None:
      self.context_stack = []
    self.count = 0
    for c in self.context_stack:
      c.reuse_mask = reuse_mask
      c.scale = scale

  def get_context(self):
    if self.context_stack is not None:
      if self.count >= len(self.context_stack):
        self.context_stack.append(DropoutContext())
      ctx = self.context_stack[self.count]
      ctx.dropout = self.drop_prob
      self.count += 1
      return ctx
    else:
      return self.drop_prob
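
# Mask-reuse sketch (hedged; `x1`/`x2` are hypothetical paired inputs).
# Calling init_context() enables the context stack, and calling it again
# before a second pass rewinds the stack so the cached masks are replayed:
#
#   drop = StableDropout(0.1)
#   drop.train()
#   drop.init_context(reuse_mask=True)  # enable the context stack
#   y1 = drop(x1)                       # draws a mask, cached in the context
#   drop.init_context(reuse_mask=True)  # rewind: reuse the cached mask
#   y2 = drop(x2)                       # the same mask is applied again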

def MaskedLayerNorm(layerNorm, input, mask=None):
  """ Masked LayerNorm which applies the mask to the output of LayerNorm to avoid inaccurate updates to the LayerNorm module.

  Args:
    layerNorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function
    input (:obj:`torch.tensor`): The input tensor
    mask (:obj:`torch.IntTensor`): The mask to apply to the output of LayerNorm, where `0` indicates that the output of that element will be ignored, i.e. set to `0`

  Example::

    # Create a tensor b x n x d
    x = torch.randn([1,10,100])
    m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
    LayerNorm = DeBERTa.deberta.LayerNorm(100)
    y = MaskedLayerNorm(LayerNorm, x, m)
  """
  output = layerNorm(input).to(input)
  if mask is None:
    return output
  if mask.dim() != input.dim():
    if mask.dim() == 4:
      mask = mask.squeeze(1).squeeze(1)
    mask = mask.unsqueeze(2)
  mask = mask.to(output.dtype)
  return output * mask

def gelu(x):
  """ Implementation of the GELU activation function.

  For information: OpenAI GPT's GELU is slightly different (and gives slightly different results):
  0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
  """
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def swish(x):
  return x * torch.sigmoid(x)

def linear_act(x):
  return x

ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish, "tanh": torch.tanh, "linear": linear_act, "sigmoid": torch.sigmoid}
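
# Lookup sketch (illustrative): model configs typically name the activation as
# a string, which is resolved through this map.
#
#   act_fn = ACT2FN['gelu']
#   h = act_fn(torch.randn(2, 8))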