Maxtimer97 committed on
Commit bb18c9e · verified · 1 Parent(s): ce81c56

Hardwired nsa attention

Files changed (1)
  1. modeling_chatglm.py +3 -1
modeling_chatglm.py CHANGED
@@ -642,7 +642,7 @@ class SelfAttention(torch.nn.Module):
         self.gate.weight.zero_()
         self.gate.bias.fill_(-math.log(2))  # sigmoid ≈ 1/3
 
-        self.core_attention = CORE_ATTENTION_CLASSES[config.attn_implementation](config, self.layer_number)
+        self.core_attention = CORE_ATTENTION_CLASSES["nsa"](config, self.layer_number)  # config.attn_implementation
 
         # Output.
         self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
@@ -734,6 +734,7 @@ class SelfAttention(torch.nn.Module):
         # core attention computation
         # ==================================
 
+        self.attn_implementation = "nsa"
         if self.attn_implementation != "nsa":
 
             if self.multi_query_attention:
@@ -1007,6 +1008,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         return
 
     def get_masks(self, input_ids, past_key_values, padding_mask=None):
+        self.config.attn_implementation = "nsa"
        if self.config.attn_implementation == "flash_attention_2" or self.config.attn_implementation == "nsa":
            if padding_mask is not None and not padding_mask.all():
                return padding_mask
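
For context, the three edits all pin the attention backend selection to the literal string "nsa" instead of reading it from the config. The sketch below illustrates the dict-based dispatch pattern the commit interacts with; the class names and the DummyConfig here are illustrative placeholders, not the actual implementations from modeling_chatglm.py.

# Minimal, self-contained sketch of the dispatch pattern this commit hardwires.
# All names below except CORE_ATTENTION_CLASSES and attn_implementation are
# hypothetical stand-ins.
from dataclasses import dataclass


class EagerCoreAttention:
    def __init__(self, config, layer_number):
        self.config, self.layer_number = config, layer_number


class NSACoreAttention:
    def __init__(self, config, layer_number):
        self.config, self.layer_number = config, layer_number


CORE_ATTENTION_CLASSES = {
    "eager": EagerCoreAttention,
    "nsa": NSACoreAttention,
}


@dataclass
class DummyConfig:
    attn_implementation: str = "eager"


config = DummyConfig()

# Before the commit: the backend follows the config value.
core_attention = CORE_ATTENTION_CLASSES[config.attn_implementation](config, layer_number=1)
print(type(core_attention).__name__)  # EagerCoreAttention

# After the commit: the lookup is pinned to "nsa" regardless of the config, and
# the runtime checks in forward() and get_masks() are forced onto the "nsa"
# code paths as well.
core_attention = CORE_ATTENTION_CLASSES["nsa"](config, layer_number=1)
print(type(core_attention).__name__)  # NSACoreAttention

An alternative to hardwiring would be to set attn_implementation="nsa" in the model config so the original config-driven lookup still works; the commit instead forces the value at each of the three call sites.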