Hardwired nsa attention
modeling_chatglm.py  +3 -1
@@ -642,7 +642,7 @@ class SelfAttention(torch.nn.Module):
         self.gate.weight.zero_()
         self.gate.bias.fill_(-math.log(2))  # sigmoid ≈ 1/3

-        self.core_attention = CORE_ATTENTION_CLASSES[config.attn_implementation](config, self.layer_number)
+        self.core_attention = CORE_ATTENTION_CLASSES["nsa"](config, self.layer_number)  # config.attn_implementation

         # Output.
         self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
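Note: the removed line selected the core attention class via config.attn_implementation; the commit pins the lookup key to "nsa" and keeps the old key as a trailing comment. A minimal, self-contained sketch of the dispatch pattern being hardwired follows; the stub classes, the "eager" entry, and Config are illustrative assumptions, not the real definitions in modeling_chatglm.py.

# Illustrative stubs standing in for the real core attention backends.
class CoreAttention:
    def __init__(self, config, layer_number):
        self.config = config
        self.layer_number = layer_number

class NsaCoreAttention(CoreAttention):
    pass

# Dispatch table keyed by implementation name, as in the hunk above.
CORE_ATTENTION_CLASSES = {
    "eager": CoreAttention,   # assumed default backend
    "nsa": NsaCoreAttention,
}

class Config:
    attn_implementation = "eager"  # whatever the checkpoint requests

config = Config()

# Before the commit: the backend follows the config.
attn = CORE_ATTENTION_CLASSES[config.attn_implementation](config, layer_number=1)
assert type(attn) is CoreAttention

# After the commit: "nsa" is forced, ignoring config.attn_implementation.
attn = CORE_ATTENTION_CLASSES["nsa"](config, layer_number=1)
assert type(attn) is NsaCoreAttention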
@@ -734,6 +734,7 @@ class SelfAttention(torch.nn.Module):
         # core attention computation
         # ==================================

+        self.attn_implementation = "nsa"
         if self.attn_implementation != "nsa":

             if self.multi_query_attention:
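The assignment added above pins the instance attribute immediately before the guard that routes to the non-NSA path, so that fallback (the multi-query / standard attention branch) becomes dead code. A trivial sketch of the resulting control flow:

attn_implementation = "nsa"  # hardwired on every forward pass
if attn_implementation != "nsa":
    raise AssertionError("unreachable: pinned to 'nsa' just above")
print("nsa core attention path taken")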
@@ -1007,6 +1008,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         return

     def get_masks(self, input_ids, past_key_values, padding_mask=None):
+        self.config.attn_implementation = "nsa"
         if self.config.attn_implementation == "flash_attention_2" or self.config.attn_implementation == "nsa":
             if padding_mask is not None and not padding_mask.all():
                 return padding_mask
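The same pinning inside get_masks forces mask construction down the flash_attention_2 / nsa early-exit on every call. A hedged sketch of the effective behavior; the None fall-through for an all-True padding mask is an assumption about the surrounding code, which this diff does not show:

import torch

def get_masks_sketch(padding_mask):
    attn_implementation = "nsa"  # hardwired by the commit
    if attn_implementation in ("flash_attention_2", "nsa"):
        if padding_mask is not None and not padding_mask.all():
            return padding_mask   # keep the 2D padding mask as-is
        return None               # assumed: no explicit mask materialized
    # (full causal-mask construction for other backends would follow)

# With real padding tokens the padding mask is passed through ...
print(get_masks_sketch(torch.tensor([[True, True, False]])))
# ... and with none, no mask is built at all.
print(get_masks_sketch(torch.tensor([[True, True, True]])))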