Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import itertools | |
| import pytest | |
| import torch | |
| import xformers.components.attention.attention_patterns as AP | |
| from xformers.components.attention.sparsity_config import ( | |
| BigBirdSparsityConfig, | |
| BSLongformerSparsityConfig, | |
| DenseSparsityConfig, | |
| FixedSparsityConfig, | |
| VariableSparsityConfig, | |
| ) | |
| # baseline implementations | |
def _local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
    """Baseline banded attention mask.

    Returns a square boolean tensor of side ``attn_size`` in which entry
    ``(i, j)`` is True exactly when ``|i - j| <= window_size // 2``.
    """
    assert (
        window_size % 2 == 1
    ), "The window size is assumed to be odd (counts self-attention + 2 wings)"
    wing = window_size // 2
    everything = torch.ones((attn_size, attn_size), dtype=torch.bool)
    # Keep what lies on or below the +wing diagonal, then on or above the
    # -wing diagonal: what remains is the band centred on the main diagonal.
    return torch.tril(everything, diagonal=wing).triu(diagonal=-wing)
def _generate_2d_grid(H, W):
    """Return the two coordinate grids of an ``H`` x ``W`` lattice.

    Both returned tensors have shape ``(H, W)``. The first holds, for every
    cell, its index along the ``H`` axis; the second holds its index along
    the ``W`` axis.
    """
    i = torch.arange(H)
    j = torch.arange(W)
    # Ask for matrix ("ij") indexing explicitly. It is what the implicit
    # default produced, and newer PyTorch warns when `indexing` is omitted.
    i, j = torch.meshgrid(i, j, indexing="ij")
    return i, j
def _horizontal_axial_2d_distance(H, W, p=2.0):
    """Baseline pairwise p-norm distance using only the first grid coordinate.

    Each of the ``H * W`` cells is described by the first tensor returned by
    ``_generate_2d_grid``; the result is an ``(H * W, H * W)`` distance matrix.
    """
    first_coord = _generate_2d_grid(H, W)[0]
    # cdist expects 2-D inputs, so each cell becomes a one-element row vector
    points = first_coord.reshape(-1, 1).float()
    return torch.cdist(points, points, p=p)
def _vertical_axial_2d_distance(H, W, p=2.0):
    """Baseline pairwise p-norm distance using only the second grid coordinate.

    Each of the ``H * W`` cells is described by the second tensor returned by
    ``_generate_2d_grid``; the result is an ``(H * W, H * W)`` distance matrix.
    """
    second_coord = _generate_2d_grid(H, W)[1]
    # cdist expects 2-D inputs, so each cell becomes a one-element row vector
    points = second_coord.reshape(-1, 1).float()
    return torch.cdist(points, points, p=p)
def _local_2d_distance(H, W, p=2.0):
    """Baseline pairwise p-norm distance between all cells of an ``H`` x ``W`` grid.

    Every cell is represented by its pair of grid coordinates; the result is
    an ``(H * W, H * W)`` distance matrix.
    """
    # axial is a special case with p=0 and distance=2
    rows, cols = _generate_2d_grid(H, W)
    coords = torch.stack((rows.flatten(), cols.flatten()), dim=1).float()
    return torch.cdist(coords, coords, p=p)
def _local_2d_gaussian_distribution(H, W, sigma=1.0):
    """Baseline Gaussian weighting of the squared Euclidean grid distance.

    Computes ``exp(-0.5 * sigma**-2 * dist**2)`` element-wise, where ``dist``
    is the p=2 output of ``_local_2d_distance``.
    """
    squared_dist = _local_2d_distance(H, W, p=2.0) ** 2
    scale = -0.5 * sigma ** (-2.0)
    return torch.exp(scale * squared_dist)
def test_local_1d_pattern(attn_size, window_size):
    """AP.local_1d_pattern must match the baseline banded mask.

    NOTE(review): no parametrization or fixture supplying the arguments is
    visible in this excerpt -- confirm they are provided elsewhere.
    """
    expected = _local_1d_pattern(attn_size, window_size).float()
    actual = AP.local_1d_pattern(attn_size, window_size).float()
    assert torch.allclose(actual, expected)
def test_horizontal_axial_2d_distance(H, W, p):
    """AP.horizontal_axial_2d_distance must match the baseline implementation.

    NOTE(review): no parametrization or fixture supplying the arguments is
    visible in this excerpt -- confirm they are provided elsewhere.
    """
    expected = _horizontal_axial_2d_distance(H, W, p=p)
    actual = AP.horizontal_axial_2d_distance(H, W, p=p)
    assert torch.allclose(actual, expected)
