Clean up inference config: remove training-only flags, set bd_size=32 default, dtype=bfloat16
Browse files- config.json +1 -16
- configuration.py +1 -13
- modeling.py +162 -474
config.json
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
{
|
| 2 |
-
"always_mask_im_end": true,
|
| 3 |
-
"anneal_block_size": true,
|
| 4 |
"architectures": [
|
| 5 |
"Fast_dVLMForConditionalGeneration"
|
| 6 |
],
|
|
@@ -11,14 +9,8 @@
|
|
| 11 |
"AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
|
| 12 |
},
|
| 13 |
"bd_size": 32,
|
| 14 |
-
"block_causal_no_dynamic": false,
|
| 15 |
-
"complementary_mask": true,
|
| 16 |
"dtype": "bfloat16",
|
| 17 |
-
"enable_efficient_vision_embed": false,
|
| 18 |
-
"entropy_loss": false,
|
| 19 |
-
"entropy_loss_weight": 1.0,
|
| 20 |
"eos_token_id": 151645,
|
| 21 |
-
"flexible_bd_size": false,
|
| 22 |
"hidden_act": "silu",
|
| 23 |
"hidden_size": 2048,
|
| 24 |
"image_token_id": 151655,
|
|
@@ -55,12 +47,7 @@
|
|
| 55 |
"AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
|
| 56 |
},
|
| 57 |
"bd_size": 8,
|
| 58 |
-
"block_causal_no_dynamic": false,
|
| 59 |
"bos_token_id": 151643,
|
| 60 |
-
"complementary_mask": true,
|
| 61 |
-
"dtype": "float32",
|
| 62 |
-
"entropy_loss": false,
|
| 63 |
-
"entropy_loss_weight": 1.0,
|
| 64 |
"eos_token_id": 151645,
|
| 65 |
"hidden_act": "silu",
|
| 66 |
"hidden_size": 2048,
|
|
@@ -125,7 +112,6 @@
|
|
| 125 |
"rope_theta": 1000000.0,
|
| 126 |
"sliding_window": null,
|
| 127 |
"tie_word_embeddings": true,
|
| 128 |
-
"use_block_causal_mask": false,
|
| 129 |
"use_cache": true,
|
| 130 |
"use_sliding_window": false,
|
| 131 |
"video_token_id": null,
|
|
@@ -135,13 +121,12 @@
|
|
| 135 |
"vocab_size": 151936
|
| 136 |
},
|
| 137 |
"transformers_version": "4.57.1",
|
| 138 |
-
"use_block_causal_mask": true,
|
| 139 |
"use_cache": true,
|
| 140 |
"use_sliding_window": false,
|
| 141 |
"video_token_id": 151656,
|
| 142 |
"vision_config": {
|
| 143 |
"depth": 32,
|
| 144 |
-
"dtype": "
|
| 145 |
"fullatt_block_indexes": [
|
| 146 |
7,
|
| 147 |
15,
|
|
|
|
| 1 |
{
|
|
|
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"Fast_dVLMForConditionalGeneration"
|
| 4 |
],
|
|
|
|
| 9 |
"AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
|
| 10 |
},
|
| 11 |
"bd_size": 32,
|
|
|
|
|
|
|
| 12 |
"dtype": "bfloat16",
|
|
|
|
|
|
|
|
|
|
| 13 |
"eos_token_id": 151645,
|
|
|
|
| 14 |
"hidden_act": "silu",
|
| 15 |
"hidden_size": 2048,
|
| 16 |
"image_token_id": 151655,
|
|
|
|
| 47 |
"AutoModelForCausalLM": "modeling.Fast_dVLMForConditionalGeneration"
|
| 48 |
},
|
| 49 |
"bd_size": 8,
|
|
|
|
| 50 |
"bos_token_id": 151643,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
"eos_token_id": 151645,
|
| 52 |
"hidden_act": "silu",
|
| 53 |
"hidden_size": 2048,
|
|
|
|
| 112 |
"rope_theta": 1000000.0,
|
| 113 |
"sliding_window": null,
|
| 114 |
"tie_word_embeddings": true,
|
|
|
|
| 115 |
"use_cache": true,
|
| 116 |
"use_sliding_window": false,
|
| 117 |
"video_token_id": null,
|
|
|
|
| 121 |
"vocab_size": 151936
|
| 122 |
},
|
| 123 |
"transformers_version": "4.57.1",
|
|
|
|
| 124 |
"use_cache": true,
|
| 125 |
"use_sliding_window": false,
|
| 126 |
"video_token_id": 151656,
|
| 127 |
"vision_config": {
|
| 128 |
"depth": 32,
|
| 129 |
+
"dtype": "bfloat16",
|
| 130 |
"fullatt_block_indexes": [
|
| 131 |
7,
|
| 132 |
15,
|
configuration.py
CHANGED
|
@@ -87,15 +87,10 @@ class Fast_dVLMTextConfig(PretrainedConfig):
|
|
| 87 |
rope_scaling=None,
|
| 88 |
image_token_id=None,
|
| 89 |
video_token_id=None,
|
| 90 |
-
bd_size=
|
| 91 |
self_spec_inference_mode=None,
|
| 92 |
block_length=None,
|
| 93 |
-
use_block_causal_mask=False,
|
| 94 |
-
complementary_mask=True,
|
| 95 |
minimum_noise_level=1e-3,
|
| 96 |
-
entropy_loss=False,
|
| 97 |
-
entropy_loss_weight=1.0,
|
| 98 |
-
block_causal_no_dynamic=False,
|
| 99 |
**kwargs,
|
| 100 |
):
|
| 101 |
self.vocab_size = vocab_size
|
|
@@ -122,12 +117,7 @@ class Fast_dVLMTextConfig(PretrainedConfig):
|
|
| 122 |
self.rope_scaling = rope_scaling
|
| 123 |
self.bd_size = bd_size
|
| 124 |
self.layer_types = layer_types
|
| 125 |
-
self.use_block_causal_mask = use_block_causal_mask
|
| 126 |
-
self.complementary_mask = complementary_mask
|
| 127 |
self.minimum_noise_level = minimum_noise_level
|
| 128 |
-
self.entropy_loss = entropy_loss
|
| 129 |
-
self.entropy_loss_weight = entropy_loss_weight
|
| 130 |
-
self.block_causal_no_dynamic = block_causal_no_dynamic
|
| 131 |
self.self_spec_inference_mode = self_spec_inference_mode
|
| 132 |
self.block_length = block_length
|
| 133 |
if self.layer_types is None:
|
|
@@ -166,7 +156,6 @@ class Fast_dVLMConfig(PretrainedConfig):
|
|
| 166 |
vision_config=None,
|
| 167 |
image_token_id=151655,
|
| 168 |
video_token_id=151656,
|
| 169 |
-
enable_efficient_vision_embed=False,
|
| 170 |
**kwargs,
|
| 171 |
):
|
| 172 |
if isinstance(vision_config, dict):
|
|
@@ -182,7 +171,6 @@ class Fast_dVLMConfig(PretrainedConfig):
|
|
| 182 |
|
| 183 |
self.image_token_id = image_token_id
|
| 184 |
self.video_token_id = video_token_id
|
| 185 |
-
self.enable_efficient_vision_embed = enable_efficient_vision_embed
|
| 186 |
|
| 187 |
super().__init__(**kwargs)
|
| 188 |
|
|
|
|
| 87 |
rope_scaling=None,
|
| 88 |
image_token_id=None,
|
| 89 |
video_token_id=None,
|
| 90 |
+
bd_size=32,
|
| 91 |
self_spec_inference_mode=None,
|
| 92 |
block_length=None,
|
|
|
|
|
|
|
| 93 |
minimum_noise_level=1e-3,
|
|
|
|
|
|
|
|
|
|
| 94 |
**kwargs,
|
| 95 |
):
|
| 96 |
self.vocab_size = vocab_size
|
|
|
|
| 117 |
self.rope_scaling = rope_scaling
|
| 118 |
self.bd_size = bd_size
|
| 119 |
self.layer_types = layer_types
|
|
|
|
|
|
|
| 120 |
self.minimum_noise_level = minimum_noise_level
|
|
|
|
|
|
|
|
|
|
| 121 |
self.self_spec_inference_mode = self_spec_inference_mode
|
| 122 |
self.block_length = block_length
|
| 123 |
if self.layer_types is None:
|
|
|
|
| 156 |
vision_config=None,
|
| 157 |
image_token_id=151655,
|
| 158 |
video_token_id=151656,
|
|
|
|
| 159 |
**kwargs,
|
| 160 |
):
|
| 161 |
if isinstance(vision_config, dict):
|
|
|
|
| 171 |
|
| 172 |
self.image_token_id = image_token_id
|
| 173 |
self.video_token_id = video_token_id
|
|
|
|
| 174 |
|
| 175 |
super().__init__(**kwargs)
|
| 176 |
|
modeling.py
CHANGED
|
@@ -23,94 +23,14 @@ from .configuration import Fast_dVLMConfig, Fast_dVLMTextConfig, Fast_dVLMVision
|
|
| 23 |
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
| 24 |
|
| 25 |
from functools import partial
|
| 26 |
-
import random
|
| 27 |
import math
|
| 28 |
|
| 29 |
logger = logging.get_logger(__name__)
|
| 30 |
|
| 31 |
|
| 32 |
-
# @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
|
| 33 |
-
# @torch.compile()
|
| 34 |
def fused_flex_attention(q, k, v, mask=None):
|
| 35 |
return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
|
| 36 |
|
| 37 |
-
def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
|
| 38 |
-
"""
|
| 39 |
-
Constructs the specialized block diffusion attention mask for training
|
| 40 |
-
composed of three masks:
|
| 41 |
-
- **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
|
| 42 |
-
- **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
|
| 43 |
-
- **Block Causal Mask (M_BC)**: Attention to update x0
|
| 44 |
-
|
| 45 |
-
Args:
|
| 46 |
-
b, h: Batch and head indices (ignored for mask logic).
|
| 47 |
-
q_idx, kv_idx: Query and Key indices.
|
| 48 |
-
seq_len: Total sequence length.
|
| 49 |
-
block_size: Defines the block structure.
|
| 50 |
-
|
| 51 |
-
Returns:
|
| 52 |
-
A boolean attention mask.
|
| 53 |
-
"""
|
| 54 |
-
# Indicate whether token belongs to xt or x0
|
| 55 |
-
x0_flag_q = (q_idx >= n)
|
| 56 |
-
x0_flag_kv = (kv_idx >= n)
|
| 57 |
-
|
| 58 |
-
# Compute block indices
|
| 59 |
-
block_q = torch.where(x0_flag_q == 1,
|
| 60 |
-
(q_idx - n) // block_size,
|
| 61 |
-
q_idx // block_size)
|
| 62 |
-
block_kv = torch.where(x0_flag_kv == 1,
|
| 63 |
-
(kv_idx - n) // block_size,
|
| 64 |
-
kv_idx // block_size)
|
| 65 |
-
|
| 66 |
-
# **1. Block Diagonal Mask (M_BD) **
|
| 67 |
-
block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
|
| 68 |
-
|
| 69 |
-
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 70 |
-
offset_block_causal = (
|
| 71 |
-
(block_q > block_kv)
|
| 72 |
-
& (x0_flag_kv == 1)
|
| 73 |
-
& (x0_flag_q == 0)
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
# **3. Block-Causal Mask (M_BC) **
|
| 77 |
-
block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 78 |
-
|
| 79 |
-
# **4. Combine Masks **
|
| 80 |
-
return block_diagonal | offset_block_causal | block_causal
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def block_causal_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
|
| 84 |
-
|
| 85 |
-
# Indicate whether token belongs to xt or x0
|
| 86 |
-
x0_flag_q = (q_idx >= n)
|
| 87 |
-
x0_flag_kv = (kv_idx >= n)
|
| 88 |
-
|
| 89 |
-
# Compute block indices
|
| 90 |
-
block_q = torch.where(x0_flag_q == 1,
|
| 91 |
-
(q_idx - n) // block_size,
|
| 92 |
-
q_idx // block_size)
|
| 93 |
-
block_kv = torch.where(x0_flag_kv == 1,
|
| 94 |
-
(kv_idx - n) // block_size,
|
| 95 |
-
kv_idx // block_size)
|
| 96 |
-
|
| 97 |
-
# **1. Block Diagonal Mask (M_BD) **
|
| 98 |
-
block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
|
| 99 |
-
|
| 100 |
-
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 101 |
-
offset_block_causal = (
|
| 102 |
-
(block_q > block_kv)
|
| 103 |
-
& (x0_flag_kv == 1)
|
| 104 |
-
& (x0_flag_q == 0)
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
# **3. Block-Causal Mask (M_BC) **
|
| 108 |
-
block_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 109 |
-
|
| 110 |
-
# **4. Combine Masks **
|
| 111 |
-
return block_diagonal | offset_block_causal | block_causal
|
| 112 |
-
|
| 113 |
-
|
| 114 |
def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=None, turn_idx=None, n=None):
|
| 115 |
"""
|
| 116 |
Multi-turn hybrid mask: Prompt uses causal, Response uses block causal.
|
|
@@ -145,29 +65,20 @@ def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=N
|
|
| 145 |
|
| 146 |
is_prompt_q = (block_q < 0)
|
| 147 |
is_prompt_kv = (block_kv < 0)
|
| 148 |
-
|
| 149 |
-
# x_t region
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
#
|
| 153 |
-
# xt_same_turn_prompt_causal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv) & is_prompt_q & (pos_q >= pos_kv)
|
| 154 |
-
# xt_same_turn_response = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv) & ~is_prompt_q & (
|
| 155 |
-
# ~is_prompt_kv
|
| 156 |
-
# )
|
| 157 |
-
block_diagonal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv)
|
| 158 |
-
|
| 159 |
-
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 160 |
offset_block_causal = (
|
| 161 |
-
(turn_q > turn_kv)
|
| 162 |
& (x0_flag_kv == 1)
|
| 163 |
& (x0_flag_q == 0)
|
| 164 |
)
|
| 165 |
-
# x_0 region
|
| 166 |
x0_causal = x0_flag_q & x0_flag_kv & (pos_q >= pos_kv)
|
| 167 |
-
|
| 168 |
-
return
|
| 169 |
-
offset_block_causal |
|
| 170 |
-
x0_causal)
|
| 171 |
|
| 172 |
|
| 173 |
def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
|
|
@@ -820,7 +731,6 @@ class Fast_dVLMAttention(nn.Module):
|
|
| 820 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 821 |
if update_kv_cache:
|
| 822 |
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 823 |
-
# elif len(past_key_values) > self.layer_idx:
|
| 824 |
elif len(past_key_values) > self.layer_idx and past_key_values[self.layer_idx][0] is not None:
|
| 825 |
key_states = torch.cat((past_key_values[self.layer_idx][0], key_states), dim=-2)
|
| 826 |
value_states = torch.cat((past_key_values[self.layer_idx][1], value_states), dim=-2)
|
|
@@ -964,7 +874,12 @@ class Fast_dVLMTextModel(Fast_dVLMPreTrainedModel):
|
|
| 964 |
# Initialize weights and apply final processing
|
| 965 |
self.post_init()
|
| 966 |
|
| 967 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
@auto_docstring
|
| 969 |
def forward(
|
| 970 |
self,
|
|
@@ -1036,34 +951,11 @@ class Fast_dVLMTextModel(Fast_dVLMPreTrainedModel):
|
|
| 1036 |
text_position_ids = position_ids[0]
|
| 1037 |
position_ids = position_ids[1:]
|
| 1038 |
else:
|
| 1039 |
-
# If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
|
| 1040 |
text_position_ids = None
|
| 1041 |
|
| 1042 |
-
# It may already have been prepared by e.g. `generate`
|
| 1043 |
-
# if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 1044 |
-
# # Prepare mask arguments
|
| 1045 |
-
# mask_kwargs = {
|
| 1046 |
-
# "config": self.config,
|
| 1047 |
-
# "input_embeds": inputs_embeds,
|
| 1048 |
-
# "attention_mask": attention_mask,
|
| 1049 |
-
# "cache_position": cache_position,
|
| 1050 |
-
# "past_key_values": past_key_values,
|
| 1051 |
-
# "position_ids": text_position_ids,
|
| 1052 |
-
# }
|
| 1053 |
-
# # Create the masks
|
| 1054 |
-
# causal_mask_mapping = {
|
| 1055 |
-
# "full_attention": create_causal_mask(**mask_kwargs),
|
| 1056 |
-
# }
|
| 1057 |
-
# # The sliding window alternating layers are not always activated depending on the config
|
| 1058 |
-
# if self.has_sliding_layers:
|
| 1059 |
-
# causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
| 1060 |
-
|
| 1061 |
hidden_states = inputs_embeds
|
| 1062 |
-
|
| 1063 |
-
# create position embeddings to be shared across the decoder layers
|
| 1064 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1065 |
|
| 1066 |
-
# decoder layers
|
| 1067 |
all_hidden_states = () if output_hidden_states else None
|
| 1068 |
all_self_attns = () if output_attentions else None
|
| 1069 |
|
|
@@ -1091,7 +983,6 @@ class Fast_dVLMTextModel(Fast_dVLMPreTrainedModel):
|
|
| 1091 |
|
| 1092 |
hidden_states = self.norm(hidden_states)
|
| 1093 |
|
| 1094 |
-
# add hidden states from the last decoder layer
|
| 1095 |
if output_hidden_states:
|
| 1096 |
all_hidden_states += (hidden_states,)
|
| 1097 |
|
|
@@ -1121,7 +1012,6 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
|
|
| 1121 |
self.visual = Fast_dVLMVisionTransformerPretrainedModel._from_config(config.vision_config)
|
| 1122 |
self.language_model = Fast_dVLMTextModel._from_config(config.text_config)
|
| 1123 |
self.rope_deltas = None # cache rope_deltas here
|
| 1124 |
-
self.use_block_causal_mask = config.use_block_causal_mask
|
| 1125 |
|
| 1126 |
# Initialize weights and apply final processing
|
| 1127 |
self.post_init()
|
|
@@ -1307,13 +1197,6 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
|
|
| 1307 |
mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
|
| 1308 |
return position_ids, mrope_position_deltas
|
| 1309 |
else:
|
| 1310 |
-
# if attention_mask is not None:
|
| 1311 |
-
# position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1312 |
-
# position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1313 |
-
# position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
| 1314 |
-
# max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
| 1315 |
-
# mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
| 1316 |
-
# else:
|
| 1317 |
if self.training:
|
| 1318 |
position_ids = (
|
| 1319 |
torch.arange(input_ids.shape[1] // 2, device=input_ids.device)
|
|
@@ -1415,16 +1298,16 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
|
|
| 1415 |
return special_image_mask, special_video_mask
|
| 1416 |
|
| 1417 |
|
| 1418 |
-
def eval_mask(self, seqlen, block_size, cache_seq_len, update_kv_cache=False
|
| 1419 |
q_indices = torch.arange(seqlen, device=self.device) + cache_seq_len
|
| 1420 |
k_indices = torch.arange(seqlen + cache_seq_len, device=self.device)
|
| 1421 |
-
if
|
| 1422 |
mask = eval_causal_mask(q_indices[:, None], k_indices[None, :])
|
| 1423 |
else:
|
| 1424 |
mask = eval_block_diff_mask(
|
| 1425 |
-
q_idx=q_indices[:, None],
|
| 1426 |
-
kv_idx=k_indices[None, :],
|
| 1427 |
-
block_size=block_size
|
| 1428 |
)
|
| 1429 |
return mask
|
| 1430 |
|
|
@@ -1536,8 +1419,6 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
|
|
| 1536 |
|
| 1537 |
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
| 1538 |
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
| 1539 |
-
# if cache_position is not None:
|
| 1540 |
-
# delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
| 1541 |
if past_key_values is not None:
|
| 1542 |
delta = (past_key_values.get_seq_length() + self.rope_deltas).to(inputs_embeds.device)
|
| 1543 |
else:
|
|
@@ -1547,7 +1428,7 @@ class Fast_dVLMModel(Fast_dVLMPreTrainedModel):
|
|
| 1547 |
|
| 1548 |
position_ids = position_ids.to(inputs_embeds.device)
|
| 1549 |
if not self.training:
|
| 1550 |
-
attention_mask = self.eval_mask(inputs_embeds.shape[1], self.bd_size if bd_size is None else bd_size, 0 if past_key_values is None else past_key_values.get_seq_length(), update_kv_cache=update_kv_cache
|
| 1551 |
|
| 1552 |
outputs = self.language_model(
|
| 1553 |
input_ids=None,
|
|
@@ -1620,19 +1501,9 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1620 |
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 1621 |
self.bd_size = config.bd_size
|
| 1622 |
self.model.bd_size = self.bd_size
|
| 1623 |
-
self.complementary_mask = getattr(config, 'complementary_mask', False)
|
| 1624 |
-
self.always_mask_im_end = getattr(config, 'always_mask_im_end', False)
|
| 1625 |
-
self.flexible_bd_size = getattr(config, 'flexible_bd_size', False)
|
| 1626 |
-
self.use_block_causal_mask = getattr(config, 'use_block_causal_mask', False)
|
| 1627 |
-
self.anneal_block_size = getattr(config, 'anneal_block_size', False)
|
| 1628 |
-
self.enable_efficient_vision_embed = getattr(config, 'enable_efficient_vision_embed', False)
|
| 1629 |
self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
|
| 1630 |
-
self.entropy_loss = getattr(config, 'entropy_loss', False)
|
| 1631 |
-
self.entropy_loss_weight = getattr(config, 'entropy_loss_weight', 1.0)
|
| 1632 |
-
self.block_causal_no_dynamic = getattr(config, 'block_causal_no_dynamic', False)
|
| 1633 |
self.im_end_token_id = 151645 # <|im_end|> token id
|
| 1634 |
-
|
| 1635 |
-
|
| 1636 |
# Vision-to-text aligner (if vision output dim != text hidden dim)
|
| 1637 |
vision_out_dim = config.vision_config.out_hidden_size
|
| 1638 |
text_hidden = config.text_config.hidden_size
|
|
@@ -1675,30 +1546,6 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1675 |
def visual(self):
|
| 1676 |
return self.model.visual
|
| 1677 |
|
| 1678 |
-
def gen_mask(self, seqlen, block_size, B, H):
|
| 1679 |
-
# ================== 修改开始 ==================
|
| 1680 |
-
# flex_attention 要求闭包捕获的变量必须是 Tensor
|
| 1681 |
-
# 将 int 转换为 Tensor,并放在对应的设备上
|
| 1682 |
-
block_size_t = torch.tensor(block_size, device=self.device, dtype=torch.int32)
|
| 1683 |
-
n_t = torch.tensor(seqlen, device=self.device, dtype=torch.int32)
|
| 1684 |
-
|
| 1685 |
-
mask = create_block_mask(
|
| 1686 |
-
# 这里将原来的 block_size=block_size 改为传入 Tensor
|
| 1687 |
-
partial(block_diff_mask, block_size=block_size_t, n=n_t),
|
| 1688 |
-
B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2
|
| 1689 |
-
)
|
| 1690 |
-
# ================== 修改结束 ==================
|
| 1691 |
-
return mask
|
| 1692 |
-
|
| 1693 |
-
def gen_block_causal_mask(self, seqlen, block_size, B, H):
|
| 1694 |
-
block_size_t = torch.tensor(block_size, device=self.device, dtype=torch.int32)
|
| 1695 |
-
n_t = torch.tensor(seqlen, device=self.device, dtype=torch.int32)
|
| 1696 |
-
mask = create_block_mask(
|
| 1697 |
-
partial(block_causal_mask, block_size=block_size_t, n=n_t),
|
| 1698 |
-
B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2
|
| 1699 |
-
)
|
| 1700 |
-
return mask
|
| 1701 |
-
|
| 1702 |
def compute_response_block_idx(self, labels, block_size):
|
| 1703 |
"""
|
| 1704 |
Compute block index and turn index for each position.
|
|
@@ -1767,36 +1614,6 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1767 |
)
|
| 1768 |
return mask
|
| 1769 |
|
| 1770 |
-
def compute_entropy_loss(self, logits, labels, num_items_in_batch=None):
|
| 1771 |
-
"""Compute entropy loss with optional global normalization.
|
| 1772 |
-
|
| 1773 |
-
Args:
|
| 1774 |
-
logits: Model logits
|
| 1775 |
-
labels: Ground truth labels (-100 for ignored tokens)
|
| 1776 |
-
num_items_in_batch: Global number of non-ignored tokens for normalization.
|
| 1777 |
-
If provided, uses sum/num_items_in_batch for global norm.
|
| 1778 |
-
If None, uses mean() for micro-batch norm.
|
| 1779 |
-
"""
|
| 1780 |
-
non_ignore_mask = labels != -100
|
| 1781 |
-
logits = logits[non_ignore_mask]
|
| 1782 |
-
labels = labels[non_ignore_mask]
|
| 1783 |
-
correct_mask = logits.argmax(dim=-1) == labels
|
| 1784 |
-
|
| 1785 |
-
compute_logits = logits[correct_mask]
|
| 1786 |
-
|
| 1787 |
-
if correct_mask.sum() == 0:
|
| 1788 |
-
return torch.tensor(0.0, device=logits.device)
|
| 1789 |
-
|
| 1790 |
-
p = F.softmax(compute_logits, dim=-1)
|
| 1791 |
-
log_p = F.log_softmax(compute_logits, dim=-1)
|
| 1792 |
-
entropy = -torch.sum(p * log_p, dim=-1)
|
| 1793 |
-
|
| 1794 |
-
if num_items_in_batch is not None:
|
| 1795 |
-
# Global normalization: use same denominator as cross entropy loss
|
| 1796 |
-
return entropy.sum() / num_items_in_batch
|
| 1797 |
-
else:
|
| 1798 |
-
return entropy.mean()
|
| 1799 |
-
|
| 1800 |
@can_return_tuple
|
| 1801 |
@auto_docstring
|
| 1802 |
def forward(
|
|
@@ -1839,34 +1656,22 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1839 |
eval_bd_size (`int`, *optional*):
|
| 1840 |
Block diffusion size to use during evaluation. Overrides the model default when set.
|
| 1841 |
"""
|
| 1842 |
-
# input_ids = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]]).to(input_ids.device, dtype=input_ids.dtype)
|
| 1843 |
-
# labels = torch.tensor([[-100,-100,3,4,5,6,-100,-100,-100,-100,11,12,13,14,15]]).to(labels.device, dtype=labels.dtype)
|
| 1844 |
-
# pixel_values = None
|
| 1845 |
-
# pixel_values_videos = None
|
| 1846 |
-
# self.bd_size = 2
|
| 1847 |
-
|
| 1848 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1849 |
output_hidden_states = (
|
| 1850 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1851 |
)
|
| 1852 |
if self.training:
|
| 1853 |
-
|
| 1854 |
-
|
| 1855 |
-
|
| 1856 |
-
|
| 1857 |
-
|
| 1858 |
-
|
| 1859 |
-
|
| 1860 |
-
|
| 1861 |
-
|
| 1862 |
-
|
| 1863 |
-
|
| 1864 |
-
max_power = int(math.log2(self.bd_size))
|
| 1865 |
-
possible_bd_sizes = [2**i for i in range(max_power + 1)]
|
| 1866 |
-
bd_size = random.choice(possible_bd_sizes)
|
| 1867 |
-
else:
|
| 1868 |
-
bd_size = self.bd_size
|
| 1869 |
-
if pixel_values is None and pixel_values_videos is None: # only train on text
|
| 1870 |
|
| 1871 |
batch_size, seq_len = input_ids.shape
|
| 1872 |
original_labels = labels.clone()
|
|
@@ -1877,79 +1682,57 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1877 |
response_mask = (labels != -100) # [B, seq_len]
|
| 1878 |
eps = self.minimum_noise_level
|
| 1879 |
|
| 1880 |
-
|
| 1881 |
-
|
| 1882 |
-
|
| 1883 |
-
|
| 1884 |
-
|
| 1885 |
-
|
| 1886 |
-
|
| 1887 |
-
|
| 1888 |
-
|
| 1889 |
-
|
| 1890 |
-
|
| 1891 |
-
|
| 1892 |
-
|
| 1893 |
-
|
| 1894 |
-
|
| 1895 |
-
|
| 1896 |
-
|
| 1897 |
-
|
| 1898 |
-
t = torch.rand((b,), device=input_ids.device)
|
| 1899 |
-
p_mask = (1 - eps) * t + eps
|
| 1900 |
-
p_mask = p_mask[:, None].repeat(1, l)
|
| 1901 |
-
|
| 1902 |
-
mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
|
| 1903 |
-
mask_indices = mask_indices.reshape(labels.shape) & response_mask
|
| 1904 |
-
input_ids = input_ids.reshape(labels.shape)
|
| 1905 |
-
|
| 1906 |
-
# Always mask <|im_end|> in response
|
| 1907 |
-
if self.always_mask_im_end:
|
| 1908 |
-
im_end_mask = (input_ids == self.im_end_token_id) & response_mask
|
| 1909 |
-
mask_indices = mask_indices | im_end_mask
|
| 1910 |
-
|
| 1911 |
-
# Apply mask only to response
|
| 1912 |
noisy_input_ids = input_ids.clone()
|
| 1913 |
noisy_input_ids[mask_indices] = mask_id
|
| 1914 |
-
|
| 1915 |
-
#
|
| 1916 |
labels = labels.clone()
|
| 1917 |
labels[~mask_indices] = -100
|
| 1918 |
-
|
| 1919 |
-
# Concatenate [noisy | clean]
|
| 1920 |
input_ids = torch.cat([noisy_input_ids, original_input_ids], dim=1)
|
| 1921 |
|
| 1922 |
-
# Complementary
|
| 1923 |
-
|
| 1924 |
-
|
| 1925 |
-
|
| 1926 |
-
|
| 1927 |
-
|
| 1928 |
-
|
| 1929 |
-
|
| 1930 |
-
|
| 1931 |
-
|
| 1932 |
-
|
| 1933 |
-
|
| 1934 |
-
|
| 1935 |
-
|
| 1936 |
-
input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
|
| 1937 |
-
labels = torch.cat([labels, complementary_labels], dim=0)
|
| 1938 |
-
|
| 1939 |
-
if self.use_block_causal_mask:
|
| 1940 |
-
if self.block_causal_no_dynamic:
|
| 1941 |
-
attention_mask = self.gen_block_causal_mask(seq_len, bd_size, input_ids.shape[0], self.config.num_attention_heads)
|
| 1942 |
-
else:
|
| 1943 |
-
attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
|
| 1944 |
-
else:
|
| 1945 |
-
attention_mask = self.gen_mask(seq_len, bd_size, input_ids.shape[0], self.config.num_attention_heads)
|
| 1946 |
-
|
| 1947 |
-
else: # 多模态 block diffusion
|
| 1948 |
-
# Phase A: Embed + masked scatter vision
|
| 1949 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1950 |
if inputs_embeds is None:
|
| 1951 |
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
| 1952 |
-
|
| 1953 |
if pixel_values is not None:
|
| 1954 |
image_embeds = self.model.get_image_features(pixel_values, image_grid_thw)
|
| 1955 |
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
|
@@ -1959,7 +1742,7 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1959 |
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
|
| 1960 |
)
|
| 1961 |
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 1962 |
-
|
| 1963 |
if pixel_values_videos is not None:
|
| 1964 |
video_embeds = self.model.get_video_features(pixel_values_videos, video_grid_thw)
|
| 1965 |
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
|
@@ -1969,8 +1752,8 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1969 |
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
|
| 1970 |
)
|
| 1971 |
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 1972 |
-
|
| 1973 |
-
# Phase B:
|
| 1974 |
if position_ids is None:
|
| 1975 |
position_ids, rope_deltas = self.model.get_rope_index(
|
| 1976 |
input_ids=input_ids,
|
|
@@ -1979,189 +1762,107 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1979 |
second_per_grid_ts=second_per_grid_ts,
|
| 1980 |
attention_mask=attention_mask,
|
| 1981 |
)
|
| 1982 |
-
|
| 1983 |
-
# Phase C:
|
| 1984 |
batch_size = input_ids.shape[0]
|
| 1985 |
L = input_ids.shape[1]
|
| 1986 |
seq_len = L
|
| 1987 |
-
|
| 1988 |
-
# if L > self.max_context_length:
|
| 1989 |
-
# L = self.max_context_length
|
| 1990 |
-
# input_ids = input_ids[:, :self.max_context_length]
|
| 1991 |
-
# labels = labels[:, :self.max_context_length]
|
| 1992 |
-
# position_ids = position_ids[:, :self.max_context_length]
|
| 1993 |
-
# attention_mask = attention_mask[:, :self.max_context_length]
|
| 1994 |
-
# inputs_embeds = inputs_embeds[:, :self.max_context_length]
|
| 1995 |
-
|
| 1996 |
hidden_size = inputs_embeds.shape[-1]
|
| 1997 |
-
|
| 1998 |
original_labels = labels.clone()
|
| 1999 |
original_input_ids = input_ids.clone()
|
| 2000 |
original_embeds = inputs_embeds.clone()
|
| 2001 |
-
original_position_ids = position_ids.clone()
|
| 2002 |
-
|
| 2003 |
-
#
|
| 2004 |
image_token_id = self.config.image_token_id
|
| 2005 |
video_token_id = self.config.video_token_id
|
| 2006 |
vision_start_token_id = self.config.vision_start_token_id
|
| 2007 |
vision_token_mask = (input_ids == image_token_id) | (input_ids == video_token_id) | (input_ids == vision_start_token_id)
|
| 2008 |
vision_mask_3d = vision_token_mask.unsqueeze(-1).expand(-1, -1, hidden_size)
|
| 2009 |
-
|
| 2010 |
-
# Block diffusion with multi-turn support
|
| 2011 |
-
# Each response segment has independent blocks
|
| 2012 |
response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
|
| 2013 |
-
|
| 2014 |
-
# Each response segment has independent blocks
|
| 2015 |
-
response_mask = (labels != -100) # [B, seq_len]
|
| 2016 |
eps = self.minimum_noise_level
|
| 2017 |
|
| 2018 |
-
|
| 2019 |
-
|
| 2020 |
-
|
| 2021 |
-
|
| 2022 |
-
|
| 2023 |
-
|
| 2024 |
-
|
| 2025 |
-
|
| 2026 |
-
|
| 2027 |
-
|
| 2028 |
-
|
| 2029 |
-
|
| 2030 |
-
block_i = response_block_idx[i].item()
|
| 2031 |
-
if block_i >= 0: # response token
|
| 2032 |
-
mask_indices[:, i] = torch.rand((batch_size,), device=input_ids.device) < p_mask_per_block[block_i]
|
| 2033 |
-
else:
|
| 2034 |
-
input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // bd_size, bd_size)
|
| 2035 |
-
b, l = input_ids.shape
|
| 2036 |
-
t = torch.rand((b,), device=input_ids.device)
|
| 2037 |
-
p_mask = (1 - eps) * t + eps
|
| 2038 |
-
p_mask = p_mask[:, None].repeat(1, l)
|
| 2039 |
-
|
| 2040 |
-
mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
|
| 2041 |
-
mask_indices = mask_indices.reshape(labels.shape) & response_mask
|
| 2042 |
-
input_ids = input_ids.reshape(labels.shape)
|
| 2043 |
-
|
| 2044 |
-
if self.always_mask_im_end:
|
| 2045 |
-
im_end_mask = (input_ids == self.im_end_token_id) & response_mask
|
| 2046 |
-
mask_indices = mask_indices | im_end_mask
|
| 2047 |
-
|
| 2048 |
noisy_input_ids = input_ids.clone()
|
| 2049 |
noisy_input_ids[mask_indices] = mask_id
|
| 2050 |
-
|
| 2051 |
-
#
|
| 2052 |
-
|
| 2053 |
-
|
| 2054 |
-
|
| 2055 |
-
mask_embeds = self.model.language_model.embed_tokens(
|
| 2056 |
-
torch.full_like(input_ids, mask_id)
|
| 2057 |
-
)
|
| 2058 |
-
noisy_embeds = torch.where(text_mask_3d, mask_embeds, noisy_embeds)
|
| 2059 |
-
else:
|
| 2060 |
-
noisy_embeds_raw = self.model.language_model.embed_tokens(noisy_input_ids)
|
| 2061 |
-
noisy_embeds = torch.where(vision_mask_3d, original_embeds, noisy_embeds_raw)
|
| 2062 |
-
|
| 2063 |
-
# 更新 labels
|
| 2064 |
labels_noisy = labels.clone()
|
| 2065 |
labels_noisy[~mask_indices] = -100
|
| 2066 |
-
|
| 2067 |
-
#
|
| 2068 |
input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
|
| 2069 |
embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
|
| 2070 |
labels_pair1 = labels_noisy
|
| 2071 |
-
position_ids_pair1 = original_position_ids
|
| 2072 |
|
| 2073 |
-
|
| 2074 |
-
|
| 2075 |
-
|
| 2076 |
-
|
| 2077 |
-
|
| 2078 |
-
# Complementary
|
| 2079 |
-
if self.complementary_mask:
|
| 2080 |
-
complementary_mask_indices = response_mask & ~mask_indices
|
| 2081 |
-
if self.always_mask_im_end:
|
| 2082 |
-
im_end_mask = (original_input_ids == self.im_end_token_id) & response_mask
|
| 2083 |
-
complementary_mask_indices = complementary_mask_indices | im_end_mask
|
| 2084 |
-
|
| 2085 |
-
complementary_noisy_input_ids = original_input_ids.clone()
|
| 2086 |
-
complementary_noisy_input_ids[complementary_mask_indices] = mask_id
|
| 2087 |
-
|
| 2088 |
-
if self.enable_efficient_vision_embed:
|
| 2089 |
-
complementary_noisy_embeds = original_embeds.clone()
|
| 2090 |
-
text_mask_3d = complementary_mask_indices.unsqueeze(-1).expand(-1, -1, hidden_size)
|
| 2091 |
-
mask_embeds = self.model.language_model.embed_tokens(
|
| 2092 |
-
torch.full_like(original_input_ids, mask_id)
|
| 2093 |
-
)
|
| 2094 |
-
complementary_noisy_embeds = torch.where(text_mask_3d, mask_embeds, complementary_noisy_embeds)
|
| 2095 |
-
else:
|
| 2096 |
-
complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
|
| 2097 |
-
complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
|
| 2098 |
-
|
| 2099 |
-
complementary_labels = original_labels.clone()
|
| 2100 |
-
complementary_labels[~complementary_mask_indices] = -100
|
| 2101 |
-
|
| 2102 |
-
input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
|
| 2103 |
-
embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
|
| 2104 |
-
labels_pair2 = complementary_labels
|
| 2105 |
-
position_ids_pair2 = original_position_ids
|
| 2106 |
-
|
| 2107 |
-
# Batch 拼接
|
| 2108 |
-
input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
|
| 2109 |
-
inputs_embeds = torch.cat([embeds_pair1, embeds_pair2], dim=0)
|
| 2110 |
-
labels = torch.cat([labels_pair1, labels_pair2], dim=0)
|
| 2111 |
-
position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
|
| 2112 |
-
|
| 2113 |
-
if self.use_block_causal_mask:
|
| 2114 |
-
if self.block_causal_no_dynamic:
|
| 2115 |
-
attention_mask = self.gen_block_causal_mask(L, bd_size, input_ids.shape[0], self.config.num_attention_heads)
|
| 2116 |
-
else:
|
| 2117 |
-
attention_mask = self.gen_hybrid_block_causal_mask(L, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
|
| 2118 |
-
else:
|
| 2119 |
-
attention_mask = self.gen_mask(L, bd_size, input_ids.shape[0], self.config.num_attention_heads)
|
| 2120 |
-
|
| 2121 |
-
# 清空 pixel_values(已替换)
|
| 2122 |
-
pixel_values = None
|
| 2123 |
-
pixel_values_videos = None
|
| 2124 |
|
|
|
|
|
|
|
| 2125 |
|
| 2126 |
-
|
| 2127 |
-
|
| 2128 |
-
|
| 2129 |
-
|
| 2130 |
-
|
| 2131 |
-
|
| 2132 |
-
|
| 2133 |
-
|
| 2134 |
-
|
| 2135 |
-
|
| 2136 |
-
|
| 2137 |
-
|
| 2138 |
-
|
| 2139 |
-
|
| 2140 |
-
|
| 2141 |
-
|
| 2142 |
-
|
| 2143 |
-
|
| 2144 |
-
|
| 2145 |
-
|
| 2146 |
-
|
| 2147 |
-
|
| 2148 |
-
|
| 2149 |
-
|
| 2150 |
-
|
| 2151 |
-
|
| 2152 |
-
|
| 2153 |
-
|
| 2154 |
-
|
| 2155 |
-
|
| 2156 |
-
|
| 2157 |
-
|
| 2158 |
-
|
| 2159 |
-
|
| 2160 |
-
|
| 2161 |
-
|
| 2162 |
-
|
| 2163 |
-
|
| 2164 |
-
|
|
|
|
|
|
|
| 2165 |
|
| 2166 |
else:
|
| 2167 |
outputs = self.model(
|
|
@@ -2193,31 +1894,18 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 2193 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 2194 |
logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
|
| 2195 |
|
| 2196 |
-
|
| 2197 |
-
|
| 2198 |
-
|
| 2199 |
-
}
|
| 2200 |
-
else:
|
| 2201 |
-
new_kwargs = kwargs
|
| 2202 |
if labels is not None:
|
| 2203 |
loss = self.loss_function(
|
| 2204 |
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|
| 2205 |
) * 0.5
|
| 2206 |
-
|
| 2207 |
-
|
| 2208 |
-
|
| 2209 |
-
|
| 2210 |
-
|
| 2211 |
-
causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
|
| 2212 |
-
loss += self.loss_function(
|
| 2213 |
-
logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|
| 2214 |
-
)
|
| 2215 |
-
|
| 2216 |
-
if self.entropy_loss:
|
| 2217 |
-
# Use num_items_in_batch for global normalization (consistent with cross entropy)
|
| 2218 |
-
num_items = kwargs.get('num_items_in_batch', None)
|
| 2219 |
-
entropy_loss = self.compute_entropy_loss(logits, labels, num_items_in_batch=num_items)
|
| 2220 |
-
loss += self.entropy_loss_weight * entropy_loss
|
| 2221 |
else:
|
| 2222 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 2223 |
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
|
|
| 23 |
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
| 24 |
|
| 25 |
from functools import partial
|
|
|
|
| 26 |
import math
|
| 27 |
|
| 28 |
logger = logging.get_logger(__name__)
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
| 31 |
def fused_flex_attention(q, k, v, mask=None):
|
| 32 |
return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=None, turn_idx=None, n=None):
|
| 35 |
"""
|
| 36 |
Multi-turn hybrid mask: Prompt uses causal, Response uses block causal.
|
|
|
|
| 65 |
|
| 66 |
is_prompt_q = (block_q < 0)
|
| 67 |
is_prompt_kv = (block_kv < 0)
|
| 68 |
+
|
| 69 |
+
# Block diagonal: same turn, both in x_t region.
|
| 70 |
+
block_diagonal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv)
|
| 71 |
+
|
| 72 |
+
# Offset block-causal: x_t can attend to x_0 of strictly earlier turns.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
offset_block_causal = (
|
| 74 |
+
(turn_q > turn_kv)
|
| 75 |
& (x0_flag_kv == 1)
|
| 76 |
& (x0_flag_q == 0)
|
| 77 |
)
|
| 78 |
+
# x_0 region uses standard causal masking.
|
| 79 |
x0_causal = x0_flag_q & x0_flag_kv & (pos_q >= pos_kv)
|
| 80 |
+
|
| 81 |
+
return block_diagonal | offset_block_causal | x0_causal
|
|
|
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
|
|
|
|
| 731 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 732 |
if update_kv_cache:
|
| 733 |
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
|
|
|
| 734 |
elif len(past_key_values) > self.layer_idx and past_key_values[self.layer_idx][0] is not None:
|
| 735 |
key_states = torch.cat((past_key_values[self.layer_idx][0], key_states), dim=-2)
|
| 736 |
value_states = torch.cat((past_key_values[self.layer_idx][1], value_states), dim=-2)
|
|
|
|
| 874 |
# Initialize weights and apply final processing
|
| 875 |
self.post_init()
|
| 876 |
|
| 877 |
+
def get_input_embeddings(self):
    """Return the object stored in `self.embed_tokens`.

    NOTE(review): presumably the token-embedding layer of the text model;
    confirm against the enclosing class's `__init__`.
    """
    embedding_layer = self.embed_tokens
    return embedding_layer
|
| 879 |
+
|
| 880 |
+
def set_input_embeddings(self, value):
    """Replace `self.embed_tokens` with `value`.

    Args:
        value: New object to store as `embed_tokens`. No validation is
            performed on it here.
    """
    setattr(self, "embed_tokens", value)
|
| 882 |
+
|
| 883 |
@auto_docstring
|
| 884 |
def forward(
|
| 885 |
self,
|
|
|
|
| 951 |
text_position_ids = position_ids[0]
|
| 952 |
position_ids = position_ids[1:]
|
| 953 |
else:
|
|
|
|
| 954 |
text_position_ids = None
|
| 955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 956 |
hidden_states = inputs_embeds
|
|
|
|
|
|
|
| 957 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 958 |
|
|
|
|
| 959 |
all_hidden_states = () if output_hidden_states else None
|
| 960 |
all_self_attns = () if output_attentions else None
|
| 961 |
|
|
|
|
| 983 |
|
| 984 |
hidden_states = self.norm(hidden_states)
|
| 985 |
|
|
|
|
| 986 |
if output_hidden_states:
|
| 987 |
all_hidden_states += (hidden_states,)
|
| 988 |
|
|
|
|
| 1012 |
self.visual = Fast_dVLMVisionTransformerPretrainedModel._from_config(config.vision_config)
|
| 1013 |
self.language_model = Fast_dVLMTextModel._from_config(config.text_config)
|
| 1014 |
self.rope_deltas = None # cache rope_deltas here
|
|
|
|
| 1015 |
|
| 1016 |
# Initialize weights and apply final processing
|
| 1017 |
self.post_init()
|
|
|
|
| 1197 |
mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
|
| 1198 |
return position_ids, mrope_position_deltas
|
| 1199 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1200 |
if self.training:
|
| 1201 |
position_ids = (
|
| 1202 |
torch.arange(input_ids.shape[1] // 2, device=input_ids.device)
|
|
|
|
| 1298 |
return special_image_mask, special_video_mask
|
| 1299 |
|
| 1300 |
|
| 1301 |
+
def eval_mask(self, seqlen, block_size, cache_seq_len, update_kv_cache=False):
    """Build the attention mask used at evaluation time.

    Query indices cover the `seqlen` new positions, shifted by
    `cache_seq_len`; key indices cover `0 .. seqlen + cache_seq_len - 1`.
    Both index tensors are created on `self.device`.

    Args:
        seqlen: Number of new (query) positions.
        block_size: Forwarded to `eval_block_diff_mask` as `block_size`;
            ignored when `update_kv_cache` is true.
        cache_seq_len: Number of positions already held in the cache; used
            as the offset of the query indices.
        update_kv_cache: When true, the mask comes from `eval_causal_mask`;
            otherwise from `eval_block_diff_mask`.

    Returns:
        The result of the selected mask helper, evaluated on a column of
        query indices against a row of key indices.
        # presumably broadcasts to (seqlen, seqlen + cache_seq_len) — confirm
        # against the helper implementations.
    """
    device = self.device
    total_len = seqlen + cache_seq_len

    # Column vector of absolute query positions, row vector of key positions.
    q_col = (torch.arange(seqlen, device=device) + cache_seq_len)[:, None]
    k_row = torch.arange(total_len, device=device)[None, :]

    if not update_kv_cache:
        return eval_block_diff_mask(
            q_idx=q_col,
            kv_idx=k_row,
            block_size=block_size,
        )
    return eval_causal_mask(q_col, k_row)
|
| 1313 |
|
|
|
|
| 1419 |
|
| 1420 |
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
| 1421 |
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
|
|
|
|
|
|
| 1422 |
if past_key_values is not None:
|
| 1423 |
delta = (past_key_values.get_seq_length() + self.rope_deltas).to(inputs_embeds.device)
|
| 1424 |
else:
|
|
|
|
| 1428 |
|
| 1429 |
position_ids = position_ids.to(inputs_embeds.device)
|
| 1430 |
if not self.training:
|
| 1431 |
+
attention_mask = self.eval_mask(inputs_embeds.shape[1], self.bd_size if bd_size is None else bd_size, 0 if past_key_values is None else past_key_values.get_seq_length(), update_kv_cache=update_kv_cache).to(inputs_embeds.device)
|
| 1432 |
|
| 1433 |
outputs = self.language_model(
|
| 1434 |
input_ids=None,
|
|
|
|
| 1501 |
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 1502 |
self.bd_size = config.bd_size
|
| 1503 |
self.model.bd_size = self.bd_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1504 |
self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
|
|
|
|
|
|
|
|
|
|
| 1505 |
self.im_end_token_id = 151645 # <|im_end|> token id
|
| 1506 |
+
|
|
|
|
| 1507 |
# Vision-to-text aligner (if vision output dim != text hidden dim)
|
| 1508 |
vision_out_dim = config.vision_config.out_hidden_size
|
| 1509 |
text_hidden = config.text_config.hidden_size
|
|
|
|
| 1546 |
def visual(self):
    """Return `self.model.visual`.

    NOTE(review): presumably the vision encoder owned by the wrapped model;
    any decorator on this method lies outside the visible hunk.
    """
    inner_model = self.model
    return inner_model.visual
|
| 1548 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1549 |
def compute_response_block_idx(self, labels, block_size):
|
| 1550 |
"""
|
| 1551 |
Compute block index and turn index for each position.
|
|
|
|
| 1614 |
)
|
| 1615 |
return mask
|
| 1616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1617 |
@can_return_tuple
|
| 1618 |
@auto_docstring
|
| 1619 |
def forward(
|
|
|
|
| 1656 |
eval_bd_size (`int`, *optional*):
|
| 1657 |
Block diffusion size to use during evaluation. Overrides the model default when set.
|
| 1658 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1659 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1660 |
output_hidden_states = (
|
| 1661 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1662 |
)
|
| 1663 |
if self.training:
|
| 1664 |
+
# Anneal block size: pick from [4, 8, ..., target_bd_size] based on
|
| 1665 |
+
# training progress. update_ratio is passed by the trainer (default 1.0
|
| 1666 |
+
# corresponds to using the target block size).
|
| 1667 |
+
update_ratio = kwargs.get('update_ratio', 1.0)
|
| 1668 |
+
max_power = int(math.log2(self.bd_size))
|
| 1669 |
+
possible_bd_sizes = [2**i for i in range(2, max_power + 1)]
|
| 1670 |
+
scaled_ratio = math.sqrt(update_ratio)
|
| 1671 |
+
idx = min(int(scaled_ratio * len(possible_bd_sizes)), len(possible_bd_sizes) - 1)
|
| 1672 |
+
bd_size = possible_bd_sizes[idx]
|
| 1673 |
+
|
| 1674 |
+
if pixel_values is None and pixel_values_videos is None: # text-only batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1675 |
|
| 1676 |
batch_size, seq_len = input_ids.shape
|
| 1677 |
original_labels = labels.clone()
|
|
|
|
| 1682 |
response_mask = (labels != -100) # [B, seq_len]
|
| 1683 |
eps = self.minimum_noise_level
|
| 1684 |
|
| 1685 |
+
response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
|
| 1686 |
+
|
| 1687 |
+
# Per-block noise level sampled from [minimum_noise_level, 1].
|
| 1688 |
+
t = torch.rand((n_blocks,), device=input_ids.device)
|
| 1689 |
+
p_mask_per_block = (1 - eps) * t + eps
|
| 1690 |
+
|
| 1691 |
+
# Build [B, seq_len] mask: prompt tokens stay clean, response tokens
|
| 1692 |
+
# are masked block-wise according to p_mask_per_block.
|
| 1693 |
+
mask_indices = torch.zeros_like(labels, dtype=torch.bool)
|
| 1694 |
+
for i in range(seq_len):
|
| 1695 |
+
block_i = response_block_idx[i].item()
|
| 1696 |
+
if block_i >= 0: # response token
|
| 1697 |
+
mask_indices[:, i] = torch.rand((batch_size,), device=input_ids.device) < p_mask_per_block[block_i]
|
| 1698 |
+
|
| 1699 |
+
# Always mask <|im_end|> tokens that fall inside the response.
|
| 1700 |
+
im_end_mask = (input_ids == self.im_end_token_id) & response_mask
|
| 1701 |
+
mask_indices = mask_indices | im_end_mask
|
| 1702 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1703 |
noisy_input_ids = input_ids.clone()
|
| 1704 |
noisy_input_ids[mask_indices] = mask_id
|
| 1705 |
+
|
| 1706 |
+
# Restrict the loss to masked response tokens only.
|
| 1707 |
labels = labels.clone()
|
| 1708 |
labels[~mask_indices] = -100
|
| 1709 |
+
|
| 1710 |
+
# Concatenate [noisy | clean] along the sequence dimension.
|
| 1711 |
input_ids = torch.cat([noisy_input_ids, original_input_ids], dim=1)
|
| 1712 |
|
| 1713 |
+
# Complementary pair: mask the positions that were left clean above.
|
| 1714 |
+
complementary_mask_indices = response_mask & ~mask_indices
|
| 1715 |
+
im_end_mask = (original_input_ids == self.im_end_token_id) & response_mask
|
| 1716 |
+
complementary_mask_indices = complementary_mask_indices | im_end_mask
|
| 1717 |
+
|
| 1718 |
+
complementary_noisy_input_ids = original_input_ids.clone()
|
| 1719 |
+
complementary_noisy_input_ids[complementary_mask_indices] = mask_id
|
| 1720 |
+
|
| 1721 |
+
complementary_labels = original_labels.clone()
|
| 1722 |
+
complementary_labels[~complementary_mask_indices] = -100
|
| 1723 |
+
complementary_input_ids = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
|
| 1724 |
+
|
| 1725 |
+
input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
|
| 1726 |
+
labels = torch.cat([labels, complementary_labels], dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1727 |
|
| 1728 |
+
attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
|
| 1729 |
+
|
| 1730 |
+
else:
|
| 1731 |
+
# Multimodal block diffusion path.
|
| 1732 |
+
# Phase A: embed input_ids and scatter vision features into placeholder positions.
|
| 1733 |
if inputs_embeds is None:
|
| 1734 |
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
| 1735 |
+
|
| 1736 |
if pixel_values is not None:
|
| 1737 |
image_embeds = self.model.get_image_features(pixel_values, image_grid_thw)
|
| 1738 |
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
|
|
|
| 1742 |
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
|
| 1743 |
)
|
| 1744 |
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 1745 |
+
|
| 1746 |
if pixel_values_videos is not None:
|
| 1747 |
video_embeds = self.model.get_video_features(pixel_values_videos, video_grid_thw)
|
| 1748 |
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
|
|
|
| 1752 |
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
|
| 1753 |
)
|
| 1754 |
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 1755 |
+
|
| 1756 |
+
# Phase B: build 3D position_ids on the original (pre-doubled) length.
|
| 1757 |
if position_ids is None:
|
| 1758 |
position_ids, rope_deltas = self.model.get_rope_index(
|
| 1759 |
input_ids=input_ids,
|
|
|
|
| 1762 |
second_per_grid_ts=second_per_grid_ts,
|
| 1763 |
attention_mask=attention_mask,
|
| 1764 |
)
|
| 1765 |
+
|
| 1766 |
+
# Phase C: block diffusion that preserves vision token positions.
|
| 1767 |
batch_size = input_ids.shape[0]
|
| 1768 |
L = input_ids.shape[1]
|
| 1769 |
seq_len = L
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1770 |
hidden_size = inputs_embeds.shape[-1]
|
| 1771 |
+
|
| 1772 |
original_labels = labels.clone()
|
| 1773 |
original_input_ids = input_ids.clone()
|
| 1774 |
original_embeds = inputs_embeds.clone()
|
| 1775 |
+
original_position_ids = position_ids.clone()
|
| 1776 |
+
|
| 1777 |
+
# Identify vision tokens so noise is never applied to them.
|
| 1778 |
image_token_id = self.config.image_token_id
|
| 1779 |
video_token_id = self.config.video_token_id
|
| 1780 |
vision_start_token_id = self.config.vision_start_token_id
|
| 1781 |
vision_token_mask = (input_ids == image_token_id) | (input_ids == video_token_id) | (input_ids == vision_start_token_id)
|
| 1782 |
vision_mask_3d = vision_token_mask.unsqueeze(-1).expand(-1, -1, hidden_size)
|
| 1783 |
+
|
| 1784 |
+
# Block diffusion with multi-turn support: each response segment has its own blocks.
|
|
|
|
| 1785 |
response_block_idx, turn_idx, n_blocks = self.compute_response_block_idx(labels, bd_size)
|
| 1786 |
+
response_mask = (labels != -100)
|
|
|
|
|
|
|
| 1787 |
eps = self.minimum_noise_level
|
| 1788 |
|
| 1789 |
+
t = torch.rand((n_blocks,), device=input_ids.device)
|
| 1790 |
+
p_mask_per_block = (1 - eps) * t + eps
|
| 1791 |
+
|
| 1792 |
+
mask_indices = torch.zeros_like(labels, dtype=torch.bool)
|
| 1793 |
+
for i in range(seq_len):
|
| 1794 |
+
block_i = response_block_idx[i].item()
|
| 1795 |
+
if block_i >= 0:
|
| 1796 |
+
mask_indices[:, i] = torch.rand((batch_size,), device=input_ids.device) < p_mask_per_block[block_i]
|
| 1797 |
+
|
| 1798 |
+
im_end_mask = (input_ids == self.im_end_token_id) & response_mask
|
| 1799 |
+
mask_indices = mask_indices | im_end_mask
|
| 1800 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1801 |
noisy_input_ids = input_ids.clone()
|
| 1802 |
noisy_input_ids[mask_indices] = mask_id
|
| 1803 |
+
|
| 1804 |
+
# Build noisy embeddings while keeping vision embeddings intact.
|
| 1805 |
+
noisy_embeds_raw = self.model.language_model.embed_tokens(noisy_input_ids)
|
| 1806 |
+
noisy_embeds = torch.where(vision_mask_3d, original_embeds, noisy_embeds_raw)
|
| 1807 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1808 |
labels_noisy = labels.clone()
|
| 1809 |
labels_noisy[~mask_indices] = -100
|
| 1810 |
+
|
| 1811 |
+
# Concatenate [noisy | clean] along the sequence dimension.
|
| 1812 |
input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
|
| 1813 |
embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
|
| 1814 |
labels_pair1 = labels_noisy
|
| 1815 |
+
position_ids_pair1 = original_position_ids
|
| 1816 |
|
| 1817 |
+
# Complementary pair: mask the positions that were left clean above.
|
| 1818 |
+
complementary_mask_indices = response_mask & ~mask_indices
|
| 1819 |
+
im_end_mask = (original_input_ids == self.im_end_token_id) & response_mask
|
| 1820 |
+
complementary_mask_indices = complementary_mask_indices | im_end_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1821 |
|
| 1822 |
+
complementary_noisy_input_ids = original_input_ids.clone()
|
| 1823 |
+
complementary_noisy_input_ids[complementary_mask_indices] = mask_id
|
| 1824 |
|
| 1825 |
+
complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
|
| 1826 |
+
complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
|
| 1827 |
+
|
| 1828 |
+
complementary_labels = original_labels.clone()
|
| 1829 |
+
complementary_labels[~complementary_mask_indices] = -100
|
| 1830 |
+
|
| 1831 |
+
input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
|
| 1832 |
+
embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
|
| 1833 |
+
labels_pair2 = complementary_labels
|
| 1834 |
+
position_ids_pair2 = original_position_ids
|
| 1835 |
+
|
| 1836 |
+
# Stack the complementary pair along the batch dimension.
|
| 1837 |
+
input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
|
| 1838 |
+
inputs_embeds = torch.cat([embeds_pair1, embeds_pair2], dim=0)
|
| 1839 |
+
labels = torch.cat([labels_pair1, labels_pair2], dim=0)
|
| 1840 |
+
position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
|
| 1841 |
+
|
| 1842 |
+
attention_mask = self.gen_hybrid_block_causal_mask(L, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
|
| 1843 |
+
|
| 1844 |
+
# Phase D: forward through the inner model. Vision features (if any)
|
| 1845 |
+
# have already been scattered into inputs_embeds, so pixel_values are
|
| 1846 |
+
# cleared to skip re-processing inside `Fast_dVLMModel`.
|
| 1847 |
+
outputs = self.model(
|
| 1848 |
+
input_ids=input_ids,
|
| 1849 |
+
pixel_values=None,
|
| 1850 |
+
pixel_values_videos=None,
|
| 1851 |
+
image_grid_thw=None,
|
| 1852 |
+
video_grid_thw=None,
|
| 1853 |
+
position_ids=position_ids,
|
| 1854 |
+
attention_mask=attention_mask,
|
| 1855 |
+
past_key_values=past_key_values,
|
| 1856 |
+
inputs_embeds=inputs_embeds,
|
| 1857 |
+
use_cache=use_cache,
|
| 1858 |
+
output_attentions=output_attentions,
|
| 1859 |
+
output_hidden_states=output_hidden_states,
|
| 1860 |
+
return_dict=True,
|
| 1861 |
+
cache_position=cache_position,
|
| 1862 |
+
update_kv_cache=update_kv_cache,
|
| 1863 |
+
bd_size=bd_size,
|
| 1864 |
+
**kwargs,
|
| 1865 |
+
)
|
| 1866 |
|
| 1867 |
else:
|
| 1868 |
outputs = self.model(
|
|
|
|
| 1894 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1895 |
logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
|
| 1896 |
|
| 1897 |
+
new_kwargs = {
|
| 1898 |
+
'num_items_in_batch': 2 * kwargs['num_items_in_batch'],
|
| 1899 |
+
}
|
|
|
|
|
|
|
|
|
|
| 1900 |
if labels is not None:
|
| 1901 |
loss = self.loss_function(
|
| 1902 |
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|
| 1903 |
) * 0.5
|
| 1904 |
+
causal_hidden_states = hidden_states[:hidden_states.shape[0]//2, hidden_states.shape[1]//2:, :]
|
| 1905 |
+
causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
|
| 1906 |
+
loss += self.loss_function(
|
| 1907 |
+
logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|
| 1908 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1909 |
else:
|
| 1910 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1911 |
logits = self.lm_head(hidden_states[:, slice_indices, :])
|