import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Any, Callable, Optional, Union
from transformers import Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
    Qwen2_5_VLModel,
    Qwen2RMSNorm,
    Qwen2_5_VLMLP,
    ALL_ATTENTION_FUNCTIONS
)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils import TextInput, PreTokenizedInput
from transformers.video_utils import VideoInput
from transformers.feature_extraction_utils import BatchFeature
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs


class ADCopilotConfig(Qwen2_5_VLConfig):
    """Qwen2.5-VL config extended with the image-comparison ("compare token") settings.

    Extra attributes:
        vision_config.compare_token_size: number of compare tokens appended after every image (default 100).
        sequence_compare: if True, the first image of the flattened image list is compared with the second
            image instead of wrapping around to the last one (default True).
    """

    model_type = "ad_copilot"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keep values restored from a saved config instead of overwriting them unconditionally.
        self.vision_config.compare_token_size = getattr(self.vision_config, "compare_token_size", 100)
        self.architectures = ["ADCopilotVLForConditionalGeneration"]
        self.sequence_compare = getattr(self, "sequence_compare", True)
        # The compare encoder is constructed from `vision_config` only, so mirror the flag there;
        # otherwise changing `sequence_compare` on this config would have no effect.
        self.vision_config.sequence_compare = self.sequence_compare


class ADCopilotProcessor(Qwen2_5_VLProcessor):
    """Qwen2.5-VL processor that reserves `compare_token_size` extra image tokens after every image."""

    config_class = ADCopilotConfig

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        # Pop our own option so it is not forwarded to the base processor as an unknown keyword.
        compare_token_size = kwargs.pop("compare_token_size", 100)
        super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)
        self.compare_token_size = compare_token_size

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the
        `text` and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not
        `None` to encode the text. To prepare the vision inputs, this method forwards the `images`/`videos` and
        `kwargs` arguments to the image/video processors if they are not `None`.

        Unlike the base Qwen2.5-VL processor, every image placeholder is expanded to the image's visual tokens
        PLUS `self.compare_token_size` additional image tokens, which the model fills with compare embeddings.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Two independent dicts (the original aliased one dict object to both names).
        image_inputs, videos_inputs = {}, {}
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]

            if isinstance(fps, (int, float)):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # the loops below rewrite entries of `text` in place
        if images is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    num_image_tokens = image_grid_thw[index].prod() // merge_length
                    # Each image expands to its merged visual tokens plus `compare_token_size` compare tokens.
                    text[i] = text[i].replace(
                        self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if videos is not None:
            merge_length = self.video_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    num_video_tokens = video_grid_thw[index].prod() // merge_length
                    text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)


