replace 1e4 mask
Browse files- README.md +1 -0
- modeling_lsg_roberta.py +11 -7
README.md
CHANGED
|
@@ -45,6 +45,7 @@ You can change various parameters like :
|
|
| 45 |
* local block size (block_size=128)
|
| 46 |
* sparse block size (sparse_block_size=128)
|
| 47 |
* sparsity factor (sparsity_factor=2)
|
|
|
|
| 48 |
* see config.json file
|
| 49 |
|
| 50 |
Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
|
|
|
|
| 45 |
* local block size (block_size=128)
|
| 46 |
* sparse block size (sparse_block_size=128)
|
| 47 |
* sparsity factor (sparsity_factor=2)
|
| 48 |
+
* mask_first_token (mask the first token since it is redundant with the first global token)
|
| 49 |
* see config.json file
|
| 50 |
|
| 51 |
Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
|
modeling_lsg_roberta.py
CHANGED
|
@@ -182,7 +182,11 @@ class CausalAttentionProduct(nn.Module):
|
|
| 182 |
|
| 183 |
# Add causal mask
|
| 184 |
causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
|
| 185 |
-
causal_mask = torch.tril(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
|
| 187 |
|
| 188 |
del attention_mask
|
|
@@ -300,7 +304,7 @@ class LSGAttentionProduct(nn.Module):
|
|
| 300 |
|
| 301 |
# Pad before block reshaping
|
| 302 |
if is_attn_mask:
|
| 303 |
-
pad_value =
|
| 304 |
hidden_states = hidden_states.transpose(-1, -2)
|
| 305 |
else:
|
| 306 |
pad_value = 0
|
|
@@ -333,7 +337,7 @@ class LSGAttentionProduct(nn.Module):
|
|
| 333 |
|
| 334 |
# Pad before block reshaping
|
| 335 |
if is_attn_mask:
|
| 336 |
-
pad_value =
|
| 337 |
hidden_states = hidden_states.transpose(-1, -2)
|
| 338 |
else:
|
| 339 |
pad_value = 0
|
|
@@ -557,7 +561,7 @@ class LSGSelfAttention(BaseSelfAttention):
|
|
| 557 |
keys = keys.sum(dim=-2) / (mask + 1e-6)
|
| 558 |
values = values.sum(dim=-2) / (mask + 1e-6)
|
| 559 |
|
| 560 |
-
mask =
|
| 561 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
| 562 |
|
| 563 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
|
@@ -622,7 +626,7 @@ class LSGSelfAttention(BaseSelfAttention):
|
|
| 622 |
keys /= mask + 1e-8
|
| 623 |
values /= mask + 1e-8
|
| 624 |
|
| 625 |
-
mask =
|
| 626 |
|
| 627 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
| 628 |
|
|
@@ -988,7 +992,7 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
|
|
| 988 |
n, t = inputs_.size()[:2]
|
| 989 |
|
| 990 |
if attention_mask is None:
|
| 991 |
-
attention_mask = torch.ones(n, t, device=inputs_.device)
|
| 992 |
if self.mask_first_token:
|
| 993 |
attention_mask[:,0] = 0
|
| 994 |
|
|
@@ -1069,7 +1073,7 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
|
|
| 1069 |
)
|
| 1070 |
|
| 1071 |
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 1072 |
-
extended_attention_mask = (1.0 - extended_attention_mask) *
|
| 1073 |
|
| 1074 |
return extended_attention_mask
|
| 1075 |
|
|
|
|
| 182 |
|
| 183 |
# Add causal mask
|
| 184 |
causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
|
| 185 |
+
causal_mask = torch.tril(
|
| 186 |
+
torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
|
| 187 |
+
diagonal=-1
|
| 188 |
+
)
|
| 189 |
+
causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
|
| 190 |
attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
|
| 191 |
|
| 192 |
del attention_mask
|
|
|
|
| 304 |
|
| 305 |
# Pad before block reshaping
|
| 306 |
if is_attn_mask:
|
| 307 |
+
pad_value = torch.finfo(hidden_states.dtype).min
|
| 308 |
hidden_states = hidden_states.transpose(-1, -2)
|
| 309 |
else:
|
| 310 |
pad_value = 0
|
|
|
|
| 337 |
|
| 338 |
# Pad before block reshaping
|
| 339 |
if is_attn_mask:
|
| 340 |
+
pad_value = torch.finfo(hidden_states.dtype).min
|
| 341 |
hidden_states = hidden_states.transpose(-1, -2)
|
| 342 |
else:
|
| 343 |
pad_value = 0
|
|
|
|
| 561 |
keys = keys.sum(dim=-2) / (mask + 1e-6)
|
| 562 |
values = values.sum(dim=-2) / (mask + 1e-6)
|
| 563 |
|
| 564 |
+
mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
|
| 565 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
| 566 |
|
| 567 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
|
|
|
| 626 |
keys /= mask + 1e-8
|
| 627 |
values /= mask + 1e-8
|
| 628 |
|
| 629 |
+
mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
|
| 630 |
|
| 631 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
| 632 |
|
|
|
|
| 992 |
n, t = inputs_.size()[:2]
|
| 993 |
|
| 994 |
if attention_mask is None:
|
| 995 |
+
attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
|
| 996 |
if self.mask_first_token:
|
| 997 |
attention_mask[:,0] = 0
|
| 998 |
|
|
|
|
| 1073 |
)
|
| 1074 |
|
| 1075 |
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 1076 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(extended_attention_mask.dtype).min
|
| 1077 |
|
| 1078 |
return extended_attention_mask
|
| 1079 |
|