| import math |
| import pytest |
| import torch |
|
|
| import sage_attention as sa |
|
|
|
|
# True when a CUDA device is present; every test in this module is skipped otherwise.
cuda_available = torch.cuda.is_available()
|
|
|
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_per_block_int8_shapes_and_types(tensor_layout):
    """per_block_int8 yields int8 q/k matching input shapes plus fp32 per-block scales.

    Sequence lengths 129 and 257 are deliberately non-multiples of the
    block sizes (128 for q, 64 for k) to exercise ceil-division padding
    of the block counts.
    """
    device = "cuda"
    dtype = torch.float16

    if tensor_layout == "HND":
        # (batch, heads, seq_len, head_dim)
        q = torch.randn(2, 4, 129, 128, dtype=dtype, device=device)
        k = torch.randn(2, 4, 257, 128, dtype=dtype, device=device)
    else:
        # (batch, seq_len, heads, head_dim)
        q = torch.randn(2, 129, 4, 128, dtype=dtype, device=device)
        k = torch.randn(2, 257, 4, 128, dtype=dtype, device=device)

    # Scale shapes are (batch, heads, num_blocks) regardless of layout,
    # so compute them once instead of duplicating per branch.
    expected_q_scale_shape = (2, 4, math.ceil(129 / 128))
    expected_k_scale_shape = (2, 4, math.ceil(257 / 64))

    # km has shape (batch, heads, head_dim) for both layouts (the original
    # ternary had two identical branches); presumably the per-head mean of
    # k used for smoothing — TODO confirm against sa.per_block_int8 docs.
    km = torch.randn(2, 4, 128, dtype=dtype, device=device)

    q_int8, q_scale, k_int8, k_scale = sa.per_block_int8(
        q, k, km, tensor_layout=tensor_layout
    )

    assert q_int8.shape == q.shape and q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape and k_int8.dtype == torch.int8
    assert q_scale.shape == expected_q_scale_shape and q_scale.dtype == torch.float32
    assert k_scale.shape == expected_k_scale_shape and k_scale.dtype == torch.float32
    assert q_int8.device == q.device == k.device == q_scale.device == k_scale.device
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()
|
|
|
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("head_dim", [64, 128])
def test_per_warp_int8_shapes_and_types(tensor_layout, head_dim):
    """per_warp_int8 yields int8 q/k matching input shapes plus fp32 scales.

    q is scaled per warp tile (WARPQ rows within each BLKQ block), k per
    block; lengths 130 and 70 are non-multiples of the block sizes to
    exercise the padded block counts.
    """
    device = "cuda"
    dtype = torch.float16
    # WARPQ depends on head_dim; hoisted so the choice appears exactly once.
    warpq = 16 if head_dim == 128 else 32

    if tensor_layout == "HND":
        # (batch, heads, seq_len, head_dim)
        q = torch.randn(1, 2, 130, head_dim, dtype=dtype, device=device)
        k = torch.randn(1, 2, 70, head_dim, dtype=dtype, device=device)
    else:
        # (batch, seq_len, heads, head_dim)
        q = torch.randn(1, 130, 2, head_dim, dtype=dtype, device=device)
        k = torch.randn(1, 70, 2, head_dim, dtype=dtype, device=device)

    # One q scale per warp tile: num_q_blocks * (BLKQ // WARPQ); one k
    # scale per k block. Layout-independent, so computed once (the
    # original duplicated these identically in both branches).
    expected_q_scale_shape = (1, 2, math.ceil(130 / 128) * (128 // warpq))
    expected_k_scale_shape = (1, 2, math.ceil(70 / 64))

    q_int8, q_scale, k_int8, k_scale = sa.per_warp_int8(
        q,
        k,
        tensor_layout=tensor_layout,
        BLKQ=128,
        WARPQ=warpq,
        BLKK=64,
    )

    assert q_int8.shape == q.shape and q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape and k_int8.dtype == torch.int8
    assert q_scale.shape == expected_q_scale_shape and q_scale.dtype == torch.float32
    assert k_scale.shape == expected_k_scale_shape and k_scale.dtype == torch.float32
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()
|
|
|
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_sub_mean_properties(tensor_layout):
    """sub_mean preserves shape/dtype and leaves a near-zero mean along the sequence dim."""
    device = "cuda"
    dtype = torch.float16

    if tensor_layout == "HND":
        seq_dim, nh_dim = 2, 1
        v = torch.randn(2, 3, 65, 128, dtype=dtype, device=device)
    else:
        seq_dim, nh_dim = 1, 2
        v = torch.randn(2, 65, 3, 128, dtype=dtype, device=device)

    v_smoothed, vm = sa.sub_mean(v, tensor_layout=tensor_layout)

    # Smoothed output keeps the input's shape and stays fp16.
    assert v_smoothed.shape == v.shape
    assert v_smoothed.dtype == torch.float16
    # The extracted mean is (batch, heads, head_dim) in the input dtype.
    assert vm.shape == (v.size(0), v.size(nh_dim), v.size(-1))
    assert vm.dtype == v.dtype

    # After subtracting the mean, the residual mean along the sequence
    # dimension should be finite and close to zero (fp16 tolerance).
    residual_mean = v_smoothed.mean(dim=seq_dim)
    assert torch.isfinite(residual_mean).all()
    assert (residual_mean.abs() < 1e-1).all()
|
|
|
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("smooth_v", [True, False])
def test_per_channel_fp8_shapes_and_outputs(tensor_layout, smooth_v):
    """per_channel_fp8 emits fp8 v with 64-padded sequence length and per-channel scales."""
    device = "cuda"
    dtype = torch.float16

    # kv_len = 77 is a non-multiple of 64 to exercise sequence padding.
    if tensor_layout == "HND":
        v_shape, kv_len = (2, 3, 77, 128), 77
    else:
        v_shape, kv_len = (2, 77, 3, 128), 77
    v = torch.randn(*v_shape, dtype=dtype, device=device)

    v_fp8, v_scale, vm = sa.per_channel_fp8(
        v, tensor_layout=tensor_layout, smooth_v=smooth_v
    )

    assert v_fp8.dtype == torch.float8_e4m3fn
    # Per-channel scale: (batch, heads, head_dim) for both layouts.
    assert v_scale.shape == (2, 3, 128)
    if smooth_v:
        # Smoothing also returns the extracted fp32 mean.
        assert vm is not None and vm.shape == (2, 3, 128) and vm.dtype == torch.float32
    else:
        assert vm is None

    # Sequence length rounded up to the next multiple of 64.
    padded_len = math.ceil(kv_len / 64) * 64
    expected_fp8_shape = (
        (2, 3, 128, padded_len) if tensor_layout == "HND" else (2, 128, 3, padded_len)
    )
    assert v_fp8.shape == expected_fp8_shape
    assert torch.isfinite(v_scale).all()
|
|