| |
| |
| |
| |
| @@ -40,7 +40,7 @@ |
| from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput |
| from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from ...modeling_utils import PreTrainedModel |
| -from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging |
| +from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging |
| from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig |
| |
| |
| @@ -358,7 +358,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_cache_class = True |
| - _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| std = self.config.get_text_config().initializer_range |
| @@ -1659,9 +1659,9 @@ def forward( |
| inputs_embeds = self.get_input_embeddings()(input_ids) |
| if pixel_values is not None: |
| image_embeds = self.get_image_features(pixel_values, image_grid_thw) |
| - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() |
| + n_image_tokens = (input_ids == self.config.image_token_id).sum() |
| n_image_features = image_embeds.shape[0] |
| - if n_image_tokens != n_image_features: |
| + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| @@ -1676,9 +1676,9 @@ def forward( |
| |
| if pixel_values_videos is not None: |
| video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) |
| - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() |
| + n_video_tokens = (input_ids == self.config.video_token_id).sum() |
| n_video_features = video_embeds.shape[0] |
| - if n_video_tokens != n_video_features: |
| + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: |
| raise ValueError( |
| f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
| ) |
| @@ -1694,20 +1694,32 @@ def forward( |
| if attention_mask is not None: |
| attention_mask = attention_mask.to(inputs_embeds.device) |
| |
| - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme |
| - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): |
| - # calculate RoPE index once per generation in the pre-fill stage only |
| - if ( |
| + if position_ids is None: |
| + attention_mask_2d = attention_mask |
| + if attention_mask is not None and attention_mask.ndim == 4: |
| + attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2) |
| + attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min |
| + attention_mask_2d = (1.0 - attention_mask_2d).int() |
| + |
| + # Calculate RoPE index once per generation in the pre-fill stage only. |
| + # When compiling, we can't check tensor values thus we check only input length |
| + # It is safe to assume that `length!=1` means we're in pre-fill because compiled |
| + # models currently cannot do assisted decoding |
| + prefill_compiled_stage = is_torchdynamo_compiling() and ( |
| + (input_ids is not None and input_ids.shape[1] != 1) |
| + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) |
| + ) |
| + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( |
| (cache_position is not None and cache_position[0] == 0) |
| - or self.rope_deltas is None |
| or (past_key_values is None or past_key_values.get_seq_length() == 0) |
| - ): |
| + ) |
| + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: |
| position_ids, rope_deltas = self.get_rope_index( |
| input_ids, |
| image_grid_thw, |
| video_grid_thw, |
| - second_per_grid_ts, |
| - attention_mask, |
| + second_per_grid_ts=second_per_grid_ts, |
| + attention_mask=attention_mask_2d, |
| ) |
| self.rope_deltas = rope_deltas |
| # then use the prev pre-calculated rope-deltas to get the correct position ids |
| @@ -1747,6 +1759,61 @@ def forward( |
| ) |
| return output if return_dict else output.to_tuple() |
| |
| + @staticmethod |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( |
| + causal_mask.device |
| + ) |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| @dataclass |
| class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): |
| @@ -2108,60 +2175,5 @@ def _expand_dict_for_generation(dict_to_expand): |
| |
| return input_ids, model_kwargs |
| |
| - @staticmethod |
| - def _prepare_4d_causal_attention_mask_with_cache_position( |
| - attention_mask: torch.Tensor, |
| - sequence_length: int, |
| - target_length: int, |
| - dtype: torch.dtype, |
| - cache_position: torch.Tensor, |
| - batch_size: int, |
| - **kwargs, |
| - ): |
| - """ |
| - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| - |
| - Args: |
| - attention_mask (`torch.Tensor`): |
| - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| - `(batch_size, 1, query_length, key_value_length)`. |
| - sequence_length (`int`): |
| - The sequence length being processed. |
| - target_length (`int`): |
| - The target length: when generating with static cache, the mask should be as long as the static cache, |
| - to account for the 0 padding, the part of the cache that is not filled yet. |
| - dtype (`torch.dtype`): |
| - The dtype to use for the 4D attention mask. |
| - cache_position (`torch.Tensor`): |
| - Indices depicting the position of the input sequence tokens in the sequence. |
| - batch_size (`torch.Tensor`): |
| - Batch size. |
| - """ |
| - if attention_mask is not None and attention_mask.dim() == 4: |
| - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| - causal_mask = attention_mask |
| - else: |
| - min_dtype = torch.finfo(dtype).min |
| - causal_mask = torch.full( |
| - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device |
| - ) |
| - if sequence_length != 1: |
| - causal_mask = torch.triu(causal_mask, diagonal=1) |
| - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) |
| - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| - if attention_mask is not None: |
| - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| - mask_length = attention_mask.shape[-1] |
| - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( |
| - causal_mask.device |
| - ) |
| - padding_mask = padding_mask == 0 |
| - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| - padding_mask, min_dtype |
| - ) |
| - |
| - return causal_mask |
| - |
| |
| __all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"] |
| |
| |
| |
| |
| @@ -50,7 +50,7 @@ |
| from ...modeling_flash_attention_utils import is_flash_attn_available |
| from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs |
| from ...tokenization_utils_base import PreTokenizedInput, TextInput |
| -from ...utils import logging |
| +from ...utils import is_torchdynamo_compiling, logging |
| from ...video_utils import VideoInput |
| |
| |
| @@ -647,9 +647,9 @@ def forward( |
| inputs_embeds = self.get_input_embeddings()(input_ids) |
| if pixel_values is not None: |
| image_embeds = self.get_image_features(pixel_values, image_grid_thw) |
| - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() |
| + n_image_tokens = (input_ids == self.config.image_token_id).sum() |
| n_image_features = image_embeds.shape[0] |
| - if n_image_tokens != n_image_features: |
| + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| @@ -664,9 +664,9 @@ def forward( |
| |
| if pixel_values_videos is not None: |
| video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) |
| - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() |
| + n_video_tokens = (input_ids == self.config.video_token_id).sum() |
| n_video_features = video_embeds.shape[0] |
| - if n_video_tokens != n_video_features: |
| + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: |
| raise ValueError( |
| f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
| ) |
| @@ -682,20 +682,32 @@ def forward( |
| if attention_mask is not None: |
| attention_mask = attention_mask.to(inputs_embeds.device) |
| |
| - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme |
| - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): |
| - # calculate RoPE index once per generation in the pre-fill stage only |
| - if ( |
| + if position_ids is None: |
| + attention_mask_2d = attention_mask |
| + if attention_mask is not None and attention_mask.ndim == 4: |
| + attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2) |
| + attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min |
| + attention_mask_2d = (1.0 - attention_mask_2d).int() |
| + |
| + # Calculate RoPE index once per generation in the pre-fill stage only. |
| + # When compiling, we can't check tensor values thus we check only input length |
| + # It is safe to assume that `length!=1` means we're in pre-fill because compiled |
| + # models currently cannot do assisted decoding |
| + prefill_compiled_stage = is_torchdynamo_compiling() and ( |
| + (input_ids is not None and input_ids.shape[1] != 1) |
| + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) |
| + ) |
| + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( |
| (cache_position is not None and cache_position[0] == 0) |
| - or self.rope_deltas is None |
| or (past_key_values is None or past_key_values.get_seq_length() == 0) |
| - ): |
| + ) |
| + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: |
| position_ids, rope_deltas = self.get_rope_index( |
| input_ids, |
| image_grid_thw, |
| video_grid_thw, |
| - second_per_grid_ts, |
| - attention_mask, |
| + second_per_grid_ts=second_per_grid_ts, |
| + attention_mask=attention_mask_2d, |
| ) |
| self.rope_deltas = rope_deltas |
| # then use the prev pre-calculated rope-deltas to get the correct position ids |
| |
| |
| |
| |
| @@ -924,7 +924,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_cache_class = True |
| - _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| std = self.config.get_text_config().initializer_range |
| @@ -1616,16 +1616,28 @@ def forward( |
| if attention_mask is not None: |
| attention_mask = attention_mask.to(inputs_embeds.device) |
| |
| - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme |
| - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): |
| - # calculate RoPE index once per generation in the pre-fill stage only |
| - if ( |
| + if position_ids is None: |
| + attention_mask_2d = attention_mask |
| + if attention_mask is not None and attention_mask.ndim == 4: |
| + attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2) |
| + attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min |
| + attention_mask_2d = (1.0 - attention_mask_2d).int() |
| + |
| + # Calculate RoPE index once per generation in the pre-fill stage only. |
| + # When compiling, we can't check tensor values thus we check only input length |
| + # It is safe to assume that `length!=1` means we're in pre-fill because compiled |
| + # models currently cannot do assisted decoding |
| + prefill_compiled_stage = is_torchdynamo_compiling() and ( |
| + (input_ids is not None and input_ids.shape[1] != 1) |
| + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) |
| + ) |
| + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( |
| (cache_position is not None and cache_position[0] == 0) |
| - or self.rope_deltas is None |
| or (past_key_values is None or past_key_values.get_seq_length() == 0) |
| - ): |
| + ) |
| + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: |
| position_ids, rope_deltas = self.get_rope_index( |
| - input_ids, image_grid_thw, video_grid_thw, attention_mask |
| + input_ids, image_grid_thw, video_grid_thw, attention_mask_2d |
| ) |
| self.rope_deltas = rope_deltas |
| # then use the prev pre-calculated rope-deltas to get the correct position ids |
| @@ -1662,6 +1674,62 @@ def forward( |
| ) |
| return output if return_dict else output.to_tuple() |
| |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( |
| + causal_mask.device |
| + ) |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): |
| _checkpoint_conversion_mapping = { |
| @@ -1974,61 +2042,5 @@ def _expand_dict_for_generation(dict_to_expand): |
| |
| return input_ids, model_kwargs |
| |
| - @staticmethod |
| - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position |
| - def _prepare_4d_causal_attention_mask_with_cache_position( |
| - attention_mask: torch.Tensor, |
| - sequence_length: int, |
| - target_length: int, |
| - dtype: torch.dtype, |
| - cache_position: torch.Tensor, |
| - batch_size: int, |
| - **kwargs, |
| - ): |
| - """ |
| - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| - |
| - Args: |
| - attention_mask (`torch.Tensor`): |
| - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| - `(batch_size, 1, query_length, key_value_length)`. |
| - sequence_length (`int`): |
| - The sequence length being processed. |
| - target_length (`int`): |
| - The target length: when generating with static cache, the mask should be as long as the static cache, |
| - to account for the 0 padding, the part of the cache that is not filled yet. |
| - dtype (`torch.dtype`): |
| - The dtype to use for the 4D attention mask. |
| - cache_position (`torch.Tensor`): |
| - Indices depicting the position of the input sequence tokens in the sequence. |
| - batch_size (`torch.Tensor`): |
| - Batch size. |
| - """ |
| - if attention_mask is not None and attention_mask.dim() == 4: |
| - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| - causal_mask = attention_mask |
| - else: |
| - min_dtype = torch.finfo(dtype).min |
| - causal_mask = torch.full( |
| - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device |
| - ) |
| - if sequence_length != 1: |
| - causal_mask = torch.triu(causal_mask, diagonal=1) |
| - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) |
| - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| - if attention_mask is not None: |
| - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| - mask_length = attention_mask.shape[-1] |
| - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( |
| - causal_mask.device |
| - ) |
| - padding_mask = padding_mask == 0 |
| - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| - padding_mask, min_dtype |
| - ) |
| - |
| - return causal_mask |
| - |
| |
| __all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"] |
| |
| |
| |
| |
| @@ -346,10 +346,6 @@ def test_disk_offload_safetensors(self): |
| def test_model_parallelism(self): |
| pass |
| |
| - @unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models") |
| - def test_sdpa_can_compile_dynamic(self): |
| - pass |
| - |
| @unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models") |
| def test_sdpa_can_dispatch_on_flash(self): |
| pass |
| @@ -368,10 +364,6 @@ def test_model_is_small(self): |
| def test_generate_from_inputs_embeds_with_static_cache(self): |
| pass |
| |
| - @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`") |
| - def test_generate_compile_fullgraph(self): |
| - pass |
| - |
| @is_flaky() # TODO (joao/raushan): Investigate why this test is flaky on this model |
| def test_prompt_lookup_decoding_matches_greedy_search(self): |
| super().test_prompt_lookup_decoding_matches_greedy_search() |
| |
| |
| |
| |
| @@ -300,10 +300,6 @@ def test_disk_offload_safetensors(self): |
| def test_model_parallelism(self): |
| pass |
| |
| - @unittest.skip(reason="Compile not yet supported because in Qwen2VL models") |
| - def test_sdpa_can_compile_dynamic(self): |
| - pass |
| - |
| @unittest.skip(reason="Compile not yet supported because in Qwen2VL models") |
| def test_sdpa_can_dispatch_on_flash(self): |
| pass |
|
|