# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import sys

import pytest
import torch

import xformers

# Triton is an optional dependency: import it defensively so this module can
# still be imported/collected on machines where Triton is absent.
try:
    import triton
    import triton.language as tl

    from xformers.triton.vararg_kernel import unroll_varargs

    # Only set when the imports above succeeded; xformers performs its own
    # additional availability checks.
    _triton_available = xformers._is_triton_available()
except ImportError as e:
    logging.warning(
        f"Triton is not available, some optimizations will not be tested.\n{e}"
    )
    _triton_available = False

# Gate for the tests in this module: the varargs-unrolling tests need
# Python >= 3.9, a working Triton install, and a CUDA device to launch on.
enable_tests = (
    (sys.version_info.major, sys.version_info.minor) >= (3, 9)
    and _triton_available
    and torch.cuda.is_available()
)
@pytest.mark.skipif(
    not enable_tests, reason="requires Python >= 3.9, Triton and a CUDA GPU"
)
def test_triton_varargs_kernel():
    """``unroll_varargs`` specializes a kernel with a ``*inputs`` vararg into a
    fixed-arity Triton kernel; check that summing N scaled inputs twice matches
    the equivalent PyTorch computation.
    """

    # NOTE: unroll_varargs expects a Triton JIT function (and the launch
    # syntax below requires one), so the kernel must be @triton.jit-decorated.
    @triton.jit
    def sumN(output_ptr, scaling_ptr, *inputs, BLOCK_SIZE: tl.constexpr):
        offset = tl.arange(0, BLOCK_SIZE)
        output = tl.zeros([BLOCK_SIZE], tl.float32)
        # VAR_ARGS_ARRAY annotation marks a local array whose length is fixed
        # at unroll time to N.
        scaling: "VAR_ARGS_ARRAY"  # type: ignore # noqa: F821
        for i in range(len(scaling)):
            scaling[i] = tl.load(scaling_ptr + i)
        # Accumulate every input twice so the test catches kernels that only
        # run the inner loop once.
        for i in range(2):
            for j in range(len(inputs)):
                output = output + tl.load(inputs[j] + offset) * scaling[j]
        tl.store(output_ptr + offset, output)

    BLOCK_SIZE = 32
    NUM_INPUTS = 2
    torch.manual_seed(0)  # deterministic inputs
    inputs = [
        torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")
        for _ in range(NUM_INPUTS)
    ]
    # Pre-filled with random values to verify the kernel overwrites the output.
    output = torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")
    scaling = torch.randn([NUM_INPUTS, 1], dtype=torch.float32, device="cuda")
    # Specialize the varargs kernel for exactly NUM_INPUTS inputs.
    sumN_unrolled = unroll_varargs(sumN, N=NUM_INPUTS)
    sumN_unrolled[(1,)](output, scaling, *inputs, BLOCK_SIZE=BLOCK_SIZE)
    # Factor 2 mirrors the kernel's outer range(2) accumulation loop.
    assert torch.allclose((2 * torch.stack(inputs) * scaling).sum(0), output)
@pytest.mark.skipif(
    not enable_tests, reason="requires Python >= 3.9, Triton and a CUDA GPU"
)
def test_triton_multiple_varargs_kernel():
    """``unroll_varargs`` must support several ``VAR_ARGS_ARRAY`` parameters in
    one signature; check a weighted sum whose weights are passed as scalars.
    """

    # NOTE: unroll_varargs expects a Triton JIT function (and the launch
    # syntax below requires one), so the kernel must be @triton.jit-decorated.
    @triton.jit
    def weighted_sumN(
        output_ptr,
        a_ptr: "VAR_ARGS_ARRAY",  # type: ignore # noqa: F821
        b: "VAR_ARGS_ARRAY",  # type: ignore # noqa: F821
        BLOCK_SIZE: tl.constexpr,
    ):
        # Weighted sum, where the weights are on CPU
        offset = tl.arange(0, BLOCK_SIZE)
        output = tl.zeros([BLOCK_SIZE], tl.float32)
        for i in range(len(a_ptr)):
            output = output + tl.load(a_ptr[i] + offset) * b[i]
        tl.store(output_ptr + offset, output)

    BLOCK_SIZE = 32
    NUM_INPUTS = 2
    torch.manual_seed(0)  # deterministic inputs
    a = [
        torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")
        for _ in range(NUM_INPUTS)
    ]
    b = [torch.randn([], dtype=torch.float32, device="cuda") for _ in range(NUM_INPUTS)]
    # The weights are handed to the kernel as plain Python floats ("on CPU").
    b_list = [x.item() for x in b]
    # Pre-filled with random values to verify the kernel overwrites the output.
    output = torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")
    # Specialize both varargs parameters for exactly NUM_INPUTS entries.
    kernel = unroll_varargs(weighted_sumN, N=NUM_INPUTS)
    kernel[(1,)](output, *a, *b_list, BLOCK_SIZE=BLOCK_SIZE)
    expected_output = (torch.stack(a) * torch.stack(b).unsqueeze(1)).sum(0)
    assert torch.allclose(expected_output, output)