Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import pytest | |
| import torch | |
| # needed to register custom ops | |
| import xformers # noqa: F401 | |
| from xformers.ops import masked_matmul | |
| from xformers.sparse import BlockSparseTensor, SparseCSRTensor | |
| from .utils import disable_tf32 | |
| cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") | |
| _devices = ( | |
| ["cpu", "cuda:0"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"] | |
| ) | |
| _tensor_types = [BlockSparseTensor, SparseCSRTensor] | |
def _create_blocksparse_tensor(
    device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
):
    """Build a random BlockSparseTensor of logical shape (Z, C, H, W).

    A random 0/1 block layout of shape (C, H // block_size, W // block_size)
    is drawn, then the first block-row and block-column are forced to 1 so
    every row/column of blocks contains at least one nonzero block.
    """
    grid = torch.randint(2, (C, H // block_size, W // block_size), device=device)
    # make sure no block-row / block-column is completely empty
    grid[:, :, 0] = 1
    grid[:, 0, :] = 1
    blocks = torch.randn(Z, grid.sum(), block_size, block_size, device=device).to(
        dtype
    )
    return BlockSparseTensor(blocks, grid)
def _create_csr_tensor(device, dtype, shape, sparsity, divisible_by=4):
    """Build a random SparseCSRTensor of the given 3D `shape`.

    A single keep-mask (shared across the batch dimension) is sampled with
    the requested `sparsity`, and the number of nonzeros is truncated to a
    multiple of `divisible_by` (required by the sputnik kernels).
    """
    dense = torch.rand(shape, dtype=torch.float32, device=device).to(dtype)
    assert dense.ndim == 3
    keep_mask = torch.rand_like(dense[0], dtype=torch.float32) > sparsity
    coords = torch.nonzero(keep_mask)
    count = coords.shape[0]
    # NOTE: need to make it a multiple of 4 for sputnik
    coords = coords[: (count - count % divisible_by)]
    rows, cols = coords.unbind(1)
    result = torch.zeros_like(dense)
    batch_idx = torch.arange(dense.shape[0], device=dense.device)[:, None]
    result[batch_idx, rows, cols] = dense[batch_idx, rows, cols]
    return SparseCSRTensor.from_dense(result)
def _create_tensor(tensor_type, device, dtype, shape, sparsity):
    """Dispatch to the right random-tensor factory for `tensor_type`.

    Note: for BlockSparseTensor the `shape`/`sparsity` arguments are ignored
    (the factory uses its own defaults with a fixed 16x16 block size).
    Returns None for unknown tensor types.
    """
    if tensor_type is BlockSparseTensor:
        return _create_blocksparse_tensor(device=device, dtype=dtype, block_size=16)
    if tensor_type is SparseCSRTensor:
        return _create_csr_tensor(
            device=device, dtype=dtype, shape=shape, sparsity=sparsity
        )
| def _seed(): | |
| torch.random.manual_seed(42) | |
| torch.cuda.manual_seed_all(42) | |
def _get_dtype_atol(tensor_type, device: str):
    """Return the (dtype, atol) pair to use for a tensor type / device combo.

    Also seeds the RNG and toggles the global TF32 flags as a side effect:
    GPU blocksparse (Triton) computes in TF32 internally, so the reference
    matmuls are allowed to use TF32 too and the tolerance is loosened.
    TF32 has the precision of fp16 but the range of fp32; see
    https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
    """
    _seed()
    use_tf32 = tensor_type == BlockSparseTensor and "cuda" in device
    # Force pytorch's behavior to match the op under test (recent cuda on
    # ampere+ GPUs would otherwise default to TF32 on its own)
    torch.backends.cuda.matmul.allow_tf32 = use_tf32
    torch.backends.cudnn.allow_tf32 = use_tf32  # type: ignore
    return (torch.float32, 1e-1) if use_tf32 else (torch.float32, 1e-5)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("func", [torch.add, torch.mul])
def test_sparse_binary_ops(func, device):
    """Elementwise binary ops between two sparse tensors must match dense.

    Fix: the function takes `func` and `device` parameters but had no
    `pytest.mark.parametrize` decorators, so pytest would fail with a
    missing-fixture error instead of running the test.
    """
    # TODO: add for BlockSparseTensor as well
    N, H, W = 8, 64, 64
    sparsity = 0.5
    shape = (N, H, W)

    a_sparse = _create_tensor(
        SparseCSRTensor, device, dtype=torch.float32, shape=shape, sparsity=sparsity
    )
    a = a_sparse.to_dense()

    # operate on the same operand twice: both sides share one sparsity pattern
    b = a
    b_sparse = a_sparse

    res = func(a_sparse, b_sparse).to_dense()
    res_gt = func(a, b)

    assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("tensor_type", _tensor_types)
def test_masked_matmul(tensor_type, device):
    """masked_matmul with a sparse mask must match the dense path, fwd + bwd.

    Fix: the function takes `tensor_type` and `device` parameters but had no
    `pytest.mark.parametrize` decorators, so pytest would fail with a
    missing-fixture error instead of running the test.
    """
    N, C, H, W, L = 8, 2, 64, 64, 32
    sparsity = 0.7
    dtype, atol = _get_dtype_atol(tensor_type, device)

    shape0 = (N, C, H, W)
    shape1 = (N, C, H, L)
    shape2 = (N, C, W, L)

    # CSR tensors are 3D; only BlockSparse keeps the extra head dimension
    if tensor_type != BlockSparseTensor:
        shape0 = shape0[1:]
        shape1 = shape1[1:]
        shape2 = shape2[1:]

    mask_sparse = _create_tensor(
        tensor_type, device, dtype=torch.bool, shape=shape0, sparsity=sparsity
    )
    mask = mask_sparse.to_dense()

    a = torch.randn(shape1, device=device, dtype=dtype)
    b = torch.randn(shape2, device=device, dtype=dtype)
    # independent clones so the dense and sparse paths get separate grads
    aa = a.clone()
    bb = b.clone()
    a.requires_grad_(True)
    b.requires_grad_(True)
    aa.requires_grad_(True)
    bb.requires_grad_(True)

    bt = b.transpose(-2, -1)
    bbt = bb.transpose(-2, -1)

    res_gt = masked_matmul(a, bt, mask)
    res = masked_matmul(aa, bbt, mask_sparse)

    res_dense = res.to_dense()
    # set masked-out entries to -inf so the sparse result can be compared
    # against the dense reference's convention
    res_dense = torch.where(mask, res_dense, torch.full_like(res_dense, float("-inf")))

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res_dense, res_gt, atol=atol)

    # try to workaround non-contiguous issues with triton for now
    res_gt.backward(torch.ones_like(res_gt))
    res.values().backward(torch.ones_like(res.values()))

    assert torch.allclose(a.grad, aa.grad, atol=atol)
    assert torch.allclose(b.grad, bb.grad, atol=atol)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("tensor_type", _tensor_types)
def test_bmm(tensor_type, device):
    """sparse @ dense must match dense @ dense, forward and backward.

    Fix: the function takes `tensor_type` and `device` parameters but had no
    `pytest.mark.parametrize` decorators, so pytest would fail with a
    missing-fixture error instead of running the test.
    """
    N, C, H, W, L = 8, 2, 64, 64, 32
    dtype, atol = _get_dtype_atol(tensor_type, device)
    sparsity = 0.8
    shape0 = (N, C, H, W)
    shape1 = (N, C, W, L)

    # CSR tensors are 3D; only BlockSparse keeps the extra head dimension
    if tensor_type != BlockSparseTensor:
        shape0 = shape0[1:]
        shape1 = shape1[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )
    a = a_sparse.to_dense()
    mask = a != 0

    a_sparse.requires_grad_(True)
    a.requires_grad_(True)

    b = torch.randn(shape1, device=device, dtype=dtype)
    b2 = b.clone()
    b.requires_grad_(True)
    b2.requires_grad_(True)

    res_gt = a @ b
    res = a_sparse @ b2

    assert res.dtype == res_gt.dtype
    assert torch.allclose(
        res, res_gt, atol=atol
    ), f"{torch.max(torch.abs(res-res_gt))} - tolerance: {atol}"

    res_gt.sum().backward()
    res.sum().backward()

    # the sparse grad only exists on the sparsity pattern; zero the dense
    # reference grad outside the mask before comparing
    a_grad = a.grad.clone().detach()
    a_grad[~mask] = 0

    assert torch.allclose(b.grad, b2.grad, atol=atol)
    assert torch.allclose(
        a_grad, a_sparse.grad.to_dense(), atol=atol
    ), f"{torch.max(torch.abs(a_grad-a_sparse.grad.to_dense()))}"
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("tensor_type", _tensor_types)
def test_sparse_softmax(tensor_type, device):
    """softmax over the last dim of a sparse tensor must match dense softmax.

    Fix: the function takes `tensor_type` and `device` parameters but had no
    `pytest.mark.parametrize` decorators, so pytest would fail with a
    missing-fixture error instead of running the test.
    """
    N, C, H, W = 8, 2, 64, 64
    dtype, atol = _get_dtype_atol(tensor_type, device)
    sparsity = 0.8
    shape0 = (N, C, H, W)

    # CSR tensors are 3D; only BlockSparse keeps the extra head dimension
    if tensor_type != BlockSparseTensor:
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )
    a = a_sparse.to_dense()
    mask = a != 0
    # dense reference: -inf outside the pattern so softmax ignores those slots
    a[~mask] = float("-inf")

    a_sparse.requires_grad_(True)
    a.requires_grad_(True)

    res_gt = torch.softmax(a, dim=-1)
    res_sparse = torch.softmax(a_sparse, dim=-1)
    res = res_sparse.to_dense()

    assert res.dtype == res_gt.dtype
    assert torch.allclose(
        res, res_gt, atol=atol
    ), f"{torch.max(torch.abs(res- res_gt))}"

    # WARNING: gradients are modified in-place!
    res_sparse.values().backward(torch.ones_like(res_sparse.values()))
    res_gt.backward(torch.ones_like(res_gt))

    # compare grads only on the sparsity pattern
    a_grad = a.grad.clone()
    a_grad[~mask] = 0

    assert torch.allclose(
        a_grad, a_sparse.grad.to_dense(), atol=atol
    ), f"{torch.max(torch.abs(a_grad- a_sparse.grad.to_dense()))}"
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("tensor_type", _tensor_types)
def test_deepcopy(tensor_type, device):
    """copy.deepcopy of a sparse tensor must produce an equal tensor.

    Fix: the function takes `tensor_type` and `device` parameters but had no
    `pytest.mark.parametrize` decorators, so pytest would fail with a
    missing-fixture error instead of running the test.
    """
    import copy

    N, C, H, W = 8, 2, 64, 64
    dtype = torch.float32
    sparsity = 0.8
    shape0 = (N, C, H, W)
    # CSR tensors are 3D; only BlockSparse keeps the extra head dimension
    if tensor_type != BlockSparseTensor:
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )

    b_sparse = copy.deepcopy(a_sparse)
    assert torch.equal(a_sparse, b_sparse)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("tensor_type", _tensor_types)
