| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | toq = torch.ops.quantized |
| |
|
| | from torch.fx import GraphModule |
| | from torch.fx.graph import Node |
| |
|
| | from torch.ao.quantization.utils import getattr_from_fqn |
| | from .ns_types import NSNodeTargetType |
| | from torch.ao.quantization.fx.backend_config_utils import get_native_quant_patterns |
| | from torch.ao.quantization import ( |
| | ObserverBase, |
| | FakeQuantizeBase, |
| | ) |
| |
|
| | from typing import Dict, Tuple, Set, Callable, Any, Union, List |
| |
|
| |
|
def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    """
    Expand groups of mutually related ops into a flat set of related pairs.

    Args:
        base_name_to_sets_of_related_ops: maps a base name to a set of ops
            that are all considered related to each other. Only the sets are
            used; the base-name keys are ignored.

    Returns:
        Every ordered pair ``(a, b)`` such that ``a`` and ``b`` appear in the
        same set, including the reflexive pair ``(a, a)``. The relation is
        therefore symmetric: ``(a, b)`` is present iff ``(b, a)`` is.
    """
    return {
        (op_a, op_b)
        for related_ops in base_name_to_sets_of_related_ops.values()
        for op_a in related_ops
        for op_b in related_ops
    }
| |
|
| |
|
# One element of a fusion pattern. It is one of:
# * a callable: either a function, compared against the target of a
#   ``call_function`` node, or a class, which the module of a
#   ``call_module`` node must be an instance of;
# * a ``str``: the method name of a ``call_method`` node;
# * a ``(method_name, second_arg)`` tuple: a ``call_method`` node whose
#   second positional argument must also equal ``second_arg``.
NSFusionElType = Union[
    Callable,
    str,
    Tuple[str, Any],
]
# A fusion pattern, stored in reverse order (last op in the graph first).
# Patterns have a variable number of elements. For example, the fusions built
# by ``get_reversed_fusions`` include the 3-element
# ``(nn.ReLU, nn.BatchNorm2d, nn.Conv2d)`` and the 5-element observer-prefixed
# fp16 pattern, so a fixed-length union of tuple types is not sufficient.
NSFusionType = Tuple[NSFusionElType, ...]
| |
|
def _append_fusion_with_observer_variants(
    results: List[Tuple[NSFusionType, int]],
    pattern: Any,
    base_op_idx: int,
) -> None:
    """
    Append ``pattern`` and its observer / fake-quant prefixed variants to ``results``.

    A tuple ``pattern`` is appended as-is first. A non-tuple ``pattern`` (a
    single op) is not appended on its own and is treated as a 1-tuple. The
    ops are then appended twice more: once with ``ObserverBase`` prepended
    and once with ``FakeQuantizeBase`` prepended, in that order. All entries
    share ``base_op_idx``.
    """
    if isinstance(pattern, tuple):
        results.append((pattern, base_op_idx))
        ops = pattern
    else:
        ops = (pattern,)
    # Patterns are stored in reverse order, so prepending the class matches
    # an observer / fake-quantize module that consumes the fused ops' output.
    for cls in (ObserverBase, FakeQuantizeBase):
        results.append(((cls, *ops), base_op_idx))


