# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components.positional_embedding import RotaryEmbedding
from xformers.components.positional_embedding.rotary import (
    apply_rotary_pos_emb,
    rotate_half,
)

DEVICES = (
    [torch.device("cpu")]
    if not torch.cuda.is_available()
    else [
        torch.device("cuda")
    ]  # save a bit on CI for now, we have separate cpu and gpu jobs
)

BATCH = 2
SEQ = 32
HEADS = 2
EMB = 32
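

# test_helper_methods exercises the two building blocks of the rotary embedding:
# rotate_half(x) splits the last dimension into halves (x1, x2) and returns
# cat(-x2, x1, dim=-1); apply_rotary_pos_emb(x, cos, sin) computes
# x * cos + rotate_half(x) * sin, which is what the expected tensors below encode.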
def test_helper_methods():
    # rotate_half
    tens = torch.tensor([[0, 1, 2, 3], [3, 1, 2, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
    tens_rotated = rotate_half(tens)
    assert torch.equal(
        tens_rotated,
        torch.tensor([[-2, -3, 0, 1], [-2, 0, 3, 1], [0, -1, 0, 1], [-1, 0, 1, 0]]),
    )

    # apply_rotary_pos_emb
    cos_test = torch.ones((1, 1, 4, 4))
    sin_test = cos_test.clone()
    q_test = 3 * torch.ones((2, 2, 3, 4))
    q_applied = apply_rotary_pos_emb(q_test, cos_test, sin_test)
    assert torch.equal(
        q_applied,
        torch.concat(
            (
                torch.zeros((2, 2, 3, 2), dtype=torch.float),
                6 * torch.ones((2, 2, 3, 2), dtype=torch.float),
            ),
            dim=-1,
        ),
    )
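

# Sanity check on the full RotaryEmbedding module: with identical, uniform q
# and k, the rotation should leave same-position attention (the diagonal)
# roughly unchanged while down-weighting attention across positions.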
# Parametrize over the available devices and two float dtypes (the exact dtype
# list is an assumption; any float dtype exercises the dtype-preservation check).
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_rotary_embeddings(device, dtype):
    rotary = RotaryEmbedding(EMB).to(device)

    # Generate dummy inputs
    q = torch.ones(
        (BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype
    )  # uniform on purpose
    k = q.clone()

    q_rot, k_rot = rotary(q, k)

    assert q_rot.dtype == q.dtype
    assert k_rot.dtype == k.dtype

    # Check that the sequences now encode relative position information
    q, k = q.float(), k.float()
    q_rot, k_rot = q_rot.float(), k_rot.float()

    att = torch.einsum("bhne,bhme->bhnm", q, k)
    att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot)

    # - the attention for the same positions is not meaningfully changed
    assert torch.allclose(
        torch.diag(att[0, 0, :, :]), torch.diag(att_rot[0, 0, :, :]), rtol=0.1
    )

    # - the post-rotary attention is more focused on the diagonal
    diag_max = torch.max(torch.diag(att_rot[0, 0, :, :]))
    att_rot -= diag_max
    att_rot = (
        att_rot <= 1e-4
    )  # every non-diagonal element has lower attention than the diagonal (within float tolerance)
    assert torch.all(att_rot)

    # Test that different sequence lengths are ok
    _, _ = rotary(q[:, :, :-16, :], k)
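

# To run these tests locally (assuming the file is saved as, e.g.,
# test_rotary_embeddings.py):
#   pytest -q test_rotary_embeddings.py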