def test_module_buffer(tensor_type, device):
    """Sparse tensors must round-trip through nn.Module buffers / state_dict.

    Fix: the function takes `tensor_type` and `device` parameters but had no
    `pytest.mark.parametrize` decorators, so pytest would fail with a
    missing-fixture error instead of running the test.
    """
    N, C, H, W = 8, 2, 64, 64
    dtype = torch.float32
    sparsity = 0.8
    shape0 = (N, C, H, W)
    # CSR tensors are 3D; only BlockSparse keeps the extra head dimension
    if tensor_type != BlockSparseTensor:
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )
    b_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )

    module = torch.nn.Module()
    # test that register_buffer works
    module.register_buffer("a_sparse", a_sparse)
    assert module.a_sparse is a_sparse

    # moving the module must move (or keep) the buffer on the right device
    module.to(device)
    assert module.a_sparse.device == torch.device(device)

    state_dict = module.state_dict()
    assert "a_sparse" in state_dict
    assert torch.equal(a_sparse.to(device), state_dict["a_sparse"])

    # reloading the same state must work, and loading a different sparse
    # tensor must replace the buffer's contents
    module.load_state_dict(state_dict)
    module.load_state_dict({"a_sparse": b_sparse})
    assert torch.equal(module.a_sparse, b_sparse.to(device))