Pavel Rykov committed
Commit a942c12
Parent(s): 511ef2e
Sparse attention fixed
Files changed:
- README.md +47 -0
- config.json +5 -0
- configuration_rugpt3xl.py +10 -0
- modeling_rugpt3xl.py +133 -4
- tokenizer.json +0 -0
README.md
CHANGED
````diff
@@ -38,6 +38,7 @@ Details in "[A family of pretrained transformer language models for Russian](htt
 | Activation | GELU |
 | Normalization | Pre-LayerNorm |
 | Position encoding | Learned absolute |
+| Attention | Alternating sparse/dense (see [Sparse Attention](#sparse-attention)) |
 | Precision | float16 |
 | Training data | 80B tokens of Russian text (4 epochs) |
 | Test perplexity | 12.05 |
@@ -225,6 +226,46 @@ trainer.train()
 
 **LoRA target modules:** `q_proj`, `k_proj`, `v_proj`, `o_proj`, `up_proj`, `down_proj`
 
+## Sparse Attention
+
+This model was originally trained with DeepSpeed's
+[SparseSelfAttention](https://www.deepspeed.ai/tutorials/sparse-attention/) using the
+**alternating** pattern: even layers (0, 2, 4, ...) use block-sparse attention, odd layers
+(1, 3, 5, ...) use standard dense causal attention. The sparse layers use a
+`FixedSparsityConfig` derived from the "Generating Long Sequences with Sparse Transformers"
+paper (Child et al., 2019).
+
+This is a **critical** architectural detail. The model weights were optimized for this specific
+attention pattern during training. Running the model with all-dense attention degrades
+perplexity from ~12 to ~50.
+
+The converted model **fully replicates** this sparse attention pattern without any DeepSpeed
+dependency, using a precomputed block-sparse mask applied to standard dense attention.
+
+| Attention mode | Test PPL (Gazeta) |
+|---|---|
+| Sparse alternating (original training regime) | **11.68** |
+| All-dense (no sparse mask) | ~50.1 |
+
+### Sparse attention parameters
+
+The sparse pattern is controlled by `config.json` fields:
+
+| Parameter | Value | Description |
+|---|---|---|
+| `sparse_mode` | `"alternating"` | Even layers sparse, odd layers dense |
+| `sparse_block_size` | `16` | Token block size for sparse layout |
+| `sparse_num_local_blocks` | `8` | Local attention window (8 blocks = 128 tokens) |
+| `sparse_num_global_blocks` | `1` | Global blocks per window |
+| `sparse_num_different_global_patterns` | `8` | Different heads use different global positions |
+
+Each sparse layer applies a per-head block-sparse mask. Within each window of 128 tokens,
+attention is causal (lower-triangular). Across windows, only designated "global" blocks are
+visible, with each attention head using a different global block position within the window.
+
+To disable sparse attention (e.g. for experiments), set `sparse_mode` to `"none"` in
+`config.json`. This will make all layers use standard dense causal attention.
+
 ## Architecture Details
 
 The model implements a custom `RuGPT3XLForCausalLM` class (loaded via `trust_remote_code=True`):
@@ -254,6 +295,10 @@ RuGPT3XLForCausalLM
 └── lm_head (Linear: 2048 -> 50264, no bias)
 ```
 
+Even-numbered decoder layers (0, 2, 4, ...) apply block-sparse attention masks. Odd-numbered
+layers use full causal attention. The sparse layout is precomputed at model initialization from
+the config parameters and stored as a non-persistent buffer.
+
 ## Conversion
 
 This model was converted from the original Megatron-LM checkpoint using a custom script.
@@ -291,5 +336,7 @@ For full conversion details and the script, see the
 ## Links
 
 - [A family of pretrained transformer language models for Russian](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=yPayeJIAAAAJ&citation_for_view=yPayeJIAAAAJ:Se3iqnhoufwC) - paper on Google Scholar
+- [Generating Long Sequences with Sparse Transformers](https://arxiv.org/abs/1904.10509) - sparse attention paper (Child et al., 2019)
 - [ai-forever/rugpt3xl](https://huggingface.co/ai-forever/rugpt3xl) - original model
 - [ai-forever/ru-gpts](https://github.com/ai-forever/ru-gpts) - original training codebase
+- [DeepSpeed Sparse Attention](https://www.deepspeed.ai/tutorials/sparse-attention/) - original sparse attention implementation
````
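The window arithmetic behind the parameter table above can be spelled out in a short sketch. All values except `num_hidden_layers` come from the shipped `config.json`; the 24-layer depth is an illustrative assumption, not read from the model:

```python
# Sketch of the sparse-attention geometry described in the README.
# sparse_block_size and sparse_num_local_blocks are the shipped config values;
# num_hidden_layers is an illustrative assumption.
sparse_block_size = 16
sparse_num_local_blocks = 8
num_hidden_layers = 24  # assumption, for illustration only

# Local attention window: 8 blocks of 16 tokens = 128 fully causal tokens.
window_tokens = sparse_block_size * sparse_num_local_blocks
print(window_tokens)  # 128

# "alternating" mode: even layers sparse, odd layers dense.
sparse_layers = [i for i in range(num_hidden_layers) if i % 2 == 0]
print(len(sparse_layers))  # 12 sparse layers out of 24
```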
config.json
CHANGED
```diff
@@ -25,6 +25,11 @@
   "eos_token_id": 1,
   "pad_token_id": 0,
   "tie_word_embeddings": false,
+  "sparse_mode": "alternating",
+  "sparse_block_size": 16,
+  "sparse_num_local_blocks": 8,
+  "sparse_num_global_blocks": 1,
+  "sparse_num_different_global_patterns": 8,
   "torch_dtype": "float16",
   "transformers_version": "5.3.0"
 }
```
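A minimal sketch of a consistency check on the fields added above: in the fixed pattern, each head's first global block offset is `num_local_blocks - (1 + h % num_different_global_patterns) * num_global_blocks`, so the global positions only stay inside the local window when `num_different_global_patterns * num_global_blocks <= num_local_blocks` (the check itself is our inference from that formula, not something the diff states):

```python
import json

# Parse the fields exactly as added to config.json in this commit.
added_fields = json.loads("""
{
  "sparse_mode": "alternating",
  "sparse_block_size": 16,
  "sparse_num_local_blocks": 8,
  "sparse_num_global_blocks": 1,
  "sparse_num_different_global_patterns": 8
}
""")

# Every head's global block position must fall inside the local window.
ok = (added_fields["sparse_num_different_global_patterns"]
      * added_fields["sparse_num_global_blocks"]
      <= added_fields["sparse_num_local_blocks"])
print(ok)  # True
```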
configuration_rugpt3xl.py
CHANGED
```diff
@@ -35,6 +35,11 @@ class RuGPT3XLConfig(PretrainedConfig):
         eos_token_id=1,
         pad_token_id=0,
         tie_word_embeddings=False,
+        sparse_mode="none",
+        sparse_block_size=16,
+        sparse_num_local_blocks=8,
+        sparse_num_global_blocks=1,
+        sparse_num_different_global_patterns=8,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -50,6 +55,11 @@ class RuGPT3XLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.output_dropout = output_dropout
         self.use_cache = use_cache
+        self.sparse_mode = sparse_mode
+        self.sparse_block_size = sparse_block_size
+        self.sparse_num_local_blocks = sparse_num_local_blocks
+        self.sparse_num_global_blocks = sparse_num_global_blocks
+        self.sparse_num_different_global_patterns = sparse_num_different_global_patterns
 
         super().__init__(
             bos_token_id=bos_token_id,
```
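Note that the constructor default is `sparse_mode="none"`, so sparsity is opt-in: the shipped `config.json` turns it on, while code constructing the config directly stays dense unless it asks otherwise. A torch-free sketch of that behavior (plain class, the `PretrainedConfig` machinery omitted; the class name is illustrative):

```python
# Mimics only the sparse-related kwargs added to RuGPT3XLConfig in this commit.
class SparseConfigSketch:
    def __init__(self, sparse_mode="none", sparse_block_size=16,
                 sparse_num_local_blocks=8, sparse_num_global_blocks=1,
                 sparse_num_different_global_patterns=8):
        self.sparse_mode = sparse_mode
        self.sparse_block_size = sparse_block_size
        self.sparse_num_local_blocks = sparse_num_local_blocks
        self.sparse_num_global_blocks = sparse_num_global_blocks
        self.sparse_num_different_global_patterns = sparse_num_different_global_patterns

default = SparseConfigSketch()                           # dense everywhere
shipped = SparseConfigSketch(sparse_mode="alternating")  # what config.json sets
print(default.sparse_mode, shipped.sparse_mode)  # none alternating
```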
modeling_rugpt3xl.py
CHANGED
```diff
@@ -27,6 +27,47 @@ from .configuration_rugpt3xl import RuGPT3XLConfig
 logger = logging.get_logger(__name__)
 
 
+def build_fixed_sparse_layout(
+    num_heads: int,
+    num_blocks: int,
+    num_local_blocks: int,
+    num_global_blocks: int,
+    num_different_global_patterns: int,
+) -> torch.Tensor:
+    """Replicate DeepSpeed FixedSparsityConfig.make_layout() for unidirectional attention.
+
+    Returns a boolean tensor of shape [num_heads, num_blocks, num_blocks] where
+    True means the query block can attend to the key block.
+    """
+    layout = torch.zeros((num_heads, num_blocks, num_blocks), dtype=torch.bool)
+
+    # Local attention within fixed-size windows (identical for all heads)
+    for window_start in range(0, num_blocks, num_local_blocks):
+        window_end = min(window_start + num_local_blocks, num_blocks)
+        window_size = window_end - window_start
+        layout[:, window_start:window_end, window_start:window_end] = torch.tril(
+            torch.ones(window_size, window_size, dtype=torch.bool)
+        )
+
+    # Global attention (per-head: different heads use different global block positions)
+    for h in range(num_heads):
+        first_global = num_local_blocks - (
+            1 + h % num_different_global_patterns
+        ) * num_global_blocks
+        regular_end = num_blocks - (num_blocks % num_local_blocks)
+
+        for gi in range(first_global, regular_end, num_local_blocks):
+            layout[h, gi:, gi : gi + num_global_blocks] = True
+
+        if regular_end < num_blocks:
+            start = min(
+                regular_end + first_global, num_blocks - num_global_blocks
+            )
+            layout[h, start:, start : start + num_global_blocks] = True
+
+    return layout
+
+
 class RuGPT3XLAttention(nn.Module):
     def __init__(self, config: RuGPT3XLConfig, layer_idx: int):
         super().__init__()
@@ -201,6 +242,27 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
         )
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        self._sparse_layers: set = set()
+        if getattr(config, "sparse_mode", "none") == "alternating":
+            self._sparse_layers = {
+                i for i in range(config.num_hidden_layers) if i % 2 == 0
+            }
+        elif getattr(config, "sparse_mode", "none") == "all":
+            self._sparse_layers = set(range(config.num_hidden_layers))
+
+        if self._sparse_layers:
+            num_blocks = config.max_position_embeddings // config.sparse_block_size
+            layout = build_fixed_sparse_layout(
+                num_heads=config.num_attention_heads,
+                num_blocks=num_blocks,
+                num_local_blocks=config.sparse_num_local_blocks,
+                num_global_blocks=config.sparse_num_global_blocks,
+                num_different_global_patterns=config.sparse_num_different_global_patterns,
+            )
+            self.register_buffer("_sparse_layout", layout, persistent=False)
+        else:
+            self._sparse_layout = None
+
         self.gradient_checkpointing = False
         self.post_init()
 
@@ -278,7 +340,7 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
         position_embeds = self.embed_positions(position_ids)
         hidden_states = self.embed_dropout(inputs_embeds + position_embeds)
 
-        # Build causal 4D attention mask
+        # Build causal 4D attention mask (dense layers)
         causal_mask = self._build_causal_mask(
             batch_size,
             seq_length,
@@ -288,19 +350,38 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
             attention_mask,
         )
 
+        # Build sparse causal mask (sparse layers)
+        sparse_mask = None
+        if self._sparse_layout is not None and self._sparse_layers:
+            sparse_mask = self._build_sparse_causal_mask(
+                seq_length,
+                past_key_values_length,
+                hidden_states.dtype,
+                hidden_states.device,
+                self._sparse_layout,
+                self.config.sparse_block_size,
+                attention_mask,
+            )
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for layer_idx, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
+            layer_mask = (
+                sparse_mask
+                if (layer_idx in self._sparse_layers and sparse_mask is not None)
+                else causal_mask
+            )
+
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    layer_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
@@ -309,7 +390,7 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
+                    attention_mask=layer_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
@@ -372,6 +453,54 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
 
         return causal
 
+    @staticmethod
+    def _build_sparse_causal_mask(
+        seq_length: int,
+        past_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        sparse_layout: torch.Tensor,
+        block_size: int,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Build block-sparse causal mask from precomputed layout.
+
+        Returns additive mask of shape [1, num_heads, seq_length, total_length].
+        """
+        total_length = past_length + seq_length
+        num_blocks = sparse_layout.shape[1]
+
+        q_block = (
+            torch.arange(past_length, past_length + seq_length, device=device)
+            // block_size
+        ).clamp(max=num_blocks - 1)
+        k_block = (
+            torch.arange(total_length, device=device) // block_size
+        ).clamp(max=num_blocks - 1)
+
+        layout_dev = sparse_layout.to(device)
+        block_ok = layout_dev[:, q_block][:, :, k_block]
+
+        q_pos = torch.arange(
+            past_length, past_length + seq_length, device=device
+        ).unsqueeze(1)
+        k_pos = torch.arange(total_length, device=device).unsqueeze(0)
+        causal_ok = k_pos <= q_pos
+
+        allowed = block_ok & causal_ok.unsqueeze(0)
+
+        min_val = torch.finfo(dtype).min
+        mask = torch.where(allowed, 0.0, min_val).to(dtype)
+        mask = mask.unsqueeze(0)
+
+        if attention_mask is not None:
+            pad_mask = (
+                (1 - attention_mask[:, None, None, :].to(dtype)) * min_val
+            )
+            mask = mask + pad_mask
+
+        return mask
+
 
 class RuGPT3XLForCausalLM(RuGPT3XLPreTrainedModel):
     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
```
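The layout construction in `build_fixed_sparse_layout` can be exercised with a torch-free sketch of the same logic (nested Python lists instead of tensors). The sizes here are small illustrative values, not the real model dimensions, and the function name is our own:

```python
# Torch-free re-expression of the fixed sparse layout logic from the diff above.
# Returns nested lists of bools: layout[head][query_block][key_block].
def fixed_sparse_layout(num_heads, num_blocks, num_local_blocks,
                        num_global_blocks, num_different_global_patterns):
    layout = [[[False] * num_blocks for _ in range(num_blocks)]
              for _ in range(num_heads)]
    for h in range(num_heads):
        # Local causal attention inside each window (same for every head).
        for ws in range(0, num_blocks, num_local_blocks):
            we = min(ws + num_local_blocks, num_blocks)
            for q in range(ws, we):
                for k in range(ws, q + 1):
                    layout[h][q][k] = True
        # Global blocks: each head uses a different position inside the window.
        first_global = num_local_blocks - (
            1 + h % num_different_global_patterns
        ) * num_global_blocks
        regular_end = num_blocks - (num_blocks % num_local_blocks)
        for gi in range(first_global, regular_end, num_local_blocks):
            for q in range(gi, num_blocks):
                for k in range(gi, gi + num_global_blocks):
                    layout[h][q][k] = True
        if regular_end < num_blocks:
            start = min(regular_end + first_global,
                        num_blocks - num_global_blocks)
            for q in range(start, num_blocks):
                for k in range(start, start + num_global_blocks):
                    layout[h][q][k] = True
    return layout

# Small demo: 8 heads, 32 blocks, windows of 8 blocks, 1 global block/window.
layout = fixed_sparse_layout(8, 32, 8, 1, 8)
```

For head 0 the first global block lands at block 7 (`8 - (1 + 0) * 1`), for head 1 at block 6, so a query in the second window sees block 7 on head 0 but block 6 on head 1: the per-head pattern the README describes.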
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.