# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch

from xformers.components.attention import maybe_sparsify
from xformers.components.attention._sputnik_sparse import _dense_to_sparse
from xformers.components.attention.core import SparseCS, _create_random_sparsity

# Small batch / sequence sizes keep these tests fast.
B = 2
M = 16  # not a nice round number, on purpose

# Every test runs on CPU, and additionally on the first GPU when available.
_devices_list = ["cpu"]
if torch.cuda.is_available():
    _devices_list.append("cuda:0")
_devices = [torch.device(d) for d in _devices_list]
@pytest.mark.parametrize("device", _devices)
def test_logical_and(device):
    """Logical-and on a SparseCS mask: `& ones` is the identity, `& ~mask`
    is empty, and sparse & sparse is (currently) unsupported."""
    mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), 0.1)
    mask_cs = SparseCS(mask, device)

    # Check that we cannot & two sparse matrices (for now)
    with pytest.raises(Exception):
        _ = mask_cs & mask_cs

    # Check that & ones returns the same values
    mask_ones = mask_cs & torch.ones_like(mask, dtype=torch.bool, device=device)
    assert torch.allclose(mask_cs.to_dense().long(), mask_ones.to_dense().long())

    # Check that & the inverse returns 0 all around
    mask_not = ~mask.to(device)
    assert (mask_cs & mask_not).values.numel() == 0
@pytest.mark.parametrize("seq", [12, 32, 128])
@pytest.mark.parametrize("device", _devices)
def test_dense_sparse(seq, device):
    """Dense -> SparseCS -> dense -> SparseCS round trip must be lossless."""
    # Check that we can .to_dense() without crashing
    mask = torch.rand(seq, seq, device=device) > 0.1
    mask_cs = SparseCS(mask, device)
    mask_back_forth = SparseCS(mask_cs.to_dense(), device)
    assert torch.allclose(mask_cs.to_dense().long(), mask_back_forth.to_dense().long())
@pytest.mark.parametrize("device", _devices)
def test_device(device):
    """Sparsifying a mask must keep the result on the mask's device."""
    mask = _create_random_sparsity(
        torch.ones(B, M, M, dtype=torch.bool, device=device), 0.1
    )
    assert mask.device.type == device.type
    sparse_mask = maybe_sparsify(mask)
    assert sparse_mask.device.type == device.type
| def _baseline_dense_to_sparse(matrix): | |
| import numpy as np | |
| # Extract the nonzero values. | |
| values = matrix.compress((matrix != 0).flatten()) | |
| # Calculate the offset of each row. | |
| mask = (matrix != 0).astype(np.int32) | |
| row_offsets = np.concatenate(([0], np.cumsum(np.add.reduce(mask, axis=1))), axis=0) | |
| # Create the row indices and sort them. | |
| # note: use torch.argsort to make it compatible as sorting is not stable in PyTorch | |
| row_indices = torch.argsort(-1 * torch.as_tensor(np.diff(row_offsets))).numpy() | |
| # Extract the column indices for the nonzero values. | |
| x = mask * (np.arange(matrix.shape[1]) + 1) | |
| column_indices = x.compress((x != 0).flatten()) | |
| column_indices = column_indices - 1 | |
| # Cast the desired precision. | |
| values = torch.as_tensor(values.astype(np.float32)) | |
| row_indices, row_offsets, column_indices = [ | |
| torch.as_tensor(x.astype(np.int32)) | |
| for x in [row_indices, row_offsets, column_indices] | |
| ] | |
| return values, row_indices, row_offsets, column_indices | |
@pytest.mark.parametrize("seq", [12, 32, 128])
@pytest.mark.parametrize("device", _devices)
def test_dense_to_sparse(seq, device):
    """_dense_to_sparse must agree with the NumPy reference implementation."""
    matrix = torch.rand(seq, seq, device=device)
    matrix[matrix > 0.9] = 0
    baseline_res = _baseline_dense_to_sparse(matrix.cpu().numpy())
    res = _dense_to_sparse(matrix, device=device)
    _idx_to_name = ["values", "row_indices", "row_offsets", "column_indices"]
    for idx, (bi, i) in enumerate(zip(baseline_res, res)):
        if idx != 1:
            # row_indices is the result of an argsort, which is not stable
            # for same number of elements
            assert torch.allclose(bi.to(device), i), f"error in {_idx_to_name[idx]}"
        # dtype/device must agree for every component, row_indices included
        assert bi.dtype == i.dtype
        assert i.device == device