def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Return the list of potential fusions, each stored in reverse order.

    The order is reversed to match how fusion patterns are defined in
    quantization code. The first element of a fusion is the last op in the
    graph, and each following element is the op producing the previous op's
    first argument.

    Each entry has the format::

        ((fusion_op_n, ..., fusion_op_1, fusion_op_0), base_op_idx)

    Fusions contain a variable number of ops. ``base_op_idx`` is the index of
    the op that should be used to match other related ops. It is specified
    in non-reverse order: a ``base_op_idx`` of 0 refers to the first op in
    regular (non-reverse) order, 1 to the second op, and so on.

    Every tuple pattern is also emitted with ``ObserverBase`` and with
    ``FakeQuantizeBase`` prepended, so that the same fusion followed by an
    observer or fake-quantize module is recognized. Single-op native
    patterns are emitted only in those prefixed forms.
    """
    results: List[Tuple[NSFusionType, int]] = []

    default_base_op_idx = 0
    for quant_pattern in get_native_quant_patterns():
        # Flatten the nested form (op, (op_a, op_b)) into (op, op_a, op_b),
        # which is the flat shape end_node_matches_reversed_fusion walks.
        if (
            isinstance(quant_pattern, tuple)
            and len(quant_pattern) == 2
            and isinstance(quant_pattern[1], tuple)
            and len(quant_pattern[1]) == 2
        ):
            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])
        _append_fusion_with_observer_variants(
            results, quant_pattern, default_base_op_idx)

    # Extra patterns that are not taken from the native quantization patterns.
    # In the 4-op pattern below, the base op (index 1 in non-reverse order)
    # is F.linear; judging by the name, this presumably covers fp16 emulation
    # (TODO confirm).
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
        # Conv followed by BatchNorm, optionally followed by ReLU (reverse order).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    for pattern, base_op_idx in patterns_to_add:
        _append_fusion_with_observer_variants(results, pattern, base_op_idx)

    return results
| |
|
| |
|
def _node_matches_fusion_el(
    node: Node,
    fusion_el: NSFusionElType,
    gm: GraphModule,
) -> bool:
    """
    Return True if a single graph ``node`` matches one fusion element.

    * ``call_function`` nodes match any element that is neither a ``str``
      nor a class, when ``node.target`` equals the element.
    * ``call_module`` nodes match a class when the module found at
      ``node.target`` in ``gm`` is an instance of that class.
    * ``call_method`` nodes match either a method name (``str``) equal to
      ``node.target``, or a ``(method_name, second_arg)`` tuple. For the
      tuple form, the node must also have at least two args and
      ``node.args[1]`` must equal ``second_arg``.

    Nodes with any other op (e.g. placeholders) never match.
    """
    if node.op == 'call_function':
        # Strings denote method names and classes denote module types, so
        # neither can describe a function call.
        if isinstance(fusion_el, (str, type)):
            return False
        if node.target != fusion_el:
            return False
        return True

    if node.op == 'call_module':
        if not isinstance(fusion_el, type):
            return False
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        return isinstance(target_mod, fusion_el)

    if node.op == 'call_method':
        if isinstance(fusion_el, str):
            if node.target != fusion_el:
                return False
            return True
        if isinstance(fusion_el, tuple) and len(fusion_el) == 2:
            method_name, second_arg = fusion_el
            if node.target != method_name:
                return False
            if len(node.args) < 2:
                return False
            if node.args[1] != second_arg:
                return False
            return True
        return False

    return False


def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: Set[Node],
) -> bool:
    """
    Returns true if a pattern ending with `end_node` matches
    the fusion pattern.

    Starting at ``end_node``, element ``i`` of ``reversed_fusion`` is matched
    against the current node. The walk then moves to that node's first
    argument, so ``reversed_fusion`` lists the ops from last to first.

    Args:
        end_node: the last node of the candidate pattern in graph order.
        reversed_fusion: the fusion elements, in reverse order.
        gm: the module used to look up submodules for ``call_module`` nodes.
        seen_nodes: nodes that must not take part in this match.

    Returns:
        False if any visited node is in ``seen_nodes``, if any node fails to
        match its element, or if an intermediate node's first argument is
        not a ``Node``. The input of the pattern's first op (the last
        element) is not part of the pattern, so its type is not checked.
    """
    cur_node = end_node
    last_fusion_idx = len(reversed_fusion) - 1
    for fusion_idx, cur_fusion_el in enumerate(reversed_fusion):
        # A node that is already claimed cannot take part in a new match.
        if cur_node in seen_nodes:
            return False

        if not _node_matches_fusion_el(cur_node, cur_fusion_el, gm):
            return False

        if fusion_idx == last_fusion_idx:
            # Every element has matched. Whatever feeds the first op is
            # outside the pattern, so a predecessor Node is not required.
            break

        if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
            cur_node = cur_node.args[0]
        else:
            return False

    return True
| |
|