|
|
from typing import Any, Callable, Union
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
|
|
|
from torch.ao.quantization.backend_config import get_native_backend_config
|
|
|
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
|
|
|
from torch.ao.quantization.utils import getattr_from_fqn
|
|
|
from torch.fx import GraphModule
|
|
|
from torch.fx.graph import Node
|
|
|
|
|
|
from .ns_types import NSNodeTargetType
|
|
|
|
|
|
|
|
|
# Short alias for the `torch.ops.quantized` operator namespace.
toq = torch.ops.quantized
|
|
|
|
|
|
|
|
|
def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]],
) -> set[tuple[NSNodeTargetType, NSNodeTargetType]]:
    """
    Expand groups of related ops into the set of all related (a, b) pairs.

    Args:
        base_name_to_sets_of_related_ops: maps a base name to the set of op
            types considered related to each other. Only the values are
            used; the keys are ignored.

    Returns:
        A set holding every ordered pair ``(a, b)`` where ``a`` and ``b``
        come from the same group. Both orderings of each pair are present,
        as is the pair of each op with itself.
    """
    related_pairs: set[tuple[NSNodeTargetType, NSNodeTargetType]] = set()

    for related_ops in base_name_to_sets_of_related_ops.values():
        # Snapshot the group so both loops below walk the same sequence.
        members = tuple(related_ops)
        # The full cross product of a group with itself is exactly "every
        # bidirectional pair, plus each op paired with itself".
        related_pairs.update(
            (op_a, op_b) for op_a in members for op_b in members
        )

    return related_pairs
|
|
|
|
|
|
|
|
|
# One element of a fusion pattern. It is one of:
#   * a callable - compared against a `call_function` node's target, or, when
#     it is a class, used for an isinstance check on a `call_module` target
#   * a str - the name of a `call_method` node's target
#   * a (str, Any) pair - a `call_method` name together with the value
#     expected as that call's second positional argument
NSFusionElType = Union[Callable, str, tuple[str, Any]]

# A fusion pattern: a tuple of fusion elements, annotated here in its
# two-element and four-element forms.
NSFusionType = Union[
    tuple[NSFusionElType, NSFusionElType],
    tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
|
|
|
|
|
|
|
|
|
def get_reversed_fusions() -> list[tuple[NSFusionType, int]]:
    """
    Build the list of candidate fusion patterns, written in reverse order.

    Patterns are listed last-op-first because that is how fusion patterns
    are defined in the quantization code.

    Each entry has the form:

      ((fusion_op_0, fusion_op_1, ...), base_op_idx)

    ``base_op_idx`` is the index of the op that should be used to match
    other related ops. It is expressed in regular (non-reversed) order:
    0 is the first op of the fusion in execution order, 1 the second, etc.

    Returns:
        The patterns taken from the native backend config followed by a
        fixed set of extra patterns. Every pattern is also emitted with
        ``ObserverBase`` and with ``FakeQuantizeBase`` prepended.
    """
    default_base_op_idx = 0
    fp16_em_base_op_idx = 1
    # Prepending one of these to a reversed pattern describes the same
    # fusion followed by an observer / fake-quantize module.
    trailing_wrappers = (ObserverBase, FakeQuantizeBase)

    fusions: list[tuple[NSFusionType, int]] = []

    native_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())
    for pattern in native_patterns:
        is_multi_op = isinstance(pattern, tuple)

        # Flatten the nested form (a, (b, c)) into (a, b, c).
        if is_multi_op and len(pattern) == 2:
            outer, inner = pattern
            if isinstance(inner, tuple) and len(inner) == 2:
                pattern = (outer, *inner)

        if is_multi_op:
            # Only multi-op patterns are fusions in their own right.
            fusions.append((pattern, default_base_op_idx))
            ops = pattern
        else:
            # A single op only becomes a pattern once a wrapper is added.
            ops = (pattern,)

        fusions.extend(
            ((wrapper, *ops), default_base_op_idx) for wrapper in trailing_wrappers
        )

    # Extra patterns that are spelled out here rather than read from the
    # backend config.
    conv_bn_pairs = (
        (nn.BatchNorm1d, nn.Conv1d),
        (nn.BatchNorm2d, nn.Conv2d),
        (nn.BatchNorm3d, nn.Conv3d),
    )
    extra_fusions = [
        # dequantize -> linear -> relu -> to(float16), in reverse order
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx),
        # Conv -> BN
        *((pair, default_base_op_idx) for pair in conv_bn_pairs),
        # Conv -> BN -> ReLU
        *(((nn.ReLU, *pair), default_base_op_idx) for pair in conv_bn_pairs),
    ]
    for fusion in extra_fusions:
        ops, base_op_idx = fusion
        fusions.append(fusion)
        fusions.extend(((wrapper, *ops), base_op_idx) for wrapper in trailing_wrappers)

    return fusions
|
|
|
|
|
|
|
|
|
def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: set[Node],
) -> bool:
    """
    Check whether the chain of nodes ending at `end_node` matches a fusion.

    The pattern is walked from its first element (the last op of the
    fusion) while the graph is walked upstream from `end_node` through each
    node's first positional argument.

    Args:
        end_node: the node expected to correspond to ``reversed_fusion[0]``.
        reversed_fusion: the fusion pattern, last op first.
        gm: graph module used to resolve ``call_module`` targets.
        seen_nodes: nodes that must not take part in the match; reaching
            any of them makes the match fail.

    Returns:
        True if every pattern element matched its node, False otherwise.
    """
    node = end_node
    for fusion_el in reversed_fusion:
        # A node listed in `seen_nodes` cannot be part of this match.
        if node in seen_nodes:
            return False

        op_kind = node.op
        if op_kind == "call_function":
            # A function element is anything that is neither a method name
            # (str) nor a module class (type).
            if isinstance(fusion_el, (str, type)) or node.target != fusion_el:
                return False

        elif op_kind == "call_module":
            if not isinstance(fusion_el, type):
                return False
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if not isinstance(module, fusion_el):
                return False

        elif op_kind == "call_method":
            if isinstance(fusion_el, str):
                # Method name only.
                if node.target != fusion_el:
                    return False
            elif isinstance(fusion_el, tuple) and len(fusion_el) == 2:
                # Method name plus the expected second positional argument.
                method_name, second_arg = fusion_el
                if (
                    node.target != method_name
                    or len(node.args) < 2
                    or node.args[1] != second_arg
                ):
                    return False
            else:
                return False

        else:
            # Any other op kind can never be part of a fusion.
            return False

        # Step upstream through the first positional argument. This is also
        # required after the final pattern element has matched.
        if not node.args or not isinstance(node.args[0], Node):
            return False
        node = node.args[0]

    return True
|
|
|
|