Upload processor
Browse files- modeling_yangjian.py +25 -20
modeling_yangjian.py
CHANGED
|
@@ -262,14 +262,14 @@ class YangJianCompareVisualEncoder(nn.Module):
|
|
| 262 |
super().__init__()
|
| 263 |
self.config = config
|
| 264 |
self.hidden_size = config.hidden_size
|
| 265 |
-
self.token_size = 100 * (config.spatial_merge_size**2) if "compare_token_size" not in config else config.compare_token_size * (config.spatial_merge_size**2)
|
| 266 |
-
|
| 267 |
# Encoder 部分:双向图像特征交互
|
| 268 |
# 第一个cross attention: previous attend to current
|
| 269 |
self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
|
| 270 |
# 第二个cross attention: current attend to previous
|
| 271 |
self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)
|
| 272 |
-
|
| 273 |
self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
| 274 |
self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
| 275 |
self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
|
@@ -290,6 +290,8 @@ class YangJianCompareVisualEncoder(nn.Module):
|
|
| 290 |
self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
| 291 |
self.decoder_mlp = Qwen2_5_VLMLP(config)
|
| 292 |
|
|
|
|
|
|
|
| 293 |
def _ensure_device_dtype_consistency(self, target_tensor):
|
| 294 |
"""
|
| 295 |
确保所有模块组件都在目标张量的设备上并使用相同的数据类型
|
|
@@ -391,8 +393,19 @@ class YangJianCompareVisualEncoder(nn.Module):
|
|
| 391 |
torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device), # query掩码
|
| 392 |
attention_masks # encoded特征的掩码
|
| 393 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
-
return compare_visual_embeds # [batch_size, token_size,
|
| 396 |
|
| 397 |
def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
|
| 398 |
"""
|
|
@@ -447,8 +460,9 @@ class YangJianCompareVisualEncoder(nn.Module):
|
|
| 447 |
residual = current_features
|
| 448 |
mlp_input2 = self.encoder_norm4(current_features)
|
| 449 |
mlp_output2 = self.encoder_mlp2(mlp_input2)
|
| 450 |
-
current_features = residual + mlp_output2
|
| 451 |
-
|
|
|
|
| 452 |
return current_features
|
| 453 |
|
| 454 |
def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
|
|
@@ -548,18 +562,7 @@ class YangJianVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrain
|
|
| 548 |
splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
|
| 549 |
# [total_images, token_size, hidden_size]
|
| 550 |
compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
|
| 551 |
-
|
| 552 |
-
batch_size = compare_visual_embeds.size(0)
|
| 553 |
-
token_size = compare_visual_embeds.size(1)
|
| 554 |
-
# 将所有batch的数据拼接在一起
|
| 555 |
-
# [batch_size * token_size, hidden_size]
|
| 556 |
-
flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
|
| 557 |
-
# 一次性进行merger操作
|
| 558 |
-
# 假设merger会将token数量变为原来的1/4
|
| 559 |
-
merged = self.merger(flattened_embeds) # [(batch_size * token_size)/4, merged_hidden_size]
|
| 560 |
-
merged_token_size = token_size // self.spatial_merge_size**2
|
| 561 |
-
# [batch_size, merged_token_size, merged_hidden_size]
|
| 562 |
-
compare_visual_embeds = merged.view(batch_size, merged_token_size, -1)
|
| 563 |
|
| 564 |
hidden_states = self.merger(hidden_states)
|
| 565 |
reverse_indices = torch.argsort(window_index)
|
|
@@ -853,8 +856,10 @@ class YangJianVLModel(Qwen2_5_VLModel):
|
|
| 853 |
if ed_image < ed_video:
|
| 854 |
# 如果当前是图片,则需要插入 compare_token_size 个图像对比的token的position
|
| 855 |
compare_t_index = t_index[-1].repeat(self.compare_token_size)
|
| 856 |
-
compare_h_index = torch.arange(self.compare_token_size)
|
| 857 |
-
compare_w_index = torch.arange(self.compare_token_size)
|
|
|
|
|
|
|
| 858 |
llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
|
| 859 |
st = st + self.compare_token_size
|
| 860 |
|
|
|
|
| 262 |
super().__init__()
|
| 263 |
self.config = config
|
| 264 |
self.hidden_size = config.hidden_size
|
| 265 |
+
# self.token_size = 100 * (config.spatial_merge_size**2) if "compare_token_size" not in config else config.compare_token_size * (config.spatial_merge_size**2)
|
| 266 |
+
self.token_size = 100 if "compare_token_size" not in config else config.compare_token_size
|
| 267 |
# Encoder 部分:双向图像特征交互
|
| 268 |
# 第一个cross attention: previous attend to current
|
| 269 |
self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
|
| 270 |
# 第二个cross attention: current attend to previous
|
| 271 |
self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)
|
| 272 |
+
|
| 273 |
self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
| 274 |
self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
| 275 |
self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
|
|
|
| 290 |
self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
|
| 291 |
self.decoder_mlp = Qwen2_5_VLMLP(config)
|
| 292 |
|
| 293 |
+
self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)
|
| 294 |
+
|
| 295 |
def _ensure_device_dtype_consistency(self, target_tensor):
|
| 296 |
"""
|
| 297 |
确保所有模块组件都在目标张量的设备上并使用相同的数据类型
|
|
|
|
| 393 |
torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device), # query掩码
|
| 394 |
attention_masks # encoded特征的掩码
|
| 395 |
)
|
| 396 |
+
|
| 397 |
+
# 记录每个batch的token数量
|
| 398 |
+
batch_size = compare_visual_embeds.size(0)
|
| 399 |
+
token_size = compare_visual_embeds.size(1)
|
| 400 |
+
# 将所有batch的数据拼接在一起
|
| 401 |
+
# [batch_size * token_size, hidden_size]
|
| 402 |
+
flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
|
| 403 |
+
merged = self.compare_projector(flattened_embeds) # [batch_size * token_size, merged_hidden_size]
|
| 404 |
+
merged_token_size = token_size
|
| 405 |
+
# [batch_size, merged_token_size, merged_hidden_size]
|
| 406 |
+
compare_visual_embeds = merged.view(batch_size, merged_token_size, -1)
|
| 407 |
|
| 408 |
+
return compare_visual_embeds # [batch_size, token_size, out_hidden_size]
|
| 409 |
|
| 410 |
def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
|
| 411 |
"""
|
|
|
|
| 460 |
residual = current_features
|
| 461 |
mlp_input2 = self.encoder_norm4(current_features)
|
| 462 |
mlp_output2 = self.encoder_mlp2(mlp_input2)
|
| 463 |
+
# current_features = residual + mlp_output2
|
| 464 |
+
# 修改为减法
|
| 465 |
+
current_features = residual - mlp_output2
|
| 466 |
return current_features
|
| 467 |
|
| 468 |
def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
|
|
|
|
| 562 |
splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
|
| 563 |
# [total_images, token_size, hidden_size]
|
| 564 |
compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
|
| 565 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
|
| 567 |
hidden_states = self.merger(hidden_states)
|
| 568 |
reverse_indices = torch.argsort(window_index)
|
|
|
|
| 856 |
if ed_image < ed_video:
|
| 857 |
# 如果当前是图片,则需要插入 compare_token_size 个图像对比的token的position
|
| 858 |
compare_t_index = t_index[-1].repeat(self.compare_token_size)
|
| 859 |
+
# compare_h_index = torch.arange(self.compare_token_size)
|
| 860 |
+
# compare_w_index = torch.arange(self.compare_token_size)
|
| 861 |
+
compare_h_index = compare_t_index
|
| 862 |
+
compare_w_index = compare_t_index
|
| 863 |
llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
|
| 864 |
st = st + self.compare_token_size
|
| 865 |
|