jiang-cc commited on
Commit
572d078
·
verified ·
1 Parent(s): fc450cf

Upload processor

Browse files
Files changed (1) hide show
  1. modeling_yangjian.py +187 -12
modeling_yangjian.py CHANGED
@@ -116,6 +116,7 @@ class YangJianProcessor(Qwen2_5_VLProcessor):
116
  for i in range(len(text)):
117
  while self.image_token in text[i]:
118
  num_image_tokens = image_grid_thw[index].prod() // merge_length
 
119
  text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
120
  index += 1
121
  text[i] = text[i].replace("<|placeholder|>", self.image_token)
@@ -486,7 +487,7 @@ class YangJianVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrain
486
  def __init__(self, config, *inputs, **kwargs) -> None:
487
  super().__init__(config, *inputs, **kwargs)
488
  self.compare_visual_encoder = YangJianCompareVisualEncoder(config)
489
-
490
  def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
491
  """
492
  Args:
@@ -570,6 +571,7 @@ class YangJianVLModel(Qwen2_5_VLModel):
570
  def __init__(self, config):
571
  super().__init__(config)
572
  self.visual = YangJianVisionTransformerPretrainedModel._from_config(config.vision_config)
 
573
  # self.learnable_image_embeddings = nn.Parameter(
574
  # torch.randn(100, config.hidden_size) * 0.02 # 使用小的初始化值
575
  # )
@@ -644,19 +646,11 @@ class YangJianVLModel(Qwen2_5_VLModel):
644
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
645
 
646
  if inputs_embeds is None:
 
647
  inputs_embeds = self.get_input_embeddings()(input_ids)
648
  if pixel_values is not None:
649
  image_embeds = self.get_image_features(pixel_values, image_grid_thw)
650
-
651
- # # 为每个图像添加 100 个可学习的 embedding
652
- # learnable_embeddings = self.learnable_image_embeddings.to(image_embeds[0].device, image_embeds[0].dtype)
653
- # enhanced_image_embeds = []
654
-
655
- # for i, embeds in enumerate(image_embeds):
656
- # # 为每个图像添加 100 个可学习的 embedding
657
- # enhanced_embeds = torch.cat([embeds, learnable_embeddings], dim=0)
658
- # enhanced_image_embeds.append(enhanced_embeds)
659
-
660
  image_embeds = torch.cat(image_embeds, dim=0)
661
  n_image_tokens = (input_ids == self.config.image_token_id).sum()
662
  n_image_features = image_embeds.shape[0]
@@ -713,7 +707,7 @@ class YangJianVLModel(Qwen2_5_VLModel):
713
  or (past_key_values is None or past_key_values.get_seq_length() == 0)
714
  )
715
  if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
716
- position_ids, rope_deltas = self.get_rope_index(
717
  input_ids,
718
  image_grid_thw,
719
  video_grid_thw,
@@ -758,6 +752,142 @@ class YangJianVLModel(Qwen2_5_VLModel):
758
  rope_deltas=self.rope_deltas,
759
  )
760
  return output if return_dict else output.to_tuple()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
761
 
762
  class YangJianVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
763
  config_class = YangJianConfig
@@ -765,3 +895,48 @@ class YangJianVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
765
  def __init__(self, config):
766
  super().__init__(config)
767
  self.model = YangJianVLModel(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  for i in range(len(text)):
117
  while self.image_token in text[i]:
118
  num_image_tokens = image_grid_thw[index].prod() // merge_length
119
+ # text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens), 1)
120
  text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
121
  index += 1
122
  text[i] = text[i].replace("<|placeholder|>", self.image_token)
 
487
  def __init__(self, config, *inputs, **kwargs) -> None:
488
  super().__init__(config, *inputs, **kwargs)
489
  self.compare_visual_encoder = YangJianCompareVisualEncoder(config)
490
+
491
  def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
492
  """
493
  Args:
 
571
  def __init__(self, config):
572
  super().__init__(config)
573
  self.visual = YangJianVisionTransformerPretrainedModel._from_config(config.vision_config)
574
+ self.compare_token_size = config.vision_config.compare_token_size
575
  # self.learnable_image_embeddings = nn.Parameter(
576
  # torch.randn(100, config.hidden_size) * 0.02 # 使用小的初始化值
577
  # )
 
646
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
647
 
648
  if inputs_embeds is None:
