WuChengyue committed on
Commit
cb43b83
·
verified ·
1 Parent(s): c79aa4d

Clean up inference config: remove training-only flags, set bd_size=32 default, dtype=bfloat16

Browse files
Files changed (3) hide show
  1. config.json +1 -16
  2. configuration.py +1 -13
  3. modeling.py +162 -474
config.json CHANGED
@@ -1,6 +1,4 @@
1
  {
2
- "always_mask_im_end": true,
3
- "anneal_block_size": true,
4
  "architectures": [
5
  "Fast_dVLMForConditionalGeneration"
6
  ],
@@ -11,14 +9,8 @@
11
  "AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
12
  },
13
  "bd_size": 32,
14
- "block_causal_no_dynamic": false,
15
- "complementary_mask": true,
16
  "dtype": "bfloat16",
17
- "enable_efficient_vision_embed": false,
18
- "entropy_loss": false,
19
- "entropy_loss_weight": 1.0,
20
  "eos_token_id": 151645,
21
- "flexible_bd_size": false,
22
  "hidden_act": "silu",
23
  "hidden_size": 2048,
24
  "image_token_id": 151655,
@@ -55,12 +47,7 @@
55
  "AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
56
  },
57
  "bd_size": 8,
58
- "block_causal_no_dynamic": false,
59
  "bos_token_id": 151643,
60
- "complementary_mask": true,
61
- "dtype": "float32",
62
- "entropy_loss": false,
63
- "entropy_loss_weight": 1.0,
64
  "eos_token_id": 151645,
65
  "hidden_act": "silu",
66
  "hidden_size": 2048,
@@ -125,7 +112,6 @@
125
  "rope_theta": 1000000.0,
126
  "sliding_window": null,
127
  "tie_word_embeddings": true,
128
- "use_block_causal_mask": false,
129
  "use_cache": true,
130
  "use_sliding_window": false,
131
  "video_token_id": null,
@@ -135,13 +121,12 @@
135
  "vocab_size": 151936
136
  },
137
  "transformers_version": "4.57.1",
138
- "use_block_causal_mask": true,
139
  "use_cache": true,
140
  "use_sliding_window": false,
141
  "video_token_id": 151656,
142
  "vision_config": {
143
  "depth": 32,
144
- "dtype": "float32",
145
  "fullatt_block_indexes": [
146
  7,
147
  15,
 
1
  {
 
 
2
  "architectures": [
3
  "Fast_dVLMForConditionalGeneration"
4
  ],
 
9
  "AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
10
  },
11
  "bd_size": 32,
 
 
12
  "dtype": "bfloat16",
 
 
 
13
  "eos_token_id": 151645,
 
14
  "hidden_act": "silu",
15
  "hidden_size": 2048,
16
  "image_token_id": 151655,
 
47
  "AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
48
  },
49
  "bd_size": 8,
 
50
  "bos_token_id": 151643,
 
 
 
 
51
  "eos_token_id": 151645,
52
  "hidden_act": "silu",
53
  "hidden_size": 2048,
 
112
  "rope_theta": 1000000.0,
113
  "sliding_window": null,
114
  "tie_word_embeddings": true,
 
115
  "use_cache": true,
116
  "use_sliding_window": false,
117
  "video_token_id": null,
 
121
  "vocab_size": 151936
122
  },
123
  "transformers_version": "4.57.1",
 
124
  "use_cache": true,
125
  "use_sliding_window": false,
126
  "video_token_id": 151656,
127
  "vision_config": {
128
  "depth": 32,
129
+ "dtype": "bfloat16",
130
  "fullatt_block_indexes": [
131
  7,
132
  15,
configuration.py CHANGED
@@ -87,15 +87,10 @@ class Fast_dVLMTextConfig(PretrainedConfig):
87
  rope_scaling=None,
88
  image_token_id=None,
89
  video_token_id=None,
90
- bd_size=8,
91
  self_spec_inference_mode=None,
92
  block_length=None,
93
- use_block_causal_mask=False,
94
- complementary_mask=True,
95
  minimum_noise_level=1e-3,
96
- entropy_loss=False,
97
- entropy_loss_weight=1.0,
98
- block_causal_no_dynamic=False,
99
  **kwargs,
100
  ):
101
  self.vocab_size = vocab_size
@@ -122,12 +117,7 @@ class Fast_dVLMTextConfig(PretrainedConfig):
122
  self.rope_scaling = rope_scaling
123
  self.bd_size = bd_size
124
  self.layer_types = layer_types
125
- self.use_block_causal_mask = use_block_causal_mask
126
- self.complementary_mask = complementary_mask
127
  self.minimum_noise_level = minimum_noise_level
128
- self.entropy_loss = entropy_loss
129
- self.entropy_loss_weight = entropy_loss_weight
130
- self.block_causal_no_dynamic = block_causal_no_dynamic
131
  self.self_spec_inference_mode = self_spec_inference_mode
132
  self.block_length = block_length
133
  if self.layer_types is None:
@@ -166,7 +156,6 @@ class Fast_dVLMConfig(PretrainedConfig):
166
  vision_config=None,
167
  image_token_id=151655,
168
  video_token_id=151656,
169
- enable_efficient_vision_embed=False,
170
  **kwargs,
171
  ):
172
  if isinstance(vision_config, dict):
@@ -182,7 +171,6 @@ class Fast_dVLMConfig(PretrainedConfig):
182
 
183
  self.image_token_id = image_token_id
184
  self.video_token_id = video_token_id
185
- self.enable_efficient_vision_embed = enable_efficient_vision_embed
186
 
187
  super().__init__(**kwargs)
188
 
 
87
  rope_scaling=None,
88
  image_token_id=None,
89
  video_token_id=None,
90
+ bd_size=32,
91
  self_spec_inference_mode=None,
92
  block_length=None,
 
 
93
  minimum_noise_level=1e-3,
 
 
 
94
  **kwargs,
95
  ):
96
  self.vocab_size = vocab_size
 
117
  self.rope_scaling = rope_scaling
118
  self.bd_size = bd_size
119
  self.layer_types = layer_types
 
 
120
  self.minimum_noise_level = minimum_noise_level
 
 
 
121
  self.self_spec_inference_mode = self_spec_inference_mode
122
  self.block_length = block_length
123
  if self.layer_types is None:
 
156
  vision_config=None,
157
  image_token_id=151655,
158
  video_token_id=151656,
 
159
  **kwargs,
160
  ):
161
  if isinstance(vision_config, dict):
 
171
 
172
  self.image_token_id = image_token_id
173
  self.video_token_id = video_token_id
 
174
 
175
  super().__init__(**kwargs)
176
 
