| import operator |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| toq = torch.ops.quantized |
|
|
| import torch.ao.nn.quantized as nnq |
| import torch.ao.nn.quantized.dynamic as nnqd |
| import torch.nn.intrinsic.quantized as nniq |
| import torch.nn.intrinsic.quantized.dynamic as nniqd |
| import torch.nn.intrinsic.qat as nniqat |
| import torch.nn.intrinsic as nni |
| import torch.ao.nn.qat as nnqat |
| import torch.ao.nn.qat.dynamic as nnqatd |
| from torch.ao.quantization.backend_config import get_native_backend_config_dict |
| import torch.ao.quantization.fx._lower_to_native_backend as \ |
| _lower_to_native_backend |
| import torch.ao.quantization.quantization_mappings as quantization_mappings |
|
|
| from .ns_types import NSNodeTargetType |
|
|
| from typing import Set, Dict, List, Optional |
|
|
|
|
def _get_related_op_connections() -> List[tuple]:
    """Return ``(op, related_op)`` pairs used to grow the related-op sets.

    Pairs come, in this order, from:
      1. a hard-coded edge case,
      2. the native backend config (fused / QAT / reference-quantized modules),
      3. the single-target module lowering maps in ``_lower_to_native_backend``,
      4. the two-target fused-module and functional lowering maps,
      5. the binary-op lowering maps and the default quantization mappings.

    The order matters to the caller: a pair is only recorded when one of its
    items already belongs to a set, so a pair that introduces an op must come
    before any pair that chains from that op.
    """
    new_connections: List[tuple] = [
        # Edge case: treat NonDynamicallyQuantizableLinear as related to nn.Linear.
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    # Connections declared by the native backend config.
    for config in get_native_backend_config_dict()['configs']:
        if 'pattern' not in config:
            continue

        # Patterns may be nested lists/tuples; repeatedly take the last
        # element to reach the root op. NOTE(review): presumably patterns are
        # stored in reverse order, so the last element is the first op
        # executed -- confirm against backend_config.
        first_element = config['pattern']
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]

        # A single config may contribute up to three pairs; the keys are
        # checked independently (not as alternatives), in this fixed order.
        for key in (
            'fused_module',
            'qat_module',
            'reference_quantized_module_for_root',
        ):
            if key in config:
                new_connections.append((first_element, config[key]))

    # Module swaps from the default lowering path: one target per source.
    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        new_connections.extend(source_to_target.items())

    # Fused-module and functional swaps: two targets per source; connect
    # the source to both of them.
    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():
            new_connections.append((source, target1))
            new_connections.append((source, target2))

    # Binary-op lowering and the default quantization mappings.
    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        new_connections.extend(source_to_target.items())

    return new_connections


def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
    """Build the mapping from base name to a set of related ops.

    Each set groups ops (module types, functions, and method-name strings)
    that are treated as variants of the same operation, for example a float
    module together with its quantized counterpart (``nn.PReLU`` and
    ``nnq.PReLU``). The sets start from the hand-written seeds below. They
    are then grown with the pairs returned by ``_get_related_op_connections``.

    Returns:
        A dict whose keys are ``'0'``, ``'1'``, ... in seed order, each mapped
        to its set of related ops. A fresh dict of fresh sets is built on
        every call, so callers may mutate the result.
    """
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {nn.Conv1d},
        {nn.Conv2d},
        {nn.Conv3d},
        # conv functionals
        {F.conv1d},
        {F.conv2d},
        {F.conv3d},
        # linear
        {nn.Linear},
        {F.linear},
        # average pool
        {nn.AvgPool1d, torch.avg_pool1d},
        {nn.AvgPool2d, torch._C._nn.avg_pool2d},
        {nn.AvgPool3d, torch._C._nn.avg_pool3d},
        # adaptive average pool
        {nn.AdaptiveAvgPool1d, F.adaptive_avg_pool1d},
        {nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d},
        {nn.AdaptiveAvgPool3d, F.adaptive_avg_pool3d},
        # LSTM
        {nn.LSTM},
        # add / cat / mul
        {torch.add, operator.add},
        {torch.cat},
        {torch.mul, operator.mul},
        # relu
        {F.relu, nn.ReLU, 'relu', 'relu_', torch.relu},
        # max pool
        {nn.MaxPool1d, F.max_pool1d},
        {nn.MaxPool2d, F.max_pool2d},
        {nn.MaxPool3d, F.max_pool3d},
        # sigmoid
        {torch.sigmoid, 'sigmoid', 'sigmoid_', nn.Sigmoid, F.sigmoid},
        # batch norm
        {nn.BatchNorm2d},
        {nn.BatchNorm3d},
        # transposed conv modules
        {nn.ConvTranspose1d},
        {nn.ConvTranspose2d},
        {nn.ConvTranspose3d},
        # misc modules
        {nn.ELU},
        {nn.Embedding},
        {nn.EmbeddingBag},
        {nn.GroupNorm},
        {nn.Hardswish},
        {nn.InstanceNorm1d},
        {nn.InstanceNorm2d},
        {nn.InstanceNorm3d},
        {nn.LayerNorm},
        {nn.LeakyReLU},
        {nn.ReLU6, F.relu6},
        # misc functionals
        {F.elu},
        {F.hardswish},
        {F.group_norm},
        {F.instance_norm},
        {F.layer_norm},
        {F.leaky_relu},
        # activations with module + functional (+ method) forms
        {nn.SiLU, F.silu},
        {nn.Mish, F.mish},
        {nn.Tanh, F.tanh, torch.tanh, 'tanh_', 'tanh'},
        {'hardsigmoid_', 'hardsigmoid', F.hardsigmoid, nn.Hardsigmoid},
        {nn.Hardtanh, F.hardtanh, F.hardtanh_},
        # shape / reduction / misc tensor ops
        {operator.floordiv},
        {torch.unsqueeze},
        {torch.stack},
        {torch.squeeze},
        {torch.sort},
        {torch.repeat_interleave},
        {torch.min},
        {torch.mean},
        {torch.max},
        {torch.transpose},
        {torch.flatten},
        {torch.clamp},
        {torch.chunk},
        {F.interpolate},
        # dropout
        {nn.Dropout},
        {F.dropout},
        # matmul / softmax
        {torch.matmul},
        {nn.Softmax},
        # prelu
        {nn.PReLU, nnq.PReLU},
        {F.prelu, toq.prelu},
    ]

    # Grow the seed sets. NOTE: each pair is added only to the FIRST set that
    # already contains either of its items. A pair whose items are in no set
    # yet is silently dropped, and two sets that each hold one item of a pair
    # are not merged -- so the order of the connections matters.
    for item1, item2 in _get_related_op_connections():
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.update((item1, item2))
                break

    # Base names are the seed indices, as strings.
    return {str(idx): ops for idx, ops in enumerate(sets_of_related_ops)}