649
+
650
  inputs_embeds = self.get_input_embeddings()(input_ids)
651
  if pixel_values is not None:
652
  image_embeds = self.get_image_features(pixel_values, image_grid_thw)
653
+
 
 
 
 
 
 
 
 
 
654
  image_embeds = torch.cat(image_embeds, dim=0)
655
  n_image_tokens = (input_ids == self.config.image_token_id).sum()
656
  n_image_features = image_embeds.shape[0]
 
707
  or (past_key_values is None or past_key_values.get_seq_length() == 0)
708
  )
709
  if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
710
+ position_ids, rope_deltas = self.get_rope_index_with_compare_token(
711
  input_ids,
712
  image_grid_thw,
713
  video_grid_thw,
 
752
  rope_deltas=self.rope_deltas,
753
  )
754
  return output if return_dict else output.to_tuple()
755
+
756
+ def get_rope_index_with_compare_token(
757
+ self,
758
+ input_ids: Optional[torch.LongTensor] = None,
759
+ image_grid_thw: Optional[torch.LongTensor] = None,
760
+ video_grid_thw: Optional[torch.LongTensor] = None,
761
+ second_per_grid_ts: Optional[torch.Tensor] = None,
762
+ attention_mask: Optional[torch.Tensor] = None,
763
+ ) -> tuple[torch.Tensor, torch.Tensor]:
764
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
765
+ image_token_id = self.config.image_token_id
766
+ video_token_id = self.config.video_token_id
767
+ vision_start_token_id = self.config.vision_start_token_id
768
+ mrope_position_deltas = []
769
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
770
+ total_input_ids = input_ids
771
+ if attention_mask is None:
772
+ attention_mask = torch.ones_like(total_input_ids)
773
+ position_ids = torch.ones(
774
+ 3,
775
+ input_ids.shape[0],
776
+ input_ids.shape[1],
777
+ dtype=input_ids.dtype,
778
+ device=input_ids.device,
779
+ )
780
+ image_index, video_index = 0, 0
781
+ attention_mask = attention_mask.to(total_input_ids.device)
782
+ for i, input_ids in enumerate(total_input_ids):
783
+ input_ids = input_ids[attention_mask[i] == 1]
784
+ image_nums, video_nums = 0, 0
785
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
786
+ vision_tokens = input_ids[vision_start_indices + 1]
787
+ image_nums = (vision_tokens == image_token_id).sum()
788
+ video_nums = (vision_tokens == video_token_id).sum()
789
+ input_tokens = input_ids.tolist()
790
+ llm_pos_ids_list: list = []
791
+ st = 0
792
+ remain_images, remain_videos = image_nums, video_nums
793
+ for vision_index in range(image_nums + video_nums):
794
+ if image_token_id in input_tokens and remain_images > 0:
795
+ ed_image = input_tokens.index(image_token_id, st)
796
+ else:
797
+ ed_image = len(input_tokens) + 1
798
+ if video_token_id in input_tokens and remain_videos > 0:
799
+ ed_video = input_tokens.index(video_token_id, st)
800
+ else:
801
+ ed_video = len(input_tokens) + 1
802
+ if ed_image < ed_video:
803
+ t, h, w = (
804
+ image_grid_thw[image_index][0],
805
+ image_grid_thw[image_index][1],
806
+ image_grid_thw[image_index][2],
807
+ )
808
+ second_per_grid_t = 0
809
+ image_index += 1
810
+ remain_images -= 1
811
+ ed = ed_image
812
+
813
+ else:
814
+ t, h, w = (
815
+ video_grid_thw[video_index][0],
816
+ video_grid_thw[video_index][1],
817
+ video_grid_thw[video_index][2],
818
+ )
819
+ if second_per_grid_ts is not None:
820
+ second_per_grid_t = second_per_grid_ts[video_index]
821
+ else:
822
+ second_per_grid_t = 1.0
823
+ video_index += 1
824
+ remain_videos -= 1
825
+ ed = ed_video
826
+ llm_grid_t, llm_grid_h, llm_grid_w = (
827
+ t.item(),
828
+ h.item() // spatial_merge_size,
829
+ w.item() // spatial_merge_size,
830
+ )
831
+ text_len = ed - st
832
+
833
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
834
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
835
+
836
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
837
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
838
+
839
+ ## normalize type, send to device.
840
+ second_per_grid_t = torch.as_tensor(
841
+ second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
842
+ )
843
+
844
+ time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
845
+
846
+ time_tensor_long = time_tensor.long()
847
+ t_index = time_tensor_long.flatten()
848
+
849
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
850
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
851
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
852
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
853
+ if ed_image < ed_video:
854
+ # 如果当前是图片,则需要插入 compare_token_size 个图像对比的token的position
855
+ compare_t_index = t_index[-1].repeat(self.compare_token_size)
856
+ compare_h_index = torch.arange(self.compare_token_size)
857
+ compare_w_index = torch.arange(self.compare_token_size)
858
+ llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
859
+ st = st + self.compare_token_size
860
+
861
+ if st < len(input_tokens):
862
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
863
+ text_len = len(input_tokens) - st
864
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
865
+
866
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
867
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
868
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
869
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
870
+ return position_ids, mrope_position_deltas
871
+ else:
872
+ if attention_mask is not None:
873
+ position_ids = attention_mask.long().cumsum(-1) - 1
874
+ position_ids.masked_fill_(attention_mask == 0, 1)
875
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
876
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
877
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
878
+ else:
879
+ position_ids = (
880
+ torch.arange(input_ids.shape[1], device=input_ids.device)
881
+ .view(1, 1, -1)
882
+ .expand(3, input_ids.shape[0], -1)
883
+ )
884
+ mrope_position_deltas = torch.zeros(
885
+ [input_ids.shape[0], 1],
886
+ device=input_ids.device,
887
+ dtype=input_ids.dtype,
888
+ )
889
+
890
+ return position_ids, mrope_position_deltas
891
 
