# Copyright 2025 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest from transformers.testing_utils import ( cleanup, is_torch_available, require_torch, torch_device, ) if is_torch_available(): import torch from torch.nn.attention.flex_attention import create_block_mask from transformers import DynamicCache, LlamaConfig from transformers.cache_utils import DynamicSlidingWindowLayer from transformers.masking_utils import ( create_bidirectional_mask, create_causal_mask, create_chunked_causal_mask, find_packed_sequence_indices, ) # fmt: off EXPECTED_PACKED_MASK = torch.tensor([[[ [ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [False, False, False, False, True, False, False, False, False, False], [False, False, False, False, True, True, False, False, False, False], [False, False, False, False, False, False, True, False, False, False], [False, False, False, False, False, False, True, True, False, False], [False, False, False, False, False, False, True, True, True, False], [False, False, False, False, False, False, True, True, True, True]]], [[[ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, True, False, False, False, False], [False, False, False, False, False, False, True, False, False, False], [False, False, False, False, False, False, True, True, False, False], [False, False, False, False, False, False, True, True, True, False], [False, False, False, False, False, False, True, True, True, True] ]]], dtype=torch.bool) # fmt: on @require_torch class MaskTest(unittest.TestCase): def setup(self): cleanup(torch_device, gc_collect=True) def tearDown(self): cleanup(torch_device, gc_collect=True) def test_packed_sequence_mask_sdpa(self): config = LlamaConfig() config._attn_implementation = "sdpa" batch_size = 2 sequence_length = 10 cache_position = torch.arange(sequence_length) # First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) causal_mask = create_causal_mask( config=config, # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), attention_mask=None, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all()) def test_packed_sequence_mask_eager(self): config = LlamaConfig() config._attn_implementation = "eager" batch_size = 2 sequence_length = 10 cache_position = torch.arange(sequence_length) # First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) causal_mask = create_causal_mask( config=config, # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), attention_mask=None, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) min_dtype = torch.finfo(torch.float16).min self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all()) def test_packed_sequence_mask_flex_attention(self): config = LlamaConfig() config._attn_implementation = "flex_attention" batch_size = 2 sequence_length = 10 cache_position = torch.arange(sequence_length) # First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) causal_mask = create_causal_mask( config=config, # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), attention_mask=None, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) def dummy_mask_mod(b, h, q, kv): return EXPECTED_PACKED_MASK[b, h, q, kv] EXPECTED_BLOCK_MASK = create_block_mask(dummy_mask_mod, 2, None, 10, 10, device="cpu") # We compatre the str representations, as the BlockMask objects themselves cannot easily be compared self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string()) def test_find_packed_sequence_indices(self): position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all()) def test_nonpacked_sequence_mask_skip(self): config = LlamaConfig() config._attn_implementation = "sdpa" batch_size = 2 sequence_length = 10 cache_position = torch.arange(sequence_length) # Non-packed sequences position_ids = torch.arange(sequence_length)[None, :] causal_mask = create_causal_mask( config=config, # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), attention_mask=None, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) # packed sequence should be skipped self.assertTrue(causal_mask is None) create_causal_mask_compiled = torch.compile(create_causal_mask, mode="reduce-overhead") causal_mask = create_causal_mask_compiled( config=config, # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), attention_mask=None, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) # cannot be skipped under compile, should result into a triu mask self.assertTrue(torch.equal(~torch.ones(*causal_mask.shape).triu(diagonal=1).bool(), causal_mask)) def test_chunked_mask_with_left_padding_and_large_prefill(self): # Make sure we have an attention_chunk_size in the config config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa") batch_size = 2 sequence_length = 8 pad_tokens = 4 input_ids = torch.randint(100, 200, (batch_size, sequence_length)) attention_mask = torch.tensor( [[0 if i < pad_tokens else 1 for i in range(sequence_length)], [1] * sequence_length] ) inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16) cache_position = torch.arange(sequence_length) position_ids = torch.empty(batch_size, sequence_length, dtype=cache_position.dtype) position_ids[0, :pad_tokens] = 1 position_ids[0, pad_tokens:] = torch.arange(sequence_length - pad_tokens) position_ids[1, :] = cache_position chunked_attention_mask = create_chunked_causal_mask( config=config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) # fmt: off EXPECTED_CHUNKED_MASK = torch.tensor( # Here, for the padded sequence, the chunk size should start correctly at index 4 (otherwise, with 4 padding # tokens are chunk_size=3, the first chunk is from indices 0-2, then 3-6 if we don't account for the padding correctly) [[[[False, False, False, False, False, False, False, False], [False, False, False, False, False, False, False, False], [False, False, False, False, False, False, False, False], [False, False, False, False, False, False, False, False], [False, False, False, False, True, False, False, False], [False, False, False, False, True, True, False, False], [False, False, False, False, True, True, True, False], [False, False, False, False, False, False, False, True]]], [[[ True, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False], [False, False, False, True, False, False, False, False], [False, False, False, True, True, False, False, False], [False, False, False, True, True, True, False, False], [False, False, False, False, False, False, True, False], [False, False, False, False, False, False, True, True]]]], dtype=torch.bool) # fmt: on self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all()) def test_chunked_mask_with_left_padding_decoding(self): # Make sure we have an attention_chunk_size in the config config = LlamaConfig(attention_chunk_size=4, attn_implementation="sdpa", num_hidden_layers=1) cache = DynamicCache(config=config) # Sanity check self.assertEqual(len(cache), 1) self.assertTrue(isinstance(cache.layers[0], DynamicSlidingWindowLayer)) # Fill-in the Cache (sequence length is bigger than chunk size here) batch_size = 2 prefill_size = 8 pad_tokens = 7 fake_kv = torch.rand(batch_size, 32, prefill_size, 32) cache.update(fake_kv, fake_kv, 0, torch.arange(prefill_size)) # Create a new input after the prefill input_ids = torch.randint(100, 200, (batch_size, 1)) attention_mask = torch.tensor( [[0 if i < pad_tokens else 1 for i in range(prefill_size + 1)], [1] * (prefill_size + 1)] ) inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16) cache_position = torch.tensor([prefill_size], dtype=int) position_ids = torch.tensor([[prefill_size - pad_tokens], [prefill_size]]) chunked_attention_mask = create_chunked_causal_mask( config=config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=cache, position_ids=position_ids, ) # To understand a bit more the following expected mask, here is the full 2d mask, where the "|" characters are the chunk # separators (where the tokens should stop seeing each other) # [0, 0, 0, 0, 0, 0, 0, | 1, 1], -> due to left padding, the first chunk only starts after the padding tokens # [| 1, 1, 1, 1, | 1, 1, 1, 1, | 1]]) -> easy case, each 4 tokens is a new chunk # fmt: off EXPECTED_CHUNKED_MASK = torch.tensor( # Here, for the padded sequence, the chunk size should start correctly at index 7 (the first unpadded # index), and so only indices 7 and 8 should be True [[[[False, False, True, True]]], # Here, for the unpadded sequence, the chunks start at index 0. Since we have 9 tokens in total, the last # token (index 8) will only see itself (we have 2 full chunks before) [[[False, False, False, True]]]], dtype=torch.bool) # fmt: on self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all()) @staticmethod def _run_bidirectional_mask(mask_fn, attn_implementation): def run_mask_creation(mask_fn, config, input_embeds, encoder_mask, cross_mask, encoder_hidden_states): encoder_attn_mask = mask_fn( config=config, input_embeds=input_embeds, attention_mask=encoder_mask, ) cross_attn_mask = mask_fn( config=config, input_embeds=input_embeds, attention_mask=cross_mask, encoder_hidden_states=encoder_hidden_states, ) return encoder_attn_mask, cross_attn_mask # We use llama but could be also bert/bart --> we only need the `_attn_implementation` here config = LlamaConfig() config._attn_implementation = attn_implementation # Meta data batch_size = 2 q_length = 10 kv_length = 5 input_embeds = torch.ones((batch_size, q_length, 1), device=torch_device, dtype=torch.float16) encoder_hidden_states = torch.ones((batch_size, kv_length, 1), device=torch_device, dtype=torch.float16) encoder_mask = torch.ones_like(input_embeds)[..., 0] cross_mask = torch.ones_like(encoder_hidden_states)[..., 0] # Case 1: Full mask full_mask_encoder_1, full_mask_cross_1 = run_mask_creation( mask_fn=mask_fn, config=config, input_embeds=input_embeds, encoder_mask=encoder_mask, cross_mask=cross_mask, encoder_hidden_states=encoder_hidden_states, ) full_mask_encoder_2, full_mask_cross_2 = run_mask_creation( mask_fn=mask_fn, config=config, input_embeds=input_embeds, encoder_mask=None, cross_mask=None, encoder_hidden_states=encoder_hidden_states, ) # Case 2: Padding involved cross_mask[:, -1] = 0 encoder_mask[:, -1] = 0 padded_mask_encoder, padded_mask_cross = run_mask_creation( mask_fn=mask_fn, config=config, input_embeds=input_embeds, encoder_mask=encoder_mask, cross_mask=cross_mask, encoder_hidden_states=encoder_hidden_states, ) full_masks = (full_mask_encoder_1, full_mask_encoder_2), (full_mask_cross_1, full_mask_cross_2) padded_masks = (padded_mask_encoder, padded_mask_cross) return full_masks, padded_masks def test_bidirectional_mask_cudagraphs(self): """ Checks whether the bidirectional mask creation is compatible with cuda graphs, i.e. we do not into any error during this test. """ mask_creation_function = torch.compile(create_bidirectional_mask, mode="reduce-overhead") self._run_bidirectional_mask(mask_fn=mask_creation_function, attn_implementation="sdpa") def test_bidirectional_mask_skip_eager(self): """ Checks whether the bidirectional mask creation can skip the mask creation if we have a full mask. """ full_masks, padded_mask = self._run_bidirectional_mask( mask_fn=create_bidirectional_mask, attn_implementation="eager" ) for alternative_masks in full_masks: self.assertTrue(alternative_masks[0] is None) self.assertTrue(alternative_masks[1] is None) self.assertTrue(padded_mask[0] is not None) self.assertTrue(padded_mask[1] is not None)