Upload processor
Browse files- modeling_yangjian.py +187 -12
modeling_yangjian.py
CHANGED
|
@@ -116,6 +116,7 @@ class YangJianProcessor(Qwen2_5_VLProcessor):
|
|
| 116 |
for i in range(len(text)):
|
| 117 |
while self.image_token in text[i]:
|
| 118 |
num_image_tokens = image_grid_thw[index].prod() // merge_length
|
|
|
|
| 119 |
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
|
| 120 |
index += 1
|
| 121 |
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
|
@@ -486,7 +487,7 @@ class YangJianVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrain
|
|
| 486 |
def __init__(self, config, *inputs, **kwargs) -> None:
|
| 487 |
super().__init__(config, *inputs, **kwargs)
|
| 488 |
self.compare_visual_encoder = YangJianCompareVisualEncoder(config)
|
| 489 |
-
|
| 490 |
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 491 |
"""
|
| 492 |
Args:
|
|
@@ -570,6 +571,7 @@ class YangJianVLModel(Qwen2_5_VLModel):
|
|
| 570 |
def __init__(self, config):
|
| 571 |
super().__init__(config)
|
| 572 |
self.visual = YangJianVisionTransformerPretrainedModel._from_config(config.vision_config)
|
|
|
|
| 573 |
# self.learnable_image_embeddings = nn.Parameter(
|
| 574 |
# torch.randn(100, config.hidden_size) * 0.02 # 使用小的初始化值
|
| 575 |
# )
|
|
@@ -644,19 +646,11 @@ class YangJianVLModel(Qwen2_5_VLModel):
|
|
| 644 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 645 |
|
| 646 |
if inputs_embeds is None:
|
|
|
|
| 647 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 648 |
if pixel_values is not None:
|
| 649 |
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
| 650 |
-
|
| 651 |
-
# # 为每个图像添加 100 个可学习的 embedding
|
| 652 |
-
# learnable_embeddings = self.learnable_image_embeddings.to(image_embeds[0].device, image_embeds[0].dtype)
|
| 653 |
-
# enhanced_image_embeds = []
|
| 654 |
-
|
| 655 |
-
# for i, embeds in enumerate(image_embeds):
|
| 656 |
-
# # 为每个图像添加 100 个可学习的 embedding
|
| 657 |
-
# enhanced_embeds = torch.cat([embeds, learnable_embeddings], dim=0)
|
| 658 |
-
# enhanced_image_embeds.append(enhanced_embeds)
|
| 659 |
-
|
| 660 |
image_embeds = torch.cat(image_embeds, dim=0)
|
| 661 |
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
| 662 |
n_image_features = image_embeds.shape[0]
|
|
@@ -713,7 +707,7 @@ class YangJianVLModel(Qwen2_5_VLModel):
|
|
| 713 |
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
| 714 |
)
|
| 715 |
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
| 716 |
-
position_ids, rope_deltas = self.
|
| 717 |
input_ids,
|
| 718 |
image_grid_thw,
|
| 719 |
video_grid_thw,
|
|
@@ -758,6 +752,142 @@ class YangJianVLModel(Qwen2_5_VLModel):
|
|
| 758 |
rope_deltas=self.rope_deltas,
|
| 759 |
)
|
| 760 |
return output if return_dict else output.to_tuple()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 761 |
|
| 762 |
class YangJianVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
| 763 |
config_class = YangJianConfig
|
|
@@ -765,3 +895,48 @@ class YangJianVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
|
| 765 |
def __init__(self, config):
|
| 766 |
super().__init__(config)
|
| 767 |
self.model = YangJianVLModel(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
for i in range(len(text)):
|
| 117 |
while self.image_token in text[i]:
|
| 118 |
num_image_tokens = image_grid_thw[index].prod() // merge_length
|
| 119 |
+
# text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens), 1)
|
| 120 |
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
|
| 121 |
index += 1
|
| 122 |
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
|
|
|
| 487 |
def __init__(self, config, *inputs, **kwargs) -> None:
    """Initialize the Qwen2.5-VL vision transformer plus a compare-token encoder.

    Args:
        config: Vision config forwarded to the Qwen2.5-VL vision transformer base.
        *inputs, **kwargs: Passed through unchanged to the base constructor.
    """
    super().__init__(config, *inputs, **kwargs)
    # Extra visual encoder; presumably produces the features for the
    # `compare_token_size` compare tokens the processor reserves per image
    # — TODO confirm against YangJianCompareVisualEncoder's forward.
    self.compare_visual_encoder = YangJianCompareVisualEncoder(config)
|
| 490 |
+
|
| 491 |
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 492 |
"""
|
| 493 |
Args:
|
|
|
|
| 571 |
def __init__(self, config):
|
| 572 |
super().__init__(config)
|
| 573 |
self.visual = YangJianVisionTransformerPretrainedModel._from_config(config.vision_config)
|
| 574 |
+
self.compare_token_size = config.vision_config.compare_token_size
|
| 575 |
# self.learnable_image_embeddings = nn.Parameter(
|
| 576 |
# torch.randn(100, config.hidden_size) * 0.02 # 使用小的初始化值
|
| 577 |
# )
|
|
|
|
| 646 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 647 |
|
| 648 |
if inputs_embeds is None:
|
| 649 |
+
|
| 650 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 651 |
if pixel_values is not None:
|
| 652 |
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
| 653 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
image_embeds = torch.cat(image_embeds, dim=0)
|
| 655 |
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
| 656 |
n_image_features = image_embeds.shape[0]
|
|
|
|
| 707 |
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
| 708 |
)
|
| 709 |
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
| 710 |
+
position_ids, rope_deltas = self.get_rope_index_with_compare_token(
|
| 711 |
input_ids,
|
| 712 |
image_grid_thw,
|
| 713 |
video_grid_thw,
|
|
|
|
| 752 |
rope_deltas=self.rope_deltas,
|
| 753 |
)
|
| 754 |
return output if return_dict else output.to_tuple()
|
| 755 |
+
|
| 756 |
+
def get_rope_index_with_compare_token(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute 3D M-RoPE position ids, accounting for extra per-image compare tokens.

    Variant of Qwen2.5-VL's `get_rope_index`: after each image's (t, h, w) grid
    positions it appends positions for `self.compare_token_size` additional
    "compare" tokens, matching the extra placeholder tokens the processor
    inserts per image.

    Args:
        input_ids: (batch, seq_len) token ids.
        image_grid_thw: (num_images, 3) temporal/height/width patch grid per image.
        video_grid_thw: (num_videos, 3) patch grid per video.
        second_per_grid_ts: Per-video seconds-per-temporal-grid; defaults to 1.0.
        attention_mask: (batch, seq_len) mask; defaults to all ones.

    Returns:
        position_ids: (3, batch, seq_len) temporal/height/width position ids.
        mrope_position_deltas: (batch, 1) offset of max position vs. sequence length,
            used to continue positions during decoding.
    """
    spatial_merge_size = self.config.vision_config.spatial_merge_size
    image_token_id = self.config.image_token_id
    video_token_id = self.config.video_token_id
    vision_start_token_id = self.config.vision_start_token_id
    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        for i, input_ids in enumerate(total_input_ids):
            # Work only on the unpadded portion of this sequence.
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            # Each vision segment is announced by a vision-start token; the token
            # right after it tells whether the segment is an image or a video.
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for vision_index in range(image_nums + video_nums):
                # Find the nearest upcoming image/video token; len+1 sentinels
                # make the min-comparison below pick the other modality.
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    # Image segment: no temporal progression (second_per_grid_t = 0).
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    second_per_grid_t = 0
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image

                else:
                    # Video segment: temporal index advances per frame grid.
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    if second_per_grid_ts is not None:
                        second_per_grid_t = second_per_grid_ts[video_index]
                    else:
                        second_per_grid_t = 1.0
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                # Spatial dims are merged by spatial_merge_size before reaching the LLM.
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                # Plain text before this vision segment: identical t/h/w positions.
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                ## normalize type, send to device.
                second_per_grid_t = torch.as_tensor(
                    second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                )

                # Temporal position scaled to absolute time units.
                time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second

                time_tensor_long = time_tensor.long()
                t_index = time_tensor_long.flatten()

                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w
                if ed_image < ed_video:
                    # Image segment: also emit positions for the compare_token_size
                    # extra image-comparison tokens inserted after the image.
                    # They reuse the last temporal index; h/w indices restart at 0
                    # and run diagonally (arange for both) —
                    # NOTE(review): these h/w positions can coincide with the
                    # image grid's own (h, w) positions; confirm this overlap is
                    # intentional for the compare tokens.
                    compare_t_index = t_index[-1].repeat(self.compare_token_size)
                    compare_h_index = torch.arange(self.compare_token_size)
                    compare_w_index = torch.arange(self.compare_token_size)
                    llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
                    st = st + self.compare_token_size

            # Trailing text after the last vision segment.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            # Scatter computed positions back into the unmasked slots of row i.
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        # Text-only fallback: standard 1D positions replicated across the 3 M-RoPE axes.
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Padded slots get a harmless constant position (1).
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_position_deltas
|
| 891 |
|
| 892 |
class YangJianVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL conditional-generation head backed by the YangJian VL model.

    Only the backbone is swapped: ``self.model`` becomes :class:`YangJianVLModel`,
    whose vision tower and RoPE indexing account for the extra per-image
    compare tokens the processor inserts. Everything else (generation loop,
    LM head, loss) is inherited unchanged from the Qwen2.5-VL base class.
    """

    config_class = YangJianConfig

    def __init__(self, config):
        """Build the base generation model, then replace its backbone.

        Args:
            config: A :class:`YangJianConfig` (Qwen2.5-VL-compatible config with
                ``vision_config.compare_token_size``).
        """
        super().__init__(config)
        # Swap the stock Qwen2.5-VL backbone for the compare-token-aware variant.
        self.model = YangJianVLModel(config)