import torch
import torch.ao.nn.quantizable

from torch.nn.modules.pooling import MaxPool2d

from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
from .dropout import Dropout
from .batchnorm import BatchNorm2d, BatchNorm3d
from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
    InstanceNorm2d, InstanceNorm3d
from .conv import Conv1d, Conv2d, Conv3d
from .conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
from .linear import Linear
from .embedding_ops import Embedding, EmbeddingBag
from .rnn import LSTM

from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional


class Quantize(torch.nn.Module):
    r"""Quantizes an incoming tensor.

    Args:
        `scale`: scale of the output Quantized Tensor
        `zero_point`: zero_point of the output Quantized Tensor
        `dtype`: data type of the output Quantized Tensor
        `factory_kwargs`: Dictionary of kwargs used for configuring initialization
            of internal buffers. Currently, `device` and `dtype` are supported.
            Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}`
            will initialize internal buffers as type `torch.float64` on the current CUDA device.
            Note that `dtype` only applies to floating-point buffers.

    Examples::
        >>> t = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> # xdoctest: +SKIP
        >>> qt = qm(t)
        >>> print(qt)
        tensor([[ 1., -1.],
                [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, scale, zero_point, dtype, factory_kwargs=None):
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        super().__init__()
        self.register_buffer('scale', torch.tensor([scale], **factory_kwargs))
        self.register_buffer('zero_point',
                             torch.tensor([zero_point], dtype=torch.long,
                                          **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
        self.dtype = dtype

    def forward(self, X):
        return torch.quantize_per_tensor(X, float(self.scale),
                                         int(self.zero_point), self.dtype)

    @staticmethod
    def from_float(mod):
        # The observed float module must carry an observer; its qparams
        # determine the scale and zero_point of the quantized output.
        assert hasattr(mod, 'activation_post_process')
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Quantize(scale.float().item(), zero_point.long().item(),
                        mod.activation_post_process.dtype)

    def extra_repr(self):
        return 'scale={}, zero_point={}, dtype={}'.format(self.scale, self.zero_point, self.dtype)
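
# A minimal illustrative sketch (not part of the public API): `Quantize.from_float`
# expects a float module carrying an observer as `activation_post_process`, whose
# `calculate_qparams()` supplies the scale and zero_point. The `_Observed` helper
# below is hypothetical and only exists to show that contract; it assumes the
# `MinMaxObserver` from `torch.ao.quantization`.
def _quantize_from_float_example():
    from torch.ao.quantization import MinMaxObserver

    class _Observed(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Observer that records min/max and derives qparams for quint8.
            self.activation_post_process = MinMaxObserver(dtype=torch.quint8)

        def forward(self, x):
            return self.activation_post_process(x)

    observed = _Observed()
    observed(torch.randn(4, 4))           # run data through so the observer sees real values
    return Quantize.from_float(observed)  # Quantize built from the observed scale/zero_point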


class DeQuantize(torch.nn.Module):
    r"""Dequantizes an incoming tensor

    Examples::
        >>> input = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> # xdoctest: +SKIP
        >>> quantized_input = qm(input)
        >>> dqm = DeQuantize()
        >>> dequantized = dqm(quantized_input)
        >>> print(dequantized)
        tensor([[ 1., -1.],
                [ 1., -1.]], dtype=torch.float32)
    """

    def __init__(self):
        super().__init__()

    def forward(self, Xq):
        return Xq.dequantize()

    @staticmethod
    def from_float(mod):
        return DeQuantize()
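
# A minimal illustrative sketch (not part of the public API): in the eager-mode
# workflow, Quantize/DeQuantize typically sit at the boundaries of a converted
# model, with quantized modules (such as the Linear imported above) in between.
# The helper below is hypothetical; the scale/zero_point values are arbitrary and
# the Linear's quantized weights are left at their default initialization, so it
# only shows data flow, not meaningful numerics. Running it requires a build with
# a quantized engine (e.g. fbgemm or qnnpack).
def _quantized_boundary_example():
    quant = Quantize(scale=0.1, zero_point=128, dtype=torch.quint8)
    fc = Linear(4, 2)        # quantized Linear from this package
    dequant = DeQuantize()

    x = torch.randn(3, 4)
    # float in -> quantize -> quantized compute -> dequantize -> float out
    return dequant(fc(quant(x)))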


__all__ = [
    'BatchNorm2d',
    'BatchNorm3d',
    'Conv1d',
    'Conv2d',
    'Conv3d',
    'ConvTranspose1d',
    'ConvTranspose2d',
    'ConvTranspose3d',
    'DeQuantize',
    'ELU',
    'Embedding',
    'EmbeddingBag',
    'GroupNorm',
    'Hardswish',
    'InstanceNorm1d',
    'InstanceNorm2d',
    'InstanceNorm3d',
    'LayerNorm',
    'LeakyReLU',
    'Linear',
    'LSTM',
    'MaxPool2d',
    'MultiheadAttention',
    'Quantize',
    'ReLU6',
    'Sigmoid',
    'Softmax',
    'Dropout',
    'PReLU',
    'FloatFunctional',
    'FXFloatFunctional',
    'QFunctional',
]