Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import random | |
| from typing import Tuple | |
| import pytest | |
| import torch | |
| from xformers import _is_triton_available | |
| from xformers.ops import ( | |
| sequence_parallel_leading_matmul, | |
| sequence_parallel_trailing_matmul, | |
| ) | |
| from .multiprocessing_utils import launch_subprocesses | |
# Compute capability reported for the "cuda" device, or (0, 0) when CUDA is
# not available so that the comparison below still works.
compute_capability = (
    torch.cuda.get_device_capability("cuda") if torch.cuda.is_available() else (0, 0)
)

# Skip marker: the decorated test needs compute capability 7.0 or newer.
cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0),
    reason="requires sm70+",
)

# Skip marker: the decorated test needs two or more visible CUDA devices.
at_least_2_gpus = pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="needs at least 2 GPUs",
)
# These tests check correctness, not speed, so the costly Triton autotuning is
# effectively switched off: each kernel keeps only its first config and every
# other candidate is dropped (in place, on the kernel's own config list).
if _is_triton_available():
    from xformers.ops._triton.sequence_parallel_fused_kernels import (
        _xformers_seqpar_matmul_kernel,
    )

    del _xformers_seqpar_matmul_kernel.configs[1:]

    from xformers.ops._triton.tiled_matmul_kernels import _xformers_tiled_matmul_kernel

    del _xformers_tiled_matmul_kernel.configs[1:]
def reference_leading(input_, w1, w2):
    """Reference for the leading step: multiply ``input_`` by each weight.

    Returns a two-element list ``[input_ @ w1.T, input_ @ w2.T]``.
    """
    return [torch.matmul(input_, weight.t()) for weight in (w1, w2)]
def reference_trailing(hidden, w):
    """Reference for the trailing step: return ``hidden @ w.T``."""
    return torch.matmul(hidden, w.t())
def xformers_leading(input_, w1, w2, *, fuse, group):
    """Run ``sequence_parallel_leading_matmul`` on ``input_`` and both weights.

    The weights are handed over transposed, as a two-element list; ``fuse`` and
    ``group`` are forwarded as the op's ``fuse`` and ``process_group`` arguments.
    """
    transposed_weights = [w1.t(), w2.t()]
    return sequence_parallel_leading_matmul(
        input_,
        transposed_weights,
        fuse=fuse,
        process_group=group,
    )
def xformers_trailing(hidden, w, *, fuse, group):
    """Run ``sequence_parallel_trailing_matmul`` on ``hidden`` and ``w``.

    The weight is handed over transposed; ``fuse`` and ``group`` are forwarded
    as the op's ``fuse`` and ``process_group`` arguments.
    """
    transposed_weight = w.t()
    return sequence_parallel_trailing_matmul(
        hidden,
        transposed_weight,
        fuse=fuse,
        process_group=group,
    )
