Pavel Rykov committed on
Commit ·
a25e87a
1
Parent(s): 9d4a84b
SDPA added
Browse files- configuration_rugpt3xl.py +2 -0
- modeling_rugpt3xl.py +30 -9
configuration_rugpt3xl.py
CHANGED
|
@@ -40,6 +40,7 @@ class RuGPT3XLConfig(PretrainedConfig):
|
|
| 40 |
sparse_num_local_blocks=8,
|
| 41 |
sparse_num_global_blocks=1,
|
| 42 |
sparse_num_different_global_patterns=8,
|
|
|
|
| 43 |
**kwargs,
|
| 44 |
):
|
| 45 |
self.vocab_size = vocab_size
|
|
@@ -60,6 +61,7 @@ class RuGPT3XLConfig(PretrainedConfig):
|
|
| 60 |
self.sparse_num_local_blocks = sparse_num_local_blocks
|
| 61 |
self.sparse_num_global_blocks = sparse_num_global_blocks
|
| 62 |
self.sparse_num_different_global_patterns = sparse_num_different_global_patterns
|
|
|
|
| 63 |
|
| 64 |
super().__init__(
|
| 65 |
bos_token_id=bos_token_id,
|
|
|
|
| 40 |
sparse_num_local_blocks=8,
|
| 41 |
sparse_num_global_blocks=1,
|
| 42 |
sparse_num_different_global_patterns=8,
|
| 43 |
+
attn_implementation="sdpa",
|
| 44 |
**kwargs,
|
| 45 |
):
|
| 46 |
self.vocab_size = vocab_size
|
|
|
|
| 61 |
self.sparse_num_local_blocks = sparse_num_local_blocks
|
| 62 |
self.sparse_num_global_blocks = sparse_num_global_blocks
|
| 63 |
self.sparse_num_different_global_patterns = sparse_num_different_global_patterns
|
| 64 |
+
self.attn_implementation = attn_implementation
|
| 65 |
|
| 66 |
super().__init__(
|
| 67 |
bos_token_id=bos_token_id,
|
modeling_rugpt3xl.py
CHANGED
|
@@ -2,7 +2,8 @@
|
|
| 2 |
|
| 3 |
GPT-3-style decoder-only transformer (1.3B) trained on Russian text.
|
| 4 |
Architecture: absolute position embeddings, pre-norm layers, GELU activation,
|
| 5 |
-
tied LM head.
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import math
|
|
@@ -107,17 +108,37 @@ class RuGPT3XLAttention(nn.Module):
|
|
| 107 |
if past_key_value is not None:
|
| 108 |
key, value = past_key_value.update(key, value, self.layer_idx)
|
| 109 |
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
attn_output = torch.matmul(attn_weights, value)
|
| 121 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 122 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 123 |
|
|
|
|
| 2 |
|
| 3 |
GPT-3-style decoder-only transformer (1.3B) trained on Russian text.
|
| 4 |
Architecture: absolute position embeddings, pre-norm layers, GELU activation,
|
| 5 |
+
tied LM head. Attention: config.attn_implementation "sdpa" uses
|
| 6 |
+
scaled_dot_product_attention (Flash/Memory-efficient/Triton backends on CUDA).
|
| 7 |
"""
|
| 8 |
|
| 9 |
import math
|
|
|
|
| 108 |
if past_key_value is not None:
|
| 109 |
key, value = past_key_value.update(key, value, self.layer_idx)
|
| 110 |
|
| 111 |
+
attn_impl = getattr(self.config, "attn_implementation", "sdpa")
|
| 112 |
+
use_sdpa = attn_impl == "sdpa" and not output_attentions
|
| 113 |
+
|
| 114 |
+
if use_sdpa:
|
| 115 |
+
dropout_p = self.attn_dropout.p if self.training else 0.0
|
| 116 |
+
sdpa_mask = attention_mask
|
| 117 |
+
if sdpa_mask is not None:
|
| 118 |
+
sdpa_mask = sdpa_mask.to(dtype=query.dtype)
|
| 119 |
+
attn_output = F.scaled_dot_product_attention(
|
| 120 |
+
query,
|
| 121 |
+
key,
|
| 122 |
+
value,
|
| 123 |
+
attn_mask=sdpa_mask,
|
| 124 |
+
dropout_p=dropout_p,
|
| 125 |
+
is_causal=False,
|
| 126 |
+
)
|
| 127 |
+
attn_weights = None
|
| 128 |
+
else:
|
| 129 |
+
attn_weights = torch.matmul(query, key.transpose(2, 3)) * self.scale
|
| 130 |
|
| 131 |
+
if attention_mask is not None:
|
| 132 |
+
attn_weights = attn_weights + attention_mask
|
| 133 |
|
| 134 |
+
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
| 135 |
+
query.dtype
|
| 136 |
+
)
|
| 137 |
+
attn_weights = self.attn_dropout(attn_weights)
|
| 138 |
+
|
| 139 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 140 |
+
attn_weights = attn_weights if output_attentions else None
|
| 141 |
|
|
|
|
| 142 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 143 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 144 |
|