jiang-cc committed on
Commit
9d08fc8
·
verified ·
1 Parent(s): 61bcc35

Upload processor

Browse files
Files changed (1) hide show
  1. modeling_yangjian.py +25 -20
modeling_yangjian.py CHANGED
@@ -262,14 +262,14 @@ class YangJianCompareVisualEncoder(nn.Module):
262
  super().__init__()
263
  self.config = config
264
  self.hidden_size = config.hidden_size
265
- self.token_size = 100 * (config.spatial_merge_size**2) if "compare_token_size" not in config else config.compare_token_size * (config.spatial_merge_size**2)
266
-
267
  # Encoder 部分:双向图像特征交互
268
  # 第一个cross attention: previous attend to current
269
  self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
270
  # 第二个cross attention: current attend to previous
271
  self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)
272
-
273
  self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
274
  self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
275
  self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
@@ -290,6 +290,8 @@ class YangJianCompareVisualEncoder(nn.Module):
290
  self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
291
  self.decoder_mlp = Qwen2_5_VLMLP(config)
292
 
 
 
293
  def _ensure_device_dtype_consistency(self, target_tensor):
294
  """
295
  确保所有模块组件都在目标张量的设备上并使用相同的数据类型
@@ -391,8 +393,19 @@ class YangJianCompareVisualEncoder(nn.Module):
391
  torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device), # query掩码
392
  attention_masks # encoded特征的掩码
393
  )
 
 
 
 
 
 
 
 
 
 
 
394
 
395
- return compare_visual_embeds # [batch_size, token_size, hidden_size]
396
 
397
  def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
398
  """
@@ -447,8 +460,9 @@ class YangJianCompareVisualEncoder(nn.Module):
447
  residual = current_features
448
  mlp_input2 = self.encoder_norm4(current_features)
449
  mlp_output2 = self.encoder_mlp2(mlp_input2)
450
- current_features = residual + mlp_output2
451
-
 
452
  return current_features
453
 
454
  def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
@@ -548,18 +562,7 @@ class YangJianVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrain
548
  splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
549
  # [total_images, token_size, hidden_size]
550
  compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
551
- # 记录每个batch的token数量
552
- batch_size = compare_visual_embeds.size(0)
553
- token_size = compare_visual_embeds.size(1)
554
- # 将所有batch的数据拼接在一起
555
- # [batch_size * token_size, hidden_size]
556
- flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
557
- # 一次性进行merger操作
558
- # 假设merger会将token数量变为原来的1/4
559
- merged = self.merger(flattened_embeds) # [(batch_size * token_size)/4, merged_hidden_size]
560
- merged_token_size = token_size // self.spatial_merge_size**2
561
- # [batch_size, merged_token_size, merged_hidden_size]
562
- compare_visual_embeds = merged.view(batch_size, merged_token_size, -1)
563
 
564
  hidden_states = self.merger(hidden_states)
565
  reverse_indices = torch.argsort(window_index)
@@ -853,8 +856,10 @@ class YangJianVLModel(Qwen2_5_VLModel):
853
  if ed_image < ed_video:
854
  # 如果当前是图片,则需要插入 compare_token_size 个图像对比的token的position
855
  compare_t_index = t_index[-1].repeat(self.compare_token_size)
856
- compare_h_index = torch.arange(self.compare_token_size)
857
- compare_w_index = torch.arange(self.compare_token_size)
 
 
858
  llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
859
  st = st + self.compare_token_size
860
 
 
262
  super().__init__()
263
  self.config = config
264
  self.hidden_size = config.hidden_size
265
+ # self.token_size = 100 * (config.spatial_merge_size**2) if "compare_token_size" not in config else config.compare_token_size * (config.spatial_merge_size**2)
266
+ self.token_size = 100 if "compare_token_size" not in config else config.compare_token_size
267
  # Encoder 部分:双向图像特征交互
268
  # 第一个cross attention: previous attend to current
269
  self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
270
  # 第二个cross attention: current attend to previous
271
  self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)
272
+
273
  self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
274
  self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
275
  self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
 
290
  self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
291
  self.decoder_mlp = Qwen2_5_VLMLP(config)
292
 
293
+ self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)
294
+
295
  def _ensure_device_dtype_consistency(self, target_tensor):
296
  """
297
  确保所有模块组件都在目标张量的设备上并使用相同的数据类型
 
393
  torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device), # query掩码
394
  attention_masks # encoded特征的掩码
395
  )
396
+
397
+ # 记录每个batch的token数量
398
+ batch_size = compare_visual_embeds.size(0)
399
+ token_size = compare_visual_embeds.size(1)
400
+ # 将所有batch的数据拼接在一起
401
+ # [batch_size * token_size, hidden_size]
402
+ flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
403
+ merged = self.compare_projector(flattened_embeds) # [batch_size * token_size, merged_hidden_size]
404
+ merged_token_size = token_size
405
+ # [batch_size, merged_token_size, merged_hidden_size]
406
+ compare_visual_embeds = merged.view(batch_size, merged_token_size, -1)
407
 
408
+ return compare_visual_embeds # [batch_size, token_size, out_hidden_size]
409
 
410
  def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
411
  """
 
460
  residual = current_features
461
  mlp_input2 = self.encoder_norm4(current_features)
462
  mlp_output2 = self.encoder_mlp2(mlp_input2)
463
+ # current_features = residual + mlp_output2
464
+ # 修改为减法
465
+ current_features = residual - mlp_output2
466
  return current_features
467
 
468
  def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
 
562
  splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
563
  # [total_images, token_size, hidden_size]
564
  compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
565
+
 
 
 
 
 
 
 
 
 
 
 
566
 
567
  hidden_states = self.merger(hidden_states)
568
  reverse_indices = torch.argsort(window_index)
 
856
  if ed_image < ed_video:
857
  # 如果当前是图片,则需要插入 compare_token_size 个图像对比的token的position
858
  compare_t_index = t_index[-1].repeat(self.compare_token_size)
859
+ # compare_h_index = torch.arange(self.compare_token_size)
860
+ # compare_w_index = torch.arange(self.compare_token_size)
861
+ compare_h_index = compare_t_index
862
+ compare_w_index = compare_t_index
863
  llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
864
  st = st + self.compare_token_size
865