from typing import Optional, Tuple, Union

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fall back to the standalone `_quantization` extension, which registers
    # its ops under torch.ops, if the packaged ops module is unavailable.
    try:
        import _quantization

        ops = torch.ops._quantization
    except ImportError:
        raise e


def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
| """ |
| Quantize input tensor to FP8 and return quantized tensor and scale. |
| |
| This function supports both static and dynamic quantization: If you |
| provide the scale, it will use static scaling and if you omit it, |
| the scale will be determined dynamically. The function also allows |
| optional padding of the output tensors for downstream kernels that |
| will benefit from padding. |
| |
| Args: |
| input: The input tensor to be quantized to FP8 |
| scale: Optional scaling factor for the FP8 quantization |
| scale_ub: Optional upper bound for scaling factor in dynamic |
| per token case |
| num_token_padding: If specified, pad the first dimension |
| of the output to at least this value. |
| use_per_token_if_dynamic: Whether to do per_tensor or per_token |
| in the dynamic quantization case. |
| |
| Returns: |
| Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and |
| scaling factor. |
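
    Example (an illustrative sketch, assuming a CUDA device and that the
    compiled `ops` extension is available):

        x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

        # Dynamic per-tensor quantization: a single scale computed from x.
        q, s = scaled_fp8_quant(x)

        # Dynamic per-token quantization: one scale per row of x.
        q, s = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

        # Static quantization with a precomputed per-tensor scale.
        static_scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
        q, s = scaled_fp8_quant(x, scale=static_scale)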
| """ |
    # The quantization kernels expect a 2-D input of shape
    # (num_tokens, hidden_size).
    assert input.ndim == 2
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    out_dtype = torch.float8_e4m3fn
    if num_token_padding:
        # Pad the token dimension so the output has at least
        # num_token_padding rows.
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            # Dynamic per-token: the kernel fills one scale per row.
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
        else:
            # Dynamic per-tensor: the kernel fills a single scale.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static quantization: padding is only supported with a per-tensor
        # (single-element) scale.
        assert scale.numel() == 1 or num_token_padding is None
        ops.static_scaled_fp8_quant(output, input, scale)

    return output, scale


def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
| """ |
| Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. |
| |
| Args: |
| input: The input tensor to be quantized to int8. |
| scale: Optional scaling factor for the int8 quantization. |
| When not provided, we invoke dynamic-per-token quantization. |
| azp: Optional zero-point for the int8 quantization. |
| Must be provided for asymmetric quantization if `scale` is provided. |
| symmetric: Whether to use symmetric quantization (scale only, azp ignored). |
| |
| Returns: |
| Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. |
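
    Example (an illustrative sketch, assuming a CUDA device and that the
    compiled `ops` extension is available):

        x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

        # Dynamic per-token symmetric quantization: one scale per row,
        # azp is returned as None.
        q, scales, azp = scaled_int8_quant(x)

        # Static asymmetric quantization with a precomputed scale and
        # zero point.
        s = torch.tensor([0.05], device="cuda", dtype=torch.float32)
        zp = torch.tensor([3], device="cuda", dtype=torch.int32)
        q, scales, azp = scaled_int8_quant(x, scale=s, azp=zp, symmetric=False)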
| """ |
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # Static quantization: azp must be provided iff the quantization
        # is asymmetric.
        assert symmetric == (
            azp is None
        ), "azp must only be provided for asymmetric quantization."
        ops.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # Dynamic per-token quantization: the kernel computes one scale (and,
    # if asymmetric, one zero point) per token row.
    input_scales = torch.empty(
        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
    )
    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
    return output, input_scales, input_azp