Pavel Rykov committed on
Commit
a25e87a
·
1 Parent(s): 9d4a84b

SDPA added

Browse files
Files changed (2) hide show
  1. configuration_rugpt3xl.py +2 -0
  2. modeling_rugpt3xl.py +30 -9
configuration_rugpt3xl.py CHANGED
@@ -40,6 +40,7 @@ class RuGPT3XLConfig(PretrainedConfig):
40
  sparse_num_local_blocks=8,
41
  sparse_num_global_blocks=1,
42
  sparse_num_different_global_patterns=8,
 
43
  **kwargs,
44
  ):
45
  self.vocab_size = vocab_size
@@ -60,6 +61,7 @@ class RuGPT3XLConfig(PretrainedConfig):
60
  self.sparse_num_local_blocks = sparse_num_local_blocks
61
  self.sparse_num_global_blocks = sparse_num_global_blocks
62
  self.sparse_num_different_global_patterns = sparse_num_different_global_patterns
 
63
 
64
  super().__init__(
65
  bos_token_id=bos_token_id,
 
40
  sparse_num_local_blocks=8,
41
  sparse_num_global_blocks=1,
42
  sparse_num_different_global_patterns=8,
43
+ attn_implementation="sdpa",
44
  **kwargs,
45
  ):
46
  self.vocab_size = vocab_size
 
61
  self.sparse_num_local_blocks = sparse_num_local_blocks
62
  self.sparse_num_global_blocks = sparse_num_global_blocks
63
  self.sparse_num_different_global_patterns = sparse_num_different_global_patterns
64
+ self.attn_implementation = attn_implementation
65
 
66
  super().__init__(
67
  bos_token_id=bos_token_id,
modeling_rugpt3xl.py CHANGED
@@ -2,7 +2,8 @@
2
 
3
  GPT-3-style decoder-only transformer (1.3B) trained on Russian text.
4
  Architecture: absolute position embeddings, pre-norm layers, GELU activation,
5
- tied LM head.
 
6
  """
7
 
8
  import math
@@ -107,17 +108,37 @@ class RuGPT3XLAttention(nn.Module):
107
  if past_key_value is not None:
108
  key, value = past_key_value.update(key, value, self.layer_idx)
109
 
110
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * self.scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- if attention_mask is not None:
113
- attn_weights = attn_weights + attention_mask
114
 
115
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
116
- query.dtype
117
- )
118
- attn_weights = self.attn_dropout(attn_weights)
 
 
 
119
 
120
- attn_output = torch.matmul(attn_weights, value)
121
  attn_output = attn_output.transpose(1, 2).contiguous()
122
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
123
 
 
2
 
3
  GPT-3-style decoder-only transformer (1.3B) trained on Russian text.
4
  Architecture: absolute position embeddings, pre-norm layers, GELU activation,
5
+ tied LM head. Attention: config.attn_implementation "sdpa" uses
6
+ scaled_dot_product_attention (Flash/Memory-efficient/Triton backends on CUDA).
7
  """
8
 
9
  import math
 
108
  if past_key_value is not None:
109
  key, value = past_key_value.update(key, value, self.layer_idx)
110
 
111
+ attn_impl = getattr(self.config, "attn_implementation", "sdpa")
112
+ use_sdpa = attn_impl == "sdpa" and not output_attentions
113
+
114
+ if use_sdpa:
115
+ dropout_p = self.attn_dropout.p if self.training else 0.0
116
+ sdpa_mask = attention_mask
117
+ if sdpa_mask is not None:
118
+ sdpa_mask = sdpa_mask.to(dtype=query.dtype)
119
+ attn_output = F.scaled_dot_product_attention(
120
+ query,
121
+ key,
122
+ value,
123
+ attn_mask=sdpa_mask,
124
+ dropout_p=dropout_p,
125
+ is_causal=False,
126
+ )
127
+ attn_weights = None
128
+ else:
129
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * self.scale
130
 
131
+ if attention_mask is not None:
132
+ attn_weights = attn_weights + attention_mask
133
 
134
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
135
+ query.dtype
136
+ )
137
+ attn_weights = self.attn_dropout(attn_weights)
138
+
139
+ attn_output = torch.matmul(attn_weights, value)
140
+ attn_weights = attn_weights if output_attentions else None
141
 
 
142
  attn_output = attn_output.transpose(1, 2).contiguous()
143
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
144