import torch

from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
from torch.ao.quantization.backend_config import (
    get_native_backend_config,
    ObservationType,
)
from torch.ao.quantization.quantization_types import (
    Pattern,
    NodePattern,
    QuantizerCls,
)
from torch.ao.quantization.utils import (
    activation_dtype,
    get_combined_dict,
)

from ..backend_config import BackendConfig
from .quantization_patterns import QuantizeHandler
from .fusion_patterns import DefaultFuseHandler

from typing import Dict, Any, Callable, Optional


def get_quantize_handler_cls(
        observation_type,
        dtype_configs,
        num_tensor_args_to_observation_type,
        overwrite_output_fake_quantizer,
        overwrite_output_observer,
        input_output_observed):
    """Return a QuantizeHandler subclass configured with the given backend pattern settings."""

    class ConfigurableQuantizeHandler(QuantizeHandler):
        def __init__(
                self,
                node_pattern: NodePattern,
                modules: Dict[str, torch.nn.Module],
                root_node_getter: Optional[Callable] = None):
            super().__init__(node_pattern, modules, root_node_getter)
            if num_tensor_args_to_observation_type:
                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
                    f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
                    f" in num_tensor_args_to_observation_type for {node_pattern}"
                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
            else:
                self.observation_type = observation_type
            self.dtype_configs = dtype_configs
            self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
            self.overwrite_output_observer = overwrite_output_observer
            self.input_output_observed_ = input_output_observed

        def is_general_tensor_value_op(self) -> bool:
            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT

        def get_activation_ctr(
                self,
                qconfig: Any,
                pattern: Pattern,
                is_training: bool,
        ) -> Optional[Callable]:
            """
            Returns the constructor for the activation observer which should be
            used for the pattern matched to this handler. Some handlers override
            this to a different value than what is specified in the qconfig.
            """
            act_dtype = activation_dtype(qconfig)
            if is_training:
                if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
                    return self.overwrite_output_fake_quantizer
            else:
                if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
                    return self.overwrite_output_observer
            return qconfig.activation

        def input_output_observed(self):
            return self.input_output_observed_

    return ConfigurableQuantizeHandler

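# A minimal usage sketch of get_quantize_handler_cls above (illustrative only; in
# practice the returned class is instantiated by the FX prepare flow, and
# `node_pattern` / `named_modules` are assumed to come from a traced GraphModule):
#
#     handler_cls = get_quantize_handler_cls(
#         ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
#         dtype_configs=[],
#         num_tensor_args_to_observation_type={},
#         overwrite_output_fake_quantizer=None,
#         overwrite_output_observer=None,
#         input_output_observed=True)
#     handler = handler_cls(node_pattern, named_modules)
#     handler.is_general_tensor_value_op()  # False for this observation type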

def get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
    """
    Build a map from pattern to quantize handler class from the given backend config.

    Note: a quantize handler is just a holder for a few check methods (e.g.
    should_insert_observer_for_output), so it could potentially be an enum instead.
    We can refactor this after the fbgemm/qnnpack path is fully converted to the
    new path. This is not exposed to backend developers.
    """
    pattern_to_quantize_handlers: Dict[Pattern, QuantizerCls] = {}
    for pattern, config in backend_config.configs.items():
        observation_type = config.observation_type
        dtype_configs = config.dtype_configs
        num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
        overwrite_fake_quantizer = config._overwrite_output_fake_quantize
        overwrite_observer = config._overwrite_output_observer
        input_output_observed = config._input_output_observed
        if input_output_observed is None:
            input_output_observed = True
        pattern_to_quantize_handlers[pattern] = \
            get_quantize_handler_cls(
                observation_type,
                dtype_configs,
                num_tensor_args_to_observation_type,
                overwrite_fake_quantizer,
                overwrite_observer,
                input_output_observed)
    return pattern_to_quantize_handlers

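# A small usage sketch (assumes the native backend config registers torch.nn.Conv2d
# as one of its patterns, which holds for the default fbgemm/qnnpack config):
#
#     pattern_to_handler = get_pattern_to_quantize_handlers(get_native_backend_config())
#     conv_handler_cls = pattern_to_handler[torch.nn.Conv2d]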

def get_fusion_pattern_to_fuse_handler_cls(
        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config.configs.items():
        if config.fuser_method is not None:
            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
    return fusion_pattern_to_fuse_handlers

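# Sketch: every pattern whose config carries a fuser_method is mapped to
# DefaultFuseHandler; the exact set of keys depends on the backend config passed in:
#
#     fuse_map = get_fusion_pattern_to_fuse_handler_cls(get_native_backend_config())
#     assert all(handler is DefaultFuseHandler for handler in fuse_map.values())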

def get_native_quant_patterns(
        additional_quant_patterns: Optional[Dict[Pattern, QuantizerCls]] = None,
) -> Dict[Pattern, QuantizerCls]:
    """
    Return a map from pattern to quantize handlers based on the default patterns and the native backend_config.
    The returned map is sorted such that longer patterns will be encountered first when iterating through it.
    """
    patterns = get_default_quant_patterns()
    if additional_quant_patterns is not None:
        patterns = get_combined_dict(patterns, additional_quant_patterns)
    # patterns derived from the native backend config take precedence over the defaults
    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items():
        patterns[pattern] = quantize_handler
    return sorted_patterns_dict(patterns)

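# Sketch of extending the native patterns with a custom handler (`MyOp` and
# `MyQuantizeHandler` are hypothetical names used only for illustration):
#
#     patterns = get_native_quant_patterns({MyOp: MyQuantizeHandler})
#     # MyOp's handler is kept unless the native backend config also registers MyOp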

get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils"

__all__ = [
    "get_fusion_pattern_to_fuse_handler_cls",
    "get_native_quant_patterns",
    "get_pattern_to_quantize_handlers",
]