892
  class YangJianVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
893
  config_class = YangJianConfig
 
895
  def __init__(self, config):
896
  super().__init__(config)
897
  self.model = YangJianVLModel(config)
898
+
899
+ # def _prepare_generation_config(self, generation_config, use_model_defaults, **kwargs: dict):
900
+ # model_kwargs = super()._prepare_generation_config(generation_config, use_model_defaults, **kwargs)
901
+ # compare_token_size = self.config.vision_config.compare_token_size
902
+ # input_dict = model_kwargs[1]
903
+ # input_ids = model_kwargs[1]["input_ids"]
904
+ # attention_mask = model_kwargs[1]["attention_mask"]
905
+ # if "pixel_values" in input_dict and input_dict["pixel_values"] is not None:
906
+ # image_grid_thw = input_dict["image_grid_thw"]
907
+
908
+ # # 计算每张图片的token数量
909
+ # image_token_counts = (image_grid_thw.prod(-1) // self.config.vision_config.spatial_merge_size**2).tolist()
910
+
911
+ # image_token_positions = (input_ids == self.config.image_token_id).nonzero(as_tuple=True)[1]
912
+ # # 倒序遍历图片,这样插入时不会影响前面图片的位置
913
+ # current_end = len(image_token_positions) # 最后一个图片token的结束位置
914
+ # for i in range(len(image_token_counts) - 1, -1, -1):
915
+ # count = image_token_counts[i]
916
+ # # 计算当前图片的结束位置
917
+ # start = current_end - count # 当前图片的起始位置
918
+ # end_index = image_token_positions[current_end - 1] # 当前图片的最后一个token位置
919
+
920
+ # # 在第i张图片的末尾插入 self.compare_token_size 个图像对比的token
921
+ # # 获取插入位置的token的值
922
+ # prev_token = input_ids[:, end_index]
923
+ # input_ids = torch.cat([
924
+ # input_ids[:, :end_index + 1],
925
+ # prev_token.repeat(input_ids.shape[0], compare_token_size),
926
+ # input_ids[:, end_index + 1:]
927
+ # ], dim=1)
928
+
929
+ # # 同步更新attention_mask和position_ids
930
+ # if attention_mask is not None:
931
+ # prev_mask = attention_mask[:, end_index]
932
+ # attention_mask = torch.cat([
933
+ # attention_mask[:, :end_index + 1],
934
+ # prev_mask.repeat(input_ids.shape[0], compare_token_size),
935
+ # attention_mask[:, end_index + 1:]
936
+ # ], dim=1)
937
+
938
+ # current_end = start # 更新结束位置为当前图片的起始位置
939
+
940
+ # model_kwargs[1]["input_ids"] = input_ids
941
+ # model_kwargs[1]["attention_mask"] = attention_mask
942
+ # return model_kwargs