|
|
import sys
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set

import torch
from torch.fx.graph import (
    Graph,
    Node,
)
from torch.ao.quantization.quantization_types import Pattern
from torch.nn.utils.parametrize import type_before_parametrizations

from .quantization_patterns import (
    QuantizeHandler,
)
from ..qconfig import (
    QConfigAny,
)
from ..utils import (
    MatchAllNode
)
from .graph_module import (
    is_observed_standalone_module,
)
|
|
|
|
|
|
|
|
|
|
|
__all__ = [
    "is_match",
    "find_matches",
]

# Per-node value stored in the map returned by ``find_matches``. The pattern
# matcher records it as
#   (last node of the match, matched node pattern, pattern, handler instance).
# The pattern slot is Optional because custom/standalone module entries are
# recorded with ``None`` as their pattern.
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]

# Same layout as ``MatchResult`` with a trailing qconfig entry. Not referenced
# in this part of the file.
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                                QConfigAny]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_match(modules, node, pattern, max_uses=sys.maxsize):
    """Return whether ``node`` (and, recursively, its args) matches ``pattern``.

    Args:
        modules: mapping from ``call_module`` targets to module instances.
            Only consulted when the matcher is an ``nn.Module`` subclass.
        node: the candidate to test. Anything that is not an fx ``Node``
            fails to match unless the matcher is a ``MatchAllNode`` subclass.
        pattern: either a single matcher or a tuple
            ``(matcher, *arg_patterns)``. A matcher is one of:

            * a ``MatchAllNode`` subclass: matches anything;
            * an ``nn.Module`` subclass: matches a ``call_module`` node whose
              module type (before parametrization) equals it;
            * a callable: matches a ``call_function`` node whose target is
              that exact object;
            * a string: matches a ``call_method`` node with that method name;
            * anything else: compared against ``node.target`` for equality.

            ``(getattr, name)`` is special: the second element is the
            attribute name compared against ``node.args[1]``, not a
            sub-pattern. A bare ``getattr`` matches any ``getattr`` call.
        max_uses: the node fails to match if it has more users than this.
            Arguments are always matched with ``max_uses=1``.

    Returns:
        ``True`` if the node matches, ``False`` otherwise.
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
        if self_match is getattr:
            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
            # The second element is an attribute name, not an arg pattern.
            arg_matches = []
    else:
        self_match = pattern
        arg_matches = []

    # The wildcard is checked first so that it also accepts non-Node
    # arguments and nodes with any number of users.
    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
        return True

    if not isinstance(node, Node) or len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
        if type_before_parametrizations(modules[node.target]) != self_match:
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
        # Only the tuple form carries an attribute name to compare. A bare
        # ``getattr`` pattern has no ``pattern[1]`` and used to raise
        # TypeError here.
        if node.target is getattr and isinstance(pattern, tuple):
            if node.args[1] != pattern[1]:
                return False
    elif isinstance(self_match, str):
        if node.op != 'call_method' or node.target != self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    # Intermediate nodes of a pattern may only feed the node being matched.
    return all(
        is_match(modules, arg, arg_match, max_uses=1)
        for arg, arg_match in zip(node.args, arg_matches)
    )
|
|
|
|
|
def find_matches(
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        patterns: Dict[Pattern, QuantizeHandler],
        root_node_getter_mapping: Dict[Pattern, Callable],
        standalone_module_names: Optional[List[str]] = None,
        standalone_module_classes: Optional[List[Type]] = None,
        custom_module_classes: Optional[List[Any]] = None) -> Dict[str, MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    The graph is walked from the last node to the first. For each node that
    is not already part of a match, ``patterns`` is tried in its iteration
    order and the first pattern that matches wins.

    Inputs:
      - graph: an fx.Graph object
      - modules: a mapping of fully qualified module name to instance,
          for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a tuple of nodes in reverse order to
          uninitialized QuantizeHandler subclass. Each value is called as
          ``handler_cls(matched_node_pattern, modules, root_node_getter)``.
      - root_node_getter_mapping: a mapping from pattern to the callable
          passed to the handler as ``root_node_getter`` (``None`` is passed
          when the pattern has no entry).
      - standalone_module_names: ``call_module`` targets to record as
          standalone modules.
      - standalone_module_classes: module types (exact ``type(...)`` match)
          to record as standalone modules.
      - custom_module_classes: module types (exact ``type(...)`` match) to
          record as custom modules.

    Outputs a map of
      node_name ->
        (node, matched_values, matched_pattern, QuantizeHandler instance)

    For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <CopyNodeQuantizeHandler instance>),
      ...
    }

    Custom module and standalone module nodes are recorded afterwards as
    ``(node, node, None, QuantizeHandler instance)`` and overwrite any
    pattern match stored under the same node name.
    """
    if custom_module_classes is None:
        custom_module_classes = []

    if standalone_module_classes is None:
        standalone_module_classes = []

    if standalone_module_names is None:
        standalone_module_names = []

    match_map: Dict[str, MatchResult] = {}

    def _recursive_record_node_in_match_map(
            last_node,
            match_map,
            node_pattern,
            matched_node_pattern,
            pattern,
            match_value):
        # Store the same match tuple under the name of every Node found in
        # the (possibly nested) node pattern.
        if isinstance(node_pattern, Node):
            match_map[node_pattern.name] = (
                last_node, matched_node_pattern, pattern, match_value)
        else:
            for n in node_pattern:
                _recursive_record_node_in_match_map(
                    last_node, match_map, n, matched_node_pattern, pattern, match_value)

    def record_match(
            pattern,
            node,
            last_node,
            matched_node_pattern,
            match_map):
        # Append to ``matched_node_pattern`` the nodes that ``pattern``
        # matched, mirroring the nesting of the pattern.
        if isinstance(pattern, tuple):
            s, *args = pattern
            current_node_pattern: List[Node] = []
            record_match(
                s,
                node,
                last_node,
                matched_node_pattern,
                match_map)
            # For (getattr, name) the second element is an attribute name,
            # not a sub-pattern, so there are no args to record.
            if pattern[0] is not getattr:
                for subpattern, arg in zip(args, node.args):
                    record_match(
                        subpattern,
                        arg,
                        node,
                        current_node_pattern,
                        match_map)
            if len(current_node_pattern) > 1:
                matched_node_pattern.append(tuple(current_node_pattern))
            elif current_node_pattern:
                matched_node_pattern.append(current_node_pattern[0])
            # else: nothing was recorded for the args (e.g. a getattr
            # pattern); indexing [0] here used to raise IndexError.
        else:
            matched_node_pattern.append(node)

    for node in reversed(graph.nodes):
        # Skip nodes already claimed by a match rooted at a later node.
        if node.name in match_map:
            continue
        for pattern, quantize_handler_cls in patterns.items():
            if not is_match(modules, node, pattern):
                continue
            matched_node_pattern: List[Node] = []
            record_match(
                pattern,
                node,
                node,
                matched_node_pattern,
                match_map)
            quantize_handler = quantize_handler_cls(  # type: ignore[operator]
                matched_node_pattern,
                modules,
                root_node_getter_mapping.get(pattern, None))
            # Record the match for all nodes in the pattern; the full
            # matched pattern is also part of each node's stored value.
            _recursive_record_node_in_match_map(
                node,
                match_map,
                matched_node_pattern,
                matched_node_pattern,
                pattern,
                quantize_handler)
            break

    # Add custom module instances to the match result.
    assert modules is not None
    for node in graph.nodes:
        if node.op == 'call_module' and \
           type(modules[node.target]) in custom_module_classes:
            match_map[node.name] = (
                node, node, None, QuantizeHandler(node, modules, is_custom_module=True))

    def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
        assert modules is not None
        return (
            node_target in standalone_module_names or  # type: ignore[operator]
            type(modules[node_target]) in standalone_module_classes  # type: ignore[operator]
        )

    # Add standalone modules to the match result.
    for node in graph.nodes:
        if node.op == 'call_module' and \
           (is_standalone_module(node.target, modules) or
                is_observed_standalone_module(modules[node.target])):
            match_map[node.name] = (
                node, node, None,
                QuantizeHandler(node, modules, is_standalone_module=True))

    return match_map
|
|
|