from torch import nn
class QuantStub(nn.Module):
    r"""Quantize stub module.

    Before calibration this is the same as an observer; it will be swapped
    for `nnq.Quantize` in `convert`. Until it is swapped, `forward` simply
    returns its input unchanged.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Attach the attribute only when a truthy qconfig is supplied, so a
        # stub built without one carries no `qconfig` attribute of its own.
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        """Return `x` unchanged."""
        return x
class DeQuantStub(nn.Module):
    r"""Dequantize stub module.

    Before calibration this is the same as identity; it will be swapped
    for `nnq.DeQuantize` in `convert`. Until it is swapped, `forward` simply
    returns its input unchanged.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Attach the attribute only when a truthy qconfig is supplied, so a
        # stub built without one carries no `qconfig` attribute of its own.
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        """Return `x` unchanged."""
        return x
class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surrounds the call to the module with calls to the quant
    and dequant modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules. Before the `convert` function, `QuantStub` will just be
    an observer that observes the input tensor; after `convert`, `QuantStub`
    will be swapped to `nnq.Quantize`, which does the actual quantization.
    Similarly for `DeQuantStub`.

    Args:
        module: the module to wrap. Its `qconfig` attribute, if present, is
            passed to both stubs, and its `training` flag is applied to the
            wrapper.
    """
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        # Reuse the wrapped module's qconfig when it has one; otherwise the
        # stubs are built with None and set no qconfig of their own.
        qconfig = getattr(module, 'qconfig', None)
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub(qconfig))
        self.add_module('module', module)
        # Put the wrapper (and its children) in the same train/eval mode as
        # the wrapped module.
        self.train(module.training)

    def forward(self, X):
        """Run `X` through quant, the wrapped module, then dequant."""
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)