def test_vertical_axial_2d_distance(H, W, p):
    """AP.vertical_axial_2d_distance must match the baseline implementation.

    NOTE(review): no parametrization or fixture supplying the arguments is
    visible in this excerpt -- confirm they are provided elsewhere.
    """
    expected = _vertical_axial_2d_distance(H, W, p=p)
    actual = AP.vertical_axial_2d_distance(H, W, p=p)
    assert torch.allclose(actual, expected)
def test_local_2d_distance(H, W, p):
    """AP.local_2d_distance must match the baseline implementation.

    NOTE(review): no parametrization or fixture supplying the arguments is
    visible in this excerpt -- confirm they are provided elsewhere.
    """
    expected = _local_2d_distance(H, W, p=p)
    actual = AP.local_2d_distance(H, W, p=p)
    assert torch.allclose(actual, expected)
def test_local_2d_gaussian_distribution(H, W, sigma):
    """AP.local_2d_gausian_distribution must match the baseline implementation.

    NOTE(review): no parametrization or fixture supplying the arguments is
    visible in this excerpt -- confirm they are provided elsewhere.
    """
    expected = _local_2d_gaussian_distribution(H, W, sigma=sigma)
    actual = AP.local_2d_gausian_distribution(H, W, sigma=sigma)
    assert torch.allclose(actual, expected)
def test_swin_attention_pattern(H, W, window_size):
    """Check the swin attention mask in the non-shifted and shifted settings.

    Non-shifted: every ``window_size`` x ``window_size`` window must attend
    fully to itself and not at all to any other window.
    Shifted: the mask must equal the non-shifted mask of an input padded by
    ``window_size`` with the padding cropped away afterwards.

    NOTE(review): no parametrization or fixture supplying the arguments is
    visible in this excerpt -- confirm they are provided elsewhere. The
    reshape below requires H and W to be multiples of ``window_size``.
    """
    # test non-shifted case
    d = AP.swin_attention_pattern(H, W, window_size, shift_size=0)

    # partition the self-attention into regions of window_size
    # similar to the window_partition function from the original paper
    h = H // window_size
    w = W // window_size
    d = d.reshape(h, window_size, w, window_size, h, window_size, w, window_size)

    # Materialise the window coordinates. An itertools.product iterator can be
    # consumed only once, so sharing one between the outer and the inner loop
    # would end the outer loop after its first window.
    windows = list(itertools.product(range(h), range(w)))
    for y, x in windows:
        # every region should fully attend to itself
        assert torch.all(d[y, :, x, :, y, :, x, :])
        for y2, x2 in windows:
            # Skip only the identical window. Windows that merely share a row
            # or a column are still different regions and must be checked.
            if y == y2 and x == x2:
                continue
            # different regions shouldn't attend between each other
            assert torch.all(~d[y, :, x, :, y2, :, x2, :])

    # test shifted case
    # in the shifted case, the self-attention should be the same
    # as in the non-shifted case, when we pad the inputs, apply the operations and then
    # remove the padding from the result
    d_shifted = AP.swin_attention_pattern(
        H, W, window_size, shift_size=window_size // 2
    )

    # add padding and remove shift
    h = H + window_size
    w = W + window_size
    d_padded = AP.swin_attention_pattern(h, w, window_size, shift_size=0)
    d_padded = d_padded.reshape(h, w, h, w)

    # remove padding elements
    half_size = window_size // 2
    s = slice(half_size, -half_size)
    d_padded = d_padded[s, s, s, s].reshape(H * W, H * W)

    assert torch.all(d_padded == d_shifted)
def test_dilated_2d_pattern(H, W, k):
    """Check that the dilated pattern keeps exactly every k-th element.

    For each query position ``(h, w)`` the keys sharing its residues modulo
    ``k`` on both axes must all be attended, and every other residue class
    must be fully masked out.

    NOTE(review): no parametrization or fixture supplying the arguments is
    visible in this excerpt -- confirm they are provided elsewhere.
    """
    d = AP.dilated_2d_pattern(H, W, k)
    d = d.reshape(H, W, H, W)

    # The residue pairs are re-walked for every (h, w), so they must be a
    # re-iterable sequence. A bare itertools.product iterator would be
    # exhausted after the first position and silently skip the rest.
    residue_pairs = list(itertools.product(range(k), range(k)))
    for h, w in itertools.product(range(H), range(W)):
        i = h % k
        j = w % k
        # every kth element is taken
        assert torch.all(d[h, w][i::k, j::k])
        for ii, jj in residue_pairs:
            if ii == i and jj == j:
                continue
            # and the other elements are discarded
            assert torch.all(~d[h, w][ii::k, jj::k])
