| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | from transformers.testing_utils import ( |
| | cleanup, |
| | is_torch_available, |
| | require_torch, |
| | torch_device, |
| | ) |
| |
|
| |
|
| | if is_torch_available(): |
| | import torch |
| | from torch.nn.attention.flex_attention import create_block_mask |
| |
|
| | from transformers import DynamicCache, LlamaConfig |
| | from transformers.cache_utils import DynamicSlidingWindowLayer |
| | from transformers.masking_utils import ( |
| | create_bidirectional_mask, |
| | create_causal_mask, |
| | create_chunked_causal_mask, |
| | find_packed_sequence_indices, |
| | ) |
| |
|
| |
|
| | |
# Expected boolean attention mask for a batch of two *packed* sequences of total
# length 10 (matches the `position_ids` used in the packed-sequence tests below):
#   - batch item 0 packs three sub-sequences of lengths 4, 2 and 4
#   - batch item 1 packs two sub-sequences of lengths 6 and 4
# Layout is (batch, 1, query_len, kv_len); True means "query may attend to key".
# Each sub-sequence is causal within itself and blind to the other sub-sequences.
# NOTE(review): defined at module import time with `torch.tensor`, so this module
# assumes torch is importable even though the import above is guarded by
# `is_torch_available()` — confirm placement relative to that guard.
EXPECTED_PACKED_MASK = torch.tensor([[[
    [ True, False, False, False, False, False, False, False, False, False],
    [ True,  True, False, False, False, False, False, False, False, False],
    [ True,  True,  True, False, False, False, False, False, False, False],
    [ True,  True,  True,  True, False, False, False, False, False, False],
    [False, False, False, False,  True, False, False, False, False, False],
    [False, False, False, False,  True,  True, False, False, False, False],
    [False, False, False, False, False, False,  True, False, False, False],
    [False, False, False, False, False, False,  True,  True, False, False],
    [False, False, False, False, False, False,  True,  True,  True, False],
    [False, False, False, False, False, False,  True,  True,  True,  True]]],
    [[[ True, False, False, False, False, False, False, False, False, False],
    [ True,  True, False, False, False, False, False, False, False, False],
    [ True,  True,  True, False, False, False, False, False, False, False],
    [ True,  True,  True,  True, False, False, False, False, False, False],
    [ True,  True,  True,  True,  True, False, False, False, False, False],
    [ True,  True,  True,  True,  True,  True, False, False, False, False],
    [False, False, False, False, False, False,  True, False, False, False],
    [False, False, False, False, False, False,  True,  True, False, False],
    [False, False, False, False, False, False,  True,  True,  True, False],
    [False, False, False, False, False, False,  True,  True,  True,  True]
]]], dtype=torch.bool)
| | |
| |
|
| |
|
@require_torch
class MaskTest(unittest.TestCase):
    """Tests for the mask-creation helpers in `transformers.masking_utils`.

    Covers packed-sequence causal masks across attention implementations
    (sdpa / eager / flex_attention), chunked (sliding-window style) masks with
    left padding, and bidirectional / cross-attention mask creation.
    """

    # BUGFIX: this was previously named `setup`, which unittest never calls —
    # the per-test lifecycle hook is `setUp` (capital U), so the pre-test
    # cleanup silently never ran.
    def setUp(self):
        cleanup(torch_device, gc_collect=True)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    def test_packed_sequence_mask_sdpa(self):
        """Packed sequences (detected from resetting position_ids) yield a block-diagonal causal bool mask for sdpa."""
        config = LlamaConfig()
        config._attn_implementation = "sdpa"

        batch_size = 2
        sequence_length = 10
        cache_position = torch.arange(sequence_length)

        # position_ids restart at 0 at each packed sub-sequence boundary:
        # item 0 packs lengths (4, 2, 4); item 1 packs lengths (6, 4).
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])

        causal_mask = create_causal_mask(
            config=config,
            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )

        self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all())

    def test_packed_sequence_mask_eager(self):
        """Eager masks are additive floats: 0.0 where attention is allowed, dtype-min where it is blocked."""
        config = LlamaConfig()
        config._attn_implementation = "eager"

        batch_size = 2
        sequence_length = 10
        cache_position = torch.arange(sequence_length)

        # Same packing layout as the sdpa test: (4, 2, 4) and (6, 4).
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])

        causal_mask = create_causal_mask(
            config=config,
            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )

        # The eager mask must be the float (additive) equivalent of the bool mask.
        min_dtype = torch.finfo(torch.float16).min
        self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all())

    def test_packed_sequence_mask_flex_attention(self):
        """flex_attention should build a BlockMask equivalent to the expected packed bool mask."""
        config = LlamaConfig()
        config._attn_implementation = "flex_attention"

        batch_size = 2
        sequence_length = 10
        cache_position = torch.arange(sequence_length)

        # Same packing layout as the sdpa test: (4, 2, 4) and (6, 4).
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])

        causal_mask = create_causal_mask(
            config=config,
            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )

        # Build the reference BlockMask directly from the expected bool mask.
        def dummy_mask_mod(b, h, q, kv):
            return EXPECTED_PACKED_MASK[b, h, q, kv]

        EXPECTED_BLOCK_MASK = create_block_mask(dummy_mask_mod, 2, None, 10, 10, device="cpu")

        # BlockMask has no `==`; compare their canonical string renderings instead.
        self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())

    def test_find_packed_sequence_indices(self):
        """Each token should be labeled with the index of the packed sub-sequence it belongs to."""
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
        EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
        self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())

    def test_nonpacked_sequence_mask_skip(self):
        """With monotonic position_ids and no padding, sdpa can skip mask creation entirely (None),
        but the compiled path must still materialize a plain causal mask."""
        config = LlamaConfig()
        config._attn_implementation = "sdpa"

        batch_size = 2
        sequence_length = 10
        cache_position = torch.arange(sequence_length)

        # Monotonic position_ids -> a single (non-packed) sequence.
        position_ids = torch.arange(sequence_length)[None, :]

        causal_mask = create_causal_mask(
            config=config,
            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )
        # Fully causal + no padding -> sdpa kernels handle it natively, so no mask is needed.
        self.assertTrue(causal_mask is None)

        # Under torch.compile the skip is not taken (data-dependent early exit),
        # so a concrete lower-triangular causal mask must be produced.
        create_causal_mask_compiled = torch.compile(create_causal_mask, mode="reduce-overhead")
        causal_mask = create_causal_mask_compiled(
            config=config,
            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )
        self.assertTrue(torch.equal(~torch.ones(*causal_mask.shape).triu(diagonal=1).bool(), causal_mask))

    def test_chunked_mask_with_left_padding_and_large_prefill(self):
        """Chunked (attention_chunk_size=3) causal mask with one left-padded batch item during prefill."""
        config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")

        batch_size = 2
        sequence_length = 8
        pad_tokens = 4

        input_ids = torch.randint(100, 200, (batch_size, sequence_length))
        # Item 0 is left-padded with `pad_tokens` zeros; item 1 is full.
        attention_mask = torch.tensor(
            [[0 if i < pad_tokens else 1 for i in range(sequence_length)], [1] * sequence_length]
        )
        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
        cache_position = torch.arange(sequence_length)
        # Real positions start after the padding for item 0.
        position_ids = torch.empty(batch_size, sequence_length, dtype=cache_position.dtype)
        position_ids[0, :pad_tokens] = 1
        position_ids[0, pad_tokens:] = torch.arange(sequence_length - pad_tokens)
        position_ids[1, :] = cache_position

        chunked_attention_mask = create_chunked_causal_mask(
            config=config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )

        # Chunk boundaries are counted from the first *real* token of each item:
        # item 0 (4 real tokens) -> chunks of size 3 then 1; item 1 -> 3, 3, 2.
        EXPECTED_CHUNKED_MASK = torch.tensor(
            [[[[False, False, False, False, False, False, False, False],
               [False, False, False, False, False, False, False, False],
               [False, False, False, False, False, False, False, False],
               [False, False, False, False, False, False, False, False],
               [False, False, False, False,  True, False, False, False],
               [False, False, False, False,  True,  True, False, False],
               [False, False, False, False,  True,  True,  True, False],
               [False, False, False, False, False, False, False,  True]]],
             [[[ True, False, False, False, False, False, False, False],
               [ True,  True, False, False, False, False, False, False],
               [ True,  True,  True, False, False, False, False, False],
               [False, False, False,  True, False, False, False, False],
               [False, False, False,  True,  True, False, False, False],
               [False, False, False,  True,  True,  True, False, False],
               [False, False, False, False, False, False,  True, False],
               [False, False, False, False, False, False,  True,  True]]]],
            dtype=torch.bool)

        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())

    def test_chunked_mask_with_left_padding_decoding(self):
        """Chunked mask for a single decode step after prefill, with a sliding-window cache and left padding."""
        config = LlamaConfig(attention_chunk_size=4, attn_implementation="sdpa", num_hidden_layers=1)

        # A chunked config must produce a sliding-window cache layer.
        cache = DynamicCache(config=config)
        self.assertEqual(len(cache), 1)
        self.assertTrue(isinstance(cache.layers[0], DynamicSlidingWindowLayer))

        # Simulate an already-performed prefill of 8 tokens.
        batch_size = 2
        prefill_size = 8
        pad_tokens = 7
        fake_kv = torch.rand(batch_size, 32, prefill_size, 32)
        cache.update(fake_kv, fake_kv, 0, torch.arange(prefill_size))

        # Single new decode token; item 0 had 7 left-pad tokens in its history.
        input_ids = torch.randint(100, 200, (batch_size, 1))
        attention_mask = torch.tensor(
            [[0 if i < pad_tokens else 1 for i in range(prefill_size + 1)], [1] * (prefill_size + 1)]
        )
        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
        cache_position = torch.tensor([prefill_size], dtype=int)
        position_ids = torch.tensor([[prefill_size - pad_tokens], [prefill_size]])

        chunked_attention_mask = create_chunked_causal_mask(
            config=config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=cache,
            position_ids=position_ids,
        )

        # KV length is capped at the chunk size (4). Item 0's new token is the
        # 2nd real token of its first chunk; item 1's is the 1st of a new chunk.
        EXPECTED_CHUNKED_MASK = torch.tensor(
            [[[[False, False,  True,  True]]],
             [[[False, False, False,  True]]]],
            dtype=torch.bool)

        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())

    @staticmethod
    def _run_bidirectional_mask(mask_fn, attn_implementation):
        """Build encoder self-attention and cross-attention masks via `mask_fn`, with and without padding.

        Returns ((full_encoder_pair, full_cross_pair), (padded_encoder, padded_cross)):
        the "full" pairs are created once with an all-ones mask and once with `None`;
        the "padded" masks have the last position masked out.
        """

        def run_mask_creation(mask_fn, config, input_embeds, encoder_mask, cross_mask, encoder_hidden_states):
            # Encoder self-attention mask (queries and keys share a length)...
            encoder_attn_mask = mask_fn(
                config=config,
                input_embeds=input_embeds,
                attention_mask=encoder_mask,
            )
            # ...and cross-attention mask (keys come from the encoder states).
            cross_attn_mask = mask_fn(
                config=config,
                input_embeds=input_embeds,
                attention_mask=cross_mask,
                encoder_hidden_states=encoder_hidden_states,
            )
            return encoder_attn_mask, cross_attn_mask

        config = LlamaConfig()
        config._attn_implementation = attn_implementation

        batch_size = 2
        q_length = 10
        kv_length = 5

        input_embeds = torch.ones((batch_size, q_length, 1), device=torch_device, dtype=torch.float16)
        encoder_hidden_states = torch.ones((batch_size, kv_length, 1), device=torch_device, dtype=torch.float16)

        encoder_mask = torch.ones_like(input_embeds)[..., 0]
        cross_mask = torch.ones_like(encoder_hidden_states)[..., 0]

        # No padding: passing the all-ones mask and passing None should be equivalent.
        full_mask_encoder_1, full_mask_cross_1 = run_mask_creation(
            mask_fn=mask_fn,
            config=config,
            input_embeds=input_embeds,
            encoder_mask=encoder_mask,
            cross_mask=cross_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        full_mask_encoder_2, full_mask_cross_2 = run_mask_creation(
            mask_fn=mask_fn,
            config=config,
            input_embeds=input_embeds,
            encoder_mask=None,
            cross_mask=None,
            encoder_hidden_states=encoder_hidden_states,
        )

        # Introduce padding on the last position of both masks.
        cross_mask[:, -1] = 0
        encoder_mask[:, -1] = 0

        padded_mask_encoder, padded_mask_cross = run_mask_creation(
            mask_fn=mask_fn,
            config=config,
            input_embeds=input_embeds,
            encoder_mask=encoder_mask,
            cross_mask=cross_mask,
            encoder_hidden_states=encoder_hidden_states,
        )

        full_masks = (full_mask_encoder_1, full_mask_encoder_2), (full_mask_cross_1, full_mask_cross_2)
        padded_masks = (padded_mask_encoder, padded_mask_cross)
        return full_masks, padded_masks

    def test_bidirectional_mask_cudagraphs(self):
        """
        Checks whether the bidirectional mask creation is compatible with cuda graphs, i.e. we do not run into any
        error during this test.
        """
        mask_creation_function = torch.compile(create_bidirectional_mask, mode="reduce-overhead")
        self._run_bidirectional_mask(mask_fn=mask_creation_function, attn_implementation="sdpa")

    def test_bidirectional_mask_skip_eager(self):
        """
        Checks whether the bidirectional mask creation can skip the mask creation if we have a full mask.
        """
        full_masks, padded_mask = self._run_bidirectional_mask(
            mask_fn=create_bidirectional_mask, attn_implementation="eager"
        )

        # Full (all-ones or None) masks are skippable -> both variants return None.
        for alternative_masks in full_masks:
            self.assertTrue(alternative_masks[0] is None)
            self.assertTrue(alternative_masks[1] is None)

        # Padded masks cannot be skipped -> a concrete mask must be returned.
        self.assertTrue(padded_mask[0] is not None)
        self.assertTrue(padded_mask[1] is not None)
| |
|