import functools
from enum import Enum

import torch
import torch.distributed as dist

TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max


class DQuantType(Enum):
    """
    Different quantization methods for the auto_quantize API are identified here.

    The auto_quantize API currently supports the fp16 and bfp16 methods.
    """

    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        return self.value


def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    # Clamp to the representable float16 range first so out-of-range values
    # saturate instead of overflowing to inf when cast to half precision.
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    elif qtype == DQuantType.BFP16:
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
    return quantized_tensor_list


def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        if quant_loss is None:
            return tensor.float()
        return tensor.float() / quant_loss
    elif qtype == DQuantType.BFP16:
        # BFP16-quantized tensors are carried in float16 storage by the
        # quantization ops, so the dtype check is the same as for FP16.
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    dequantized_tensor_list = [
        _dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list
    ]
    return dequantized_tensor_list
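

# A minimal round-trip sketch for the helpers above. `_example_fp16_round_trip`
# is an illustrative name only and is not part of this module's API; it simply
# shows how a float32 tensor is expected to flow through _quantize_tensor and
# _dequantize_tensor with the FP16 method.
def _example_fp16_round_trip() -> torch.Tensor:
    original = torch.tensor([0.5, -2.0, 1e10])
    # 1e10 lies outside the float16 range, so it saturates at TORCH_HALF_MAX
    # during quantization; the round trip is exact only for in-range values.
    quantized = _quantize_tensor(original, DQuantType.FP16)    # dtype: torch.float16
    restored = _dequantize_tensor(quantized, DQuantType.FP16)  # dtype: torch.float32
    return restored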


def auto_quantize(func, qtype, quant_loss=None):
    """
    This is a prototype API that automatically quantizes the input tensors, chooses the precision type,
    passes the other necessary arguments, and then dequantizes the output.

    Currently it only supports:
        . FP16 and BFP16 quantization methods for the gloo and nccl backends
        . all_gather, all_to_all, and all_to_all_single collective ops

    Note: BFP16 only supports 2D tensors.

    Args:
        func (Callable): A function representing collective operations.
        qtype (DQuantType): Quantization method
        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.

    Returns:
        (Callable): the same collective as func but enables automatic quantization/dequantization.
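
    Example (a minimal sketch; assumes the default process group is already
    initialized and ``output_list`` is preallocated as for a plain ``all_gather``)::

        >>> quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
        >>> output_list = [torch.zeros(4) for _ in range(dist.get_world_size())]
        >>> quantized_all_gather(output_list, torch.randn(4))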
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get("group", None)
        async_op = kwargs.get("async_op", False)
        if async_op is True:
            raise RuntimeError("The async_op=True mode is not supported yet.")
        if func == dist.all_gather:
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            # Dequantize the gathered results back into the caller's output list.
            for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t

        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t

        elif func == dist.all_to_all_single:
            tensors = args[0]
            out_splits = kwargs.get("out_splits", None)
            in_splits = kwargs.get("in_splits", None)
            # Quantize the input/output tensors before running the collective.
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(
                out_tensors, input_tensors, out_splits, in_splits, group=group
            )
            # Copy the dequantized rows back into the caller's output tensor.
            for i, t in enumerate(
                _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t
        else:
            raise RuntimeError(f"The collective op {func} is not supported yet")

    return wrapper
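

# A minimal, runnable usage sketch of auto_quantize (not part of the module API).
# It assumes a single-process "gloo" group; the MASTER_ADDR/MASTER_PORT values
# below are illustrative placeholders, and in a real job the process group would
# normally be created by the launcher instead.
if __name__ == "__main__":
    import os

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
    gathered = [torch.zeros(4) for _ in range(dist.get_world_size())]
    quantized_all_gather(gathered, torch.arange(4, dtype=torch.float32))
    print(gathered)  # expected: [tensor([0., 1., 2., 3.])]

    dist.destroy_process_group()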