| | from __future__ import annotations |
| | from dataclasses import dataclass |
| | from typing import Any, Callable, Dict, List, Optional, Type, Union |
| |
|
| | import torch |
| | from torch.ao.quantization.observer import _PartialWrapper |
| | from torch.ao.quantization.utils import Pattern |
| | from enum import Enum |
| |
|
| |
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "BackendConfig",
    "BackendPatternConfig",
    "DTypeConfig",
    "DTypeWithConstraints",
    "ObservationType",
]


# DTypeConfig dict keys (read by DTypeConfig.from_dict, written by DTypeConfig.to_dict)
INPUT_DTYPE_DICT_KEY = "input_dtype"
OUTPUT_DTYPE_DICT_KEY = "output_dtype"
WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
BIAS_DTYPE_DICT_KEY = "bias_dtype"
IS_DYNAMIC_DICT_KEY = "is_dynamic"

# BackendConfig dict keys (read by BackendConfig.from_dict, written by BackendConfig.to_dict)
NAME_DICT_KEY = "name"
CONFIGS_DICT_KEY = "configs"

# BackendPatternConfig dict keys (read by BackendPatternConfig.from_dict, written by
# BackendPatternConfig.to_dict). Note that the key string for the reference quantized
# module is "reference_quantized_module_for_root", not "reference_quantized_module".
PATTERN_DICT_KEY = "pattern"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
QAT_MODULE_DICT_KEY = "qat_module"
REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
FUSED_MODULE_DICT_KEY = "fused_module"
FUSER_METHOD_DICT_KEY = "fuser_method"
ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"
INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed"
OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize"
OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer"
| |
|
| |
|
| | |
| | |
class ObservationType(Enum):
    """Enumerates the ways an operator (or operator pattern) can be observed.

    The members differ in how the observer for the output relates to the
    observer for the input.
    """

    OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
    """Input and output are observed by separate observers, both derived from
    ``qconfig.activation``.

    Examples: conv, linear, softmax.
    """

    OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
    """The output is observed by the same observer instance as the input, derived
    from ``qconfig.activation``.

    Examples: torch.cat, maxpool.
    """
| |
|
| |
|
@dataclass
class DTypeWithConstraints:
    """
    A dtype bundled with optional extra constraints (for example bounds on the quantization
    value range and on the scale value range), for use inside
    :class:`~torch.ao.quantization.backend_config.DTypeConfig`.

    Every field defaults to ``None``. How the individual bounds are enforced is decided by
    the code consuming this config and is not defined here.
    """
    # The data type the constraints below apply to.
    dtype: Optional[torch.dtype] = None
    # Bounds related to the quantization value range.
    quant_min_lower_bound: Optional[Union[int, float]] = None
    quant_max_upper_bound: Optional[Union[int, float]] = None
    # Bounds related to the scale value range.
    scale_min_lower_bound: Optional[Union[int, float]] = None
    scale_max_upper_bound: Optional[Union[int, float]] = None
