import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any

from .utils import ReferenceQuantizedModule

__all__ = ['Linear']


class Linear(nn.Linear, ReferenceQuantizedModule):
    """A reference quantized linear module that fits into the FX Graph
    Mode Quantization workflow.

    The activation will be a floating point Tensor; we also store the
    weight in floating point in the module, but in forward we quantize
    and dequantize the weight before running the floating point
    functional linear operator.
    """
    _IS_REFERENCE = True

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias_: bool = True,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            weight_qparams: Optional[Dict[str, Any]] = None):
        super().__init__(in_features, out_features, bias_, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        We have:
        w(float) -- quant - dequant \
        x(float) ------------- F.linear ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.linear --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a
        quantized linear.
        """
        # get_weight() quantizes and immediately dequantizes the stored
        # floating point weight according to the module's weight qparams.
        weight_quant_dequant = self.get_weight()
        result = F.linear(x, weight_quant_dequant, self.bias)
        return result

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
        """Create a reference quantized linear module from a float
        ``nn.Linear`` and the observed weight quantization parameters."""
        qref_linear = Linear(
            float_linear.in_features, float_linear.out_features,
            float_linear.bias is not None, device=float_linear.weight.device,
            dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
        if float_linear.bias is not None:
            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
        return qref_linear
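
# A minimal usage sketch (illustrative only, not part of this module): build
# a reference quantized linear from a float nn.Linear via from_float(). The
# weight_qparams dict below assumes per-tensor affine quantization and the
# keys this workflow records ("qscheme", "dtype", "scale", "zero_point");
# in a real run these values come from an observer, not hand-picked numbers.
#
#     float_linear = nn.Linear(16, 8)
#     weight_qparams = {
#         "qscheme": torch.per_tensor_affine,
#         "dtype": torch.qint8,
#         "scale": 0.02,
#         "zero_point": 0,
#     }
#     qref_linear = Linear.from_float(float_linear, weight_qparams)
#     out = qref_linear(torch.randn(4, 16))  # float in, float out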