# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components import MultiHeadDispatch

# Automatically test all the registered attentions
from xformers.components.attention import ATTENTION_REGISTRY, build_attention

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 16
MODEL = 128 if torch.cuda.is_available() else 32

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"
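# For a quick sanity check, the available mechanisms can be listed with
# `sorted(ATTENTION_REGISTRY.keys())`; "compositional" has to be among them
# for the test below to build.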
# pytest injects each argument below; the parameter grids are illustrative
# values (assumed, not from the source), kept small so the CPU path stays fast.
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("rules", [2])
@pytest.mark.parametrize("q_compose", [False, True])
@pytest.mark.parametrize("dim_selection", [16])
@pytest.mark.parametrize("num_rules", [1, 2])
@pytest.mark.parametrize("qk_rule", [False, True])
@pytest.mark.parametrize("nonlinear", [False, True])
@pytest.mark.parametrize("device", DEVICES)
def test_build_and_run(
    heads: int,
    attn_dropout: float,
    causal: bool,
    rules: int,
    q_compose: bool,
    dim_selection: int,
    num_rules: int,
    qk_rule: bool,
    nonlinear: bool,
    device: torch.device,
):
    torch.manual_seed(42)

    test_config = {
        "name": "compositional",
        "dropout": attn_dropout,
        "causal": causal,
        "seq_len": SEQ,
        "dim_model": MODEL,
        "num_heads": heads,
        "num_rules": num_rules,
        "q_compose": q_compose,
        "rules": rules,
        "dim_selection": dim_selection,
        "qk_rule": qk_rule,
        "nonlinear": nonlinear,
    }
    attention = build_attention(test_config)

    # build a multi head dispatch to test this attention mechanism
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        num_heads=heads,
        attention=attention,
        residual_dropout=0.0,
    ).to(device)
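    # Note: MultiHeadDispatch owns the input/output projections and the
    # per-head reshaping, so the attention mechanism itself only sees the
    # already-projected q/k/v tensors.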
    # Check that a shuffled input produces the same results
    seqs = [SEQ, SEQ // 2]

    for seq in seqs:
        # Check that we can pass a smaller sequence
        inputs = torch.rand(BATCH, seq, MODEL, device=device)
        shuffle = torch.randperm(inputs.shape[1])
        inputs_shuffled = inputs[:, shuffle, :].clone()

        results = multi_head(inputs, inputs, inputs)
        results_shuffled = multi_head(inputs_shuffled, inputs_shuffled, inputs_shuffled)

        if attn_dropout == 0.0 and num_rules == 1 and not causal:
            assert (results[:, shuffle, :] - results_shuffled).abs().max() < 1e-3

        # Test the non-self-attention codepath
        att = multi_head(inputs, inputs_shuffled, inputs)

        # Check that dropout actually drops some values
        if attn_dropout > 0:
            att_2 = multi_head(inputs, inputs_shuffled, inputs)
            assert (att != att_2).any()
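
# Minimal direct invocation -- a sketch, not part of the original test file:
# running the module as a script exercises a single configuration without
# going through pytest's parametrization. All argument values here are
# assumptions, chosen so the permutation-equivariance assert above fires.
if __name__ == "__main__":
    test_build_and_run(
        heads=4,
        attn_dropout=0.0,
        causal=False,
        rules=2,
        q_compose=False,
        dim_selection=16,
        num_rules=1,
        qk_rule=False,
        nonlinear=False,
        device=DEVICES[0],
    )
    print("test_build_and_run passed for one illustrative configuration")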