| |
|
| |
|
@dataclass
class DTypeConfig:
    """
    Config for the set of supported input/output activation, weight, and bias data types for the
    patterns defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.

    The input, output and weight dtypes may each be given either as a plain ``torch.dtype`` or as a
    :class:`~torch.ao.quantization.backend_config.DTypeWithConstraints`. Internally they are always
    stored as ``DTypeWithConstraints``; the ``input_dtype`` / ``output_dtype`` / ``weight_dtype``
    properties expose just the underlying ``torch.dtype``.

    Example usage::

        >>> dtype_config1 = DTypeConfig(
        ...     input_dtype=torch.quint8,
        ...     output_dtype=torch.quint8,
        ...     weight_dtype=torch.qint8,
        ...     bias_dtype=torch.float)

        >>> dtype_config2 = DTypeConfig(
        ...     input_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     output_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     weight_dtype=DTypeWithConstraints(
        ...         dtype=torch.qint8,
        ...         quant_min_lower_bound=-128,
        ...         quant_max_upper_bound=127,
        ...     ),
        ...     bias_dtype=torch.float)

        >>> dtype_config1.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype_with_constraints
        DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \
scale_min_lower_bound=None, scale_max_upper_bound=None)
    """
    # These annotations are the dataclass fields (used for the generated __repr__/__eq__).
    # __init__ is written by hand below; @dataclass does not replace an __init__ that the
    # class body already defines.
    input_dtype_with_constraints: DTypeWithConstraints
    output_dtype_with_constraints: DTypeWithConstraints
    weight_dtype_with_constraints: DTypeWithConstraints
    bias_dtype: Optional[torch.dtype]
    is_dynamic: Optional[bool]

    def __init__(
        self,
        input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        bias_dtype: Optional[torch.dtype] = None,
        is_dynamic: Optional[bool] = None,
    ):
        """
        Build a ``DTypeConfig``.

        ``input_dtype``, ``output_dtype`` and ``weight_dtype`` are normalized to
        ``DTypeWithConstraints``: an existing instance is stored as is, anything else
        (including ``None``) is wrapped as ``DTypeWithConstraints(dtype=value)``.
        ``bias_dtype`` and ``is_dynamic`` are stored unchanged.
        """
        self.input_dtype_with_constraints = self._as_dtype_with_constraints(input_dtype)
        self.output_dtype_with_constraints = self._as_dtype_with_constraints(output_dtype)
        self.weight_dtype_with_constraints = self._as_dtype_with_constraints(weight_dtype)
        self.bias_dtype = bias_dtype
        self.is_dynamic = is_dynamic

    @staticmethod
    def _as_dtype_with_constraints(
        dtype: Union[torch.dtype, DTypeWithConstraints, None],
    ) -> DTypeWithConstraints:
        """Return ``dtype`` unchanged if it already is a ``DTypeWithConstraints``, else wrap it."""
        if isinstance(dtype, DTypeWithConstraints):
            return dtype
        return DTypeWithConstraints(dtype=dtype)

    @property
    def input_dtype(self) -> Optional[torch.dtype]:
        """The ``torch.dtype`` of the input activation, without its constraints."""
        return self.input_dtype_with_constraints.dtype

    @property
    def output_dtype(self) -> Optional[torch.dtype]:
        """The ``torch.dtype`` of the output activation, without its constraints."""
        return self.output_dtype_with_constraints.dtype

    @property
    def weight_dtype(self) -> Optional[torch.dtype]:
        """The ``torch.dtype`` of the weight, without its constraints."""
        return self.weight_dtype_with_constraints.dtype

    @classmethod
    def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
        """
        Create a ``DTypeConfig`` from a dictionary with the following items (all optional):
            "input_dtype": torch.dtype or ``DTypeWithConstraints``
            "output_dtype": torch.dtype or ``DTypeWithConstraints``
            "weight_dtype": torch.dtype or ``DTypeWithConstraints``
            "bias_dtype": torch.dtype
            "is_dynamic": bool

        Raises:
            ValueError: if "input_dtype", "output_dtype" or "weight_dtype" is present, not
                ``None``, and neither a ``torch.dtype`` nor a ``DTypeWithConstraints``.
                "bias_dtype" and "is_dynamic" are passed through without validation.
        """
        def _checked_dtype(key: str) -> Union[torch.dtype, DTypeWithConstraints, None]:
            # Fetch an optional dtype entry and reject values of an unexpected type.
            value = dtype_config_dict.get(key, None)
            if value is not None and not isinstance(value, (torch.dtype, DTypeWithConstraints)):
                raise ValueError(f"Expected {key} to be a torch.dtype or DTypeWithConstraints")
            return value

        # Validated in this order so the first offending key is the one reported.
        input_dtype = _checked_dtype(INPUT_DTYPE_DICT_KEY)
        output_dtype = _checked_dtype(OUTPUT_DTYPE_DICT_KEY)
        weight_dtype = _checked_dtype(WEIGHT_DTYPE_DICT_KEY)
        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
        return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``DTypeConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.

        Entries whose value is ``None`` are omitted. For the input, output and weight entries
        the stored value is the full ``DTypeWithConstraints`` object (so constraints survive a
        round trip through ``from_dict``), not the bare ``torch.dtype``.
        """
        dtype_config_dict: Dict[str, Any] = {}
        if self.input_dtype is not None:
            dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints
        if self.output_dtype is not None:
            dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype_with_constraints
        if self.weight_dtype is not None:
            dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype_with_constraints
        if self.bias_dtype is not None:
            dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
        if self.is_dynamic is not None:
            dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
        return dtype_config_dict
| |
|
| |
|
class BackendConfig:
    """Config that defines the set of patterns that can be quantized on a given backend, and how reference
    quantized models can be produced from these patterns.

    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
    of the above. Each pattern supported on the target backend can be individually configured through
    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:

    (1) The supported input/output activation, weight, and bias data types

    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and

    (3) (Optionally) Fusion, QAT, and reference module mappings.

    The format of the patterns is described in:
    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

    Example usage::

        import torch
        from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
        from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2

        weighted_int8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float)

        linear_config = (
            BackendPatternConfig(torch.nn.Linear)
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
            .add_dtype_config(weighted_int8_dtype_config)
            .set_root_module(torch.nn.Linear)
            .set_qat_module(torch.nn.qat.Linear)
            .set_reference_quantized_module(torch.nn.quantized._reference.Linear)
        )

        conv_relu_config = (
            BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d))
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
            .add_dtype_config(weighted_int8_dtype_config)
            .set_fused_module(torch.nn.intrinsic.ConvReLU2d)
            .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))
        )

        backend_config = (
            BackendConfig("my_backend")
            .set_backend_pattern_config(linear_config)
            .set_backend_pattern_config(conv_relu_config)
        )

    """
    def __init__(self, name: str = ""):
        self.name = name
        # Registered pattern configs, keyed by their pattern (one config per pattern).
        self.configs: Dict[Pattern, BackendPatternConfig] = {}

    def set_name(self, name: str) -> BackendConfig:
        """
        Set the name of the target backend.

        Returns ``self`` so that calls can be chained.
        """
        self.name = name
        return self

    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
        """
        Set the config for a pattern that can be run on the target backend.
        This overrides any existing config for the given pattern.

        Returns ``self`` so that calls can be chained.
        """
        self.configs[config.pattern] = config
        return self

    def set_backend_pattern_configs(self, configs: List[BackendPatternConfig]) -> BackendConfig:
        """
        Set the configs for patterns that can be run on the target backend.
        This overrides any existing config for a given pattern if it was previously registered already.

        The configs are registered in iteration order, so if ``configs`` itself contains two
        entries for the same pattern, the later one wins.

        Returns ``self`` so that calls can be chained.
        """
        for pattern_config in configs:
            self.set_backend_pattern_config(pattern_config)
        return self

    @classmethod
    def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
        """
        Create a ``BackendConfig`` from a dictionary with the following items:

            "name": the name of the target backend (defaults to ``""`` when absent)

            "configs": a list whose entries are each either a dictionary that represents a
            `BackendPatternConfig` or an already constructed `BackendPatternConfig`
            (defaults to an empty list when absent)

        Raises:
            ValueError: if an entry of "configs" is neither a dictionary nor a
                ``BackendPatternConfig``.
        """
        conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
        for entry in backend_config_dict.get(CONFIGS_DICT_KEY, []):
            if isinstance(entry, BackendPatternConfig):
                conf.set_backend_pattern_config(entry)
            elif isinstance(entry, dict):
                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(entry))
            else:
                raise ValueError(f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary")
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.

        Every registered pattern config is itself converted with its ``to_dict`` method.
        """
        return {
            NAME_DICT_KEY: self.name,
            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs.values()],
        }