def test_pattern_to_layout():
    """Check AP.pattern_to_layout on hand-built 128x128 masks with 16-wide blocks."""
    block = 16
    size = 128
    n_blocks = size // block

    # All ones
    dense_mask = torch.ones((size, size), dtype=torch.bool)
    dense_ref = torch.ones((n_blocks, n_blocks), dtype=torch.long)
    assert torch.allclose(AP.pattern_to_layout(dense_mask, block), dense_ref)

    # Diagonal -> expect block diagonal
    diag_mask = torch.eye(size, dtype=torch.bool)
    diag_ref = torch.eye(n_blocks, dtype=torch.long)
    assert torch.allclose(AP.pattern_to_layout(diag_mask, block), diag_ref)

    # Lower triangular, without the diagonal
    # note that the layout will need to have the diagonal, else the coefficients close enough would not be computed
    lower_mask = torch.tril(torch.ones((size, size)), diagonal=-1).to(torch.bool)
    lower_ref = torch.tril(torch.ones((n_blocks, n_blocks)), diagonal=0).to(
        torch.long
    )
    assert torch.allclose(AP.pattern_to_layout(lower_mask, block), lower_ref)

    # Handle heads properly
    stacked_mask = torch.cat((dense_mask, diag_mask, lower_mask))
    stacked_ref = torch.cat((dense_ref, diag_ref, lower_ref))
    assert torch.allclose(AP.pattern_to_layout(stacked_mask, block), stacked_ref)

    # Catch problematic dimensions
    misaligned_mask = torch.ones((size + 3, size), dtype=torch.bool)
    with pytest.raises(AssertionError):
        AP.pattern_to_layout(misaligned_mask, block)
def test_alibi_pattern():
    """Smoke-test AP.alibi_pattern called with 1e-3 and a (16, 128, 128) shape."""
    num_heads = 16
    mask = AP.alibi_pattern(1e-3, (num_heads, 128, 128))

    # Minor, check that all the top left corners are True
    top_left = mask[:, 0, 0]
    assert torch.sum(top_left) == num_heads
def test_quick_layouts():
    """Compare the AP.quick_*_layout helpers against hard-coded block layouts.

    Uses 2 heads, block size 16 and sequence length 128, i.e. 8x8 layouts.
    The expected pattern is the same for both heads in every case.
    """
    seq_size = 128
    block_size = 16
    num_heads = 2

    def per_head(rows):
        # Replicate one expected 8x8 pattern for every head
        return torch.Tensor([rows] * num_heads).long()

    # Fixed
    fixed_rows = [[1, 1, 1, 1, 0, 0, 0, 1]] * 4 + [[0, 0, 0, 1, 1, 1, 1, 1]] * 4
    assert torch.allclose(
        AP.quick_fixed_layout(num_heads, block_size, seq_size),
        per_head(fixed_rows),
    )

    # BSLongformer
    bslongformer_rows = [
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 0, 1, 1, 1, 0, 0, 0],
        [1, 0, 0, 1, 1, 1, 0, 0],
        [1, 0, 0, 0, 1, 1, 1, 0],
        [1, 0, 0, 0, 0, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 1, 1],
    ]
    assert torch.allclose(
        AP.quick_bslongformer_layout(num_heads, block_size, seq_size),
        per_head(bslongformer_rows),
    )

    # Variable
    variable_rows = [[1, 1, 1, 1, 0, 0, 0, 0]] * 4 + [[1, 0, 0, 0, 1, 1, 1, 1]] * 4
    assert torch.allclose(
        AP.quick_variable_layout(num_heads, block_size, seq_size),
        per_head(variable_rows),
    )

    # BigBird (just the shape)
    num_blocks = seq_size // block_size
    bigbird_layout = AP.quick_bigbird_layout(num_heads, block_size, seq_size)
    assert bigbird_layout.shape == torch.Size([num_heads, num_blocks, num_blocks])
def test_layout_to_pattern():
    """Check that AP.layout_to_pattern expands each layout entry into a block.

    A two-head 2x2 layout with ``block_size=2`` must become a two-head 4x4
    pattern in which every layout entry fills a 2x2 tile.
    """
    layout = torch.Tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]]])
    expected = torch.Tensor(
        [
            [
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 1.0, 1.0],
                [1.0, 1.0, 0.0, 0.0],
                [1.0, 1.0, 0.0, 0.0],
            ],
            [
                [1.0, 1.0, 0.0, 0.0],
                [1.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 1.0, 1.0],
            ],
        ]
    )
    # The comparison result must be asserted; a bare torch.allclose(...) call
    # discards it and the test could never fail.
    assert torch.allclose(
        AP.layout_to_pattern(layout=layout, block_size=2),
        expected,
    )
