| # Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py | |
| import pytest | |
| import torch | |
| from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant | |
| from sglang.srt.utils import is_cuda, is_hip | |
| _is_cuda = is_cuda() | |
| _is_hip = is_hip() | |
| _is_fp8_fnuz = is_fp8_fnuz() | |
| fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn | |
| def test_scaled_fp8_quant_per_tensor(dtype) -> None: | |
| def quantize_ref_per_tensor(tensor, inv_scale): | |
| # The reference implementation that fully aligns to | |
| # the kernel being tested. | |
| finfo = torch.finfo(fp8_dtype) | |
| scale = inv_scale.reciprocal() | |
| qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max) | |
| qweight = qweight.to(fp8_dtype) | |
| return qweight | |
| def dequantize_per_tensor(tensor, inv_scale, dtype): | |
| fake_qweight = tensor.to(dtype) | |
| dq_weight = fake_qweight * inv_scale | |
| return dq_weight | |
| # Note that we use a shape % 8 != 0 to cover edge cases, | |
| # because scaled_fp8_quant is vectorized by 8. | |
| x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype) | |
| # Test Per Tensor Dynamic quantization | |
| # scale = max(abs(x)) / FP8_E4M3_MAX | |
| y, scale = scaled_fp8_quant(x, None) | |
| ref_y = quantize_ref_per_tensor(x, scale) | |
| torch.testing.assert_close(y, ref_y) | |
| torch.testing.assert_close( | |
| dequantize_per_tensor(y, scale, dtype), | |
| dequantize_per_tensor(ref_y, scale, dtype), | |
| ) | |
| # Test Per Tensor Static quantization | |
| y, _ = scaled_fp8_quant(x, scale) | |
| ref_y = quantize_ref_per_tensor(x, scale) | |
| torch.testing.assert_close(y, ref_y) | |
| torch.testing.assert_close( | |
| dequantize_per_tensor(y, scale, dtype), | |
| dequantize_per_tensor(ref_y, scale, dtype), | |
| ) | |
| if _is_cuda or _is_hip: | |
| def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None: | |
| def quantize_ref_per_token(tensor, inv_scale): | |
| # The reference implementation that fully aligns to | |
| # the kernel being tested. | |
| finfo = torch.finfo(fp8_dtype) | |
| scale = inv_scale.reciprocal() | |
| qweight = (tensor.to(torch.float32) * scale).clamp( | |
| min=finfo.min, max=finfo.max | |
| ) | |
| qweight = qweight.to(fp8_dtype) | |
| return qweight | |
| def dequantize_per_token(tensor, inv_scale, dtype): | |
| fake_qweight = tensor.to(dtype) | |
| dq_weight = fake_qweight * inv_scale | |
| return dq_weight | |
| # Note that we use a shape % 8 = 0, | |
| # because per_token_quant_fp8 is vectorized by 8 elements. | |
| x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype) | |
| # Test Per Tensor Dynamic quantization | |
| # scale = max(abs(x)) / FP8_E4M3_MAX | |
| y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True) | |
| ref_y = quantize_ref_per_token(x, scale) | |
| torch.testing.assert_close(y, ref_y) | |
| torch.testing.assert_close( | |
| dequantize_per_token(y, scale, dtype), | |
| dequantize_per_token(ref_y, scale, dtype), | |
| ) | |
| def test_scaled_fp8_quant_with_padding(dtype) -> None: | |
| original_rows = 5 | |
| x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype) | |
| padding_size = 10 | |
| # Test with dynamic quantization | |
| y_dynamic, scale_dynamic = scaled_fp8_quant( | |
| x, None, num_token_padding=padding_size | |
| ) | |
| # Verify output shape has the padded size | |
| assert y_dynamic.shape[0] == padding_size | |
| assert y_dynamic.shape[1] == x.shape[1] | |
| # Verify that the actual data in the non-padded region is correctly quantized | |
| y_without_padding, scale_without_padding = scaled_fp8_quant(x, None) | |
| torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding) | |
| # Test with static quantization | |
| # First get a scale | |
| _, scale = scaled_fp8_quant(x, None) | |
| # Then use it for static quantization with padding | |
| y_static, _ = scaled_fp8_quant(x, scale, num_token_padding=padding_size) | |
| # Verify output shape has the padded size | |
| assert y_static.shape[0] == padding_size | |
| assert y_static.shape[1] == x.shape[1] | |
| # Verify that the actual data in the non-padded region is correctly quantized | |
| y_static_without_padding, _ = scaled_fp8_quant(x, scale) | |
| torch.testing.assert_close(y_static[:original_rows], y_static_without_padding) | |
| # Test with per-token dynamic quantization | |
| y_per_token, scale_per_token = scaled_fp8_quant( | |
| x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True | |
| ) | |
| # Verify output shape has the padded size | |
| assert y_per_token.shape[0] == padding_size | |
| assert y_per_token.shape[1] == x.shape[1] | |
| # Verify that the actual data in the non-padded region is correctly quantized | |
| y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant( | |
| x, None, use_per_token_if_dynamic=True | |
| ) | |
| torch.testing.assert_close( | |
| y_per_token[:original_rows], y_per_token_without_padding | |
| ) | |
| torch.testing.assert_close( | |
| scale_per_token[:original_rows], scale_per_token_without_padding | |
| ) | |
| if __name__ == "__main__": | |
| # Run the specific test function directly | |
| pytest.main([__file__]) | |
Xet Storage Details
- Size:
- 5.71 kB
- Xet hash:
- 5b68040eef55c337275aaccda8f2115ad04883e37adb2e72491aecf367b5887e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.