| |
|
| |
|
class BackendPatternConfig:
    """
    Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.

    All setters return ``self`` so that calls can be chained.
    """
    def __init__(self, pattern: Pattern):
        self.pattern = pattern
        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        self.dtype_configs: List[DTypeConfig] = []
        self.root_module: Optional[Type[torch.nn.Module]] = None
        self.qat_module: Optional[Type[torch.nn.Module]] = None
        self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None
        self.fused_module: Optional[Type[torch.nn.Module]] = None
        self.fuser_method: Optional[Callable] = None

        # Private settings: only reachable through the underscore-prefixed setters and the
        # corresponding dictionary keys handled in from_dict / to_dict.
        self._root_node_getter: Optional[Callable] = None
        self._extra_inputs_getter: Optional[Callable] = None
        self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
        self._input_type_to_index: Dict[str, int] = {}
        self._input_output_observed: Optional[bool] = None
        self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None
        self._overwrite_output_observer: Optional[_PartialWrapper] = None

    def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
        """
        Set how observers should be inserted for this pattern.
        See :class:`~torch.ao.quantization.backend_config.ObservationType` for details

        """
        self.observation_type = observation_type
        return self

    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
        """
        Add a set of supported input/output activation, weight, and bias data types for this pattern.
        """
        self.dtype_configs.append(dtype_config)
        return self

    def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
        """
        Set the supported input/output activation, weight, and bias data types for this pattern,
        overriding all previously registered data types.

        The given list is copied, so a later :meth:`add_dtype_config` on this config does not
        modify the caller's list (which may be shared between several pattern configs).
        """
        self.dtype_configs = list(dtype_configs)
        return self

    def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the root for this pattern.
        For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`.
        """
        self.root_module = root_module
        return self

    def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the QAT implementation for this pattern.
        """
        self.qat_module = qat_module
        return self

    def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the reference quantized implementation for this pattern's root module.
        """
        self.reference_quantized_module = reference_quantized_module
        return self

    def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the fused implementation for this pattern.
        """
        self.fused_module = fused_module
        return self

    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
        """
        Set the function that specifies how to fuse the pattern for this pattern.
        """
        self.fuser_method = fuser_method
        return self

    # The private setters below only store their argument on the matching underscore
    # attribute; what each value means is decided by the code that consumes this config.

    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
        self._root_node_getter = root_node_getter
        return self

    def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig:
        self._extra_inputs_getter = extra_inputs_getter
        return self

    def _set_num_tensor_args_to_observation_type(
            self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig:
        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
        return self

    def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig:
        self._input_type_to_index = input_type_to_index
        return self

    def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatternConfig:
        self._input_output_observed = input_output_observed
        return self

    def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_fake_quantize = overwrite_output_fake_quantize
        return self

    def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_observer = overwrite_output_observer
        return self

    @classmethod
    def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
        """
        Create a ``BackendPatternConfig`` from a dictionary with the following items:

            "pattern": the pattern being configured (required)
            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
            observers should be inserted for this pattern
            "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s
            (already constructed ``DTypeConfig`` objects are accepted as well)
            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
            "reference_quantized_module_for_root": a :class:`torch.nn.Module` that represents the reference quantized
            implementation for this pattern's root module.
            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
            "fuser_method": a function that specifies how to fuse the pattern for this pattern

        The private keys "root_node_getter", "extra_inputs_getter",
        "num_tensor_args_to_observation_type", "input_type_to_index", "input_output_observed",
        "overwrite_output_fake_quantize" and "overwrite_output_observer" are read as well and
        forwarded to the corresponding private setters.

        Raises:
            ValueError: if "pattern" is missing, or if an entry of "dtype_configs" is neither a
                dictionary nor a ``DTypeConfig``.
        """
        def _get_dtype_config(obj: Any) -> DTypeConfig:
            """
            Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
            """
            if isinstance(obj, DTypeConfig):
                return obj
            if isinstance(obj, dict):
                return DTypeConfig.from_dict(obj)
            raise ValueError(
                f"Expected a list of DTypeConfigs in backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], "
                f"got '{type(obj)}'"
            )

        src = backend_pattern_config_dict
        if PATTERN_DICT_KEY not in src:
            raise ValueError(f"backend_pattern_config_dict must contain '{PATTERN_DICT_KEY}'")
        conf = cls(src[PATTERN_DICT_KEY])
        # Only override the constructor's default observation type when one is given.
        if OBSERVATION_TYPE_DICT_KEY in src:
            conf.set_observation_type(src[OBSERVATION_TYPE_DICT_KEY])
        for dtype_config in src.get(DTYPE_CONFIGS_DICT_KEY, []):
            conf.add_dtype_config(_get_dtype_config(dtype_config))
        conf.set_root_module(src.get(ROOT_MODULE_DICT_KEY, None))
        conf.set_qat_module(src.get(QAT_MODULE_DICT_KEY, None))
        conf.set_reference_quantized_module(src.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None))
        conf.set_fused_module(src.get(FUSED_MODULE_DICT_KEY, None))
        conf.set_fuser_method(src.get(FUSER_METHOD_DICT_KEY, None))
        conf._set_root_node_getter(src.get(ROOT_NODE_GETTER_DICT_KEY, None))
        conf._set_extra_inputs_getter(src.get(EXTRA_INPUTS_GETTER_DICT_KEY, None))
        conf._set_num_tensor_args_to_observation_type(src.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
        conf._set_input_type_to_index(src.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
        conf._set_input_output_observed(src.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None))
        conf._set_overwrite_output_fake_quantize(src.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None))
        conf._set_overwrite_output_observer(src.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None))
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendPatternConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.

        "pattern", "observation_type" and "dtype_configs" are always present. Every other
        entry is included only when it has been set (is not ``None``); the two private
        mapping entries are included only when they are non-empty.
        """
        backend_pattern_config_dict: Dict[str, Any] = {
            PATTERN_DICT_KEY: self.pattern,
            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
        }
        # Listed in the order in which the keys should appear in the resulting dictionary.
        optional_items = (
            (ROOT_MODULE_DICT_KEY, self.root_module),
            (QAT_MODULE_DICT_KEY, self.qat_module),
            (REFERENCE_QUANTIZED_MODULE_DICT_KEY, self.reference_quantized_module),
            (FUSED_MODULE_DICT_KEY, self.fused_module),
            (FUSER_METHOD_DICT_KEY, self.fuser_method),
            (ROOT_NODE_GETTER_DICT_KEY, self._root_node_getter),
            (EXTRA_INPUTS_GETTER_DICT_KEY, self._extra_inputs_getter),
            # ``or None`` makes an empty mapping count as "not set" so that it is left out.
            (NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, self._num_tensor_args_to_observation_type or None),
            (INPUT_TYPE_TO_INDEX_DICT_KEY, self._input_type_to_index or None),
            # Compared against None below, so an explicit ``False`` is still emitted.
            (INPUT_OUTPUT_OBSERVED_DICT_KEY, self._input_output_observed),
            (OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, self._overwrite_output_fake_quantize),
            (OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, self._overwrite_output_observer),
        )
        for key, value in optional_items:
            if value is not None:
                backend_pattern_config_dict[key] = value
        return backend_pattern_config_dict
| |
|