# File: transformers/tests/utils/test_masking_utils.py
# Uploaded by AbdulElahGwaith ("Upload folder using huggingface_hub"), commit a9bd396 (verified)
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import (
cleanup,
is_torch_available,
require_torch,
torch_device,
)
if is_torch_available():
    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    from transformers import DynamicCache, LlamaConfig
    from transformers.cache_utils import DynamicSlidingWindowLayer
    from transformers.masking_utils import (
        create_bidirectional_mask,
        create_causal_mask,
        create_chunked_causal_mask,
        find_packed_sequence_indices,
    )

    # Expected boolean causal mask of shape (2, 1, 10, 10) for two batch rows of packed sequences:
    # row 0 packs sequences of 4, 2 and 4 tokens, row 1 packs sequences of 6 and 4 tokens.
    # Each token attends causally (lower triangle) and only to tokens of its own packed sequence,
    # which yields a block-diagonal layout of lower-triangular blocks.
    # fmt: off
    EXPECTED_PACKED_MASK = torch.tensor([[[
        [ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [False, False, False, False,  True, False, False, False, False, False],
        [False, False, False, False,  True,  True, False, False, False, False],
        [False, False, False, False, False, False,  True, False, False, False],
        [False, False, False, False, False, False,  True,  True, False, False],
        [False, False, False, False, False, False,  True,  True,  True, False],
        [False, False, False, False, False, False,  True,  True,  True,  True]]],


        [[[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [False, False, False, False, False, False,  True, False, False, False],
        [False, False, False, False, False, False,  True,  True, False, False],
        [False, False, False, False, False, False,  True,  True,  True, False],
        [False, False, False, False, False, False,  True,  True,  True,  True]
    ]]], dtype=torch.bool)
    # fmt: on
@require_torch
class MaskTest(unittest.TestCase):
    """Tests for the attention-mask factories and helpers in `transformers.masking_utils`."""

    def setUp(self):
        # NOTE: `unittest` only calls `setUp` (camel case); a method named `setup` would never run.
        cleanup(torch_device, gc_collect=True)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @staticmethod
    def _create_packed_causal_mask(attn_implementation):
        """
        Build a causal mask for two batch rows of packed sequences using `attn_implementation`.

        Row 0 packs 3 sequences of 4, 2 and 4 tokens; row 1 packs 2 sequences of 6 and 4 tokens. The sequence
        boundaries are encoded only through `position_ids` restarting at 0, so no `attention_mask` is passed.

        Returns:
            Whatever `create_causal_mask` returns for the given attention implementation.
        """
        config = LlamaConfig()
        config._attn_implementation = attn_implementation
        batch_size = 2
        sequence_length = 10
        cache_position = torch.arange(sequence_length)
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
        return create_causal_mask(
            config=config,
            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )

    def test_packed_sequence_mask_sdpa(self):
        """The sdpa mask for packed sequences is the boolean block-diagonal causal mask."""
        causal_mask = self._create_packed_causal_mask("sdpa")
        self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all())

    def test_packed_sequence_mask_eager(self):
        """The eager mask for packed sequences is additive: 0 where attending, dtype-min elsewhere."""
        causal_mask = self._create_packed_causal_mask("eager")
        min_dtype = torch.finfo(torch.float16).min
        self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all())

    def test_packed_sequence_mask_flex_attention(self):
        """The flex-attention BlockMask for packed sequences matches one built from the expected boolean mask."""
        causal_mask = self._create_packed_causal_mask("flex_attention")

        def dummy_mask_mod(b, h, q, kv):
            return EXPECTED_PACKED_MASK[b, h, q, kv]

        batch_size, _, query_length, key_length = EXPECTED_PACKED_MASK.shape
        # H=None broadcasts the mask over all heads
        expected_block_mask = create_block_mask(
            dummy_mask_mod, batch_size, None, query_length, key_length, device="cpu"
        )
        # We compare the str representations, as the BlockMask objects themselves cannot easily be compared
        self.assertEqual(causal_mask.to_string(), expected_block_mask.to_string())

    def test_find_packed_sequence_indices(self):
        """Each token is labelled with the index of the packed sequence it belongs to."""
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
        expected_sequence_indices = torch.tensor(
            [[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]
        )
        self.assertTrue((find_packed_sequence_indices(position_ids) == expected_sequence_indices).all())

    def test_nonpacked_sequence_mask_skip(self):
        """Without packing, sdpa mask creation is skipped eagerly but yields a plain causal mask under compile."""
        config = LlamaConfig()
        config._attn_implementation = "sdpa"
        batch_size = 2
        sequence_length = 10
        cache_position = torch.arange(sequence_length)
        # Non-packed sequences
        position_ids = torch.arange(sequence_length)[None, :]

        mask_kwargs = {
            "config": config,
            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
            "input_embeds": torch.empty((batch_size, sequence_length), dtype=torch.float16),
            "attention_mask": None,
            "cache_position": cache_position,
            "past_key_values": None,
            "position_ids": position_ids,
        }

        causal_mask = create_causal_mask(**mask_kwargs)
        # mask creation for non-packed sequences should be skipped
        self.assertIsNone(causal_mask)

        create_causal_mask_compiled = torch.compile(create_causal_mask, mode="reduce-overhead")
        causal_mask = create_causal_mask_compiled(**mask_kwargs)
        # cannot be skipped under compile, should result in a lower-triangular (causal) mask
        self.assertTrue(torch.equal(~torch.ones(*causal_mask.shape).triu(diagonal=1).bool(), causal_mask))

    def test_chunked_mask_with_left_padding_and_large_prefill(self):
        """Chunk boundaries for a left-padded row start at the first non-padding token during prefill."""
        # Make sure we have an attention_chunk_size in the config
        config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")

        batch_size = 2
        sequence_length = 8
        pad_tokens = 4

        input_ids = torch.randint(100, 200, (batch_size, sequence_length))
        attention_mask = torch.tensor(
            [[0 if i < pad_tokens else 1 for i in range(sequence_length)], [1] * sequence_length]
        )
        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
        cache_position = torch.arange(sequence_length)
        position_ids = torch.empty(batch_size, sequence_length, dtype=cache_position.dtype)
        position_ids[0, :pad_tokens] = 1
        position_ids[0, pad_tokens:] = torch.arange(sequence_length - pad_tokens)
        position_ids[1, :] = cache_position

        chunked_attention_mask = create_chunked_causal_mask(
            config=config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )

        # fmt: off
        EXPECTED_CHUNKED_MASK = torch.tensor(
            # Here, for the padded sequence, the first chunk should start at index 4 (the first non-padding token).
            # With 4 padding tokens and chunk_size=3, not accounting for the padding would instead give chunks
            # spanning indices 0-2, 3-5 and 6-7.
            [[[[False, False, False, False, False, False, False, False],
                [False, False, False, False, False, False, False, False],
                [False, False, False, False, False, False, False, False],
                [False, False, False, False, False, False, False, False],
                [False, False, False, False,  True, False, False, False],
                [False, False, False, False,  True,  True, False, False],
                [False, False, False, False,  True,  True,  True, False],
                [False, False, False, False, False, False, False,  True]]],


            [[[ True, False, False, False, False, False, False, False],
                [ True,  True, False, False, False, False, False, False],
                [ True,  True,  True, False, False, False, False, False],
                [False, False, False,  True, False, False, False, False],
                [False, False, False,  True,  True, False, False, False],
                [False, False, False,  True,  True,  True, False, False],
                [False, False, False, False, False, False,  True, False],
                [False, False, False, False, False, False,  True,  True]]]],
            dtype=torch.bool)
        # fmt: on

        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())

    def test_chunked_mask_with_left_padding_decoding(self):
        """Chunk boundaries for a left-padded row start at the first non-padding token during decoding."""
        # Make sure we have an attention_chunk_size in the config
        config = LlamaConfig(attention_chunk_size=4, attn_implementation="sdpa", num_hidden_layers=1)

        cache = DynamicCache(config=config)
        # Sanity check
        self.assertEqual(len(cache), 1)
        self.assertIsInstance(cache.layers[0], DynamicSlidingWindowLayer)

        # Fill-in the Cache (sequence length is bigger than chunk size here)
        batch_size = 2
        prefill_size = 8
        pad_tokens = 7
        fake_kv = torch.rand(batch_size, 32, prefill_size, 32)
        cache.update(fake_kv, fake_kv, 0, torch.arange(prefill_size))

        # Create a new input after the prefill
        input_ids = torch.randint(100, 200, (batch_size, 1))
        attention_mask = torch.tensor(
            [[0 if i < pad_tokens else 1 for i in range(prefill_size + 1)], [1] * (prefill_size + 1)]
        )
        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
        cache_position = torch.tensor([prefill_size], dtype=int)
        position_ids = torch.tensor([[prefill_size - pad_tokens], [prefill_size]])

        chunked_attention_mask = create_chunked_causal_mask(
            config=config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=cache,
            position_ids=position_ids,
        )

        # To understand the following expected mask, here is the full 2d attention_mask, where the "|" characters
        # are the chunk separators (where the tokens should stop seeing each other):
        # [0, 0, 0, 0, 0, 0, 0, | 1, 1],     -> due to left padding, the first chunk only starts after the padding tokens
        # [| 1, 1, 1, 1, | 1, 1, 1, 1, | 1]  -> easy case, each 4 tokens is a new chunk

        # fmt: off
        EXPECTED_CHUNKED_MASK = torch.tensor(
            # Here, for the padded sequence, the chunk should start at index 7 (the first unpadded
            # index), and so only indices 7 and 8 should be True
            [[[[False, False, True, True]]],
            # Here, for the unpadded sequence, the chunks start at index 0. Since we have 9 tokens in total, the last
            # token (index 8) will only see itself (we have 2 full chunks before)
            [[[False, False, False, True]]]],
            dtype=torch.bool)
        # fmt: on

        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())

    @staticmethod
    def _run_bidirectional_mask(mask_fn, attn_implementation):
        """
        Run `mask_fn` for encoder self-attention and cross-attention, with full masks and with padded masks.

        Returns:
            `(full_masks, padded_masks)` where `full_masks` is
            `((encoder_with_mask, encoder_without_mask), (cross_with_mask, cross_without_mask))` and
            `padded_masks` is `(padded_encoder, padded_cross)`.
        """

        def run_mask_creation(mask_fn, config, input_embeds, encoder_mask, cross_mask, encoder_hidden_states):
            encoder_attn_mask = mask_fn(
                config=config,
                input_embeds=input_embeds,
                attention_mask=encoder_mask,
            )
            cross_attn_mask = mask_fn(
                config=config,
                input_embeds=input_embeds,
                attention_mask=cross_mask,
                encoder_hidden_states=encoder_hidden_states,
            )
            return encoder_attn_mask, cross_attn_mask

        # We use llama but could be also bert/bart --> we only need the `_attn_implementation` here
        config = LlamaConfig()
        config._attn_implementation = attn_implementation

        # Meta data
        batch_size = 2
        q_length = 10
        kv_length = 5

        input_embeds = torch.ones((batch_size, q_length, 1), device=torch_device, dtype=torch.float16)
        encoder_hidden_states = torch.ones((batch_size, kv_length, 1), device=torch_device, dtype=torch.float16)

        encoder_mask = torch.ones_like(input_embeds)[..., 0]
        cross_mask = torch.ones_like(encoder_hidden_states)[..., 0]

        # Case 1: Full mask
        full_mask_encoder_1, full_mask_cross_1 = run_mask_creation(
            mask_fn=mask_fn,
            config=config,
            input_embeds=input_embeds,
            encoder_mask=encoder_mask,
            cross_mask=cross_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        full_mask_encoder_2, full_mask_cross_2 = run_mask_creation(
            mask_fn=mask_fn,
            config=config,
            input_embeds=input_embeds,
            encoder_mask=None,
            cross_mask=None,
            encoder_hidden_states=encoder_hidden_states,
        )

        # Case 2: Padding involved. Use fresh copies so the inputs already used in case 1 are never mutated.
        padded_cross_mask = cross_mask.clone()
        padded_encoder_mask = encoder_mask.clone()
        padded_cross_mask[:, -1] = 0
        padded_encoder_mask[:, -1] = 0

        padded_mask_encoder, padded_mask_cross = run_mask_creation(
            mask_fn=mask_fn,
            config=config,
            input_embeds=input_embeds,
            encoder_mask=padded_encoder_mask,
            cross_mask=padded_cross_mask,
            encoder_hidden_states=encoder_hidden_states,
        )

        full_masks = (full_mask_encoder_1, full_mask_encoder_2), (full_mask_cross_1, full_mask_cross_2)
        padded_masks = (padded_mask_encoder, padded_mask_cross)
        return full_masks, padded_masks

    def test_bidirectional_mask_cudagraphs(self):
        """
        Checks whether the bidirectional mask creation is compatible with cuda graphs, i.e. we do not run into any
        error during this test.
        """
        mask_creation_function = torch.compile(create_bidirectional_mask, mode="reduce-overhead")
        self._run_bidirectional_mask(mask_fn=mask_creation_function, attn_implementation="sdpa")

    def test_bidirectional_mask_skip_eager(self):
        """
        Checks whether the bidirectional mask creation can skip the mask creation if we have a full mask.
        """
        full_masks, padded_mask = self._run_bidirectional_mask(
            mask_fn=create_bidirectional_mask, attn_implementation="eager"
        )

        for alternative_masks in full_masks:
            self.assertIsNone(alternative_masks[0])
            self.assertIsNone(alternative_masks[1])

        self.assertIsNotNone(padded_mask[0])
        self.assertIsNotNone(padded_mask[1])