# torch/ao/quantization/fx/backend_config_utils.py
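# Utilities for turning a BackendConfig into the pattern -> quantize handler
# and pattern -> fuse handler mappings used by the FX graph mode quantization
# prepare and fuse steps.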
import torch
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
from torch.ao.quantization.backend_config import (
get_native_backend_config,
ObservationType,
)
from torch.ao.quantization.quantization_types import (
Pattern,
NodePattern,
QuantizerCls,
)
from torch.ao.quantization.utils import (
activation_dtype,
get_combined_dict,
)
from ..backend_config import BackendConfig
from .quantization_patterns import QuantizeHandler
from .fusion_patterns import DefaultFuseHandler
from typing import Dict, Any, Callable, Optional
def get_quantize_handler_cls(
observation_type,
dtype_configs,
num_tensor_args_to_observation_type,
overwrite_output_fake_quantizer,
overwrite_output_observer,
input_output_observed):
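    """
    Return a QuantizeHandler subclass whose behavior is driven by the fields
    of a single backend pattern config entry (observation type, dtype configs,
    and the output observer / fake quantize overrides).
    """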
class ConfigurableQuantizeHandler(QuantizeHandler):
def __init__(
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
                root_node_getter: Optional[Callable] = None):
super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, \
f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
f" in num_tensor_args_to_observation_type for {node_pattern}"
self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
else:
self.observation_type = observation_type
self.dtype_configs = dtype_configs
self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
self.overwrite_output_observer = overwrite_output_observer
self.input_output_observed_ = input_output_observed
def is_general_tensor_value_op(self) -> bool:
return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
# TODO: change this to output activation
def get_activation_ctr(
self,
qconfig: Any,
pattern: Pattern,
is_training: bool,
) -> Optional[Callable]:
"""
Returns the constructor for the activation observer which should be
used for the pattern matched to this handler. Some handlers override
this to a different value than what is specified in the qconfig.
"""
act_dtype = activation_dtype(qconfig)
# TODO: change to is_qat
if is_training:
if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
return self.overwrite_output_fake_quantizer
else:
if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
return self.overwrite_output_observer
return qconfig.activation
# This is temporary, and will be removed soon
def input_output_observed(self):
return self.input_output_observed_
return ConfigurableQuantizeHandler
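
# Illustrative sketch (not from the original file): for a hypothetical set of
# backend pattern config fields, a handler class could be built like this:
#
#   handler_cls = get_quantize_handler_cls(
#       ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
#       dtype_configs=[],
#       num_tensor_args_to_observation_type={},
#       overwrite_output_fake_quantizer=None,
#       overwrite_output_observer=None,
#       input_output_observed=True)
#   handler = handler_cls(node_pattern, modules, root_node_getter=None)
#
# where node_pattern and modules come from the matched FX graph. In practice,
# get_pattern_to_quantize_handlers below does this for every config entry.
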
def get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
"""
Note: Quantize handler is just a holder for some check methods like
(should_insert_observer_for_output), maybe this can be a enum as well,
we can refactor this after we convert the path for fbgemm/qnnpack fully to the
new path, this is not exposed to backend developers
"""
pattern_to_quantize_handlers = {}
for pattern, config in backend_config.configs.items():
observation_type = config.observation_type
dtype_configs = config.dtype_configs
num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
overwrite_fake_quantizer = config._overwrite_output_fake_quantize
overwrite_observer = config._overwrite_output_observer
input_output_observed = config._input_output_observed
if input_output_observed is None:
input_output_observed = True
pattern_to_quantize_handlers[pattern] = \
get_quantize_handler_cls(
observation_type,
dtype_configs,
num_tensor_args_to_observation_type,
overwrite_fake_quantizer,
overwrite_observer,
input_output_observed)
return pattern_to_quantize_handlers
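
# Illustrative sketch (assumption, not from the original file): a typical call
# site builds the handler map from the native backend config and looks up the
# handler class for a matched pattern:
#
#   handlers = get_pattern_to_quantize_handlers(get_native_backend_config())
#   conv_handler_cls = handlers.get(torch.nn.Conv2d)
#
# Keys are the patterns declared in the BackendConfig (module types,
# functionals, or tuples of them); values are the ConfigurableQuantizeHandler
# classes produced by get_quantize_handler_cls above.
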
# TODO: move this to torch/ao/quantization/backend_config/utils.py
def get_fusion_pattern_to_fuse_handler_cls(
backend_config: BackendConfig) -> Dict[Pattern, Callable]:
fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
for pattern, config in backend_config.configs.items():
if config.fuser_method is not None:
# TODO: is this logic right?
fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
return fusion_pattern_to_fuse_handlers
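
# Illustrative sketch (assumption): every pattern that declares a fuser_method
# in the BackendConfig is currently mapped to the same DefaultFuseHandler:
#
#   fuse_handlers = get_fusion_pattern_to_fuse_handler_cls(
#       get_native_backend_config())
#
# Every value in fuse_handlers is DefaultFuseHandler; the actual fusion
# function is taken from the matching config's fuser_method when fusion runs.
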
# TODO: remove when all uses are changed to backend_config
def get_native_quant_patterns(
        additional_quant_patterns: Optional[Dict[Pattern, QuantizerCls]] = None) -> Dict[Pattern, QuantizerCls]:
"""
Return a map from pattern to quantize handlers based on the default patterns and the native backend_config.
The returned map is sorted such that longer patterns will be encountered first when iterating through it.
"""
patterns = get_default_quant_patterns()
if additional_quant_patterns is not None:
patterns = get_combined_dict(patterns, additional_quant_patterns)
# TODO: currently we just extend the quantize handlers generated from
# `get_native_backend_config`
# in the future we can just assign backend_config when everything is defined
for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items():
patterns[pattern] = quantize_handler
return sorted_patterns_dict(patterns)
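
# Illustrative sketch (assumption): callers still on the old pattern-dict API
# can get the native handlers and extend them:
#
#   patterns = get_native_quant_patterns()
#   # patterns maps Pattern -> QuantizerCls, sorted so that longer (fused)
#   # patterns are encountered before their shorter prefixes.
#
# New code should build a BackendConfig and call
# get_pattern_to_quantize_handlers instead, per the TODO above.
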
get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils"
__all__ = [
"get_fusion_pattern_to_fuse_handler_cls",
"get_native_quant_patterns",
"get_pattern_to_quantize_handlers",
]