def test_dense_sparsity_config():
    """DenseSparsityConfig rejects seq_len=17 and yields an all-ones layout for seq_len=32."""
    config = DenseSparsityConfig(num_heads=1, block_size=16)

    # 17 is not a multiple of the 16-wide block
    with pytest.raises(expected_exception=ValueError):
        config.setup_layout(seq_len=17)

    # One head, two blocks per side, every block active
    expected = torch.ones((1, 2, 2)).long()
    assert torch.allclose(config.make_layout(seq_len=32), expected)
def test_big_bird_sparsity_config():
    """Check the error paths of BigBirdSparsityConfig.

    With ``block_size=16`` and ``seq_len=16``, requesting two random, two
    sliding-window or two global blocks must each make ``make_layout`` raise
    ValueError. An ``attention="directional"`` argument must be rejected at
    construction with NotImplementedError.
    """
    # (num_random_blocks, num_sliding_window_blocks, num_global_blocks)
    oversized_requests = ((2, 1, 1), (1, 2, 1), (1, 1, 2))
    for n_random, n_sliding, n_global in oversized_requests:
        sc = BigBirdSparsityConfig(
            num_heads=1,
            block_size=16,
            num_random_blocks=n_random,
            num_sliding_window_blocks=n_sliding,
            num_global_blocks=n_global,
        )
        with pytest.raises(expected_exception=ValueError):
            sc.make_layout(seq_len=16)

    with pytest.raises(expected_exception=NotImplementedError):
        BigBirdSparsityConfig(num_heads=1, attention="directional")
def test_bslongformer_sparsity_config():
    """Check BSLongformerSparsityConfig's layout for seq_len=128 and its argument validation."""
    config = BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[1])
    expected_rows = [
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 0, 1, 1, 1, 0, 0, 0],
        [1, 0, 0, 1, 1, 1, 0, 0],
        [1, 0, 0, 0, 1, 1, 1, 0],
        [1, 0, 0, 0, 0, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 1, 1],
    ]
    # Single head, so wrap the rows in one extra leading dimension
    expected = torch.Tensor([expected_rows]).long()
    assert torch.allclose(config.make_layout(128), expected)

    # Empty and negative end indices must both be rejected at construction
    for bad_end_indices in ([], [-1]):
        with pytest.raises(expected_exception=ValueError):
            BSLongformerSparsityConfig(
                num_heads=1, global_block_end_indices=bad_end_indices
            )
def test_fixed_sparsity_config():
    """Check FixedSparsityConfig's layout for seq_len=112 and its argument validation."""
    # check that the case end < num_blocks is correct
    config = FixedSparsityConfig(num_heads=1, horizontal_global_attention=True)
    expected_rows = [
        [1, 1, 1, 1, 0, 0, 1],
        [1, 1, 1, 1, 0, 0, 1],
        [1, 1, 1, 1, 0, 0, 1],
        [1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1],
    ]
    # Single head, so wrap the rows in one extra leading dimension
    expected = torch.Tensor([expected_rows]).long()
    assert torch.allclose(config.make_layout(112), expected)

    # Constructor arguments that must be rejected, with the expected error type
    invalid_cases = (
        (ValueError, dict(num_local_blocks=3, num_global_blocks=2)),
        (NotImplementedError, dict(attention="directional")),
        (
            ValueError,
            dict(attention="unidirectional", horizontal_global_attention=True),
        ),
        (
            ValueError,
            dict(num_different_global_patterns=2, different_layout_per_head=False),
        ),
        (
            ValueError,
            dict(
                num_different_global_patterns=10,
                num_local_blocks=4,
                num_global_blocks=1,
            ),
        ),
    )
    for expected_error, kwargs in invalid_cases:
        with pytest.raises(expected_exception=expected_error):
            FixedSparsityConfig(num_heads=1, **kwargs)
def test_variable_sparsity_config():
    """Check VariableSparsityConfig's layout for seq_len=128 and its argument validation."""
    config = VariableSparsityConfig(num_heads=1, global_block_end_indices=[1])
    expected_rows = [[1, 1, 1, 1, 0, 0, 0, 0]] * 4 + [[1, 0, 0, 0, 1, 1, 1, 1]] * 4
    # Single head, so wrap the rows in one extra leading dimension
    expected = torch.Tensor([expected_rows]).long()
    assert torch.allclose(config.make_layout(128), expected)

    # Empty and negative end indices must both be rejected at construction
    for bad_end_indices in ([], [-1]):
        with pytest.raises(expected_exception=ValueError):
            VariableSparsityConfig(
                num_heads=1, global_block_end_indices=bad_end_indices
            )