UMMJ's picture
Upload 5875 files
9dd3461
import torch
import typing
class ReferenceQuantizedModule(torch.nn.Module):
    """Base for reference quantized modules that carry weight qparams.

    The weight quantization parameters (qscheme, dtype, scale, zero point and
    axis) are kept on the module, and helpers are provided to quantize, or
    quantize-then-dequantize, ``self.weight`` with them. ``self.weight`` is
    read but never assigned in this class, so it has to be supplied elsewhere
    (presumably by the concrete subclass).
    """

    def _init_weight_qparams(self, weight_qparams, device):
        """Store the weight qparams and register them as buffers.

        Args:
            weight_qparams: mapping with ``"qscheme"`` and ``"dtype"`` and,
                for quantized dtypes, ``"scale"`` and ``"zero_point"`` (plus
                ``"axis"`` for the per-channel qschemes). ``None`` selects
                per-tensor affine ``quint8`` with scale 1.0, zero point 0.
            device: device used for tensors created from non-tensor values;
                values that already are tensors are cloned as they are.
        """
        if weight_qparams is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }

        self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
        self.weight_dtype = weight_qparams["dtype"]

        supported_qschemes = [
            None,
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams,
        ]
        assert self.weight_qscheme in supported_qschemes, Exception(
            f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}")

        def as_buffer(value, dtype):
            # Tensors are copied untouched; anything else becomes a fresh
            # tensor of the requested dtype on ``device``.
            if isinstance(value, torch.Tensor):
                return value.clone().detach()
            return torch.tensor(value, dtype=dtype, device=device)

        quantized_dtypes = (torch.quint8, torch.qint8, torch.quint4x2, torch.qint32)
        if self.weight_dtype not in quantized_dtypes:
            # added for TorchScriptability, and for torch.float
            self.register_buffer("weight_scale", as_buffer(1.0, torch.float))
            self.register_buffer("weight_zero_point", as_buffer(0, torch.int))
            self.register_buffer("weight_axis", as_buffer(0, torch.int))
            return

        zero_point = weight_qparams["zero_point"]
        scale = weight_qparams["scale"]
        self.register_buffer("weight_scale", as_buffer(scale, torch.float))
        # A non-tensor zero point is stored as torch.int; a tensor keeps
        # whatever dtype it already has.
        self.register_buffer("weight_zero_point", as_buffer(zero_point, torch.int))

        per_channel_qschemes = (
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams,
        )
        if self.weight_qscheme in per_channel_qschemes:
            axis = weight_qparams["axis"]
        else:
            # added for TorchScriptability, not used
            axis = 0
        self.register_buffer("weight_axis", as_buffer(axis, torch.int))

    def get_weight(self):
        """
        Return the fake-quantized weight: ``self.weight`` quantized and
        then dequantized with the stored weight quantization parameters.
        This simulates the numerics of the quantized weight in a
        quantized model.
        """
        # suppress mypy warning
        scale = self.weight_scale
        assert isinstance(scale, torch.Tensor)
        zero_point = self.weight_zero_point
        assert isinstance(zero_point, torch.Tensor)
        axis = self.weight_axis
        assert isinstance(axis, torch.Tensor)
        return _quantize_and_dequantize_weight(
            weight=self.weight,  # type: ignore[arg-type]
            weight_qscheme=self.weight_qscheme,
            weight_dtype=self.weight_dtype,
            weight_scale=scale,
            weight_zero_point=zero_point,
            weight_axis=axis)

    def get_quantized_weight(self):
        """Return ``self.weight`` quantized with the stored weight qparams."""
        # suppress mypy warning
        scale = self.weight_scale
        assert isinstance(scale, torch.Tensor)
        zero_point = self.weight_zero_point
        assert isinstance(zero_point, torch.Tensor)
        axis = self.weight_axis
        assert isinstance(axis, torch.Tensor)
        return _quantize_weight(
            weight=self.weight,  # type: ignore[arg-type]
            weight_qscheme=self.weight_qscheme,
            weight_dtype=self.weight_dtype,
            weight_scale=scale,
            weight_zero_point=zero_point,
            weight_axis=axis)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the regular module state, then the weight qparams."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        _save_weight_qparams(
            destination=destination,
            prefix=prefix,
            weight_qscheme=self.weight_qscheme,
            weight_dtype=self.weight_dtype,
            weight_scale=self.weight_scale,
            weight_zero_point=self.weight_zero_point,
            weight_axis=self.weight_axis)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Restore the weight qparams, then defer to the parent loader.

        Every qparam entry is assigned onto the module and removed from
        ``state_dict`` before the parent implementation sees it.
        """
        for name in _get_weight_qparam_keys(state_dict, prefix):
            full_key = prefix + name
            setattr(self, name, state_dict[full_key])
            state_dict.pop(full_key)

        # NOTE: the caller's ``strict`` is not forwarded; the parent loader
        # is always invoked with strict=False.
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False,
            missing_keys, unexpected_keys, error_msgs)
def _quantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis: torch.Tensor):
    """Quantize ``weight`` according to the given dtype and qscheme.

    * ``torch.float16`` dtype: the weight is cast to half precision,
      whatever the qscheme is.
    * per-tensor affine with ``quint8``/``qint8``/``qint32``:
      ``torch.quantize_per_tensor`` is applied.
    * per-channel affine (with or without float qparams) with
      ``quint8``/``qint8``/``quint4x2``/``qint32``:
      ``torch.quantize_per_channel`` is applied along ``weight_axis``.

    Raises:
        Exception: for every other dtype/qscheme combination.
    """
    if weight_dtype == torch.float16:
        return weight.to(weight_dtype)

    per_tensor_dtypes = (torch.quint8, torch.qint8, torch.qint32)
    per_channel_dtypes = (torch.quint8, torch.qint8, torch.quint4x2, torch.qint32)
    per_channel_qschemes = (
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    )

    if weight_qscheme == torch.per_tensor_affine and weight_dtype in per_tensor_dtypes:
        return torch.quantize_per_tensor(
            weight, weight_scale, weight_zero_point, weight_dtype)

    if weight_qscheme in per_channel_qschemes and weight_dtype in per_channel_dtypes:
        # The axis lives in a tensor buffer; the op takes a Python number.
        channel_axis = weight_axis.item()
        return torch.quantize_per_channel(
            weight, weight_scale, weight_zero_point,
            channel_axis, weight_dtype)  # type: ignore[arg-type]

    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
def _quantize_and_dequantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis: torch.Tensor):
    """Fake-quantize ``weight`` with the given quantization parameters.

    For the per-tensor affine and both per-channel affine qschemes the
    weight is quantized and immediately dequantized; for any other qscheme
    it is returned unchanged.
    """
    fake_quant_qschemes = (
        torch.per_tensor_affine,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    )
    if weight_qscheme not in fake_quant_qschemes:
        # Nothing to simulate: hand the original tensor back.
        return weight

    quantized = _quantize_weight(
        weight, weight_qscheme, weight_dtype,
        weight_scale, weight_zero_point, weight_axis)
    return quantized.dequantize()
def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis):
    """Write the weight quantization parameters into ``destination``.

    ``weight_qscheme`` and ``weight_dtype`` are always stored under
    ``prefix``. Scale and zero point are stored only when a qscheme is set,
    and the axis only for ``torch.per_channel_affine``.
    """
    def store(name, value):
        destination[prefix + name] = value

    store("weight_qscheme", weight_qscheme)
    store("weight_dtype", weight_dtype)
    if weight_qscheme is None:
        return

    store("weight_scale", weight_scale)
    store("weight_zero_point", weight_zero_point)
    if weight_qscheme == torch.per_channel_affine:
        store("weight_axis", weight_axis)
def _get_weight_qparam_keys(
        state_dict: typing.Dict[str, typing.Any],
        prefix: str):
    """List the weight-qparam entries to read from ``state_dict``.

    Args:
        state_dict: state dict that holds ``prefix + "weight_qscheme"``.
        prefix: key prefix of the module inside ``state_dict``.

    Returns:
        Un-prefixed key names: always ``"weight_qscheme"`` and
        ``"weight_dtype"``; additionally ``"weight_scale"`` and
        ``"weight_zero_point"`` when the stored qscheme is not ``None``; and
        ``"weight_axis"`` when the stored qscheme is
        ``torch.per_channel_affine`` and the axis entry is present.

    Raises:
        KeyError: if ``prefix + "weight_qscheme"`` is not in ``state_dict``.
    """
    keys = ["weight_qscheme", "weight_dtype"]
    weight_qscheme = state_dict[prefix + "weight_qscheme"]
    if weight_qscheme is None:
        return keys

    keys += ["weight_scale", "weight_zero_point"]
    # The qscheme must be compared with the qscheme constant
    # torch.per_channel_affine (the condition _save_weight_qparams uses when
    # writing the axis). Comparing with the function
    # torch.quantize_per_channel is never true, so the axis was never listed.
    # The membership test keeps state dicts without an axis entry loadable,
    # as they were before.
    if (weight_qscheme == torch.per_channel_affine
            and prefix + "weight_axis" in state_dict):
        keys.append("weight_axis")
    return keys