def inner_seqpar(
    kind: str,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
    seed: int,
):
    """Check one sequence-parallel matmul op against plain ``torch.matmul``.

    This is the per-rank body that ``test_seqpar`` hands to
    ``launch_subprocesses``. It queries ``torch.distributed`` for its rank and
    the world size, so it expects to run where a process group is available.
    Every rank builds the same full tensors from ``seed``, computes the
    non-parallel reference (forward and backward) on them, runs the xformers op
    on its own chunk, and asserts that outputs and gradients match the
    corresponding chunk of the reference.

    Args:
        kind: ``"unfused"`` calls the op with ``fuse=False``. ``"fallback"``
            calls it with ``fuse=True`` after setting
            ``DISABLE_FUSED_SEQUENCE_PARALLEL=1`` in this process' environment.
            Any other value calls it with ``fuse=True``.
        step: ``"leading"`` or ``"trailing"``; selects which op is checked.
        dims: ``(*batch_dims, outer_dim, inner_dim)``.
        dtype: Floating-point dtype of all tensors (``torch.finfo`` is applied
            to it).
        seed: Seed for torch's global RNG.

    Raises:
        ValueError: If ``step`` is neither ``"leading"`` nor ``"trailing"``.
        AssertionError: If the dimensions are too large for exact arithmetic in
            ``dtype``, or if any output/gradient differs from the reference.
    """
    # Reject unknown steps before touching torch.distributed: otherwise neither
    # branch below would run and the test would pass without checking anything.
    if step not in ("leading", "trailing"):
        raise ValueError(
            f"unknown step {step!r}, expected 'leading' or 'trailing'"
        )

    my_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    subgroup = torch.distributed.new_group()

    # Only "unfused" turns fusing off; "fallback" keeps fuse=True but sets the
    # environment variable below.
    fused = kind != "unfused"
    if kind == "fallback":
        os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1"

    torch.random.manual_seed(seed)

    batch_dims = dims[:-2]
    outer_dim = dims[-2]
    inner_dim = dims[-1]

    # To check for correctness we want to compare the outputs but the accuracy
    # of matmuls, apparently, is not that great. We thus try to produce inputs
    # for which no rounding at all will occur. We do this by using zero or one
    # inputs, so their product will also be zero or one, and keep the reduction
    # dimension small enough so that they fit in the mantissa without overflow.
    max_exact_value = 2 * (1 / torch.finfo(dtype).eps)
    # 0.25 is the ratio of expected ones and we aim at 2/3 of the safe range
    assert outer_dim * 0.25 <= max_exact_value * 0.66
    assert inner_dim * world_size * 0.25 <= max_exact_value * 0.66

    def my_chunk(t, *, dim):
        # This rank's slice of `t` when split into `world_size` parts on `dim`.
        return t.tensor_split(world_size, dim=dim)[my_rank]

    def binary_tensor(shape):
        # Random tensor on the GPU: values are drawn between 0 and 1 and then
        # rounded to the nearest integer.
        return torch.testing.make_tensor(
            shape,
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()

    if step == "leading":
        # Tensors are generated in a fixed order so the RNG stream is consumed
        # identically on every rank.
        input_ = binary_tensor(batch_dims + (outer_dim,))
        weight1, weight2 = [
            binary_tensor((inner_dim * (idx + 1), outer_dim)) for idx in range(2)
        ]
        gradient1, gradient2 = [
            binary_tensor(batch_dims + (inner_dim * (idx + 1),)) for idx in range(2)
        ]

        # Non-fused reference code
        input_ref = input_.detach().requires_grad_()
        weight1_ref = weight1.detach().requires_grad_()
        weight2_ref = weight2.detach().requires_grad_()

        output1_ref, output2_ref = reference_leading(
            input_ref, weight1_ref, weight2_ref
        )
        torch.autograd.backward([output1_ref, output2_ref], [gradient1, gradient2])

        my_output1_ref = my_chunk(output1_ref, dim=-1)
        my_output2_ref = my_chunk(output2_ref, dim=-1)
        my_weight1_grad_ref = my_chunk(weight1_ref.grad, dim=0)
        my_weight2_grad_ref = my_chunk(weight2_ref.grad, dim=0)
        my_input_grad_ref = my_chunk(input_ref.grad, dim=0)

        # Faster fused mode
        my_input_xf = my_chunk(input_, dim=0).detach().requires_grad_()
        my_weight1_xf = my_chunk(weight1, dim=0).detach().requires_grad_()
        my_weight2_xf = my_chunk(weight2, dim=0).detach().requires_grad_()
        my_gradient1 = my_chunk(gradient1, dim=-1)
        my_gradient2 = my_chunk(gradient2, dim=-1)

        my_output1_xf, my_output2_xf = xformers_leading(
            my_input_xf, my_weight1_xf, my_weight2_xf, fuse=fused, group=subgroup
        )
        torch.autograd.backward(
            [my_output1_xf, my_output2_xf], [my_gradient1, my_gradient2]
        )

        my_weight1_grad_xf = my_weight1_xf.grad
        my_weight2_grad_xf = my_weight2_xf.grad
        my_input_grad_xf = my_input_xf.grad

        # Checks
        torch.testing.assert_close(my_output1_ref, my_output1_xf)
        torch.testing.assert_close(my_output2_ref, my_output2_xf)
        torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf)
        torch.testing.assert_close(my_weight1_grad_ref, my_weight1_grad_xf)
        torch.testing.assert_close(my_weight2_grad_ref, my_weight2_grad_xf)

    else:  # step == "trailing" (validated above)
        input_ = binary_tensor(batch_dims + (inner_dim,))
        weight = binary_tensor((outer_dim, inner_dim))
        gradient = binary_tensor(batch_dims + (outer_dim,))

        # Non-fused reference code
        input_ref = input_.detach().requires_grad_()
        weight_ref = weight.detach().requires_grad_()

        output_ref = reference_trailing(input_ref, weight_ref)
        torch.autograd.backward([output_ref], [gradient])

        my_output_ref = my_chunk(output_ref, dim=0)
        my_weight_grad_ref = my_chunk(weight_ref.grad, dim=1)
        my_input_grad_ref = my_chunk(input_ref.grad, dim=-1)

        # Faster fused mode
        # clone() gives the last-dim slice its own tensor instead of a view
        # into `input_`.
        my_input_xf = my_chunk(input_, dim=-1).detach().clone().requires_grad_()
        my_weight_xf = my_chunk(weight, dim=1).detach().requires_grad_()
        my_gradient = my_chunk(gradient, dim=0)

        my_output_xf = xformers_trailing(
            my_input_xf, my_weight_xf, fuse=fused, group=subgroup
        )
        torch.autograd.backward([my_output_xf], [my_gradient])

        my_weight_grad_xf = my_weight_xf.grad
        my_input_grad_xf = my_input_xf.grad

        # Checks
        torch.testing.assert_close(my_output_ref, my_output_xf)
        torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf)
        torch.testing.assert_close(my_weight_grad_ref, my_weight_grad_xf)
