Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import random | |
| import pytest | |
| import torch | |
| import xformers.ops | |
| from xformers.ops.common import _get_storage_base | |
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4])
def test_unbind(dim: int, contiguous: bool):
    """Check ``xformers.ops.unbind`` against ``torch.unbind``.

    Both the forward outputs and the gradient accumulated into the input must
    match the PyTorch reference, for contiguous and non-contiguous inputs.

    Args:
        dim: dimension to unbind along (the input is 5-D).
        contiguous: if False, the input is a permuted (non-contiguous) view.
    """
    x = torch.randn([10, 20, 4, 10, 3])
    x2 = x.clone()
    if not contiguous:
        identity = list(range(x.ndim))
        perm = list(identity)
        rng = random.Random(dim)
        rng.shuffle(perm)
        # An identity shuffle would leave ``x`` contiguous and make the
        # assertion below fail spuriously, so reshuffle (deterministically,
        # same seeded RNG) until the permutation differs. No dimension has
        # size 1, so any non-identity permutation gives a non-contiguous view.
        while perm == identity:
            rng.shuffle(perm)
        x = x.permute(perm)
        x2 = x2.permute(perm)
    assert contiguous == x.is_contiguous()
    x.requires_grad_(True)
    x2.requires_grad_(True)

    # FW
    tensors = xformers.ops.unbind(x, dim)
    tensors2 = torch.unbind(x2, dim)
    assert len(tensors) == len(tensors2)
    for t1, t2 in zip(tensors, tensors2):
        assert torch.allclose(t1, t2)

    # BW: reduce each set of slices with the same random weights so the
    # gradients reaching ``x`` and ``x2`` should be identical.
    grads = torch.unbind(torch.randn(x.shape), dim)
    zero = torch.zeros_like(tensors[0])
    loss1 = sum(((g * t) for (g, t) in zip(grads, tensors)), zero)
    loss2 = sum(((g * t) for (g, t) in zip(grads, tensors2)), zero)
    assert torch.allclose(loss1, loss2)
    g = torch.randn_like(loss1)
    loss1.backward(g)
    loss2.backward(g)
    assert x.grad is not None
    assert x2.grad is not None
    assert torch.allclose(x.grad, x2.grad)
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4])
def test_unbind_get_stack_strides(dim: int, contiguous: bool):
    """Check ``get_stack_strides`` / ``stack_or_none`` on unbound views.

    Tensors coming from different storages must not be considered stackable.
    Slices of the views returned by ``xformers.ops.unbind`` must stack, along
    every ``cat_dim``, into a tensor with the same values and strides as the
    equivalently permuted and sliced input, sharing the input's storage base.

    Args:
        dim: dimension to unbind along (the input is 5-D).
        contiguous: if False, ``dim`` is transposed with its neighbour first.
    """

    def not_stacked(t, d):
        return xformers.ops.get_stack_strides(t, d) is None

    x = torch.randn([10, 20, 4, 4, 3])
    ndim = x.ndim

    # Non-contiguous tensors
    if not contiguous:
        x = x.transpose(dim, (dim + 1) % ndim)
    assert contiguous == x.is_contiguous()

    tensors = xformers.ops.unbind(x, dim)
    tensors2 = torch.unbind(x.clone(), dim)

    def _check_slice(s, cat_dim, x_permuted):
        # Stack ``tensors[s]`` along ``cat_dim`` and compare against the
        # matching slice of ``x`` with the unbound axis moved to ``cat_dim``.
        slices = [slice(None)] * ndim
        slices[cat_dim] = s
        reference = x_permuted[tuple(slices)]
        stacked = xformers.ops.stack_or_none(tensors[s], cat_dim)
        assert stacked is not None
        assert (
            xformers.ops.get_stack_strides(tensors[s], cat_dim)
            == reference.stride()
        )
        assert torch.allclose(stacked, torch.stack(tensors2[s], cat_dim))
        assert _get_storage_base(stacked) == _get_storage_base(tensors[0])

    for cat_dim in range(ndim):
        # Move the unbound axis ``dim`` to position ``cat_dim``.
        permute = list(range(ndim))
        permute.pop(dim)
        permute.insert(cat_dim, dim)
        x_permuted = x.permute(permute)

        assert not_stacked([tensors2[0], tensors[1]], cat_dim), "different storage"
        assert not_stacked(
            [tensors[0], tensors[1].clone()], cat_dim
        ), "different storage"

        # tensors, tensors[1:], tensors[:2], tensors[::2]
        for s in (
            slice(None),
            slice(1, None),
            slice(None, 2),
            slice(None, None, 2),
        ):
            _check_slice(s, cat_dim, x_permuted)