Pavel Rykov committed

Commit a942c12 · 1 Parent(s): 511ef2e

Sparse attention fixed
Files changed (5):
  1. README.md +47 -0
  2. config.json +5 -0
  3. configuration_rugpt3xl.py +10 -0
  4. modeling_rugpt3xl.py +133 -4
  5. tokenizer.json +0 -0
README.md CHANGED
@@ -38,6 +38,7 @@ Details in "[A family of pretrained transformer language models for Russian](htt
 | Activation | GELU |
 | Normalization | Pre-LayerNorm |
 | Position encoding | Learned absolute |
+| Attention | Alternating sparse/dense (see [Sparse Attention](#sparse-attention)) |
 | Precision | float16 |
 | Training data | 80B tokens of Russian text (4 epochs) |
 | Test perplexity | 12.05 |
@@ -225,6 +226,46 @@ trainer.train()
 
 **LoRA target modules:** `q_proj`, `k_proj`, `v_proj`, `o_proj`, `up_proj`, `down_proj`
 
+## Sparse Attention
+
+This model was originally trained with DeepSpeed's
+[SparseSelfAttention](https://www.deepspeed.ai/tutorials/sparse-attention/) using the
+**alternating** pattern: even layers (0, 2, 4, ...) use block-sparse attention, odd layers
+(1, 3, 5, ...) use standard dense causal attention. The sparse layers use a
+`FixedSparsityConfig` derived from the "Generating Long Sequences with Sparse Transformers"
+paper (Child et al., 2019).
+
+This is a **critical** architectural detail. The model weights were optimized for this specific
+attention pattern during training. Running the model with all-dense attention degrades
+perplexity from ~12 to ~50.
+
+The converted model **fully replicates** this sparse attention pattern without any DeepSpeed
+dependency, using a precomputed block-sparse mask applied to standard dense attention.
+
+| Attention mode | Test PPL (Gazeta) |
+|---|---|
+| Sparse alternating (original training regime) | **11.68** |
+| All-dense (no sparse mask) | ~50.1 |
+
+### Sparse attention parameters
+
+The sparse pattern is controlled by `config.json` fields:
+
+| Parameter | Value | Description |
+|---|---|---|
+| `sparse_mode` | `"alternating"` | Even layers sparse, odd layers dense |
+| `sparse_block_size` | `16` | Token block size for the sparse layout |
+| `sparse_num_local_blocks` | `8` | Local attention window (8 blocks = 128 tokens) |
+| `sparse_num_global_blocks` | `1` | Global blocks per window |
+| `sparse_num_different_global_patterns` | `8` | Different heads use different global positions |
+
+Each sparse layer applies a per-head block-sparse mask. Within each window of 128 tokens,
+attention is causal (lower-triangular). Across windows, only designated "global" blocks are
+visible, with each attention head using a different global block position within the window.
+
+To disable sparse attention (e.g. for experiments), set `sparse_mode` to `"none"` in
+`config.json`. This will make all layers use standard dense causal attention.
+
 ## Architecture Details
 
 The model implements a custom `RuGPT3XLForCausalLM` class (loaded via `trust_remote_code=True`):
@@ -254,6 +295,10 @@ RuGPT3XLForCausalLM
 └── lm_head (Linear: 2048 -> 50264, no bias)
 ```
 
+Even-numbered decoder layers (0, 2, 4, ...) apply block-sparse attention masks. Odd-numbered
+layers use full causal attention. The sparse layout is precomputed at model initialization from
+the config parameters and stored as a non-persistent buffer.
+
 ## Conversion
 
 This model was converted from the original Megatron-LM checkpoint using a custom script.
@@ -291,5 +336,7 @@ For full conversion details and the script, see the
 ## Links
 
 - [A family of pretrained transformer language models for Russian](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=yPayeJIAAAAJ&citation_for_view=yPayeJIAAAAJ:Se3iqnhoufwC) - paper on Google Scholar
+- [Generating Long Sequences with Sparse Transformers](https://arxiv.org/abs/1904.10509) - sparse attention paper (Child et al., 2019)
 - [ai-forever/rugpt3xl](https://huggingface.co/ai-forever/rugpt3xl) - original model
 - [ai-forever/ru-gpts](https://github.com/ai-forever/ru-gpts) - original training codebase
+- [DeepSpeed Sparse Attention](https://www.deepspeed.ai/tutorials/sparse-attention/) - original sparse attention implementation
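The window arithmetic behind the sparse parameters can be sketched in a few lines. This is an illustrative sketch, not code from the repo: `sparse_block_size * sparse_num_local_blocks` gives the 128-token local window, and a query attends locally only within its own window, causally.

```python
# Illustrative sketch (not part of the repo): token-level view of the
# *local* component of the sparse pattern, using the config.json values.
sparse_block_size = 16
sparse_num_local_blocks = 8

# 8 blocks of 16 tokens = 128-token local windows
window_tokens = sparse_block_size * sparse_num_local_blocks

def local_window(query_pos: int) -> range:
    """Key positions visible to `query_pos` via local (causal) attention."""
    window_start = (query_pos // window_tokens) * window_tokens
    return range(window_start, query_pos + 1)

# A query at position 130 lies in the second window [128, 256), so it can
# locally attend only to positions 128..130; anything earlier is reachable
# only through a head's global block.
print(list(local_window(130)))  # -> [128, 129, 130]
```

Cross-window visibility comes solely from the per-head global blocks, which is why each head using a different global position matters: together the heads cover several distinct "summary" columns per window.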
config.json CHANGED
@@ -25,6 +25,11 @@
   "eos_token_id": 1,
   "pad_token_id": 0,
   "tie_word_embeddings": false,
+  "sparse_mode": "alternating",
+  "sparse_block_size": 16,
+  "sparse_num_local_blocks": 8,
+  "sparse_num_global_blocks": 1,
+  "sparse_num_different_global_patterns": 8,
   "torch_dtype": "float16",
   "transformers_version": "5.3.0"
 }
configuration_rugpt3xl.py CHANGED
@@ -35,6 +35,11 @@ class RuGPT3XLConfig(PretrainedConfig):
         eos_token_id=1,
         pad_token_id=0,
         tie_word_embeddings=False,
+        sparse_mode="none",
+        sparse_block_size=16,
+        sparse_num_local_blocks=8,
+        sparse_num_global_blocks=1,
+        sparse_num_different_global_patterns=8,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -50,6 +55,11 @@ class RuGPT3XLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.output_dropout = output_dropout
         self.use_cache = use_cache
+        self.sparse_mode = sparse_mode
+        self.sparse_block_size = sparse_block_size
+        self.sparse_num_local_blocks = sparse_num_local_blocks
+        self.sparse_num_global_blocks = sparse_num_global_blocks
+        self.sparse_num_different_global_patterns = sparse_num_different_global_patterns
 
         super().__init__(
             bos_token_id=bos_token_id,
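Note that the class default is `sparse_mode="none"` while the shipped `config.json` sets `"alternating"`, so older configs without the new fields still load as all-dense. The sketch below mirrors that layer-selection logic; the 24-layer count is an assumption for illustration only (the real value comes from the config's `num_hidden_layers`):

```python
import json
from types import SimpleNamespace

# Hypothetical config snippet; num_hidden_layers=24 is assumed for illustration.
cfg = SimpleNamespace(**json.loads(
    '{"sparse_mode": "alternating", "num_hidden_layers": 24}'
))

# Mirrors the getattr-based resolution used at model init:
mode = getattr(cfg, "sparse_mode", "none")  # missing field -> dense everywhere
if mode == "alternating":
    sparse_layers = {i for i in range(cfg.num_hidden_layers) if i % 2 == 0}
elif mode == "all":
    sparse_layers = set(range(cfg.num_hidden_layers))
else:
    sparse_layers = set()

print(sorted(sparse_layers))  # even layers: 0, 2, ..., 22
```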
modeling_rugpt3xl.py CHANGED
@@ -27,6 +27,47 @@ from .configuration_rugpt3xl import RuGPT3XLConfig
 logger = logging.get_logger(__name__)
 
 
+def build_fixed_sparse_layout(
+    num_heads: int,
+    num_blocks: int,
+    num_local_blocks: int,
+    num_global_blocks: int,
+    num_different_global_patterns: int,
+) -> torch.Tensor:
+    """Replicate DeepSpeed FixedSparsityConfig.make_layout() for unidirectional attention.
+
+    Returns a boolean tensor of shape [num_heads, num_blocks, num_blocks] where
+    True means the query block can attend to the key block.
+    """
+    layout = torch.zeros((num_heads, num_blocks, num_blocks), dtype=torch.bool)
+
+    # Local attention within fixed-size windows (identical for all heads)
+    for window_start in range(0, num_blocks, num_local_blocks):
+        window_end = min(window_start + num_local_blocks, num_blocks)
+        window_size = window_end - window_start
+        layout[:, window_start:window_end, window_start:window_end] = torch.tril(
+            torch.ones(window_size, window_size, dtype=torch.bool)
+        )
+
+    # Global attention (per-head: different heads use different global block positions)
+    for h in range(num_heads):
+        first_global = num_local_blocks - (
+            1 + h % num_different_global_patterns
+        ) * num_global_blocks
+        regular_end = num_blocks - (num_blocks % num_local_blocks)
+
+        for gi in range(first_global, regular_end, num_local_blocks):
+            layout[h, gi:, gi : gi + num_global_blocks] = True
+
+        if regular_end < num_blocks:
+            start = min(
+                regular_end + first_global, num_blocks - num_global_blocks
+            )
+            layout[h, start:, start : start + num_global_blocks] = True
+
+    return layout
+
+
 class RuGPT3XLAttention(nn.Module):
     def __init__(self, config: RuGPT3XLConfig, layer_idx: int):
         super().__init__()
@@ -201,6 +242,27 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
         )
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        self._sparse_layers: set = set()
+        if getattr(config, "sparse_mode", "none") == "alternating":
+            self._sparse_layers = {
+                i for i in range(config.num_hidden_layers) if i % 2 == 0
+            }
+        elif getattr(config, "sparse_mode", "none") == "all":
+            self._sparse_layers = set(range(config.num_hidden_layers))
+
+        if self._sparse_layers:
+            num_blocks = config.max_position_embeddings // config.sparse_block_size
+            layout = build_fixed_sparse_layout(
+                num_heads=config.num_attention_heads,
+                num_blocks=num_blocks,
+                num_local_blocks=config.sparse_num_local_blocks,
+                num_global_blocks=config.sparse_num_global_blocks,
+                num_different_global_patterns=config.sparse_num_different_global_patterns,
+            )
+            self.register_buffer("_sparse_layout", layout, persistent=False)
+        else:
+            self._sparse_layout = None
+
         self.gradient_checkpointing = False
         self.post_init()
 
@@ -278,7 +340,7 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
         position_embeds = self.embed_positions(position_ids)
         hidden_states = self.embed_dropout(inputs_embeds + position_embeds)
 
-        # Build causal 4D attention mask
+        # Build causal 4D attention mask (dense layers)
         causal_mask = self._build_causal_mask(
             batch_size,
             seq_length,
@@ -288,19 +350,38 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
             attention_mask,
         )
 
+        # Build sparse causal mask (sparse layers)
+        sparse_mask = None
+        if self._sparse_layout is not None and self._sparse_layers:
+            sparse_mask = self._build_sparse_causal_mask(
+                seq_length,
+                past_key_values_length,
+                hidden_states.dtype,
+                hidden_states.device,
+                self._sparse_layout,
+                self.config.sparse_block_size,
+                attention_mask,
+            )
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for layer_idx, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
+            layer_mask = (
+                sparse_mask
+                if (layer_idx in self._sparse_layers and sparse_mask is not None)
+                else causal_mask
+            )
+
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    layer_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
@@ -309,7 +390,7 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
+                    attention_mask=layer_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
@@ -372,6 +453,54 @@ class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
 
         return causal
 
+    @staticmethod
+    def _build_sparse_causal_mask(
+        seq_length: int,
+        past_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        sparse_layout: torch.Tensor,
+        block_size: int,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Build block-sparse causal mask from precomputed layout.
+
+        Returns additive mask of shape [1, num_heads, seq_length, total_length].
+        """
+        total_length = past_length + seq_length
+        num_blocks = sparse_layout.shape[1]
+
+        q_block = (
+            torch.arange(past_length, past_length + seq_length, device=device)
+            // block_size
+        ).clamp(max=num_blocks - 1)
+        k_block = (
+            torch.arange(total_length, device=device) // block_size
+        ).clamp(max=num_blocks - 1)
+
+        layout_dev = sparse_layout.to(device)
+        block_ok = layout_dev[:, q_block][:, :, k_block]
+
+        q_pos = torch.arange(
+            past_length, past_length + seq_length, device=device
+        ).unsqueeze(1)
+        k_pos = torch.arange(total_length, device=device).unsqueeze(0)
+        causal_ok = k_pos <= q_pos
+
+        allowed = block_ok & causal_ok.unsqueeze(0)
+
+        min_val = torch.finfo(dtype).min
+        mask = torch.where(allowed, 0.0, min_val).to(dtype)
+        mask = mask.unsqueeze(0)
+
+        if attention_mask is not None:
+            pad_mask = (
+                (1 - attention_mask[:, None, None, :].to(dtype)) * min_val
+            )
+            mask = mask + pad_mask
+
+        return mask
+
 
 class RuGPT3XLForCausalLM(RuGPT3XLPreTrainedModel):
     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
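As a sanity check on the layout logic, the same block-level rules can be reproduced in a dependency-free sketch (pure Python, toy dimensions; the real `build_fixed_sparse_layout` returns a `torch.Tensor` and is combined with a token-level causal mask at runtime):

```python
# Pure-Python sketch of the block layout built by build_fixed_sparse_layout,
# on toy dimensions. layout[h][q][k] == True means "query block q of head h
# may attend to key block k".
def fixed_sparse_layout(num_heads, num_blocks, num_local, num_global, num_patterns):
    layout = [[[False] * num_blocks for _ in range(num_blocks)]
              for _ in range(num_heads)]

    # Local: lower-triangular within each window of `num_local` blocks
    for h in range(num_heads):
        for ws in range(0, num_blocks, num_local):
            we = min(ws + num_local, num_blocks)
            for q in range(ws, we):
                for k in range(ws, q + 1):
                    layout[h][q][k] = True

    # Global: head-dependent global columns, visible to all later query blocks
    for h in range(num_heads):
        first_global = num_local - (1 + h % num_patterns) * num_global
        regular_end = num_blocks - (num_blocks % num_local)
        for gi in range(first_global, regular_end, num_local):
            for q in range(gi, num_blocks):
                for k in range(gi, min(gi + num_global, num_blocks)):
                    layout[h][q][k] = True
    return layout

L = fixed_sparse_layout(num_heads=2, num_blocks=16,
                        num_local=4, num_global=1, num_patterns=2)

# Local attention is causal within a window and never crosses windows on its own:
assert L[0][1][0] and not L[0][0][1]
assert not L[0][5][0]
# Heads pick different global columns: head 0 uses block 3, head 1 uses block 2:
assert L[0][5][3] and not L[0][5][2]
assert L[1][5][2]
```

The trailing-window branch of the real function is omitted here because the toy `num_blocks` divides evenly by `num_local`; the asserts illustrate why all-dense inference breaks the model — the weights expect exactly this visibility pattern.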
tokenizer.json ADDED
The diff for this file is too large to render.