|
|
|
|
def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
    """Look up the base name of the set that contains ``op``.

    Sets are scanned in the dict's iteration order. The key of the first set
    containing ``op`` is returned, or ``None`` when no set contains it.
    """
    matching_names = (
        name
        for name, related_ops in base_name_to_sets_of_related_ops.items()
        if op in related_ops
    )
    return next(matching_names, None)
|
|
|
|
def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
    """Register ``op`` in the base-name -> related-ops mapping, in place.

    If ``related_op`` is given, ``op`` is added to the first set (in dict
    iteration order) that contains ``related_op``. If ``related_op`` is
    ``None``, ``op`` gets a new singleton set. That set is stored under the
    smallest non-negative integer, as a string, that is not already a key.

    Raises:
        AssertionError: if ``related_op`` is given but no set contains it.
    """
    if related_op is None:
        # Find the first unused numeric base name (fills gaps such as '1'
        # when only '0' and '2' exist).
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}
        return

    # Only the sets are needed here, not their base names.
    for set_of_related_ops in base_name_to_sets_of_related_ops.values():
        if related_op in set_of_related_ops:
            set_of_related_ops.add(op)
            return

    # The exception type is kept as AssertionError so existing callers that
    # may catch it are unaffected.
    raise AssertionError(f"{related_op} was not found")
|
|
|
|
| |
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Return node targets grouped by input/output dtype category.

    The categories are encoded in the key names:
      * ``funs_*``  -- functions (call targets),
      * ``mods_*``  -- module types,
      * ``meths_*`` -- method names (strings),
    each suffixed with ``io_type_fp32``, ``io_type_fp16``, ``io_type_int8`` or
    ``io_type_fp32_or_int8``. Fresh sets are built on every call.

    NOTE(review): a few targets appear in more than one category:
      * ``torch.cat``, ``F.dropout`` and ``operator.add`` are in both
        ``funs_io_type_fp32`` and ``funs_io_type_fp32_or_int8``;
      * ``nn.Dropout`` and ``nn.ReLU6`` are in both ``mods_io_type_fp32`` and
        ``mods_io_type_fp32_or_int8``.
    Which category wins presumably depends on the order in which callers test
    membership -- confirm before moving any of these entries.
    """
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
    }

    # nnq.Dropout was listed twice here; the duplicate entry was a no-op in
    # the set and has been removed.
    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.ReLU6,
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        'sigmoid_',
        'sigmoid',
        'tanh_',
        'tanh',
        'hardsigmoid_',
        'hardsigmoid',
        'relu_',
        'relu',
    }

    return {
        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
        'funs_io_type_fp16': FUNS_IO_TYPE_FP16,
        'funs_io_type_int8': FUNS_IO_TYPE_INT8,
        'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8,
        'mods_io_type_fp32': MODS_IO_TYPE_FP32,
        'mods_io_type_int8': MODS_IO_TYPE_INT8,
        'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8,
        'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8,
    }
|
|
|
|
def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Return the node targets classified as unmatchable.

    The result has three entries, each a newly built set:
      * ``'funs_unmatchable'``  -- function targets,
      * ``'mods_unmatchable'``  -- module types,
      * ``'meths_unmatchable'`` -- method names (strings).
    NOTE(review): presumably these targets are skipped when matching nodes
    between models -- confirm against the callers.
    """
    unmatchable_functions: Set[NSNodeTargetType] = {
        torch.quantize_per_tensor,
        operator.getitem,
    }
    unmatchable_modules: Set[NSNodeTargetType] = {nn.Identity}
    unmatchable_methods: Set[NSNodeTargetType] = {
        'to', 'dequantize', 'reshape', 'view',
        'unsqueeze_', 'unsqueeze', 'transpose',
        'squeeze_', 'squeeze', 'size', 'shape', 'resize_',
        'repeat_interleave', 'repeat', 'permute', 'numel', 'mean',
        'detach_', 'detach', 'contiguous', 'clamp', 'chunk',
    }
    return dict(
        funs_unmatchable=unmatchable_functions,
        mods_unmatchable=unmatchable_modules,
        meths_unmatchable=unmatchable_methods,
    )
|
|