modeling.py CHANGED
@@ -23,94 +23,14 @@ from .configuration import Fast_dVLMConfig, Fast_dVLMTextConfig, Fast_dVLMVision
23
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
24
 
25
  from functools import partial
26
- import random
27
  import math
28
 
29
  logger = logging.get_logger(__name__)
30
 
31
 
32
- # @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
33
- # @torch.compile()
34
  def fused_flex_attention(q, k, v, mask=None):
35
  return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
36
 
37
- def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
38
- """
39
- Constructs the specialized block diffusion attention mask for training
40
- composed of three masks:
41
- - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
42
- - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
43
- - **Block Causal Mask (M_BC)**: Attention to update x0
44
-
45
- Args:
46
- b, h: Batch and head indices (ignored for mask logic).
47
- q_idx, kv_idx: Query and Key indices.
48
- seq_len: Total sequence length.
49
- block_size: Defines the block structure.
50
-
51
- Returns:
52
- A boolean attention mask.
53
- """
54
- # Indicate whether token belongs to xt or x0
55
- x0_flag_q = (q_idx >= n)
56
- x0_flag_kv = (kv_idx >= n)
57
-
58
- # Compute block indices
59
- block_q = torch.where(x0_flag_q == 1,
60
- (q_idx - n) // block_size,
61
- q_idx // block_size)
62
- block_kv = torch.where(x0_flag_kv == 1,
63
- (kv_idx - n) // block_size,
64
- kv_idx // block_size)
65
-
66
- # **1. Block Diagonal Mask (M_BD) **
67
- block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
68
-
69
- # **2. Offset Block-Causal Mask (M_OBC) **
70
- offset_block_causal = (
71
- (block_q > block_kv)
72
- & (x0_flag_kv == 1)
73
- & (x0_flag_q == 0)
74
- )
75
-
76
- # **3. Block-Causal Mask (M_BC) **
77
- block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
78
-
79
- # **4. Combine Masks **
80
- return block_diagonal | offset_block_causal | block_causal
81
-
82
-
83
- def block_causal_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
84
-
85
- # Indicate whether token belongs to xt or x0
86
- x0_flag_q = (q_idx >= n)
87
- x0_flag_kv = (kv_idx >= n)
88
-
89
- # Compute block indices
90
- block_q = torch.where(x0_flag_q == 1,
91
- (q_idx - n) // block_size,
92
- q_idx // block_size)
93
- block_kv = torch.where(x0_flag_kv == 1,
94
- (kv_idx - n) // block_size,
95
- kv_idx // block_size)
96
-
97
- # **1. Block Diagonal Mask (M_BD) **
98
- block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
99
-
100
- # **2. Offset Block-Causal Mask (M_OBC) **
101
- offset_block_causal = (
102
- (block_q > block_kv)
103
- & (x0_flag_kv == 1)
104
- & (x0_flag_q == 0)
105
- )
106
-
107
- # **3. Block-Causal Mask (M_BC) **
108
- block_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
109
-
110
- # **4. Combine Masks **
111
- return block_diagonal | offset_block_causal | block_causal
112
-
113
-
114
  def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=None, turn_idx=None, n=None):
115
  """
116
  Multi-turn hybrid mask: Prompt uses causal, Response uses block causal.
@@ -145,29 +65,20 @@ def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=N
145
 
146
  is_prompt_q = (block_q < 0)
147
  is_prompt_kv = (block_kv < 0)
148
-
149
- # x_t region rules:
150
- # 1. Can see all previous turns: turn_q > turn_kv
151
- # 2. Within same turn, prompt: causal (turn same + is prompt + pos satisfies causal)
152
- # 3. Within same turn, response: sees all prompt in same turn + block causal for response
153
- # xt_same_turn_prompt_causal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv) & is_prompt_q & (pos_q >= pos_kv)
154
- # xt_same_turn_response = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv) & ~is_prompt_q & (
155
- # ~is_prompt_kv
156
- # )
157
- block_diagonal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv)
158
-
159
- # **2. Offset Block-Causal Mask (M_OBC) **
160
  offset_block_causal = (
161
- (turn_q > turn_kv)
162
  & (x0_flag_kv == 1)
163
  & (x0_flag_q == 0)
164
  )
165
- # x_0 region: standard causal
166
  x0_causal = x0_flag_q & x0_flag_kv & (pos_q >= pos_kv)
167
-
168
- return (block_diagonal |
169
- offset_block_causal |
170
- x0_causal)
171
 
172
 
173
  def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
@@ -820,7 +731,6 @@ class Fast_dVLMAttention(nn.Module):
820
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
821
  if update_kv_cache:
822
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
823
- # elif len(past_key_values) > self.layer_idx:
824
  elif len(past_key_values) > self.layer_idx and past_key_values[self.layer_idx][0] is not None:
825
  key_states = torch.cat((past_key_values[self.layer_idx][0], key_states), dim=-2)
826
  value_states = torch.cat((past_key_values[self.layer_idx][1], value_states), dim=-2)
@@ -964,7 +874,12 @@ class Fast_dVLMTextModel(Fast_dVLMPreTrainedModel):
964
  # Initialize weights and apply final processing
965
  self.post_init()
966
 
967
-
 
 
 
 
 
968
  @auto_docstring
969
  def forward(
970
  self,
@@ -1036,34 +951,11 @@ class Fast_dVLMTextModel(Fast_dVLMPreTrainedModel):
1036
  text_position_ids = position_ids[0]
1037
  position_ids = position_ids[1:]
1038
  else:
1039
- # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
1040
  text_position_ids = None
1041
 
1042
- # It may already have been prepared by e.g. `generate`
1043
- # if not isinstance(causal_mask_mapping := attention_mask, dict):
1044
- # # Prepare mask arguments
1045
- # mask_kwargs = {
1046
- # "config": self.config,
1047
- # "input_embeds": inputs_embeds,
1048
- # "attention_mask": attention_mask,
1049
- # "cache_position": cache_position,
1050
- # "past_key_values": past_key_values,
1051
- # "position_ids": text_position_ids,
1052
- # }
1053
- # # Create the masks
1054
- # causal_mask_mapping = {
1055
- # "full_attention": create_causal_mask(**mask_kwargs),
1056
- # }
1057
- # # The sliding window alternating layers are not always activated depending on the config
1058
- # if self.has_sliding_layers:
1059
- # causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1060
-
1061
  hidden_states = inputs_embeds
1062
-
1063
- # create position embeddings to be shared across the decoder layers
1064
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1065
 
1066
- # decoder layers
1067
  all_hidden_states = () if output_hidden_states else None
1068
  all_self_attns = () if output_attentions else None
1069
 
@@ -1091,7 +983,6 @@ class Fast_dVLMTextModel(Fast_dVLMPreTrainedModel):
1091
 
1092
  hidden_states = self.norm(hidden_states)
1093
 
1094
- # add hidden states from the last decoder layer
1095
  if output_hidden_states:
1096
  all_hidden_states += (hidden_states,)
1097
 
@@ -1121,7 +1012,6 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
1121
  self.visual = Fast_dVLMVisionTransformerPretrainedModel._from_config(config.vision_config)
1122
  self.language_model = Fast_dVLMTextModel._from_config(config.text_config)
1123
  self.rope_deltas = None # cache rope_deltas here
1124
- self.use_block_causal_mask = config.use_block_causal_mask
1125
 
1126
  # Initialize weights and apply final processing
1127
  self.post_init()
@@ -1307,13 +1197,6 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
1307
  mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
1308
  return position_ids, mrope_position_deltas
1309
  else:
1310
- # if attention_mask is not None:
1311
- # position_ids = attention_mask.long().cumsum(-1) - 1
1312
- # position_ids.masked_fill_(attention_mask == 0, 1)
1313
- # position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1314
- # max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1315
- # mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1316
- # else:
1317
  if self.training:
1318
  position_ids = (
1319
  torch.arange(input_ids.shape[1] // 2, device=input_ids.device)
@@ -1415,16 +1298,16 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
1415
  return special_image_mask, special_video_mask
1416
 
1417
 
1418
- def eval_mask(self, seqlen, block_size, cache_seq_len, update_kv_cache=False, use_block_causal_mask=False):
1419
  q_indices = torch.arange(seqlen, device=self.device) + cache_seq_len
1420
  k_indices = torch.arange(seqlen + cache_seq_len, device=self.device)
1421
- if use_block_causal_mask and update_kv_cache:
1422
  mask = eval_causal_mask(q_indices[:, None], k_indices[None, :])
1423
  else:
1424
  mask = eval_block_diff_mask(
1425
- q_idx=q_indices[:, None],
1426
- kv_idx=k_indices[None, :],
1427
- block_size=block_size
1428
  )
1429
  return mask
1430
 
@@ -1536,8 +1419,6 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
1536
 
1537
  position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1538
  position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1539
- # if cache_position is not None:
1540
- # delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1541
  if past_key_values is not None:
1542
  delta = (past_key_values.get_seq_length() + self.rope_deltas).to(inputs_embeds.device)
1543
  else:
@@ -1547,7 +1428,7 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
1547
 
1548
  position_ids = position_ids.to(inputs_embeds.device)
1549
  if not self.training:
1550
- attention_mask = self.eval_mask(inputs_embeds.shape[1], self.bd_size if bd_size is None else bd_size, 0 if past_key_values is None else past_key_values.get_seq_length(), update_kv_cache=update_kv_cache, use_block_causal_mask=self.use_block_causal_mask).to(inputs_embeds.device)
1551
 
1552
  outputs = self.language_model(
1553
  input_ids=None,
@@ -1620,19 +1501,9 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1620
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1621
  self.bd_size = config.bd_size
1622
  self.model.bd_size = self.bd_size
1623
- self.complementary_mask = getattr(config, 'complementary_mask', False)
1624
- self.always_mask_im_end = getattr(config, 'always_mask_im_end', False)
1625
- self.flexible_bd_size = getattr(config, 'flexible_bd_size', False)
1626
- self.use_block_causal_mask = getattr(config, 'use_block_causal_mask', False)
1627
- self.anneal_block_size = getattr(config, 'anneal_block_size', False)
1628
- self.enable_efficient_vision_embed = getattr(config, 'enable_efficient_vision_embed', False)
1629
  self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
1630
- self.entropy_loss = getattr(config, 'entropy_loss', False)
1631
- self.entropy_loss_weight = getattr(config, 'entropy_loss_weight', 1.0)
1632
- self.block_causal_no_dynamic = getattr(config, 'block_causal_no_dynamic', False)
1633
  self.im_end_token_id = 151645 # <|im_end|> token id
1634
- # self.max_context_length = 4096
1635
-
1636
  # Vision-to-text aligner (if vision output dim != text hidden dim)
1637
  vision_out_dim = config.vision_config.out_hidden_size
1638
  text_hidden = config.text_config.hidden_size
@@ -1675,30 +1546,6 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1675
  def visual(self):
1676
  return self.model.visual
1677
 
1678
- def gen_mask(self, seqlen, block_size, B, H):
1679
- # ================== 修改开始 ==================
1680
- # flex_attention 要求闭包捕获的变量必须是 Tensor
1681
- # 将 int 转换为 Tensor,并放在对应的设备上
1682
- block_size_t = torch.tensor(block_size, device=self.device, dtype=torch.int32)
1683
- n_t = torch.tensor(seqlen, device=self.device, dtype=torch.int32)
1684
-
1685
- mask = create_block_mask(
1686
- # 这里将原来的 block_size=block_size 改为传入 Tensor
1687
- partial(block_diff_mask, block_size=block_size_t, n=n_t),
1688
- B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2
1689
- )
1690
- # ================== 修改结束 ==================
1691
- return mask
1692
-
1693
- def gen_block_causal_mask(self, seqlen, block_size, B, H):
1694
- block_size_t = torch.tensor(block_size, device=self.device, dtype=torch.int32)
1695
- n_t = torch.tensor(seqlen, device=self.device, dtype=torch.int32)
1696
- mask = create_block_mask(
1697
- partial(block_causal_mask, block_size=block_size_t, n=n_t),
1698
- B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2
1699
- )
1700
- return mask
1701
-
1702
  def compute_response_block_idx(self, labels, block_size):
1703
  """
1704
  Compute block index and turn index for each position.
@@ -1767,36 +1614,6 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1767
  )
1768
  return mask
1769
 
1770
- def compute_entropy_loss(self, logits, labels, num_items_in_batch=None):
1771
- """Compute entropy loss with optional global normalization.
1772
-
1773
- Args:
1774
- logits: Model logits
1775
- labels: Ground truth labels (-100 for ignored tokens)
1776
- num_items_in_batch: Global number of non-ignored tokens for normalization.
1777
- If provided, uses sum/num_items_in_batch for global norm.
1778
- If None, uses mean() for micro-batch norm.
1779
- """
1780
- non_ignore_mask = labels != -100
1781
- logits = logits[non_ignore_mask]
1782
- labels = labels[non_ignore_mask]
1783
- correct_mask = logits.argmax(dim=-1) == labels
1784
-
1785
- compute_logits = logits[correct_mask]
1786
-
1787
- if correct_mask.sum() == 0:
1788
- return torch.tensor(0.0, device=logits.device)
1789
-
1790
- p = F.softmax(compute_logits, dim=-1)
1791
- log_p = F.log_softmax(compute_logits, dim=-1)
1792
- entropy = -torch.sum(p * log_p, dim=-1)
1793
-
1794
- if num_items_in_batch is not None:
1795
- # Global normalization: use same denominator as cross entropy loss
1796
- return entropy.sum() / num_items_in_batch
1797
- else:
1798
- return entropy.mean()
1799
-
1800
  @can_return_tuple
1801
  @auto_docstring
1802
  def forward(
@@ -1839,34 +1656,22 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1839
  eval_bd_size (`int`, *optional*):
1840
  Block diffusion size to use during evaluation. Overrides the model default when set.
1841
  """
1842
- # input_ids = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]]).to(input_ids.device, dtype=input_ids.dtype)
1843
- # labels = torch.tensor([[-100,-100,3,4,5,6,-100,-100,-100,-100,11,12,13,14,15]]).to(labels.device, dtype=labels.dtype)
1844
- # pixel_values = None
1845
- # pixel_values_videos = None
1846
- # self.bd_size = 2
1847
-
1848
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1849
  output_hidden_states = (
1850
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1851
  )
1852
  if self.training:
1853
- if self.anneal_block_size:
1854
- # Get update_ratio from kwargs (passed by trainer)
1855
- update_ratio = kwargs.get('update_ratio', 1.0)
1856
- # Compute possible bd_sizes: [2, 4, 8, ..., target_bd_size]
1857
- max_power = int(math.log2(self.bd_size))
1858
- possible_bd_sizes = [2**i for i in range(2, max_power + 1)] # Start from 4
1859
- # sqrt mapping: larger block sizes get more training time
1860
- scaled_ratio = math.sqrt(update_ratio)
1861
- idx = min(int(scaled_ratio * len(possible_bd_sizes)), len(possible_bd_sizes) - 1)
1862
- bd_size = possible_bd_sizes[idx]
1863
- elif self.flexible_bd_size:
1864
- max_power = int(math.log2(self.bd_size))
1865
- possible_bd_sizes = [2**i for i in range(max_power + 1)]
1866
- bd_size = random.choice(possible_bd_sizes)
1867
- else:
1868
- bd_size = self.bd_size
1869
- if pixel_values is None and pixel_values_videos is None: # only train on text
1870
 
1871
  batch_size, seq_len = input_ids.shape
1872
  original_labels = labels.clone()
@@ -1877,79 +1682,57 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1877
  response_mask = (labels != -100) # [B, seq_len]
1878
  eps = self.minimum_noise_level
1879
 
1880
- if self.use_block_causal_mask and not self.block_causal_no_dynamic:
1881
- response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
1882
-
1883
- # Sample t for each block: [n_blocks]
1884
-
1885
- # random sample t for each block from [self.minimum_noise_level, 1]
1886
- t = torch.rand((n_blocks,), device=input_ids.device)
1887
- p_mask_per_block = (1 - eps) * t + eps
1888
-
1889
- # Create mask_indices: [B, seq_len]
1890
- mask_indices = torch.zeros_like(labels, dtype=torch.bool)
1891
- for i in range(seq_len):
1892
- block_i = response_block_idx[i].item()
1893
- if block_i >= 0: # response token
1894
- mask_indices[:, i] = torch.rand((batch_size,), device=input_ids.device) < p_mask_per_block[block_i]
1895
- else:
1896
- input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // bd_size, bd_size)
1897
- b, l = input_ids.shape
1898
- t = torch.rand((b,), device=input_ids.device)
1899
- p_mask = (1 - eps) * t + eps
1900
- p_mask = p_mask[:, None].repeat(1, l)
1901
-
1902
- mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
1903
- mask_indices = mask_indices.reshape(labels.shape) & response_mask
1904
- input_ids = input_ids.reshape(labels.shape)
1905
-
1906
- # Always mask <|im_end|> in response
1907
- if self.always_mask_im_end:
1908
- im_end_mask = (input_ids == self.im_end_token_id) & response_mask
1909
- mask_indices = mask_indices | im_end_mask
1910
-
1911
- # Apply mask only to response
1912
  noisy_input_ids = input_ids.clone()
1913
  noisy_input_ids[mask_indices] = mask_id
1914
-
1915
- # Update labels: only predict masked response tokens
1916
  labels = labels.clone()
1917
  labels[~mask_indices] = -100
1918
-
1919
- # Concatenate [noisy | clean]
1920
  input_ids = torch.cat([noisy_input_ids, original_input_ids], dim=1)
1921
 
1922
- # Complementary version
1923
- if self.complementary_mask:
1924
- complementary_mask_indices = response_mask & ~mask_indices
1925
- if self.always_mask_im_end:
1926
- im_end_mask = (original_input_ids == self.im_end_token_id) & response_mask
1927
- complementary_mask_indices = complementary_mask_indices | im_end_mask
1928
-
1929
- complementary_noisy_input_ids = original_input_ids.clone()
1930
- complementary_noisy_input_ids[complementary_mask_indices] = mask_id
1931
-
1932
- complementary_labels = original_labels.clone()
1933
- complementary_labels[~complementary_mask_indices] = -100
1934
- complementary_input_ids = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
1935
-
1936
- input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
1937
- labels = torch.cat([labels, complementary_labels], dim=0)
1938
-
1939
- if self.use_block_causal_mask:
1940
- if self.block_causal_no_dynamic:
1941
- attention_mask = self.gen_block_causal_mask(seq_len, bd_size, input_ids.shape[0], self.config.num_attention_heads)
1942
- else:
1943
- attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
1944
- else:
1945
- attention_mask = self.gen_mask(seq_len, bd_size, input_ids.shape[0], self.config.num_attention_heads)
1946
-
1947
- else: # 多模态 block diffusion
1948
- # Phase A: Embed + masked scatter vision
1949
 
 
 
 
 
 
1950
  if inputs_embeds is None:
1951
  inputs_embeds = self.model.get_input_embeddings()(input_ids)
1952
-
1953
  if pixel_values is not None:
1954
  image_embeds = self.model.get_image_features(pixel_values, image_grid_thw)
1955
  image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
@@ -1959,7 +1742,7 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1959
  input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
1960
  )
1961
  inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1962
-
1963
  if pixel_values_videos is not None:
1964
  video_embeds = self.model.get_video_features(pixel_values_videos, video_grid_thw)
1965
  video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
@@ -1969,8 +1752,8 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1969
  input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
1970
  )
1971
  inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1972
-
1973
- # Phase B: 生成 3D position_ids(在扩倍前,基于原长)
1974
  if position_ids is None:
1975
  position_ids, rope_deltas = self.model.get_rope_index(
1976
  input_ids=input_ids,
@@ -1979,189 +1762,107 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1979
  second_per_grid_ts=second_per_grid_ts,
1980
  attention_mask=attention_mask,
1981
  )
1982
-
1983
- # Phase C: Block diffusion (保护 vision token 位置)
1984
  batch_size = input_ids.shape[0]
1985
  L = input_ids.shape[1]
1986
  seq_len = L
1987
-
1988
- # if L > self.max_context_length:
1989
- # L = self.max_context_length
1990
- # input_ids = input_ids[:, :self.max_context_length]
1991
- # labels = labels[:, :self.max_context_length]
1992
- # position_ids = position_ids[:, :self.max_context_length]
1993
- # attention_mask = attention_mask[:, :self.max_context_length]
1994
- # inputs_embeds = inputs_embeds[:, :self.max_context_length]
1995
-
1996
  hidden_size = inputs_embeds.shape[-1]
1997
-
1998
  original_labels = labels.clone()
1999
  original_input_ids = input_ids.clone()
2000
  original_embeds = inputs_embeds.clone()
2001
- original_position_ids = position_ids.clone() # 保存原长 position [3, B, L]
2002
-
2003
- # 识别 vision tokens(不加噪声)
2004
  image_token_id = self.config.image_token_id
2005
  video_token_id = self.config.video_token_id
2006
  vision_start_token_id = self.config.vision_start_token_id
2007
  vision_token_mask = (input_ids == image_token_id) | (input_ids == video_token_id) | (input_ids == vision_start_token_id)
2008
  vision_mask_3d = vision_token_mask.unsqueeze(-1).expand(-1, -1, hidden_size)
2009
-
2010
- # Block diffusion with multi-turn support
2011
- # Each response segment has independent blocks
2012
  response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
2013
- # Compute response block index: -1 for prompt, >=0 for response
2014
- # Each response segment has independent blocks
2015
- response_mask = (labels != -100) # [B, seq_len]
2016
  eps = self.minimum_noise_level
2017
 
2018
- if self.use_block_causal_mask and not self.block_causal_no_dynamic:
2019
- response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
2020
-
2021
- # Sample t for each block: [n_blocks]
2022
-
2023
- # random sample t for each block from [self.minimum_noise_level, 1]
2024
- t = torch.rand((n_blocks,), device=input_ids.device)
2025
- p_mask_per_block = (1 - eps) * t + eps
2026
-
2027
- # Create mask_indices: [B, seq_len]
2028
- mask_indices = torch.zeros_like(labels, dtype=torch.bool)
2029
- for i in range(seq_len):
2030
- block_i = response_block_idx[i].item()
2031
- if block_i >= 0: # response token
2032
- mask_indices[:, i] = torch.rand((batch_size,), device=input_ids.device) < p_mask_per_block[block_i]
2033
- else:
2034
- input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // bd_size, bd_size)
2035
- b, l = input_ids.shape
2036
- t = torch.rand((b,), device=input_ids.device)
2037
- p_mask = (1 - eps) * t + eps
2038
- p_mask = p_mask[:, None].repeat(1, l)
2039
-
2040
- mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
2041
- mask_indices = mask_indices.reshape(labels.shape) & response_mask
2042
- input_ids = input_ids.reshape(labels.shape)
2043
-
2044
- if self.always_mask_im_end:
2045
- im_end_mask = (input_ids == self.im_end_token_id) & response_mask
2046
- mask_indices = mask_indices | im_end_mask
2047
-
2048
  noisy_input_ids = input_ids.clone()
2049
  noisy_input_ids[mask_indices] = mask_id
2050
-
2051
- # Noisy embeds(保护 vision
2052
- if self.enable_efficient_vision_embed:
2053
- noisy_embeds = original_embeds.clone()
2054
- text_mask_3d = mask_indices.unsqueeze(-1).expand(-1, -1, hidden_size)
2055
- mask_embeds = self.model.language_model.embed_tokens(
2056
- torch.full_like(input_ids, mask_id)
2057
- )
2058
- noisy_embeds = torch.where(text_mask_3d, mask_embeds, noisy_embeds)
2059
- else:
2060
- noisy_embeds_raw = self.model.language_model.embed_tokens(noisy_input_ids)
2061
- noisy_embeds = torch.where(vision_mask_3d, original_embeds, noisy_embeds_raw)
2062
-
2063
- # 更新 labels
2064
  labels_noisy = labels.clone()
2065
  labels_noisy[~mask_indices] = -100
2066
-
2067
- # 拼接 [noisy | clean]
2068
  input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
2069
  embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
2070
  labels_pair1 = labels_noisy
2071
- position_ids_pair1 = original_position_ids # [3, B, L]
2072
 
2073
- input_ids = input_ids_pair1
2074
- inputs_embeds = embeds_pair1
2075
- labels = labels_pair1
2076
- position_ids = position_ids_pair1
2077
-
2078
- # Complementary
2079
- if self.complementary_mask:
2080
- complementary_mask_indices = response_mask & ~mask_indices
2081
- if self.always_mask_im_end:
2082
- im_end_mask = (original_input_ids == self.im_end_token_id) & response_mask
2083
- complementary_mask_indices = complementary_mask_indices | im_end_mask
2084
-
2085
- complementary_noisy_input_ids = original_input_ids.clone()
2086
- complementary_noisy_input_ids[complementary_mask_indices] = mask_id
2087
-
2088
- if self.enable_efficient_vision_embed:
2089
- complementary_noisy_embeds = original_embeds.clone()
2090
- text_mask_3d = complementary_mask_indices.unsqueeze(-1).expand(-1, -1, hidden_size)
2091
- mask_embeds = self.model.language_model.embed_tokens(
2092
- torch.full_like(original_input_ids, mask_id)
2093
- )
2094
- complementary_noisy_embeds = torch.where(text_mask_3d, mask_embeds, complementary_noisy_embeds)
2095
- else:
2096
- complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
2097
- complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
2098
-
2099
- complementary_labels = original_labels.clone()
2100
- complementary_labels[~complementary_mask_indices] = -100
2101
-
2102
- input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
2103
- embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
2104
- labels_pair2 = complementary_labels
2105
- position_ids_pair2 = original_position_ids
2106
-
2107
- # Batch 拼接
2108
- input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
2109
- inputs_embeds = torch.cat([embeds_pair1, embeds_pair2], dim=0)
2110
- labels = torch.cat([labels_pair1, labels_pair2], dim=0)
2111
- position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
2112
-
2113
- if self.use_block_causal_mask:
2114
- if self.block_causal_no_dynamic:
2115
- attention_mask = self.gen_block_causal_mask(L, bd_size, input_ids.shape[0], self.config.num_attention_heads)
2116
- else:
2117
- attention_mask = self.gen_hybrid_block_causal_mask(L, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
2118
- else:
2119
- attention_mask = self.gen_mask(L, bd_size, input_ids.shape[0], self.config.num_attention_heads)
2120
-
2121
- # 清空 pixel_values(已替换)
2122
- pixel_values = None
2123
- pixel_values_videos = None
2124
 
 
 
2125
 
2126
- # Phase D: 调用内层(多模态时传 inputs_embeds,纯文本时传 input_ids)
2127
- if pixel_values is None and pixel_values_videos is None:
2128
- # 纯文本:传 input_ids(内层会 embed)
2129
- outputs = self.model(
2130
- input_ids=input_ids,
2131
- pixel_values=None,
2132
- pixel_values_videos=None,
2133
- image_grid_thw=None,
2134
- video_grid_thw=None,
2135
- position_ids=position_ids,
2136
- attention_mask=attention_mask,
2137
- past_key_values=past_key_values,
2138
- inputs_embeds=inputs_embeds,
2139
- use_cache=use_cache,
2140
- output_attentions=output_attentions,
2141
- output_hidden_states=output_hidden_states,
2142
- return_dict=True,
2143
- cache_position=cache_position,
2144
- update_kv_cache=update_kv_cache,
2145
- bd_size=bd_size,
2146
- **kwargs,
2147
- )
2148
- else:
2149
- # 多模态:传 inputs_embeds(已 masked_scatter)
2150
- outputs = self.model.language_model(
2151
- input_ids=None,
2152
- position_ids=position_ids,
2153
- attention_mask=attention_mask,
2154
- past_key_values=past_key_values,
2155
- inputs_embeds=inputs_embeds,
2156
- use_cache=use_cache,
2157
- output_attentions=output_attentions,
2158
- output_hidden_states=output_hidden_states,
2159
- return_dict=True,
2160
- cache_position=cache_position,
2161
- update_kv_cache=update_kv_cache,
2162
- bd_size=bd_size,
2163
- **kwargs,
2164
- )
 
 
2165
 
2166
  else:
2167
  outputs = self.model(
@@ -2193,31 +1894,18 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
2193
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
2194
  logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
2195
 
2196
- if self.use_block_causal_mask:
2197
- new_kwargs = {
2198
- 'num_items_in_batch': 2*kwargs['num_items_in_batch'],
2199
- }
2200
- else:
2201
- new_kwargs = kwargs
2202
  if labels is not None:
2203
  loss = self.loss_function(
2204
  logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
2205
  ) * 0.5
2206
- if self.use_block_causal_mask:
2207
- if self.complementary_mask:
2208
- causal_hidden_states = hidden_states[:hidden_states.shape[0]//2, hidden_states.shape[1]//2:, :]
2209
- else:
2210
- causal_hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
2211
- causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
2212
- loss += self.loss_function(
2213
- logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
2214
- )
2215
-
2216
- if self.entropy_loss:
2217
- # Use num_items_in_batch for global normalization (consistent with cross entropy)
2218
- num_items = kwargs.get('num_items_in_batch', None)
2219
- entropy_loss = self.compute_entropy_loss(logits, labels, num_items_in_batch=num_items)
2220
- loss += self.entropy_loss_weight * entropy_loss
2221
  else:
2222
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
2223
  logits = self.lm_head(hidden_states[:, slice_indices, :])
 
23
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
24
 
25
  from functools import partial
 
26
  import math
27
 
28
  logger = logging.get_logger(__name__)
29
 
30
 
 
 
31
  def fused_flex_attention(q, k, v, mask=None):
32
  return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=None, turn_idx=None, n=None):
35
  """
36
  Multi-turn hybrid mask: Prompt uses causal, Response uses block causal.
 
65
 
66
  is_prompt_q = (block_q < 0)
67
  is_prompt_kv = (block_kv < 0)
68
+
69
+ # Block diagonal: same turn, both in x_t region.
70
+ block_diagonal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv)
71
+
72
+ # Offset block-causal: x_t can attend to x_0 of strictly earlier turns.
 
 
 
 
 
 
 
73
  offset_block_causal = (
74
+ (turn_q > turn_kv)
75
  & (x0_flag_kv == 1)
76
  & (x0_flag_q == 0)
77
  )
78
+ # x_0 region uses standard causal masking.
79
  x0_causal = x0_flag_q & x0_flag_kv & (pos_q >= pos_kv)
80
+
81
+ return block_diagonal | offset_block_causal | x0_causal
 
 
82
 
83
 
84
  def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
 
731
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
732
  if update_kv_cache:
733
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
734
  elif len(past_key_values) > self.layer_idx and past_key_values[self.layer_idx][0] is not None:
735
  key_states = torch.cat((past_key_values[self.layer_idx][0], key_states), dim=-2)
736
  value_states = torch.cat((past_key_values[self.layer_idx][1], value_states), dim=-2)
 
874
  # Initialize weights and apply final processing
875
  self.post_init()
876
 
877
def get_input_embeddings(self):
    """Return the token-embedding module held in ``self.embed_tokens``."""
    embedding_layer = self.embed_tokens
    return embedding_layer
879
+
880
def set_input_embeddings(self, value):
    """Replace the token-embedding module stored in ``self.embed_tokens``.

    Args:
        value: The new embedding object to install.
    """
    setattr(self, "embed_tokens", value)
882
+
883
  @auto_docstring
884
  def forward(
885
  self,
 
951
  text_position_ids = position_ids[0]
952
  position_ids = position_ids[1:]
953
  else:
 
954
  text_position_ids = None
955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
956
  hidden_states = inputs_embeds
 
 
957
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
958
 
 
959
  all_hidden_states = () if output_hidden_states else None
960
  all_self_attns = () if output_attentions else None
961
 
 
983
 
984
  hidden_states = self.norm(hidden_states)
985
 
 
986
  if output_hidden_states:
987
  all_hidden_states += (hidden_states,)
988
 
 
1012
  self.visual = Fast_dVLMVisionTransformerPretrainedModel._from_config(config.vision_config)
1013
  self.language_model = Fast_dVLMTextModel._from_config(config.text_config)
1014
  self.rope_deltas = None # cache rope_deltas here
 
1015
 
1016
  # Initialize weights and apply final processing
1017
  self.post_init()
 
1197
  mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
1198
  return position_ids, mrope_position_deltas
1199
  else:
 
 
 
 
 
 
 
1200
  if self.training:
1201
  position_ids = (
1202
  torch.arange(input_ids.shape[1] // 2, device=input_ids.device)
 
1298
  return special_image_mask, special_video_mask
1299
 
1300
 
1301
def eval_mask(self, seqlen, block_size, cache_seq_len, update_kv_cache=False):
    """Build the attention mask used outside of training.

    Query indices cover the ``seqlen`` new positions, offset by
    ``cache_seq_len``; key indices cover ``seqlen + cache_seq_len``
    positions starting at zero.

    Args:
        seqlen: Number of new query positions.
        block_size: Block size forwarded to ``eval_block_diff_mask``.
        cache_seq_len: Number of positions already held in the KV cache.
        update_kv_cache: When true, the mask comes from ``eval_causal_mask``;
            otherwise it comes from ``eval_block_diff_mask``.

    Returns:
        The mask produced by the selected helper for the broadcast
        (query column, key row) index grids.
    """
    query_positions = torch.arange(seqlen, device=self.device) + cache_seq_len
    key_positions = torch.arange(seqlen + cache_seq_len, device=self.device)

    # Column vector of queries against a row vector of keys, so the
    # helpers broadcast to a [seqlen, seqlen + cache_seq_len] grid.
    query_column = query_positions[:, None]
    key_row = key_positions[None, :]

    if not update_kv_cache:
        return eval_block_diff_mask(
            q_idx=query_column,
            kv_idx=key_row,
            block_size=block_size,
        )
    return eval_causal_mask(query_column, key_row)
1313
 
 
1419
 
1420
  position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1421
  position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
 
 
1422
  if past_key_values is not None:
1423
  delta = (past_key_values.get_seq_length() + self.rope_deltas).to(inputs_embeds.device)
1424
  else:
 
1428
 
1429
  position_ids = position_ids.to(inputs_embeds.device)
1430
  if not self.training:
1431
+ attention_mask = self.eval_mask(inputs_embeds.shape[1], self.bd_size if bd_size is None else bd_size, 0 if past_key_values is None else past_key_values.get_seq_length(), update_kv_cache=update_kv_cache).to(inputs_embeds.device)
1432
 
1433
  outputs = self.language_model(
1434
  input_ids=None,
 
1501
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1502
  self.bd_size = config.bd_size
1503
  self.model.bd_size = self.bd_size
 
 
 
 
 
 
1504
  self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
 
 
 
1505
  self.im_end_token_id = 151645 # <|im_end|> token id
1506
+
 
1507
  # Vision-to-text aligner (if vision output dim != text hidden dim)
1508
  vision_out_dim = config.vision_config.out_hidden_size
1509
  text_hidden = config.text_config.hidden_size
 
1546
  def visual(self):
1547
  return self.model.visual
1548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1549
  def compute_response_block_idx(self, labels, block_size):
1550
  """
1551
  Compute block index and turn index for each position.
 
1614
  )
1615
  return mask
1616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1617
  @can_return_tuple
1618
  @auto_docstring
1619
  def forward(
 
1656
  eval_bd_size (`int`, *optional*):
1657
  Block diffusion size to use during evaluation. Overrides the model default when set.
1658
  """
 
 
 
 
 
 
1659
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1660
  output_hidden_states = (
1661
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1662
  )
1663
  if self.training:
1664
+ # Anneal block size: pick from [4, 8, ..., target_bd_size] based on
1665
+ # training progress. update_ratio is passed by the trainer (default 1.0
1666
+ # corresponds to using the target block size).
1667
+ update_ratio = kwargs.get('update_ratio', 1.0)
1668
+ max_power = int(math.log2(self.bd_size))
1669
+ possible_bd_sizes = [2**i for i in range(2, max_power + 1)]
1670
+ scaled_ratio = math.sqrt(update_ratio)
1671
+ idx = min(int(scaled_ratio * len(possible_bd_sizes)), len(possible_bd_sizes) - 1)
1672
+ bd_size = possible_bd_sizes[idx]
1673
+
1674
+ if pixel_values is None and pixel_values_videos is None: # text-only batch
 
 
 
 
 
 
1675
 
1676
  batch_size, seq_len = input_ids.shape
1677
  original_labels = labels.clone()
 
1682
  response_mask = (labels != -100) # [B, seq_len]
1683
  eps = self.minimum_noise_level
1684
 
1685
+ response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
1686
+
1687
+ # Per-block noise level sampled from [minimum_noise_level, 1].
1688
+ t = torch.rand((n_blocks,), device=input_ids.device)
1689
+ p_mask_per_block = (1 - eps) * t + eps
1690
+
1691
+ # Build [B, seq_len] mask: prompt tokens stay clean, response tokens
1692
+ # are masked block-wise according to p_mask_per_block.
1693
+ mask_indices = torch.zeros_like(labels, dtype=torch.bool)
1694
+ for i in range(seq_len):
1695
+ block_i = response_block_idx[i].item()
1696
+ if block_i >= 0: # response token
1697
+ mask_indices[:, i] = torch.rand((batch_size,), device=input_ids.device) < p_mask_per_block[block_i]
1698
+
1699
+ # Always mask <|im_end|> tokens that fall inside the response.
1700
+ im_end_mask = (input_ids == self.im_end_token_id) & response_mask
1701
+ mask_indices = mask_indices | im_end_mask
1702
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1703
  noisy_input_ids = input_ids.clone()
1704
  noisy_input_ids[mask_indices] = mask_id
1705
+
1706
+ # Restrict the loss to masked response tokens only.
1707
  labels = labels.clone()
1708
  labels[~mask_indices] = -100
1709
+
1710
+ # Concatenate [noisy | clean] along the sequence dimension.
1711
  input_ids = torch.cat([noisy_input_ids, original_input_ids], dim=1)
1712
 
1713
+ # Complementary pair: mask the positions that were left clean above.
1714
+ complementary_mask_indices = response_mask & ~mask_indices
1715
+ im_end_mask = (original_input_ids == self.im_end_token_id) & response_mask
1716
+ complementary_mask_indices = complementary_mask_indices | im_end_mask
1717
+
1718
+ complementary_noisy_input_ids = original_input_ids.clone()
1719
+ complementary_noisy_input_ids[complementary_mask_indices] = mask_id
1720
+
1721
+ complementary_labels = original_labels.clone()
1722
+ complementary_labels[~complementary_mask_indices] = -100
1723
+ complementary_input_ids = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
1724
+
1725
+ input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
1726
+ labels = torch.cat([labels, complementary_labels], dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
1727
 
1728
+ attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
1729
+
1730
+ else:
1731
+ # Multimodal block diffusion path.
1732
+ # Phase A: embed input_ids and scatter vision features into placeholder positions.
1733
  if inputs_embeds is None:
1734
  inputs_embeds = self.model.get_input_embeddings()(input_ids)
1735
+
1736
  if pixel_values is not None:
1737
  image_embeds = self.model.get_image_features(pixel_values, image_grid_thw)
1738
  image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
 
1742
  input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
1743
  )
1744
  inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1745
+
1746
  if pixel_values_videos is not None:
1747
  video_embeds = self.model.get_video_features(pixel_values_videos, video_grid_thw)
1748
  video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
 
1752
  input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
1753
  )
1754
  inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1755
+
1756
+ # Phase B: build 3D position_ids on the original (pre-doubled) length.
1757
  if position_ids is None:
1758
  position_ids, rope_deltas = self.model.get_rope_index(
1759
  input_ids=input_ids,
 
1762
  second_per_grid_ts=second_per_grid_ts,
1763
  attention_mask=attention_mask,
1764
  )
1765
+
1766
+ # Phase C: block diffusion that preserves vision token positions.
1767
  batch_size = input_ids.shape[0]
1768
  L = input_ids.shape[1]
1769
  seq_len = L
 
 
 
 
 
 
 
 
 
1770
  hidden_size = inputs_embeds.shape[-1]
1771
+
1772
  original_labels = labels.clone()
1773
  original_input_ids = input_ids.clone()
1774
  original_embeds = inputs_embeds.clone()
1775
+ original_position_ids = position_ids.clone()
1776
+
1777
+ # Identify vision tokens so noise is never applied to them.
1778
  image_token_id = self.config.image_token_id
1779
  video_token_id = self.config.video_token_id
1780
  vision_start_token_id = self.config.vision_start_token_id
1781
  vision_token_mask = (input_ids == image_token_id) | (input_ids == video_token_id) | (input_ids == vision_start_token_id)
1782
  vision_mask_3d = vision_token_mask.unsqueeze(-1).expand(-1, -1, hidden_size)
1783
+
1784
+ # Block diffusion with multi-turn support: each response segment has its own blocks.
 
1785
  response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
1786
+ response_mask = (labels != -100)
 
 
1787
  eps = self.minimum_noise_level
1788
 
1789
+ t = torch.rand((n_blocks,), device=input_ids.device)
1790
+ p_mask_per_block = (1 - eps) * t + eps
1791
+
1792
+ mask_indices = torch.zeros_like(labels, dtype=torch.bool)
1793
+ for i in range(seq_len):
1794
+ block_i = response_block_idx[i].item()
1795
+ if block_i >= 0:
1796
+ mask_indices[:, i] = torch.rand((batch_size,), device=input_ids.device) < p_mask_per_block[block_i]
1797
+
1798
+ im_end_mask = (input_ids == self.im_end_token_id) & response_mask
1799
+ mask_indices = mask_indices | im_end_mask
1800
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1801
  noisy_input_ids = input_ids.clone()
1802
  noisy_input_ids[mask_indices] = mask_id
1803
+
1804
+ # Build noisy embeddings while keeping vision embeddings intact.
1805
+ noisy_embeds_raw = self.model.language_model.embed_tokens(noisy_input_ids)
1806
+ noisy_embeds = torch.where(vision_mask_3d, original_embeds, noisy_embeds_raw)
1807
+
 
 
 
 
 
 
 
 
 
1808
  labels_noisy = labels.clone()
1809
  labels_noisy[~mask_indices] = -100
1810
+
1811
+ # Concatenate [noisy | clean] along the sequence dimension.
1812
  input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
1813
  embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
1814
  labels_pair1 = labels_noisy
1815
+ position_ids_pair1 = original_position_ids
1816
 
1817
+ # Complementary pair: mask the positions that were left clean above.
1818
+ complementary_mask_indices = response_mask & ~mask_indices
1819
+ im_end_mask = (original_input_ids == self.im_end_token_id) & response_mask
1820
+ complementary_mask_indices = complementary_mask_indices | im_end_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1821
 
1822
+ complementary_noisy_input_ids = original_input_ids.clone()
1823
+ complementary_noisy_input_ids[complementary_mask_indices] = mask_id
1824
 
1825
+ complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
1826
+ complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
1827
+
1828
+ complementary_labels = original_labels.clone()
1829
+ complementary_labels[~complementary_mask_indices] = -100
1830
+
1831
+ input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
1832
+ embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
1833
+ labels_pair2 = complementary_labels
1834
+ position_ids_pair2 = original_position_ids
1835
+
1836
+ # Stack the complementary pair along the batch dimension.
1837
+ input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
1838
+ inputs_embeds = torch.cat([embeds_pair1, embeds_pair2], dim=0)
1839
+ labels = torch.cat([labels_pair1, labels_pair2], dim=0)
1840
+ position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
1841
+
1842
+ attention_mask = self.gen_hybrid_block_causal_mask(L, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
1843
+
1844
+ # Phase D: forward through the inner model. Vision features (if any)
1845
+ # have already been scattered into inputs_embeds, so pixel_values are
1846
+ # cleared to skip re-processing inside `Fast_dVLMModel`.
1847
+ outputs = self.model(
1848
+ input_ids=input_ids,
1849
+ pixel_values=None,
1850
+ pixel_values_videos=None,
1851
+ image_grid_thw=None,
1852
+ video_grid_thw=None,
1853
+ position_ids=position_ids,
1854
+ attention_mask=attention_mask,
1855
+ past_key_values=past_key_values,
1856
+ inputs_embeds=inputs_embeds,
1857
+ use_cache=use_cache,
1858
+ output_attentions=output_attentions,
1859
+ output_hidden_states=output_hidden_states,
1860
+ return_dict=True,
1861
+ cache_position=cache_position,
1862
+ update_kv_cache=update_kv_cache,
1863
+ bd_size=bd_size,
1864
+ **kwargs,
1865
+ )
1866
 
1867
  else:
1868
  outputs = self.model(
 
1894
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1895
  logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
1896
 
1897
+ new_kwargs = {
1898
+ 'num_items_in_batch': 2 * kwargs['num_items_in_batch'],
1899
+ }
 
 
 
1900
  if labels is not None:
1901
  loss = self.loss_function(
1902
  logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
1903
  ) * 0.5
1904
+ causal_hidden_states = hidden_states[:hidden_states.shape[0]//2, hidden_states.shape[1]//2:, :]
1905
+ causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
1906
+ loss += self.loss_function(
1907
+ logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
1908
+ )
 
 
 
 
 
 
 
 
 
 
1909
  else:
1910
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1911
  logits = self.lm_head(hidden_states[:, slice_indices, :])