kilianhaefeli committed on
Commit e5351ca · 1 Parent(s): c277c56
Files changed (2)
  1. modeling.py +34 -15
  2. modeling_f.py +949 -0
modeling.py CHANGED
@@ -479,11 +479,15 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
479
  block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
480
  )
481
  else:
 
482
  cache_position = torch.arange(
483
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
484
  )
485
 
486
  # --- keep the user/tokenizer padding mask BEFORE you overwrite attention_mask ---
 
 
 
487
  padding_mask_2d = attention_mask # shape [B, KV_LEN], 1=token, 0=pad
488
 
489
  # -------------------------
@@ -492,17 +496,21 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
492
  if position_ids is None:
493
  if (padding_mask_2d is not None) and (not self.training):
494
  # full, per-sample positions over KV_LEN
 
495
  pos_full = padding_mask_2d.long().cumsum(-1) - 1 # pads => -1
496
  pos_full = pos_full.clamp_min(0) # pads => 0
497
 
 
498
  q_len = inputs_embeds.shape[1]
499
  kv_len = pos_full.shape[1]
 
500
  if kv_len < q_len:
501
  raise ValueError(f"attention_mask KV_LEN={kv_len} < input_len={q_len}. "
502
  "When using cache, pass the FULL mask (past+current).")
503
 
504
- q_start = kv_len - q_len # assumes current tokens are the last q_len positions
505
- position_ids = pos_full[:, q_start:]
 
506
  else:
507
  # no padding mask: same positions for all batch elements
508
  position_ids = cache_position.unsqueeze(0)
@@ -527,19 +535,23 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
527
  attention_mask = structural[None, None, :, :] # [1,1,Q,KV]
528
  else:
529
  pad = padding_mask_2d.to(torch.bool) # [B, KV]
530
- B, kv_len = pad.shape
531
- q_len = inputs_embeds.shape[1]
532
- q_start = kv_len - q_len
533
 
534
  # Per-sample block ids computed from *non-pad* positions
535
- pos_full = pad.long().cumsum(-1) - 1
536
- pos_full = pos_full.clamp_min(0)
537
- block_full = pos_full // block_size # [B, KV]
 
538
 
539
- block_q = block_full[:, q_start:] # [B, Q]
540
- block_k = block_full # [B, KV]
 
 
 
541
 
542
- structural = block_q.unsqueeze(-1) >= block_k.unsqueeze(-2) # [B, Q, KV]
543
 
544
  # Mask keys AND queries (only valid tokens participate)
545
  key_ok = pad[:, None, None, :] # [B,1,1,KV]
@@ -630,7 +642,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
630
  mask_id: Optional[int] = 151665,
631
  **kwargs
632
  ) -> CausalLMOutputWithPastAndBlockCache:
633
-
634
  if self.training:
635
  original_labels = labels.clone()
636
  original_input_ids = input_ids.clone()
@@ -727,11 +739,13 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
727
  assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
728
 
729
  # pad the initial input_ids and attention_mask to be multiple of block_size
730
- if input_ids.shape[1] % block_size != 0:
731
  pad_len = block_size - (input_ids.shape[1] % block_size)
732
  input_ids = torch.cat([torch.full((input_ids.shape[0], pad_len), self.config.pad_token_id, device=input_ids.device), input_ids], dim=1)
733
  attention_mask = torch.cat([torch.zeros((attention_mask.shape[0], pad_len), device=attention_mask.device), attention_mask], dim=1)
734
 
 
 
735
  num_blocks = max_new_tokens // block_size
736
  device = input_ids.device
737
  batch_size = input_ids.size(0)
@@ -747,7 +761,9 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
747
 
748
  # Handle prefix processing (Context Encoding)
749
  if input_ids.shape[1] >= block_size:
 
750
  output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], attention_mask=attention_mask[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
 
751
  logits, past_key_values = output.logits, output.past_key_values
752
  if input_ids.shape[1] % block_size == 0:
753
  next_token = logits[:, -1:, :].argmax(dim=-1)
@@ -780,13 +796,16 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
780
  prompt_length = input_ids.shape[1]
781
 
782
  # Initialize x_init with mask_id with all mask tokens for the new block
783
- x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
784
 
785
  # Concatenate input_ids with x_init to form the new input_ids (we added a block-1 of masks to our current generation)
786
  x_init = torch.cat([input_ids, x_init], dim=1)
787
 
788
  # mask extension is extending the current mask by the number of new tokens we are generating in this block by adding ones.
 
789
  mask_extension = unfinished_sequences.unsqueeze(1).repeat(1, block_size - prompt_length % block_size).to(dtype=attention_mask.dtype)
 
 
790
  # mask is the current attention mask extended by the new tokens we are generating in this block by adding ones.
791
  curr_attention_mask = torch.cat([attention_mask, mask_extension], dim=1)
792
 
@@ -795,7 +814,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
795
  while True:
796
  # mask_idx indicates where the mask tokens are in the current block
797
  mask_idx = (x_t[:, -block_size:] == mask_id)
798
- # TODOL assert that first element is always not a mask
799
 
800
  if mask_idx.sum() == 0:
801
  # If no mask tokens left in the current block, then we generate the next token autoregressively
 
479
  block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
480
  )
481
  else:
482
+ # from past_seen_tokens to past_seen_tokens + current_input_length (for us this is always the last blocks + the current block)
483
  cache_position = torch.arange(
484
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
485
  )
486
 
487
  # --- keep the user/tokenizer padding mask BEFORE you overwrite attention_mask ---
488
+ # the kv mask covers however many tokens are in the sequence so far
489
+
490
+ # kv_len covers the previously processed blocks plus the current block.
491
  padding_mask_2d = attention_mask # shape [B, KV_LEN], 1=token, 0=pad
492
 
493
  # -------------------------
 
496
  if position_ids is None:
497
  if (padding_mask_2d is not None) and (not self.training):
498
  # full, per-sample positions over KV_LEN
499
+ # the first real token gets position 0, then positions arange upward over the non-pad tokens
500
  pos_full = padding_mask_2d.long().cumsum(-1) - 1 # pads => -1
501
  pos_full = pos_full.clamp_min(0) # pads => 0
502
 
503
+
504
  q_len = inputs_embeds.shape[1]
505
  kv_len = pos_full.shape[1]
506
+
507
  if kv_len < q_len:
508
  raise ValueError(f"attention_mask KV_LEN={kv_len} < input_len={q_len}. "
509
  "When using cache, pass the FULL mask (past+current).")
510
 
511
+ # position ids are the arange but only taking the last block of values!
512
+ q_start = kv_len - q_len # assumes current tokens are the last q_len positions (the query length is one block, which it always is).
513
+ position_ids = pos_full[:, q_start:] # TODO assert same as just taking last block
514
  else:
515
  # no padding mask: same positions for all batch elements
516
  position_ids = cache_position.unsqueeze(0)
 
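For reference, the left-padding position-id logic above can be exercised standalone; the mask values and q_len below are illustrative only, not taken from the model:

import torch

padding_mask_2d = torch.tensor([[0, 0, 1, 1, 1, 1],
                                [1, 1, 1, 1, 1, 1]])   # [B, KV], 1 = token, 0 = pad
q_len = 4                                               # current block length (illustrative)

pos_full = padding_mask_2d.long().cumsum(-1) - 1        # pads => -1, real tokens => 0, 1, 2, ...
pos_full = pos_full.clamp_min(0)                        # pads => 0
q_start = pos_full.shape[1] - q_len
position_ids = pos_full[:, q_start:]                    # positions of the last q_len tokens
print(position_ids)                                     # tensor([[0, 1, 2, 3], [2, 3, 4, 5]])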
535
  attention_mask = structural[None, None, :, :] # [1,1,Q,KV]
536
  else:
537
  pad = padding_mask_2d.to(torch.bool) # [B, KV]
538
+ B, kv_len = pad.shape # kv_len is the length of the full mask, i.e. the cached length plus the current block (here +32)
539
+ q_len = inputs_embeds.shape[1] # inputs_embeds covers the current block (32 tokens); TODO check
540
+ q_start = kv_len - q_len
541
 
542
  # Per-sample block ids computed from *non-pad* positions
543
+ # TODO fix!
544
+ # pos_full = pad.long().cumsum(-1) - 1 # again basically an arange over the non-pad tokens
545
+ # pos_full = pos_full.clamp_min(0)
546
+ # block_full = pos_full // block_size # [B, KV] # this can put block transitions in the wrong place, so tokens would attend incorrectly!
547
 
548
+ pos_full = torch.arange(0, kv_len, device=inputs_embeds.device)[None, ...]
549
+ block_full = pos_full // block_size # 0,0...,0,1...1,2...2,...
550
+
551
+ block_q = block_full[:, q_start:] # [B, Q] # the latest block; TODO check these are all the same value
552
+ block_k = block_full # [B, KV] # everything we attend to!
553
 
554
+ structural = block_q.unsqueeze(-1) >= block_k.unsqueeze(-2) # [B, Q, KV] # True when the query block index is >= the key block index, i.e. attend within the same block and to all earlier blocks
555
 
556
  # Mask keys AND queries (only valid tokens participate)
557
  key_ok = pad[:, None, None, :] # [B,1,1,KV]
 
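The new shared-arange block mask can likewise be checked in isolation; block_size, q_len and the padding pattern below are toy values:

import torch

block_size, q_len = 2, 2
pad = torch.tensor([[0, 1, 1, 1, 1, 1]], dtype=torch.bool)      # [B, KV], left padding
kv_len = pad.shape[1]
q_start = kv_len - q_len

pos_full = torch.arange(0, kv_len)[None, ...]                    # shared positions, as in the new code
block_full = pos_full // block_size                              # 0, 0, 1, 1, 2, 2
block_q = block_full[:, q_start:]                                # [B, Q] block id of the current block
block_k = block_full                                             # [B, KV]
structural = block_q.unsqueeze(-1) >= block_k.unsqueeze(-2)      # [B, Q, KV] block-causal pattern

key_ok = pad[:, None, None, :]                                   # keys must be real tokens
query_ok = pad[:, None, q_start:, None]                          # queries must be real tokens
attention_mask = structural[:, None, :, :] & key_ok & query_ok   # [B, 1, Q, KV]
print(attention_mask.shape)                                      # torch.Size([1, 1, 2, 6])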
642
  mask_id: Optional[int] = 151665,
643
  **kwargs
644
  ) -> CausalLMOutputWithPastAndBlockCache:
645
+
646
  if self.training:
647
  original_labels = labels.clone()
648
  original_input_ids = input_ids.clone()
 
739
  assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
740
 
741
  # pad the initial input_ids and attention_mask to be multiple of block_size
742
+ if False: # input_ids.shape[1] % block_size != 0:
743
  pad_len = block_size - (input_ids.shape[1] % block_size)
744
  input_ids = torch.cat([torch.full((input_ids.shape[0], pad_len), self.config.pad_token_id, device=input_ids.device), input_ids], dim=1)
745
  attention_mask = torch.cat([torch.zeros((attention_mask.shape[0], pad_len), device=attention_mask.device), attention_mask], dim=1)
746
 
747
+ # attention_mask length matches the (padded) prompt length
748
+
749
  num_blocks = max_new_tokens // block_size
750
  device = input_ids.device
751
  batch_size = input_ids.size(0)
 
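The (now disabled) left-padding branch above corresponds to this standalone sketch; block_size and pad_token_id are placeholder values:

import torch

block_size, pad_token_id = 4, 0                          # placeholder values
input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])     # length 6, not a multiple of block_size
attention_mask = torch.ones_like(input_ids)

if input_ids.shape[1] % block_size != 0:
    pad_len = block_size - (input_ids.shape[1] % block_size)
    input_ids = torch.cat(
        [torch.full((input_ids.shape[0], pad_len), pad_token_id), input_ids], dim=1)
    attention_mask = torch.cat(
        [torch.zeros((attention_mask.shape[0], pad_len), dtype=attention_mask.dtype), attention_mask], dim=1)

print(input_ids.shape[1], attention_mask.tolist())       # 8 [[0, 0, 1, 1, 1, 1, 1, 1]]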
761
 
762
  # Handle prefix processing (Context Encoding)
763
  if input_ids.shape[1] >= block_size:
764
+ # pass in all full blocks of the context (everything except the trailing partial block) and cache them.
765
  output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], attention_mask=attention_mask[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
766
+ # if we passed the whole prompt (no partial block left), extend it by one predicted token.
767
  logits, past_key_values = output.logits, output.past_key_values
768
  if input_ids.shape[1] % block_size == 0:
769
  next_token = logits[:, -1:, :].argmax(dim=-1)
 
796
  prompt_length = input_ids.shape[1]
797
 
798
  # Initialize x_init with mask_id with all mask tokens for the new block
799
+ x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long) # pad with however many mask tokens are needed to reach a multiple of block_size
800
 
801
  # Concatenate input_ids with x_init to form the new input_ids (we added a block-1 of masks to our current generation)
802
  x_init = torch.cat([input_ids, x_init], dim=1)
803
 
804
  # mask extension is extending the current mask by the number of new tokens we are generating in this block by adding ones.
805
+ # the mask now spans all tokens, including the appended mask tokens
806
  mask_extension = unfinished_sequences.unsqueeze(1).repeat(1, block_size - prompt_length % block_size).to(dtype=attention_mask.dtype)
807
+
808
+
809
  # mask is the current attention mask extended by the new tokens we are generating in this block by adding ones.
810
  curr_attention_mask = torch.cat([attention_mask, mask_extension], dim=1)
811
 
 
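A compact sketch of the per-block setup above: append mask tokens up to the next block boundary and extend the attention mask for unfinished sequences (mask_id and block_size here are placeholders):

import torch

mask_id, block_size = 151665, 4
input_ids = torch.tensor([[5, 6, 7, 8, 9, 10]])                     # prompt_length = 6
attention_mask = torch.ones_like(input_ids)
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long)

prompt_length = input_ids.shape[1]
n_new = block_size - prompt_length % block_size                      # 2 mask tokens to reach the boundary
x_init = torch.cat(
    [input_ids, torch.full((input_ids.shape[0], n_new), mask_id)], dim=1)
mask_extension = unfinished_sequences.unsqueeze(1).repeat(1, n_new).to(dtype=attention_mask.dtype)
curr_attention_mask = torch.cat([attention_mask, mask_extension], dim=1)
print(x_init.shape[1], curr_attention_mask.shape[1])                 # 8 8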
814
  while True:
815
  # mask_idx indicates where the mask tokens are in the current block
816
  mask_idx = (x_t[:, -block_size:] == mask_id)
817
+ # TODO: assert that first element is always not a mask
818
 
819
  if mask_idx.sum() == 0:
820
  # If no mask tokens left in the current block, then we generate the next token autoregressively
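The body of this loop (unchanged here, visible in full in modeling_f.py below) unmasks positions whose sampled-token probability exceeds threshold, always keeping at least the most confident one; a toy sketch of that rule with made-up probabilities:

import torch

threshold = 0.9
mask_idx = torch.tensor([[True, True, False, True]])    # positions still masked in the current sub-block
x1_p = torch.tensor([[0.95, 0.40, 0.99, 0.20]])         # probability of each sampled token
x1_p = torch.where(mask_idx, x1_p, -torch.inf)          # only masked positions are candidates

unmask_idx = x1_p > threshold                            # unmask confident positions ...
max_prob_idx = x1_p.argmax(dim=-1)                       # ... and always at least the best one
unmask_idx[torch.arange(x1_p.shape[0]), max_prob_idx] = True
unmask_idx = unmask_idx & mask_idx
print(unmask_idx)                                        # tensor([[ True, False, False, False]])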
modeling_f.py ADDED
@@ -0,0 +1,949 @@
1
+ from typing import Callable, Optional, Union
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from functools import partial
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.integrations import use_kernel_forward_from_hub
13
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
14
+ from transformers.modeling_layers import GradientCheckpointingLayer
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast,
18
+ )
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import auto_docstring, can_return_tuple, logging
23
+ from .configuration import Fast_dLLM_QwenConfig
24
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
25
+ from einops import rearrange, repeat
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ @dataclass
31
+ class CausalLMOutputWithPastAndBlockCache(CausalLMOutputWithPast):
32
+ block_past_key_values: Optional[Cache] = None
33
+
34
+ @dataclass
35
+ class BaseModelOutputWithPastAndBlockCache(BaseModelOutputWithPast):
36
+ block_past_key_values: Optional[Cache] = None
37
+
38
+
39
+ # @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
40
+ def fused_flex_attention(q, k, v, mask=None):
41
+ return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
42
+
43
+ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
44
+ """
45
+ Constructs the specialized block diffusion attention mask for training
46
+ composed of three masks:
47
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
48
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
49
+ - **Block Causal Mask (M_BC)**: Attention to update x0
50
+
51
+ Args:
52
+ b, h: Batch and head indices (ignored for mask logic).
53
+ q_idx, kv_idx: Query and Key indices.
54
+ n: Length of the noised sequence x_t; indices >= n belong to the clean copy x0.
55
+ block_size: Defines the block structure.
56
+
57
+ Returns:
58
+ A boolean attention mask.
59
+ """
60
+ # Indicate whether token belongs to xt or x0
61
+ x0_flag_q = (q_idx >= n)
62
+ x0_flag_kv = (kv_idx >= n)
63
+
64
+ # Compute block indices
65
+ block_q = torch.where(x0_flag_q == 1,
66
+ (q_idx - n) // block_size,
67
+ q_idx // block_size)
68
+ block_kv = torch.where(x0_flag_kv == 1,
69
+ (kv_idx - n) // block_size,
70
+ kv_idx // block_size)
71
+
72
+ # **1. Block Diagonal Mask (M_BD) **
73
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
74
+
75
+ # **2. Offset Block-Causal Mask (M_OBC) **
76
+ offset_block_causal = (
77
+ (block_q > block_kv)
78
+ & (x0_flag_kv == 1)
79
+ & (x0_flag_q == 0)
80
+ )
81
+
82
+ # **3. Block-Causal Mask (M_BC) **
83
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
84
+
85
+ # **4. Combine Masks **
86
+ return block_diagonal | offset_block_causal | block_causal
87
+
88
+ def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
89
+ # Compute block indices
90
+ block_q = q_idx // block_size
91
+ block_kv = kv_idx // block_size
92
+
93
+ return block_q >= block_kv
94
+
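# A toy check of the training mask predicate above (assumes torch and block_diff_mask as defined
# in this file): n = 4 noised tokens (x_t) followed by 4 clean tokens (x0), block_size = 2.
# gen_mask() further below does the real materialization via create_block_mask.
idx = torch.arange(8)
toy_mask = block_diff_mask(None, None, idx[:, None], idx[None, :], block_size=2, n=4)
print(toy_mask.int())   # 8x8 pattern combining the block-diagonal, offset block-causal and block-causal parts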
95
+ class Fast_dLLM_QwenMLP(nn.Module):
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.config = config
99
+ self.hidden_size = config.hidden_size
100
+ self.intermediate_size = config.intermediate_size
101
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
102
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
103
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
104
+ self.act_fn = ACT2FN[config.hidden_act]
105
+
106
+ def forward(self, x):
107
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
108
+ return down_proj
109
+
110
+
111
+ def rotate_half(x):
112
+ """Rotates half the hidden dims of the input."""
113
+ x1 = x[..., : x.shape[-1] // 2]
114
+ x2 = x[..., x.shape[-1] // 2 :]
115
+ return torch.cat((-x2, x1), dim=-1)
116
+
117
+
118
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
119
+ """Applies Rotary Position Embedding to the query and key tensors.
120
+
121
+ Args:
122
+ q (`torch.Tensor`): The query tensor.
123
+ k (`torch.Tensor`): The key tensor.
124
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
125
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
126
+ position_ids (`torch.Tensor`, *optional*):
127
+ Deprecated and unused.
128
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
129
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
130
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
131
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
132
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
133
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
134
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
135
+ Returns:
136
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
137
+ """
138
+ cos = cos.unsqueeze(unsqueeze_dim)
139
+ sin = sin.unsqueeze(unsqueeze_dim)
140
+ q_embed = (q * cos) + (rotate_half(q) * sin)
141
+ k_embed = (k * cos) + (rotate_half(k) * sin)
142
+ return q_embed, k_embed
143
+
144
+
145
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
146
+ """
147
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
148
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
149
+ """
150
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
151
+ if n_rep == 1:
152
+ return hidden_states
153
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
154
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
155
+
156
+
157
+ class Fast_dLLM_QwenAttention(nn.Module):
158
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
159
+
160
+ def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
161
+ super().__init__()
162
+ self.config = config
163
+ self.layer_idx = layer_idx
164
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
165
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
166
+ self.scaling = self.head_dim**-0.5
167
+ self.attention_dropout = config.attention_dropout
168
+ self.is_causal = True
169
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
170
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
171
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
172
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
173
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
174
+
175
+ def forward(
176
+ self,
177
+ hidden_states: torch.Tensor,
178
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
179
+ attention_mask: Optional[torch.Tensor],
180
+ past_key_value: Optional[Cache] = None,
181
+ cache_position: Optional[torch.LongTensor] = None,
182
+ update_past_key_values: Optional[bool] = False,
183
+ block_past_key_values: Optional[Cache] = None,
184
+ replace_position: Optional[int] = None,
185
+ **kwargs: Unpack[FlashAttentionKwargs],
186
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
187
+ input_shape = hidden_states.shape[:-1]
188
+ hidden_shape = (*input_shape, -1, self.head_dim)
189
+
190
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
191
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
192
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+ if self.training:
197
+ #split q into two parts
198
+ q_1 = query_states[:,:,:query_states.shape[2]//2]
199
+ q_2 = query_states[:,:,query_states.shape[2]//2:]
200
+ #split k into two parts
201
+ k_1 = key_states[:,:,:key_states.shape[2]//2]
202
+ k_2 = key_states[:,:,key_states.shape[2]//2:]
203
+ q_1, k_1 = apply_rotary_pos_emb(q_1, k_1, cos, sin)
204
+ q_2, k_2 = apply_rotary_pos_emb(q_2, k_2, cos, sin)
205
+ query_states = torch.cat((q_1, q_2), dim=-2)
206
+ key_states = torch.cat((k_1, k_2), dim=-2)
207
+ else:
208
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
209
+
210
+ if block_past_key_values is not None:
211
+ if len(block_past_key_values) <= self.layer_idx:
212
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
213
+ key_states, value_states = block_past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
214
+ else:
215
+ block_cache_key_states = block_past_key_values[self.layer_idx][0]
216
+ block_cache_value_states = block_past_key_values[self.layer_idx][1]
217
+
218
+ block_cache_key_states[:, :, replace_position:replace_position+key_states.shape[2]] = key_states
219
+ block_cache_value_states[:, :, replace_position:replace_position+value_states.shape[2]] = value_states
220
+ key_states = block_cache_key_states
221
+ value_states = block_cache_value_states
222
+
223
+ if past_key_value is not None:
224
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
225
+ if update_past_key_values:
226
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
227
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
228
+ elif len(past_key_value) > self.layer_idx:
229
+ key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
230
+ value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)
231
+
232
+ if self.training:
233
+ attn_output = fused_flex_attention(query_states, key_states, value_states, mask=attention_mask)
234
+ attn_output = attn_output.transpose(1, 2).contiguous()
235
+ else:
236
+ attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
237
+
238
+ attn_output, attn_weights = attention_interface(
239
+ self,
240
+ query_states,
241
+ key_states,
242
+ value_states,
243
+ attention_mask,
244
+ is_causal=False,
245
+ dropout=0.0 if not self.training else self.attention_dropout,
246
+ scaling=self.scaling,
247
+ sliding_window=self.sliding_window, # main diff with Llama
248
+ **kwargs,
249
+ )
250
+
251
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
252
+ attn_output = self.o_proj(attn_output)
253
+ return attn_output
254
+
255
+ @use_kernel_forward_from_hub("RMSNorm")
256
+ class Fast_dLLM_QwenRMSNorm(nn.Module):
257
+ def __init__(self, hidden_size, eps=1e-6):
258
+ """
259
+ Fast_dLLM_QwenRMSNorm is equivalent to T5LayerNorm
260
+ """
261
+ super().__init__()
262
+ self.weight = nn.Parameter(torch.ones(hidden_size))
263
+ self.variance_epsilon = eps
264
+
265
+ def forward(self, hidden_states):
266
+ input_dtype = hidden_states.dtype
267
+ hidden_states = hidden_states.to(torch.float32)
268
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
269
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
270
+ return self.weight * hidden_states.to(input_dtype)
271
+
272
+ def extra_repr(self):
273
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
274
+
275
+
276
+ class Fast_dLLM_QwenDecoderLayer(GradientCheckpointingLayer):
277
+ def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
278
+ super().__init__()
279
+ self.hidden_size = config.hidden_size
280
+
281
+ self.self_attn = Fast_dLLM_QwenAttention(config=config, layer_idx=layer_idx)
282
+
283
+ self.mlp = Fast_dLLM_QwenMLP(config)
284
+ self.input_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
285
+ self.post_attention_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
286
+ self.attention_type = config.layer_types[layer_idx]
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states: torch.Tensor,
291
+ attention_mask: Optional[torch.Tensor] = None,
292
+ position_ids: Optional[torch.LongTensor] = None,
293
+ past_key_value: Optional[Cache] = None,
294
+ use_cache: Optional[bool] = False,
295
+ cache_position: Optional[torch.LongTensor] = None,
296
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
297
+ update_past_key_values: Optional[bool] = False,
298
+ use_block_cache: Optional[bool] = False,
299
+ block_past_key_values: Optional[Cache] = None,
300
+ replace_position: Optional[int] = None,
301
+ **kwargs
302
+ ) -> tuple[torch.Tensor]:
303
+ residual = hidden_states
304
+ hidden_states = self.input_layernorm(hidden_states)
305
+ # Self Attention
306
+ hidden_states = self.self_attn(
307
+ hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ position_ids=position_ids,
310
+ past_key_value=past_key_value,
311
+ use_cache=use_cache,
312
+ cache_position=cache_position,
313
+ position_embeddings=position_embeddings,
314
+ update_past_key_values=update_past_key_values,
315
+ use_block_cache=use_block_cache,
316
+ block_past_key_values=block_past_key_values,
317
+ replace_position=replace_position,
318
+ **kwargs,
319
+ )
320
+ hidden_states = residual + hidden_states
321
+
322
+ # Fully Connected
323
+ residual = hidden_states
324
+ hidden_states = self.post_attention_layernorm(hidden_states)
325
+ hidden_states = self.mlp(hidden_states)
326
+ hidden_states = residual + hidden_states
327
+ return hidden_states
328
+
329
+
330
+
331
+ class Fast_dLLM_QwenPreTrainedModel(PreTrainedModel):
332
+ config_class = Fast_dLLM_QwenConfig
333
+ base_model_prefix = "model"
334
+ supports_gradient_checkpointing = True
335
+ _no_split_modules = ["Fast_dLLM_QwenDecoderLayer"]
336
+ _skip_keys_device_placement = ["past_key_values"]
337
+ _supports_flash_attn_2 = True
338
+ _supports_sdpa = True
339
+ _supports_flex_attn = True
340
+ _supports_cache_class = True
341
+ _supports_quantized_cache = True
342
+ _supports_static_cache = True
343
+ _supports_attention_backend = True
344
+ _can_record_outputs = {
345
+ "hidden_states": Fast_dLLM_QwenDecoderLayer,
346
+ "attentions": Fast_dLLM_QwenAttention,
347
+ }
348
+
349
+ def _init_weights(self, module):
350
+ std = self.config.initializer_range
351
+ if isinstance(module, nn.Linear):
352
+ module.weight.data.normal_(mean=0.0, std=std)
353
+ if module.bias is not None:
354
+ module.bias.data.zero_()
355
+ elif isinstance(module, nn.Embedding):
356
+ module.weight.data.normal_(mean=0.0, std=std)
357
+ if module.padding_idx is not None:
358
+ module.weight.data[module.padding_idx].zero_()
359
+ elif isinstance(module, Fast_dLLM_QwenRMSNorm):
360
+ module.weight.data.fill_(1.0)
361
+
362
+
363
+ class Fast_dLLM_QwenRotaryEmbedding(nn.Module):
364
+ def __init__(self, config: Fast_dLLM_QwenConfig, device=None):
365
+ super().__init__()
366
+ # BC: "rope_type" was originally "type"
367
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
368
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
369
+ else:
370
+ self.rope_type = "default"
371
+ self.max_seq_len_cached = config.max_position_embeddings
372
+ self.original_max_seq_len = config.max_position_embeddings
373
+
374
+ self.config = config
375
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
376
+
377
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
378
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
379
+ self.original_inv_freq = self.inv_freq
380
+
381
+ @torch.no_grad()
382
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
383
+ def forward(self, x, position_ids):
384
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
385
+ position_ids_expanded = position_ids[:, None, :].float()
386
+
387
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
388
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
389
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
390
+ emb = torch.cat((freqs, freqs), dim=-1)
391
+ cos = emb.cos() * self.attention_scaling
392
+ sin = emb.sin() * self.attention_scaling
393
+
394
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
395
+
396
+
397
+
398
+ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
399
+ def __init__(self, config: Fast_dLLM_QwenConfig):
400
+ super().__init__(config)
401
+ self.padding_idx = config.pad_token_id
402
+ self.vocab_size = config.vocab_size
403
+ self.bd_size = config.bd_size
404
+
405
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
406
+ self.layers = nn.ModuleList(
407
+ [Fast_dLLM_QwenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
408
+ )
409
+ self.norm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
410
+ self.rotary_emb = Fast_dLLM_QwenRotaryEmbedding(config=config)
411
+ self.gradient_checkpointing = True
412
+
413
+ # Initialize weights and apply final processing
414
+ self.post_init()
415
+
416
+ def get_input_embeddings(self):
417
+ return self.embed_tokens
418
+
419
+ def set_input_embeddings(self, value):
420
+ self.embed_tokens = value
421
+
422
+
423
+ def eval_mask(self, seqlen, block_size, cache_seq_len):
424
+ q_indices = torch.arange(seqlen) + cache_seq_len
425
+ k_indices = torch.arange(seqlen + cache_seq_len)
426
+ mask = eval_block_diff_mask(
427
+ q_idx=q_indices[:, None],
428
+ kv_idx=k_indices[None, :],
429
+ block_size=block_size
430
+ )
431
+ return mask
432
+
433
+ def gen_mask(self, seqlen, block_size, B, H):
434
+ mask = create_block_mask(
435
+ partial(block_diff_mask, block_size=block_size, n=seqlen),
436
+ B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
437
+
438
+ return mask
439
+
440
+ def forward(
441
+ self,
442
+ input_ids: Optional[torch.LongTensor] = None,
443
+ labels: Optional[torch.LongTensor] = None,
444
+ attention_mask: Optional[torch.Tensor] = None,
445
+ position_ids: Optional[torch.LongTensor] = None,
446
+ past_key_values: Optional[Cache] = None,
447
+ inputs_embeds: Optional[torch.FloatTensor] = None,
448
+ use_cache: Optional[bool] = None,
449
+ cache_position: Optional[torch.LongTensor] = None,
450
+ update_past_key_values: Optional[bool] = False,
451
+ block_size: Optional[int] = 32,
452
+ use_block_cache: Optional[bool] = False,
453
+ block_past_key_values: Optional[Cache] = None,
454
+ replace_position: Optional[int] = None,
455
+ **kwargs
456
+ ) -> BaseModelOutputWithPast:
457
+ if (input_ids is None) ^ (inputs_embeds is not None):
458
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
459
+
460
+ if inputs_embeds is None:
461
+ inputs_embeds = self.embed_tokens(input_ids)
462
+
463
+ if use_cache and past_key_values is None:
464
+ past_key_values = DynamicCache()
465
+
466
+ if use_block_cache and block_past_key_values is None:
467
+ block_past_key_values = DynamicCache()
468
+
469
+ if cache_position is None:
470
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
471
+ if self.training:
472
+ cache_position = torch.arange(
473
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]//2, device=inputs_embeds.device
474
+ )
475
+ else:
476
+ if use_block_cache:
477
+ block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
478
+ cache_position = torch.arange(
479
+ block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
480
+ )
481
+ else:
482
+ cache_position = torch.arange(
483
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
484
+ )
485
+
486
+ # --- keep the user/tokenizer padding mask BEFORE you overwrite attention_mask ---
487
+ padding_mask_2d = attention_mask # shape [B, KV_LEN], 1=token, 0=pad
488
+
489
+ # -------------------------
490
+ # Position ids (left padding)
491
+ # -------------------------
492
+ if position_ids is None:
493
+ if (padding_mask_2d is not None) and (not self.training):
494
+ # full, per-sample positions over KV_LEN
495
+ pos_full = padding_mask_2d.long().cumsum(-1) - 1 # pads => -1
496
+ pos_full = pos_full.clamp_min(0) # pads => 0
497
+
498
+ q_len = inputs_embeds.shape[1]
499
+ kv_len = pos_full.shape[1]
500
+ if kv_len < q_len:
501
+ raise ValueError(f"attention_mask KV_LEN={kv_len} < input_len={q_len}. "
502
+ "When using cache, pass the FULL mask (past+current).")
503
+
504
+ q_start = kv_len - q_len # assumes current tokens are the last q_len positions
505
+ position_ids = pos_full[:, q_start:]
506
+ else:
507
+ # no padding mask: same positions for all batch elements
508
+ position_ids = cache_position.unsqueeze(0)
509
+
510
+ # -------------------------
511
+ # Attention mask (block-causal + padding), per sample
512
+ # -------------------------
513
+ if self.training:
514
+ attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
515
+ else:
516
+ if use_block_cache and block_past_key_values.get_seq_length() != 0:
517
+ attention_mask = None
518
+ else:
519
+ # attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
520
+ if padding_mask_2d is None:
521
+ # fallback: original behavior (no padding)
522
+ structural = self.eval_mask(
523
+ seqlen=input_ids.shape[1],
524
+ block_size=block_size,
525
+ cache_seq_len=past_key_values.get_seq_length() if past_key_values is not None else 0,
526
+ ).to(device=inputs_embeds.device)
527
+ attention_mask = structural[None, None, :, :] # [1,1,Q,KV]
528
+ else:
529
+ pad = padding_mask_2d.to(torch.bool) # [B, KV]
530
+ B, kv_len = pad.shape
531
+ q_len = inputs_embeds.shape[1]
532
+ q_start = kv_len - q_len
533
+
534
+ # Per-sample block ids computed from *non-pad* positions
535
+ pos_full = pad.long().cumsum(-1) - 1
536
+ pos_full = pos_full.clamp_min(0)
537
+ block_full = pos_full // block_size # [B, KV]
538
+
539
+ block_q = block_full[:, q_start:] # [B, Q]
540
+ block_k = block_full # [B, KV]
541
+
542
+ structural = block_q.unsqueeze(-1) >= block_k.unsqueeze(-2) # [B, Q, KV]
543
+
544
+ # Mask keys AND queries (only valid tokens participate)
545
+ key_ok = pad[:, None, None, :] # [B,1,1,KV]
546
+ query_ok = pad[:, None, q_start:, None] # [B,1,Q,1]
547
+
548
+ attention_mask = structural[:, None, :, :] & key_ok & query_ok # [B,1,Q,KV]
549
+
550
+ hidden_states = inputs_embeds
551
+
552
+ # create position embeddings to be shared across the decoder layers
553
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
554
+
555
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
556
+ hidden_states = decoder_layer(
557
+ hidden_states,
558
+ attention_mask=attention_mask,
559
+ position_ids=position_ids,
560
+ past_key_value=past_key_values,
561
+ use_cache=use_cache,
562
+ cache_position=cache_position,
563
+ position_embeddings=position_embeddings,
564
+ update_past_key_values=update_past_key_values,
565
+ use_block_cache=use_block_cache,
566
+ block_past_key_values=block_past_key_values,
567
+ replace_position=replace_position,
568
+ **kwargs,
569
+ )
570
+
571
+ hidden_states = self.norm(hidden_states)
572
+ return BaseModelOutputWithPastAndBlockCache(
573
+ last_hidden_state=hidden_states,
574
+ past_key_values=past_key_values if use_cache else None,
575
+ block_past_key_values=block_past_key_values if use_block_cache else None,
576
+ )
577
+
578
+
579
+ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
580
+ _tied_weights_keys = ["lm_head.weight"]
581
+ _tp_plan = {"lm_head": "colwise_rep"}
582
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
583
+
584
+ def __init__(self, config):
585
+ super().__init__(config)
586
+ self.model = Fast_dLLM_QwenModel(config)
587
+ self.vocab_size = config.vocab_size
588
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
589
+
590
+ # Initialize weights and apply final processing
591
+ self.post_init()
592
+
593
+ self.generate_statistics = {}
594
+
595
+ def get_input_embeddings(self):
596
+ return self.model.embed_tokens
597
+
598
+ def set_input_embeddings(self, value):
599
+ self.model.embed_tokens = value
600
+
601
+ def get_output_embeddings(self):
602
+ return self.lm_head
603
+
604
+ def set_output_embeddings(self, new_embeddings):
605
+ self.lm_head = new_embeddings
606
+
607
+ def set_decoder(self, decoder):
608
+ self.model = decoder
609
+
610
+ def get_decoder(self):
611
+ return self.model
612
+
613
+ @can_return_tuple
614
+ def forward(
615
+ self,
616
+ input_ids: Optional[torch.LongTensor] = None,
617
+ attention_mask: Optional[torch.Tensor] = None,
618
+ position_ids: Optional[torch.LongTensor] = None,
619
+ past_key_values: Optional[Cache] = None,
620
+ inputs_embeds: Optional[torch.FloatTensor] = None,
621
+ labels: Optional[torch.LongTensor] = None,
622
+ use_cache: Optional[bool] = None,
623
+ cache_position: Optional[torch.LongTensor] = None,
624
+ logits_to_keep: Union[int, torch.Tensor] = 0,
625
+ update_past_key_values: Optional[bool] = False,
626
+ block_size: Optional[int] = 32,
627
+ use_block_cache: Optional[bool] = False,
628
+ block_past_key_values: Optional[Cache] = None,
629
+ replace_position: Optional[int] = None,
630
+ mask_id: Optional[int] = 151665,
631
+ **kwargs
632
+ ) -> CausalLMOutputWithPastAndBlockCache:
633
+
634
+ if self.training:
635
+ original_labels = labels.clone()
636
+ original_input_ids = input_ids.clone()
637
+
638
+ noisy_input_ids = input_ids.clone()
639
+
640
+ input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
641
+ b, l = input_ids.shape
642
+ t = torch.rand((b,), device=input_ids.device)
643
+ eps=1e-3
644
+ p_mask = (1 - eps) * t + eps
645
+ p_mask = p_mask[:, None].repeat(1, l)
646
+
647
+ mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
648
+ x_t = torch.where(mask_indices, mask_id, input_ids).reshape(labels.shape)
649
+ noisy_input_ids[labels != -100] = x_t[labels != -100]
650
+ mask = (noisy_input_ids != mask_id)
651
+ labels[mask] = -100
652
+ input_ids = torch.cat([noisy_input_ids, input_ids.reshape(labels.shape)], dim=1)
653
+
654
+ complementary_noisy_input_ids = original_input_ids.clone()
655
+ complementary_labels = original_labels.clone()
656
+
657
+ complementary_input_ids = original_input_ids.reshape(original_input_ids.shape[0] * original_input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
658
+
659
+ complementary_mask_indices = ~mask_indices
660
+ complementary_x_t = torch.where(complementary_mask_indices, mask_id, complementary_input_ids).reshape(labels.shape)
661
+ complementary_noisy_input_ids[complementary_labels != -100] = complementary_x_t[complementary_labels != -100]
662
+ complementary_mask = (complementary_noisy_input_ids != mask_id)
663
+ complementary_labels[complementary_mask] = -100
664
+ complementary_input_ids = torch.cat([complementary_noisy_input_ids, complementary_input_ids.reshape(complementary_labels.shape)], dim=1)
665
+
666
+ input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
667
+ labels = torch.cat([labels, complementary_labels], dim=0)
668
+
669
+ outputs: BaseModelOutputWithPastAndBlockCache = self.model(
670
+ input_ids=input_ids,
671
+ labels=labels,
672
+ attention_mask=attention_mask,
673
+ position_ids=position_ids,
674
+ past_key_values=past_key_values,
675
+ inputs_embeds=inputs_embeds,
676
+ use_cache=use_cache,
677
+ cache_position=cache_position,
678
+ update_past_key_values=update_past_key_values,
679
+ block_size=block_size,
680
+ use_block_cache=use_block_cache,
681
+ block_past_key_values=block_past_key_values,
682
+ replace_position=replace_position,
683
+ **kwargs,
684
+ )
685
+
686
+ hidden_states = outputs.last_hidden_state
687
+ if self.training:
688
+ hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
689
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
690
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
691
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
692
+
693
+ loss = None
694
+ if labels is not None:
695
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
696
+
697
+ return CausalLMOutputWithPastAndBlockCache(
698
+ loss=loss,
699
+ logits=logits,
700
+ past_key_values=outputs.past_key_values,
701
+ hidden_states=outputs.hidden_states,
702
+ attentions=outputs.attentions,
703
+ block_past_key_values=outputs.block_past_key_values,
704
+ )
705
+
706
+ @torch.no_grad()
707
+ def generate(
708
+ self,
709
+ input_ids,
710
+ attention_mask=None, # --- ADDED ARGUMENT ---
711
+ max_new_tokens=20, # Added default value for safety
712
+ mask_id=151665,
713
+ threshold=1,
714
+ small_block_size=8,
715
+ block_size=32,
716
+ stop_token=151645,
717
+ stopping_criteria=None,
718
+ top_p=0.95,
719
+ temperature=0,
720
+ use_block_cache=False,
721
+ log_lengths=False,
722
+ log_steps=False,
723
+ **kwargs
724
+ ):
725
+ if use_block_cache:
726
+ raise ValueError("use_block_cache=True is not supported in this generate() implementation.")
727
+ assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
728
+
729
+ # pad the initial input_ids and attention_mask to be multiple of block_size
730
+ if input_ids.shape[1] % block_size != 0:
731
+ pad_len = block_size - (input_ids.shape[1] % block_size)
732
+ input_ids = torch.cat([torch.full((input_ids.shape[0], pad_len), self.config.pad_token_id, device=input_ids.device), input_ids], dim=1)
733
+ attention_mask = torch.cat([torch.zeros((attention_mask.shape[0], pad_len), device=attention_mask.device), attention_mask], dim=1)
734
+
735
+ num_blocks = max_new_tokens // block_size
736
+ device = input_ids.device
737
+ batch_size = input_ids.size(0)
738
+ original_input_length = input_ids.shape[1]
739
+
740
+ # Track which sequences in the batch are still active
741
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
742
+ # Keep track of how many NFE each sequence uses and how many tokens are generated
743
+ iterations = torch.zeros((batch_size,), device=device)
744
+ n_generated_tokens = torch.zeros((batch_size,), device=device)
745
+ # Keep track if each sequence is finished
746
+ finished = torch.zeros((batch_size,), dtype=torch.bool, device=device)
747
+
748
+ # Handle prefix processing (Context Encoding)
749
+ if input_ids.shape[1] >= block_size:
750
+ output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], attention_mask=attention_mask[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
751
+ logits, past_key_values = output.logits, output.past_key_values
752
+ if input_ids.shape[1] % block_size == 0:
753
+ next_token = logits[:, -1:, :].argmax(dim=-1)
754
+ input_ids = torch.cat([input_ids, next_token], dim=1)
755
+
756
+ n_generated_tokens += (~finished).long()
757
+ iterations += (~finished).long()
758
+
759
+ # Update finished status
760
+ unfinished_sequences = unfinished_sequences & (next_token.squeeze(-1) != stop_token).long()
761
+ finished |= (next_token.squeeze(-1) == stop_token)
762
+
763
+ # Append to mask: If unfinished, append 1. If finished, append 0.
764
+ new_mask_col = unfinished_sequences.unsqueeze(1).to(dtype=attention_mask.dtype)
765
+ attention_mask = torch.cat([attention_mask, new_mask_col], dim=1)
766
+ else:
767
+ past_key_values = None
768
+
769
+ num_small_blocks = block_size // small_block_size
770
+
771
+ for block_idx in range(num_blocks):
772
+ new_tokens = input_ids[:, original_input_length:]
773
+ has_stop_now = (new_tokens == stop_token).any(dim=1) # check whether any generated token is a stop token
774
+ finished |= has_stop_now # whenever that is true we halt the sequence generation forever
775
+
776
+ if finished.all(): # whenever that is true we halt the sequence generation forever
777
+ break
778
+
779
+ # Length of current prompt
780
+ prompt_length = input_ids.shape[1]
781
+
782
+ # Initialize x_init with mask_id with all mask tokens for the new block
783
+ x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
784
+
785
+ # Concatenate input_ids with x_init to form the new input_ids (we added a block-1 of masks to our current generation)
786
+ x_init = torch.cat([input_ids, x_init], dim=1)
787
+
788
+ # mask extension is extending the current mask by the number of new tokens we are generating in this block by adding ones.
789
+ mask_extension = unfinished_sequences.unsqueeze(1).repeat(1, block_size - prompt_length % block_size).to(dtype=attention_mask.dtype)
790
+ # mask is the current attention mask extended by the new tokens we are generating in this block by adding ones.
791
+ curr_attention_mask = torch.cat([attention_mask, mask_extension], dim=1)
792
+
793
+ x_t = x_init.clone()
794
+ block_past_key_values = None
795
+ while True:
796
+ # mask_idx indicates where the mask tokens are in the current block
797
+ mask_idx = (x_t[:, -block_size:] == mask_id)
798
+ # TODO: assert that the first element is never a mask
799
+
800
+ if mask_idx.sum() == 0:
801
+ # If no mask tokens left in the current block, then we generate the next token autoregressively
802
+ output = self.forward(input_ids=x_t[:, -block_size:], attention_mask=curr_attention_mask, use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
803
+ logits, past_key_values = output.logits, output.past_key_values
804
+ next_token = logits[:, -1:, :].argmax(dim=-1)
805
+ x_t = torch.cat([x_t, next_token], dim=1)
806
+
807
+ # generating one extra token means the mask needs to be extended by one more position 1 if not finished and 0 else
808
+ curr_attention_mask = torch.cat([curr_attention_mask, unfinished_sequences.unsqueeze(1).to(curr_attention_mask.dtype)], dim=1)
809
+
810
+ # add 1 to iterations for each unfinished sequence
811
+ iterations += (~finished).long()
812
+ n_generated_tokens += (~finished).long()
813
+
814
+ # TODO: we don't update the finished status here because we only care about tokens generated in the masked positions
815
+ break
816
+ for small_block_idx in range(num_small_blocks):
817
+ small_block_start_idx = small_block_idx * small_block_size
818
+ small_block_end_idx = small_block_start_idx + small_block_size
819
+
820
+ start = -block_size + small_block_start_idx
821
+ end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
822
+ while True:
823
+ mask_idx = (x_t[:, -block_size:] == mask_id)
824
+ if mask_idx[:, start:end].sum() == 0:
825
+ break # loop until all tokens are generated in this sub-block
826
+ # is this batch invariant? If one sequence finishes we still loop until all sequences are finished
827
+ if use_block_cache:
828
+ assert False, "use_block_cache=True is not supported in this generate() implementation."
829
+ if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
830
+ output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
831
+ logits, block_past_key_values = output.logits, output.block_past_key_values
832
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
833
+ logits = logits[:, start:end]
834
+ else:
835
+ logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
836
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
837
+ else:
838
+ # input ids are the most recent block_size tokens; the attention mask needs to cover the past plus the current tokens
839
+ logits = self.forward(input_ids=x_t[:, -block_size:], attention_mask=curr_attention_mask, use_cache=True, past_key_values=past_key_values, update_past_key_values=False,block_size=block_size,).logits
840
+ # the logits to be sampled from are the most recent 32 tokens
841
+ # shift the logits by one for the autoregressive-style conversion; prepending anything at the start is valid because the first token of the block is never a mask
842
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1) # TODO maybe prepend nan or sth
843
+ logits = logits[:, start:end]
844
+
845
+ x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
846
+ # Select tokens with probability greater than threshold from p_1t
847
+ x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
848
+ x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
849
+
850
+ unmask_idx = (x1_p > threshold)
851
+ # Ensure at least one token is unmasked in the current small block
852
+ max_prob_idx = x1_p.argmax(dim=-1)
853
+ unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
854
+ unmask_idx = unmask_idx & mask_idx[:, start:end]
855
+
856
+ # Add 1 to iterations if the sequence is not stopped AND at least one token is generated in this iteration
857
+ # aka if not finished and unmask id has some True value
858
+ iterations += (~finished & unmask_idx.any(dim=1)).long()
859
+
860
+ # Count number of generated tokens in this iteration if not stopped
861
+ n_generated_iter = torch.where(finished, 0, unmask_idx.sum(dim=1)) # if not finished then count generated tokens
862
+ n_generated_tokens += n_generated_iter
863
+
864
+ # Only update the positions where unmask_idx is True AND the sequence is not finished (TODO check this)
865
+ x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
866
+
867
+ # new_tokens = input_ids[:, original_input_length:]
868
+
869
+ # check if any newly generated token is stop token
870
+ # has_stop_now = (new_tokens == stop_token).any(dim=1)
871
+ # finished |= has_stop_now # TODO confirm if that is true here.
872
+
873
+ input_ids = x_t
874
+ attention_mask = curr_attention_mask
875
+
876
+ if log_lengths:
877
+ if self.generate_statistics.get("generation_lengths", None) is None:
878
+ self.generate_statistics["generation_lengths"] = []
879
+ self.generate_statistics["generation_lengths"].extend(n_generated_tokens.cpu().tolist())
880
+
881
+ if log_steps:
882
+ if self.generate_statistics.get("generation_steps", None) is None:
883
+ self.generate_statistics["generation_steps"] = []
884
+ self.generate_statistics["generation_steps"].extend(iterations.cpu().tolist())
885
+
886
+ # Final truncation: keep everything up to the *latest* first stop_token
887
+ new_tokens = input_ids[:, original_input_length:]
888
+ has_stop = (new_tokens == stop_token)
889
+
890
+ gen = input_ids[:, original_input_length:] # (B, T)
891
+
892
+ T = gen.size(1)
893
+
894
+ if T > 0:
895
+ device = input_ids.device
896
+ B = input_ids.size(0)
897
+
898
+ idx = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
899
+ stop_mask = gen.eq(stop_token)
900
+
901
+ first_stop = torch.where(stop_mask, idx, torch.full_like(idx, T)).min(dim=1).values
902
+ has_stop = first_stop < T
903
+ keep = torch.where(has_stop, first_stop + 1, torch.full_like(first_stop, T))
904
+
905
+ pad_id = self.config.pad_token_id if getattr(self.config, "pad_token_id", None) is not None else stop_token
906
+ after = idx >= keep.unsqueeze(1)
907
+ gen = gen.clone()
908
+ gen[after] = pad_id
909
+
910
+ input_ids = torch.cat([input_ids[:, :original_input_length], gen], dim=1)
911
+
912
+ return input_ids
913
+
914
+ def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
915
+ # Calculate probabilities
916
+ if temperature > 0:
917
+ scaled_logits = logits / temperature
918
+ else:
919
+ p_1t = torch.softmax(logits, dim=-1)
920
+ x_1 = p_1t.argmax(dim=-1)
921
+ return x_1, p_1t
922
+ probs = torch.softmax(scaled_logits, dim=-1) # [B, seq_len, vocab_size]
923
+
924
+ sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) # [B, seq_len, sorted(vocab_size)]
925
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # [B, seq_len, cumsum(sorted(vocab_size))]
926
+
927
+ sorted_indices_to_remove = cumulative_probs > top_p # [B, seq_len, bool(sorted(vocab_size))]
928
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() # clone the tensor to avoid in-place operation error
929
+ sorted_indices_to_remove[..., 0] = 0 # always keep at least one token
930
+
931
+ indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
932
+ dim=-1, index=sorted_indices, src=sorted_indices_to_remove
933
+ ) # [B, seq_len, vocab_size]: take a zeros array and
934
+ # set True at the indices where sorted_indices_to_remove is True
935
+ # we index using the sorted indices in order to put the values back to their original position
936
+
937
+ # prev: probs[indices_to_remove] = 0, indices_to_remove is of the same shape as probs
938
+ # and therefore this operation simply zeroes out the removed entries
939
+ probs = probs.masked_fill(indices_to_remove, 0.0)
940
+
941
+ probs_sum = probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
942
+ p_1t = probs / probs_sum
943
+
944
+ vocab = p_1t.shape[-1]
945
+ flat = p_1t.reshape(-1, vocab)
946
+ samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
947
+ x_1 = samples.view(*p_1t.shape[:-1])
948
+
949
+ return x_1, p_1t
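
As a sanity check of the nucleus filtering implemented above, a hand-built distribution with top_p = 0.5 keeps only the two most likely tokens and renormalizes them (toy values only):

import torch

probs = torch.tensor([[[0.4, 0.3, 0.2, 0.1]]])                       # [B=1, seq_len=1, vocab=4]
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
remove = cumulative > 0.5
remove[..., 1:] = remove[..., :-1].clone()                            # shift so the boundary token is kept
remove[..., 0] = 0                                                    # always keep at least one token
indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, remove)
kept = probs.masked_fill(indices_to_remove, 0.0)
print(kept / kept.sum(-1, keepdim=True))   # tensor([[[0.5714, 0.4286, 0.0000, 0.0000]]])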