# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple

import pytest
import torch

from xformers.components import (
    InputProjection,
    InputProjectionConfig,
    MultiHeadDispatch,
)

# Automatically test all the registered attentions
from xformers.components.attention import (
    _DENSITY_THRESHOLD,
    ATTENTION_REGISTRY,
    build_attention,
)

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 36
MODEL = 128 if torch.cuda.is_available() else 16
GLOBAL_ATTENTION_RATIO = (
    _DENSITY_THRESHOLD * 0.9
)  # Make sure that we test the sparse implementation, no matter the threshold

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"

_non_order_invariant_attentions = ["visual", "pooling"]

def _get_multihead(
    attention_name,
    attn_dropout,
    res_dropout,
    causal,
    heads,
    device,
    skip_output_projection=False,
    use_separate_proj_weights=True,
):
    test_config = {
        "name": attention_name,
        "dropout": attn_dropout,
        "causal": causal,
        "seq_len": SEQ,
        "window_size": SEQ // 8 + 1,  # local attention
        "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
        "dim_model": MODEL,
        "num_heads": heads,
        "dim_head": MODEL // heads,  # per-head dimension must be an integer
        "num_rules": 2,  # Compositional Attention
        "r": 0.5,  # random attention, ratio of tokens that the attention can attend to
    }

    if skip_output_projection:

        def noop(x):
            return x

        test_config["out_proj"] = noop

    # Add some blocksparse layout to test the corresponding attention
    block_size = 16
    test_config["layout"] = torch.eye(
        SEQ // block_size, SEQ // block_size, dtype=torch.long
    )
    test_config["block_size"] = block_size

    attention = build_attention(test_config)

    # build a multi head dispatch to test this attention mechanism
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        residual_dropout=res_dropout,
        num_heads=heads,
        attention=attention,
        use_separate_proj_weight=use_separate_proj_weights,
    ).to(device)

    return multi_head
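
# The parametrize decorators below are an assumed, representative set (the dropout,
# causal and head values are illustrative): they feed every registered attention that
# is order invariant, on the available devices, into the test signature.
@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize(
    "attention_name",
    [a for a in ATTENTION_REGISTRY.keys() if a not in _non_order_invariant_attentions],
)
@pytest.mark.parametrize("device", DEVICES)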
def test_order_invariance(
    attention_name: str,
    heads: int,
    attn_dropout: float,
    residual_dropout: float,
    causal: bool,
    device: torch.device,
):
    if (
        torch.version.hip
        and device == torch.device("cuda")
        and attention_name == "local"
    ):
        # Backend calls into Sputnik library which isn't built on ROCm
        device = torch.device("cpu")

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    multi_head = _get_multihead(
        attention_name,
        attn_dropout,
        residual_dropout,
        causal,
        heads,
        device,
        use_separate_proj_weights=False,
    )

    if (
        int(math.sqrt(SEQ)) ** 2 != SEQ
        and multi_head.attention.requires_squared_context
    ):
        pytest.skip(f"{attention_name} requires squared sequence lengths")

    # Check that we can pass a smaller sequence
    seqs = (
        [SEQ, SEQ // 2]
        if not multi_head.attention.requires_same_k_q_dimensions
        else [SEQ]
    )

    for seq in seqs:
        # Check that the attention is invariant to a permutation of K, V
        inputs = torch.rand(BATCH, seq, MODEL, device=device)
        shuffle = torch.randperm(inputs.shape[1])
        inputs_shuffled = inputs[:, shuffle, :].clone()

        results = multi_head(inputs, inputs, inputs)
        results_shuffled = multi_head(inputs, inputs_shuffled, inputs_shuffled)
        torch.allclose(results, results_shuffled)

        # Check that the attention is equivariant to a permutation of Q,
        # meaning that the result is permuted in the same way
        results_shuffled = multi_head(inputs_shuffled, inputs, inputs)
        torch.allclose(results[:, shuffle, :], results_shuffled)

        # Check that dropout actually drops some values
        if attn_dropout > 0:
            att_1 = multi_head(inputs, inputs_shuffled, inputs)
            att_2 = multi_head(inputs, inputs_shuffled, inputs)
            assert (att_1 != att_2).any()

    # Test AMP, if available
    if device.type == "cuda":
        with torch.cuda.amp.autocast(enabled=True):
            _ = multi_head(inputs, inputs_shuffled, inputs)
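
# Assumed, representative parametrization (illustrative values, not from the original):
# every registered attention, a couple of head counts, on the available devices.
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)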
def test_kqv_ordering(
    attention_name: str,
    heads: int,
    device: torch.device,
):
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    # Check that kqv are not flipped.
    # This will not catch all issues, but would catch a V being misplaced:
    # make k and q complementary, so that QKt is all zero and attention is uniform
    q = torch.cat(
        (
            torch.rand((1, MODEL // 2), device=device),
            torch.zeros((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    k = torch.cat(
        (
            torch.zeros((1, MODEL // 2), device=device),
            torch.rand((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    v = torch.rand(BATCH, SEQ, MODEL, device=device)

    # Normal call
    res = multi_head(query=q, key=k, value=v)
    for i in range(BATCH):
        assert torch.allclose(res[i, :, :], res[i, 0, :].unsqueeze(-2))

    assert not torch.allclose(res[0, :, :], res[1, :, :])

    # Flip qkv, and check that we invert the above check properly
    res_false = multi_head(query=v, key=k, value=q)
    assert torch.allclose(res_false[0, :, :], res_false[1, :, :])
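
# Assumed, representative parametrization (illustrative values).
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)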
def test_different_seqlen(
    attention_name: str,
    heads: int,
    device: torch.device,
):
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    # Check that kqv are not flipped.
    # This will not catch all issues, but would catch a V being misplaced:
    # make k and q complementary, so that QKt is all zero and attention is uniform
    q = torch.cat(
        (
            torch.rand((1, MODEL // 2), device=device),
            torch.zeros((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    k = torch.cat(
        (
            torch.zeros((1, MODEL // 2), device=device),
            torch.rand((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    v = torch.rand(BATCH, SEQ, MODEL, device=device)

    # Normal call
    res = multi_head(query=q, key=k, value=v)

    # Halve the sequence length to simulate a different sequence length
    q2 = torch.cat(
        (
            torch.rand((1, MODEL // 2), device=device),
            torch.zeros((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ // 2, MODEL))
    k2 = torch.cat(
        (
            torch.zeros((1, MODEL // 2), device=device),
            torch.rand((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ // 2, MODEL))
    v2 = torch.rand(BATCH, SEQ // 2, MODEL, device=device)

    res2 = multi_head(query=q2, key=k2, value=v2)

    assert res.shape != res2.shape
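
# Assumed parametrization: the three flags in the signature are simply toggled.
@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])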
def test_inproj(proj_bias: bool, same_sizes: bool, same_settings: bool):

    test_config = {
        "name": "scaled_dot_product",
        "dropout": 0.1,
        "causal": False,
        "seq_len": SEQ,
        "window_size": SEQ // 8 + 1,
        "num_heads": 1,
        "dim_head": MODEL,
    }

    attention = build_attention(test_config)

    # Construct the initial projection, test different options
    in_params = InputProjectionConfig(MODEL, MODEL, proj_bias)

    if same_settings:
        in_proj = InputProjection(in_params, None, None)
        out_features = MODEL
    else:
        out_features = MODEL if same_sizes else MODEL // 2
        in_params_flip = InputProjectionConfig(MODEL, out_features, proj_bias)
        in_proj = InputProjection(
            in_params_flip,  # Q proj
            in_params_flip,  # K proj
            in_params,  # V proj
        )

    # build a multi head dispatch to test this attention mechanism
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        residual_dropout=0.1,
        num_heads=1,
        attention=attention,
        in_proj_container=in_proj,
        dim_key=out_features,
        dim_value=MODEL,
    )

    # Check that kqv are not flipped.
    # This will not catch all issues, but would catch a V being misplaced:
    # make k and q complementary, so that QKt is all zero and attention is uniform
    q = torch.cat(
        (
            torch.rand((1, MODEL // 2)),
            torch.zeros((1, MODEL // 2)),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    k = torch.cat(
        (
            torch.zeros((1, MODEL // 2)),
            torch.rand((1, MODEL // 2)),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    v = torch.rand(BATCH, SEQ, MODEL)

    # just check that a FW does not assert out
    _ = multi_head(query=q, key=k, value=v)
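
# Assumed, representative parametrization (illustrative values).
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)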
def test_different_kq_dimensions(
    attention_name: str,
    heads: int,
    device: torch.device,
):
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    if multi_head.attention.requires_same_k_q_dimensions:
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")

    seq_q = SEQ // 2
    q = torch.rand((BATCH, seq_q, MODEL), device=device)
    k = torch.rand((BATCH, SEQ, MODEL), device=device)
    v = torch.rand((BATCH, SEQ, MODEL), device=device)

    res = multi_head(query=q, key=k, value=v)
    assert res.shape == torch.Size([BATCH, seq_q, MODEL])
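
# Assumed parametrization: the batch-size triples are chosen so that the broadcasted
# result keeps BATCH as its batch dimension, matching the assert in the test body.
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize(
    "batch_sizes", [(1, BATCH, BATCH), (BATCH, 1, BATCH), (BATCH, BATCH, 1)]
)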
def test_broadcast_batch_dimension(
    attention_name: str,
    heads: int,
    device: torch.device,
    batch_sizes: Tuple[int, int, int],
):
    Q_BATCH, K_BATCH, V_BATCH = batch_sizes
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    if (
        int(math.sqrt(SEQ)) ** 2 != SEQ
        and multi_head.attention.requires_squared_context
    ):
        pytest.skip(f"{attention_name} requires squared sequence lengths")

    if multi_head.attention.requires_same_k_q_dimensions:
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")

    q = torch.rand((Q_BATCH, SEQ, MODEL), device=device)
    k = torch.rand((K_BATCH, SEQ, MODEL), device=device)
    v = torch.rand((V_BATCH, SEQ, MODEL), device=device)

    res = multi_head(query=q, key=k, value=v)
    assert res.shape == torch.Size([BATCH, SEQ, MODEL])
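
# Assumed decorators: test_causal hard-codes a CUDA device below, so skip when CUDA
# is unavailable; the attention name and head counts are illustrative choices.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize("heads", [1, 2])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product"])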
def test_causal(
    attention_name: str,
    heads: int,
):
    """
    Make sure that the causal flag is respected.
    The input data is orthogonal by design if causal is respected, but if the
    attention looks ahead this test will fail.
    """

    torch.random.manual_seed(42)
    device = torch.device("cuda")

    multi_head = _get_multihead(
        attention_name,
        0.0,
        0.0,
        causal=True,
        heads=heads,
        device=device,
        skip_output_projection=True,
    )

    k = (
        torch.tril(torch.ones((SEQ, SEQ), device=device), diagonal=0)
        .unsqueeze(0)
        .expand(1, -1, -1)
    )
    q = (
        torch.triu(torch.ones((SEQ, SEQ), device=device), diagonal=0)
        .unsqueeze(0)
        .expand(1, -1, -1)
    )
    v = (
        torch.arange(SEQ, device=device)
        .float()
        .unsqueeze(0)
        .unsqueeze(-1)
        .expand(1, -1, SEQ)
    )

    # Make sure that we don't project, to keep the embeddings orthogonal
    multi_head.attention.requires_input_projection = False

    res = multi_head(query=q, key=k, value=v).squeeze(0)

    # Consolidate along the embedding; if causal was respected, the amplitudes
    # should already be sorted
    res_sum = torch.sum(res, dim=1).cpu()

    assert torch.allclose(torch.sort(res_sum)[1], torch.arange(SEQ)) or torch.allclose(
        torch.sort(res_sum, descending=True)[1], torch.arange(SEQ)
    ), res_sum
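
# Assumed, representative parametrization: the test body itself skips the attention
# mechanisms that cannot be scripted yet, so the full registry is passed through.
@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())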
def test_torch_script_ability(
    attention_name: str,
    heads: int,
    attn_dropout: float,
):
    if attention_name in {"favor", "global", "local", "random"}:
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip(f"{attention_name} does not support scripting yet.")

    device = torch.device("cpu")

    multi_head = _get_multihead(attention_name, attn_dropout, 0.0, False, heads, device)

    if (
        int(math.sqrt(SEQ)) ** 2 != SEQ
        and multi_head.attention.requires_squared_context
    ):
        pytest.skip(f"{attention_name} requires squared sequence lengths")

    # input for tracing the function
    q = torch.rand((BATCH, SEQ, MODEL), device=device)
    k = torch.rand((BATCH, SEQ, MODEL), device=device)
    v = torch.rand((BATCH, SEQ, MODEL), device=device)

    # to make sure dropout behaves deterministically
    torch.random.manual_seed(42)
    # tracing the attention module
    traced_multi_head = torch.jit.trace(multi_head, (q, k, v))

    # create new random inputs for testing the eager model and traced model
    q = torch.rand((BATCH, SEQ, MODEL), device=device)
    k = torch.rand((BATCH, SEQ, MODEL), device=device)
    v = torch.rand((BATCH, SEQ, MODEL), device=device)

    # to make sure dropout behaves deterministically, we need to set the seed again
    torch.random.manual_seed(42)
    res = multi_head(query=q, key=k, value=v)

    # to make sure dropout behaves deterministically, we need to set the seed again
    torch.random.manual_seed(42)
    res_traced = traced_multi_head(query=q, key=k, value=v)

    assert torch.allclose(res, res_traced)


# TODO: way more unit tests..