simplified softmax (to allow torch.compile)
Browse files- modeling_norbert.py +4 -25
modeling_norbert.py
CHANGED
|
@@ -101,23 +101,6 @@ class FeedForward(nn.Module):
|
|
| 101 |
return self.mlp(x)
|
| 102 |
|
| 103 |
|
| 104 |
-
class MaskedSoftmax(torch.autograd.Function):
|
| 105 |
-
@staticmethod
|
| 106 |
-
def forward(self, x, mask, dim):
|
| 107 |
-
self.dim = dim
|
| 108 |
-
x.masked_fill_(mask, float('-inf'))
|
| 109 |
-
x = torch.softmax(x, self.dim)
|
| 110 |
-
x.masked_fill_(mask, 0.0)
|
| 111 |
-
self.save_for_backward(x)
|
| 112 |
-
return x
|
| 113 |
-
|
| 114 |
-
@staticmethod
|
| 115 |
-
def backward(self, grad_output):
|
| 116 |
-
output, = self.saved_tensors
|
| 117 |
-
input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
|
| 118 |
-
return input_grad, None, None
|
| 119 |
-
|
| 120 |
-
|
| 121 |
class Attention(nn.Module):
|
| 122 |
def __init__(self, config):
|
| 123 |
super().__init__()
|
|
@@ -155,7 +138,7 @@ class Attention(nn.Module):
|
|
| 155 |
bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
|
| 156 |
return bucket_pos
|
| 157 |
|
| 158 |
-
def
|
| 159 |
key_len, batch_size, _ = hidden_states.size()
|
| 160 |
query_len = key_len
|
| 161 |
|
|
@@ -193,21 +176,17 @@ class Attention(nn.Module):
|
|
| 193 |
attention_scores.add_(attention_c_p)
|
| 194 |
attention_scores.add_(attention_p_c)
|
| 195 |
|
| 196 |
-
|
|
|
|
| 197 |
|
| 198 |
-
def compute_output(self, attention_probs, value):
|
| 199 |
attention_probs = self.dropout(attention_probs)
|
| 200 |
context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
|
| 201 |
context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size) # shape: [Q, B, H*D]
|
| 202 |
context = self.out_proj(context)
|
| 203 |
context = self.post_layer_norm(context)
|
| 204 |
context = self.dropout(context)
|
| 205 |
-
return context
|
| 206 |
|
| 207 |
-
|
| 208 |
-
attention_scores, value = self.compute_attention_scores(hidden_states, relative_embedding)
|
| 209 |
-
attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
|
| 210 |
-
return self.compute_output(attention_probs, value), attention_probs.detach()
|
| 211 |
|
| 212 |
|
| 213 |
class Embedding(nn.Module):
|
|
|
|
| 101 |
return self.mlp(x)
|
| 102 |
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
class Attention(nn.Module):
|
| 105 |
def __init__(self, config):
|
| 106 |
super().__init__()
|
|
|
|
| 138 |
bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
|
| 139 |
return bucket_pos
|
| 140 |
|
| 141 |
+
def forward(self, hidden_states, attention_mask, relative_embedding):
|
| 142 |
key_len, batch_size, _ = hidden_states.size()
|
| 143 |
query_len = key_len
|
| 144 |
|
|
|
|
| 176 |
attention_scores.add_(attention_c_p)
|
| 177 |
attention_scores.add_(attention_p_c)
|
| 178 |
|
| 179 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float('-inf'))
|
| 180 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
| 181 |
|
|
|
|
| 182 |
attention_probs = self.dropout(attention_probs)
|
| 183 |
context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
|
| 184 |
context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size) # shape: [Q, B, H*D]
|
| 185 |
context = self.out_proj(context)
|
| 186 |
context = self.post_layer_norm(context)
|
| 187 |
context = self.dropout(context)
|
|
|
|
| 188 |
|
| 189 |
+
return context, attention_probs.detach()
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
class Embedding(nn.Module):
|