# NOTE(review): no parametrization is visible on this test, so pytest cannot
# supply its arguments. The grid below is reconstructed from the values that
# `inner_seqpar` and this function handle and from the otherwise-unused skip
# markers defined earlier in the file -- confirm against the upstream test.
@cuda_sm70_only
@pytest.mark.parametrize(
    "kind",
    [
        # "singleton" runs with a world size of 1, so one GPU is enough.
        "singleton",
        pytest.param("fallback", marks=at_least_2_gpus),
        pytest.param("unfused", marks=at_least_2_gpus),
        pytest.param("fused", marks=at_least_2_gpus),
    ],
)
@pytest.mark.parametrize("step", ["leading", "trailing"])
@pytest.mark.parametrize(
    "dims",
    [
        # Both shapes satisfy the exact-arithmetic asserts in `inner_seqpar`
        # for every dtype below at a world size of 2.
        pytest.param((2, 2, 512, 512, 256), id="nice-shapes"),
        pytest.param((2, 1023, 511, 257), id="ugly-shapes"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(torch.bfloat16, id="bf16"),
        pytest.param(torch.float16, id="fp16"),
        pytest.param(torch.float32, id="fp32"),
    ],
)
def test_seqpar(
    kind: str,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
):
    """Run ``inner_seqpar`` in subprocesses for one parameter combination.

    Args:
        kind: Test variant; ``"singleton"`` uses a world size of 1, every other
            value uses a world size of 2. Forwarded to ``inner_seqpar``.
        step: ``"leading"`` or ``"trailing"``; forwarded to ``inner_seqpar``.
        dims: ``(*batch_dims, outer_dim, inner_dim)``; forwarded.
        dtype: Tensor dtype; forwarded.
    """
    world_size = 1 if kind == "singleton" else 2
    # One seed is drawn here and passed to every subprocess, so all ranks
    # seed their RNG identically.
    seed = random.getrandbits(32)
    launch_subprocesses(
        world_size=world_size,
        fn=inner_seqpar,
        kind=kind,
        step=step,
        dims=dims,
        dtype=dtype,
        seed=seed,
    )