| | import math |
| | import pytest |
| | import torch |
| |
|
| | import sage_attention as sa |
| |
|
| |
|
# Evaluated once at import time; the tests in this module use it in their
# ``skipif`` markers so they are skipped when no CUDA device is available.
cuda_available = torch.cuda.is_available()
| |
|
| |
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_per_block_int8_shapes_and_types(tensor_layout):
    """Check shapes, dtypes and devices of the ``sa.per_block_int8`` outputs.

    The sequence lengths (129 for q, 257 for k) are deliberately not multiples
    of the block sizes assumed by the expectations, so the final partial block
    is covered as well.

    Args:
        tensor_layout: ``"HND"`` builds inputs as (batch, heads, seq, dim);
            any other value builds them as (batch, seq, heads, dim).
    """
    device = "cuda"
    dtype = torch.float16
    batch, heads, head_dim = 2, 4, 128
    q_len, k_len = 129, 257

    if tensor_layout == "HND":
        q = torch.randn(batch, heads, q_len, head_dim, dtype=dtype, device=device)
        k = torch.randn(batch, heads, k_len, head_dim, dtype=dtype, device=device)
    else:
        q = torch.randn(batch, q_len, heads, head_dim, dtype=dtype, device=device)
        k = torch.randn(batch, k_len, heads, head_dim, dtype=dtype, device=device)

    # The expected scale shapes are the same for both layouts: one scale per
    # 128 query rows and one per 64 key rows, per (batch, head).
    # NOTE(review): 128/64 are presumably the default block sizes of
    # per_block_int8 -- the call below does not pass them explicitly.
    expected_q_scale_shape = (batch, heads, math.ceil(q_len / 128))
    expected_k_scale_shape = (batch, heads, math.ceil(k_len / 64))

    # The key-mean tensor has the same (batch, heads, dim) shape for both
    # layouts, so no layout switch is needed here.
    km = torch.randn(batch, heads, head_dim, dtype=dtype, device=device)

    q_int8, q_scale, k_int8, k_scale = sa.per_block_int8(
        q, k, km, tensor_layout=tensor_layout
    )

    assert q_int8.shape == q.shape and q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape and k_int8.dtype == torch.int8
    assert q_scale.shape == expected_q_scale_shape and q_scale.dtype == torch.float32
    assert k_scale.shape == expected_k_scale_shape and k_scale.dtype == torch.float32
    # Every output, including k_int8, must live on the same device as the inputs.
    assert (
        q_int8.device
        == k_int8.device
        == q.device
        == k.device
        == q_scale.device
        == k_scale.device
    )
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()
| |
|
| |
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("head_dim", [64, 128])
def test_per_warp_int8_shapes_and_types(tensor_layout, head_dim):
    """Check shapes and dtypes of the ``sa.per_warp_int8`` outputs.

    Inputs use sequence lengths (130 for q, 70 for k) that are not multiples
    of the block sizes, so the trailing partial block is covered too.

    Args:
        tensor_layout: ``"HND"`` builds inputs as (batch, heads, seq, dim);
            any other value builds them as (batch, seq, heads, dim).
        head_dim: Size of the last dimension; it also selects the warp size
            passed to the function (16 for 128, otherwise 32).
    """
    blk_q, blk_k = 128, 64
    warp_q = 16 if head_dim == 128 else 32
    batch, heads = 1, 2
    q_len, k_len = 130, 70
    tensor_opts = {"dtype": torch.float16, "device": "cuda"}

    if tensor_layout == "HND":
        q_shape = (batch, heads, q_len, head_dim)
        k_shape = (batch, heads, k_len, head_dim)
    else:
        q_shape = (batch, q_len, heads, head_dim)
        k_shape = (batch, k_len, heads, head_dim)

    q = torch.randn(*q_shape, **tensor_opts)
    k = torch.randn(*k_shape, **tensor_opts)

    # Expected: one q scale per warp-sized slice of every q block, and one
    # k scale per k block; neither depends on the layout.
    want_q_scale_shape = (batch, heads, math.ceil(q_len / blk_q) * (blk_q // warp_q))
    want_k_scale_shape = (batch, heads, math.ceil(k_len / blk_k))

    q_int8, q_scale, k_int8, k_scale = sa.per_warp_int8(
        q,
        k,
        tensor_layout=tensor_layout,
        BLKQ=blk_q,
        WARPQ=warp_q,
        BLKK=blk_k,
    )

    assert q_int8.shape == q.shape
    assert q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape
    assert k_int8.dtype == torch.int8
    assert q_scale.shape == want_q_scale_shape
    assert q_scale.dtype == torch.float32
    assert k_scale.shape == want_k_scale_shape
    assert k_scale.dtype == torch.float32
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()
| |
|
| |
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_sub_mean_properties(tensor_layout):
    """Check that ``sa.sub_mean`` returns a near-zero-mean tensor and its mean.

    Args:
        tensor_layout: ``"HND"`` builds the input as (batch, heads, seq, dim);
            any other value builds it as (batch, seq, heads, dim).
    """
    is_hnd = tensor_layout == "HND"
    shape = (2, 3, 65, 128) if is_hnd else (2, 65, 3, 128)
    # Positions of the sequence and head axes in the chosen layout.
    seq_axis, head_axis = (2, 1) if is_hnd else (1, 2)

    v = torch.randn(*shape, dtype=torch.float16, device="cuda")

    centered, vm = sa.sub_mean(v, tensor_layout=tensor_layout)

    assert centered.shape == v.shape
    assert centered.dtype == torch.float16
    assert vm.shape == (v.size(0), v.size(head_axis), v.size(-1))
    assert vm.dtype == v.dtype

    # After mean subtraction the per-channel mean over the sequence axis
    # should be close to zero; the tolerance is loose because the data is fp16.
    residual = centered.mean(dim=seq_axis)
    assert torch.isfinite(residual).all()
    assert (residual.abs() < 1e-1).all()
| |
|
| |
|
@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("smooth_v", [True, False])
def test_per_channel_fp8_shapes_and_outputs(tensor_layout, smooth_v):
    """Check dtypes and shapes of the ``sa.per_channel_fp8`` outputs.

    Args:
        tensor_layout: ``"HND"`` builds the input as (batch, heads, seq, dim);
            any other value builds it as (batch, seq, heads, dim).
        smooth_v: Forwarded to the function; the mean output is expected to be
            a float32 tensor when true and ``None`` when false.
    """
    is_hnd = tensor_layout == "HND"
    v_shape = (2, 3, 77, 128) if is_hnd else (2, 77, 3, 128)
    v = torch.randn(*v_shape, dtype=torch.float16, device="cuda")
    kv_len = v.size(2 if is_hnd else 1)

    v_fp8, v_scale, vm = sa.per_channel_fp8(
        v, tensor_layout=tensor_layout, smooth_v=smooth_v
    )

    assert v_fp8.dtype == torch.float8_e4m3fn
    assert v_scale.shape == (2, 3, 128)
    if not smooth_v:
        assert vm is None
    else:
        assert vm is not None
        assert vm.shape == (2, 3, 128)
        assert vm.dtype == torch.float32

    # The quantized output is expected with the sequence axis moved last and
    # rounded up to the next multiple of 64.
    padded_len = kv_len + (-kv_len % 64)
    want_fp8_shape = (2, 3, 128, padded_len) if is_hnd else (2, 128, 3, padded_len)
    assert v_fp8.shape == want_fp8_shape
    assert torch.isfinite(v_scale).all()
| |
|