class OptimizedCrossAttention(nn.Module):
    """Multi-head (cross-)attention modelled on the structure of `Qwen2_5_VLVisionAttention`.

    In cross-attention mode Q is projected from `query_states` and K/V from `key_value_states`
    (fused `kv` projection); in self-attention mode Q/K/V come from one fused `qkv` projection.
    Attention is always computed with the Transformers SDPA interface so that boolean key-padding
    masks are honoured.

    NOTE: the Transformers attention interface returns `[B, seq_q, nH, d]`. An extra transpose before
    the final reshape (which mixed tokens across heads) has been removed; checkpoints trained with
    that layout will produce different outputs from this module.
    """

    def __init__(self, config, is_cross_attention=True):
        super().__init__()
        self.config = config
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_causal = False  # cross attention needs no causal mask
        self.is_cross_attention = is_cross_attention

        if is_cross_attention:
            # Cross attention: Q from one sequence, K and V from another.
            self.q_proj = nn.Linear(self.dim, self.dim, bias=True)
            self.kv = nn.Linear(self.dim, self.dim * 2, bias=True)  # fused K, V
        else:
            # Self attention: Q, K, V from the same sequence.
            self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)  # fused Q, K, V
        self.proj = nn.Linear(self.dim, self.dim, bias=True)

    def forward(
        self,
        query_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        kv_cu_seqlens: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend `query_states` to `key_value_states` (or to themselves).

        Args:
            query_states: `[B, T_q, D]` or `[T_q, D]` (a batch dim of 1 is added and removed again).
            key_value_states: `[B, T_kv, D]` or `[T_kv, D]`. In cross-attention mode, `None` means
                "attend to `query_states`". Ignored in self-attention mode.
            attention_mask: optional mask passed to the SDPA interface; callers in this file pass a
                boolean `[B, 1, 1, T_kv]` mask where True marks keys that may be attended.
            cu_seqlens, kv_cu_seqlens: accepted for backward compatibility only; unused because this
                module always runs SDPA (FlashAttention-2 would ignore the padding mask).

        Returns:
            Tensor with the same leading shape as `query_states` and last dim `D`.
        """
        orig_2d = query_states.dim() == 2
        if orig_2d:
            query_states = query_states.unsqueeze(0)
        batch_size, seq_len_q, _ = query_states.shape

        if self.is_cross_attention:
            kv_source = query_states if key_value_states is None else key_value_states
            if kv_source.dim() == 2:
                kv_source = kv_source.unsqueeze(0)
            q = self.q_proj(query_states)
            kv = self.kv(kv_source)
            seq_len_kv = kv.shape[1]
            # [B, T_kv, 2, nH, d] -> 2 x [B, nH, T_kv, d]
            k, v = kv.reshape(batch_size, seq_len_kv, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
            q = q.reshape(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            qkv = self.qkv(query_states)
            # [B, T, 3, nH, d] -> 3 x [B, nH, T, d]
            q, k, v = qkv.reshape(batch_size, seq_len_q, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
        attn_output, _ = attention_interface(
            self,
            q,
            k,
            v,
            attention_mask=attention_mask,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.attention_dropout,
            is_causal=self.is_causal,
            **kwargs,
        )
        # The interface already returns [B, seq_q, nH, d]; merge heads directly.
        attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim)
        attn_output = self.proj(attn_output)

        if orig_2d:
            attn_output = attn_output.squeeze(0)
        return attn_output.contiguous()


class ADCopilotCompareVisualEncoder(nn.Module):
    """Turn per-image vision features into `token_size` compare tokens per image.

    Encoder: each image is paired with a "previous" image (the preceding image in the flattened list);
    the previous image attends to the current one, then the current one attends to the enhanced previous.
    Decoder: learnable queries attend to the encoded current features and are projected to
    `config.out_hidden_size`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sequence_compare = getattr(config, "sequence_compare", True)
        self.hidden_size = config.hidden_size
        self.token_size = getattr(config, "compare_token_size", 100)

        # Encoder: bidirectional interaction between the two images' features.
        # First cross attention: previous attends to current.
        self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
        # Second cross attention: current attends to previous.
        self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)
        self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm4 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_mlp1 = Qwen2_5_VLMLP(config)
        self.encoder_mlp2 = Qwen2_5_VLMLP(config)

        # Decoder: learnable queries interacting with the encoded features.
        # Initialised N(0, 0.02) (same distribution as `init_query_embeddings`) instead of `torch.empty`,
        # which left the parameter holding uninitialised memory when no checkpoint value is loaded.
        self.query_embeddings = nn.Parameter(torch.randn(self.token_size, self.hidden_size) * 0.02)
        # Only a cross attention: queries attend to the encoded features.
        self.decoder_cross_attn = OptimizedCrossAttention(config, is_cross_attention=True)
        self.decoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_mlp = Qwen2_5_VLMLP(config)

        self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)

    def init_query_embeddings(self):
        """Re-initialise the learnable queries from N(0, 0.02)."""
        nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)

    @staticmethod
    def _key_mask(mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Expand a `[B, T_kv]` key mask to the `[B, 1, 1, T_kv]` shape SDPA broadcasts over."""
        return None if mask is None else mask[:, None, None, :]

    def forward(self, images_hidden_states: list) -> torch.Tensor:
        """
        Args:
            images_hidden_states: sequence of tensors, one per image, each of shape `[seq_len_i, hidden_size]`.

        Returns:
            Tensor of shape `[num_images, token_size, out_hidden_size]`.

        NOTE(review): the "previous" image is taken from the flattened image list. If the caller passes
        images from several samples, a sample's first image may be compared with another sample's image.
        Confirm this is intended.
        """
        if not images_hidden_states:
            return torch.empty(
                0,
                self.token_size,
                self.compare_projector.out_features,
                device=self.query_embeddings.device,
                dtype=self.query_embeddings.dtype,
            )

        # Debug guard: report NaNs in the learnable queries.
        if torch.isnan(self.query_embeddings).any():
            print("警告:query_embeddings 包含 NaN 值")

        device = images_hidden_states[0].device
        batch_size = len(images_hidden_states)
        seq_lengths = [state.size(0) for state in images_hidden_states]
        max_seq_len = max(seq_lengths)

        # Zero-pad every image to the longest sequence: [batch_size, max_seq_len, hidden_size].
        batched_states = nn.utils.rnn.pad_sequence(list(images_hidden_states), batch_first=True)
        # True for real tokens, False for padding: [batch_size, max_seq_len].
        attention_masks = (
            torch.arange(max_seq_len, device=device)[None, :] < torch.tensor(seq_lengths, device=device)[:, None]
        )

        # Image i is compared with image i-1 (circular shift). With `sequence_compare`, the first image
        # uses the second image as its "previous" instead of wrapping around to the last one.
        previous_states = torch.roll(batched_states, shifts=1, dims=0)
        previous_masks = torch.roll(attention_masks, shifts=1, dims=0)
        if previous_states.size(0) > 1 and self.sequence_compare:
            previous_states[0] = previous_states[1]
            previous_masks[0] = previous_masks[1]

        encoded_features = self._encoder_forward(batched_states, previous_states, attention_masks, previous_masks)

        # Broadcast the shared queries over all images: [batch_size, token_size, hidden_size].
        batch_queries = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        compare_visual_embeds = self._decoder_forward(
            batch_queries,
            encoded_features,
            query_mask=None,  # every query is valid; the decoder only masks keys
            encoded_mask=attention_masks,
        )

        # nn.Linear acts on the last dim: [batch_size, token_size, out_hidden_size].
        return self.compare_projector(compare_visual_embeds)

    def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
        """
        Encoder: bidirectional interaction between the current and previous image features.

        Args:
            current_features: [batch_size, seq_len, hidden_size]
            previous_features: [batch_size, seq_len, hidden_size]
            current_mask: [batch_size, seq_len]
            previous_mask: [batch_size, seq_len]

        Returns:
            Updated current features, [batch_size, seq_len, hidden_size].
        """
        # Step 1: previous attends to current.
        residual = previous_features
        previous_normed = self.encoder_norm1(previous_features)
        current_normed1 = self.encoder_norm1(current_features)
        cross_attn_output1 = self.encoder_cross_attn1(
            query_states=previous_normed,
            key_value_states=current_normed1,
            attention_mask=self._key_mask(current_mask),
        )
        previous_features = residual + cross_attn_output1

        # MLP on the enhanced previous features.
        residual = previous_features
        mlp_output1 = self.encoder_mlp1(self.encoder_norm2(previous_features))
        previous_features = residual + mlp_output1

        # Step 2: current attends to the enhanced previous features.
        residual = current_features
        current_normed2 = self.encoder_norm3(current_features)
        previous_normed2 = self.encoder_norm3(previous_features)
        cross_attn_output2 = self.encoder_cross_attn2(
            query_states=current_normed2,
            key_value_states=previous_normed2,
            attention_mask=self._key_mask(previous_mask),
        )
        current_features = residual + cross_attn_output2

        # MLP on the current features. The MLP output is deliberately SUBTRACTED
        # (the residual addition was changed to a subtraction in this model).
        residual = current_features
        mlp_output2 = self.encoder_mlp2(self.encoder_norm4(current_features))
        current_features = residual - mlp_output2

        return current_features

    def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
        """
        Decoder: queries interact with the encoded features.

        Args:
            queries: [batch_size, token_size, hidden_size]
            encoded_features: [batch_size, seq_len, hidden_size]
            query_mask: [batch_size, token_size] (currently unused)
            encoded_mask: [batch_size, seq_len]

        Returns:
            [batch_size, token_size, hidden_size]
        """
        # Cross attention: queries attend to the encoded features.
        residual = queries
        queries_normed = self.decoder_norm1(queries)
        encoded_normed = self.decoder_norm1(encoded_features)
        cross_attn_output = self.decoder_cross_attn(
            query_states=queries_normed,
            key_value_states=encoded_normed,
            attention_mask=self._key_mask(encoded_mask),
        )
        queries = residual + cross_attn_output

        # MLP
        residual = queries
        mlp_output = self.decoder_mlp(self.decoder_norm2(queries))
        queries = residual + mlp_output

        return queries


class ADCopilotVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):
    """Qwen2.5-VL vision tower that additionally returns compare embeddings per image."""

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.compare_visual_encoder = ADCopilotCompareVisualEncoder(config)

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `(hidden_states, compare_visual_embeds)`: merged vision embeddings (in original token order)
            and the compare embeddings, `[num_images, compare_token_size, out_hidden_size]`.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            cu_seqlens_now = cu_seqlens if layer_num in self.fullatt_block_indexes else cu_window_seqlens
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Split the pre-merger features per image. Tokens are still in window order here, but each
        # image's tokens remain contiguous, and the compare encoder applies no positional encoding.
        split_sizes = grid_thw.prod(-1).tolist()
        splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
        # [total_images, token_size, out_hidden_size]
        compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)

        hidden_states = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states, compare_visual_embeds


class ADCopilotVLModel(Qwen2_5_VLModel):
    """Qwen2.5-VL backbone whose image features carry `compare_token_size` extra tokens per image."""

    def __init__(self, config):
        super().__init__(config)
        self.visual = ADCopilotVisionTransformerPretrainedModel._from_config(config.vision_config)
        self.compare_token_size = config.vision_config.compare_token_size

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            List with one tensor per image: the image's merged embeddings followed by its compare embeddings.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, compare_visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)

        # Append each image's compare tokens after its visual tokens (matching the processor's layout).
        enhanced_image_embeds = []
        for i, embeds in enumerate(image_embeds):
            compare_embed = compare_visual_embeds[i].to(device=embeds.device, dtype=embeds.dtype)
            enhanced_image_embeds.append(torch.cat([embeds, compare_embed], dim=0))

        return enhanced_image_embeds

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to `get_rope_index_with_compare_token`."""
        return self.get_rope_index_with_compare_token(
            input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask
        )

    def get_rope_index_with_compare_token(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """3-D M-RoPE position ids (as in Qwen2.5-VL) plus positions for the compare tokens after every image.

        Every compare token of an image gets the same position in all three axes:
        `t_index[-1] + text_len + st_idx`.

        NOTE(review): the segment after an image starts at `(compare position) + 1`, which can be smaller
        than the largest h/w position of that image's grid, so later positions may overlap the image's
        spatial positions. Changing this alters positions seen by trained checkpoints; review before fixing.

        Returns:
            `(position_ids [3, B, L], mrope_position_deltas [B, 1])`.
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for vision_index in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                    # Normalise type and device.
                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                    )

                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second

                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                    if ed_image < ed_video:
                        # Images are followed by `compare_token_size` compare tokens; give them positions too.
                        compare_t_index = t_index[-1].repeat(self.compare_token_size)
                        compare_h_index = compare_t_index
                        compare_w_index = compare_t_index
                        llm_pos_ids_list.append(
                            torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx
                        )
                        st = st + self.compare_token_size

                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas


class ADCopilotVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL generation model whose backbone is replaced by `ADCopilotVLModel`."""

    config_class = ADCopilotConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ADCopilotVLModel(config)