import copy
import re
import torch
import torch.nn as nn
from torch.ao.quantization import (
    QConfigAny,
    QuantType,
)
from torch.ao.quantization.backend_config import (
    BackendConfig,
    DTypeWithConstraints,
)
from torch.ao.quantization.fake_quantize import FakeQuantizeBase
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
    activation_is_statically_quantized,
    is_per_tensor,
    is_per_channel,
)
from torch.ao.quantization.quantize import is_activation_post_process

from torch.fx import GraphModule, map_arg
from torch.fx.graph import (
    Graph,
    Node,
)
from .custom_config import PrepareCustomConfig

from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
from collections import namedtuple
import operator
import warnings

__all__ = [
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "create_qparam_nodes",
    "EMPTY_ARG_DICT",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_per_tensor_qparams",
    "get_qconv_op",
    "get_qconv_prepack_op",
    "get_quantize_node_info",
    "get_skipped_module_name_and_classes",
    "graph_module_from_producer_nodes",
    "graph_pretty_str",
    "is_get_tensor_info_node",
    "maybe_get_next_module",
    "NodeInfo",
    "node_return_type_is_int",
    "node_arg_is_bias",
    "node_arg_is_weight",
    "NON_OBSERVABLE_ARG_DICT",
    "NON_QUANTIZABLE_WEIGHT_OPS",
    "quantize_node",
    "return_arg_list",
]

NON_QUANTIZABLE_WEIGHT_OPS = {
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.instance_norm,
}

def node_arg_is_weight(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
    """Returns whether the given arg is the weight argument of `node`."""
    if isinstance(node, Node) and node.op == "call_function" and node.target in backend_config.configs:
        weight_index = backend_config.configs[node.target]._input_type_to_index.get("weight")
        if weight_index is not None and weight_index < len(node.args) and node.args[weight_index] is arg:
            return True
        return node.kwargs.get("weight") is arg
    return False


def node_arg_is_bias(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
    """Returns whether the given arg is the bias argument of `node`."""
    if isinstance(node, Node) and node.op == "call_function" and node.target in backend_config.configs:
        bias_index = backend_config.configs[node.target]._input_type_to_index.get("bias")
        if bias_index is not None and bias_index < len(node.args) and node.args[bias_index] is arg:
            return True
        return node.kwargs.get("bias") is arg
    return False

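# A minimal usage sketch (not part of the original file; names hypothetical).
# For an FX node traced from `F.linear(x, w, b)` whose BackendPatternConfig
# maps "weight" -> index 1 and "bias" -> index 2:
#
#     node_arg_is_weight(linear_node, linear_node.args[1], backend_config)  # True
#     node_arg_is_bias(linear_node, linear_node.args[2], backend_config)    # True
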
def graph_pretty_str(g, shorten=True) -> str:
    """Returns a printable representation of the ops in the graph of g.
    If shorten is True, tries to abbreviate fields.
    """
    built_in_func_re = re.compile('<built-in function (.*)>')
    built_in_meth_re = re.compile('<built-in method (.*) of type.*>')
    op_dict = {
        'placeholder': 'plchdr',
        'get_attr': 'gt_prm',
        'call_function': 'cl_fun',
        'call_module': 'cl_mod',
        'call_method': 'cl_meth',
    }

    max_lens = {}
    col_names = ("name", "op", "target", "args", "kwargs")
    for s in col_names:
        max_lens[s] = len(s)

    results = []
    for n in g.nodes:
        # shorten observer names, e.g. activation_post_process_0 -> obs_0
        name = str(n.name)
        if shorten:
            name = name.replace("activation_post_process", "obs")

        # abbreviate the op kind, e.g. placeholder -> plchdr
        op = str(n.op)
        if shorten and op in op_dict:
            op = op_dict[op]

        # abbreviate built-in targets, e.g. <built-in function add> -> <bi_fun add>
        target = str(n.target)
        if shorten:
            built_in_func = built_in_func_re.search(target)
            if built_in_func:
                target = f"<bi_fun {built_in_func.group(1)}>"
            built_in_meth = built_in_meth_re.search(target)
            if built_in_meth:
                target = f"<bi_meth {built_in_meth.group(1)}>"
            target = target.replace("activation_post_process", "obs")

        args = str(n.args)
        if shorten:
            args = args.replace("activation_post_process", "obs")

        kwargs = str(n.kwargs)

        # calculate maximum length of each column, so we can tabulate properly
        for k, v in zip(col_names, (name, op, target, args, kwargs)):
            max_lens[k] = max(max_lens[k], len(v))
        results.append([name, op, target, args, kwargs])

    res_str = ""
    format_str = "{:<{name}} {:<{op}} {:<{target}} {:<{args}} {:<{kwargs}}\n"
    res_str += format_str.format(*col_names, **max_lens)
    for result in results:
        res_str += format_str.format(*result, **max_lens)

    # print a legend for the abbreviated observer names
    if shorten:
        res_str += "*obs_{n} = activation_post_process_{n}\n"
    return res_str

def get_per_tensor_qparams(activation_post_process):
    assert is_per_tensor(activation_post_process.qscheme), 'Only per tensor quantization is supported'
    scale, zero_point = activation_post_process.calculate_qparams()
    scale = float(scale)
    zero_point = int(zero_point)
    dtype = activation_post_process.dtype
    return scale, zero_point, dtype

def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
    ''' Given an activation_post_process module, return the node_type
    (e.g. "call_function"), the quantize op (e.g. torch.quantize_per_tensor),
    and a dictionary of qparams extracted from the module
    '''
    dtype = activation_post_process.dtype
    compute_dtype = None
    if hasattr(activation_post_process, "compute_dtype"):
        compute_dtype = activation_post_process.compute_dtype
    quantize_op: Optional[Union[Callable, str]] = None
    if dtype in [torch.quint8, torch.qint8] and \
            not hasattr(activation_post_process, 'compute_dtype'):
        # static quantization
        node_type = "call_function"
        scale, zero_point = activation_post_process.calculate_qparams()
        if is_per_channel(activation_post_process.qscheme):
            ch_axis = int(activation_post_process.ch_axis)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
            quantize_op = torch.quantize_per_channel
        else:
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor
    elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
        # dynamic quantization
        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # fbgemm/x86 kernels expect reduced activation ranges
        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
        qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range}
    elif dtype == torch.float16:
        node_type = "call_method"
        quantize_op = "to"
        qparams = {"_dtype_": dtype}
    else:
        warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
        return None
    return node_type, quantize_op, qparams

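# A hedged sketch of the return value (values hypothetical): for a static
# per-tensor quint8 observer that has already seen data, this helper returns
# roughly:
#
#     node_type, quantize_op, qparams = get_quantize_node_info(obs)
#     # node_type == "call_function"
#     # quantize_op is torch.quantize_per_tensor
#     # qparams == {"_scale_": 0.02, "_zero_point_": 128, "_dtype_": torch.quint8}
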
def quantize_node(
        in_node: Node,
        obs_module: torch.nn.Module,
        obs_node: Node,
        modules: Dict[str, torch.nn.Module],
        quantized_graph: Graph,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        is_input: bool,
        output_prefix: str = "_output") -> Node:
    ''' Add quantization nodes (e.g. quantize_per_tensor/per_channel) for the given node to
    the graph, with the qparams calculated from activation_post_process (obs_module).
    The observer node (obs_node) is used to find the FQN of the user of activation_post_process.
    e.g. Given input `node` in `node = self.conv(x)`, insert node:
    `quantized_node = torch.quantize_per_tensor(x, self._scale_0, self._zero_point_0, self._dtype_0)`
    where self._scale_0, self._zero_point_0 and self._dtype_0 are
    calculated from `obs_module`
    '''
    # Find the first use of the observer node; we use this to get the scope of the module.
    if is_input:
        # if the quantize function is at the input of the op, we find the first user of the
        # observer node to get the path. If a linear call_function is in the user list, we
        # return the first instance of a linear node to get the FQN.
        users = list(obs_node.users)
        first_linear_use_or_first_use = users[0] if users else None
        linear_node = None
        for n in users:
            if n.op == "call_function" and n.target == torch.nn.functional.linear:
                linear_node = n
                break
        if linear_node:
            first_linear_use_or_first_use = linear_node
        prefix = "_input"
    else:
        # if the quantize function is at the output of the op, we use the observer input node
        # to get the path
        first_linear_use_or_first_use = in_node
        prefix = output_prefix

    if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
    else:
        # TODO: it's not used, so actually we can skip quantization,
        # but this requires changing the return type of quantize_node;
        # we can fix it later if needed
        module_path = ""
    root_module = modules['']
    graph = quantized_graph
    maybe_quantize_node_info = get_quantize_node_info(obs_module)
    assert maybe_quantize_node_info is not None, \
        f"Expecting quantize node info not to be None, observer: {obs_module}"
    node_type, quantize_op, qparams = maybe_quantize_node_info
    inputs = [in_node]

    for key, value in qparams.items():
        if key in ['_scale_', '_zero_point_']:
            # scale and zero_point values are registered as buffers in the root module
            qparam_node = create_getattr_from_value(root_module, graph, module_path + prefix + key, value)
            inputs.append(qparam_node)
        else:
            # other qparams (e.g. axis, dtype) are stored as literals in the graph
            inputs.append(value)
    return graph.create_node(node_type, quantize_op, tuple(inputs), {})

def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
    r""" Get all the unique custom module keys in the custom config dict
    e.g.
    Input:
    {
        QuantType.STATIC: {
            CustomModule1: ObservedCustomModule
        },
        QuantType.DYNAMIC: {
            CustomModule2: DynamicObservedCustomModule
        },
        QuantType.WEIGHT_ONLY: {
            CustomModule3: WeightOnlyObservedCustomModule
        },
    }

    Output:
    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
    [CustomModule1, CustomModule2, CustomModule3]
    """
    float_custom_module_classes: Set[Any] = set()
    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
        float_custom_module_classes |= quant_mode_custom_module_classes
    return list(float_custom_module_classes)

def get_linear_prepack_op_for_dtype(dtype):
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    elif dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    else:
        raise Exception(f"can't get linear prepack op for dtype: {dtype}")


def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    prepack_ops = {
        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
    }
    prepack_op = prepack_ops.get(conv_op, None)
    assert prepack_op, f"Didn't find prepack op for {conv_op}"
    return prepack_op

def get_qconv_op(conv_op: Callable, has_relu: bool) -> Callable:
    qconv_op = {
        # has_relu
        True: {
            torch.nn.functional.conv1d: torch.ops.quantized.conv1d_relu,
            torch.nn.functional.conv2d: torch.ops.quantized.conv2d_relu,
            torch.nn.functional.conv3d: torch.ops.quantized.conv3d_relu,
        },
        False: {
            torch.nn.functional.conv1d: torch.ops.quantized.conv1d,
            torch.nn.functional.conv2d: torch.ops.quantized.conv2d,
            torch.nn.functional.conv3d: torch.ops.quantized.conv3d,
        },
    }
    qconv = qconv_op[has_relu].get(conv_op)
    assert qconv, f"Can't find corresponding quantized conv op for {conv_op} {has_relu}"
    return qconv

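# A minimal usage sketch: selecting the fused conv + relu kernel.
#
#     get_qconv_op(torch.nn.functional.conv2d, has_relu=True)
#     # -> torch.ops.quantized.conv2d_relu
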
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    prefix = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        def get_attr_name(i: int):
            return prefix + str(i)
        i = 0
        attr_name = get_attr_name(i)
        while hasattr(module, attr_name):
            i += 1
            attr_name = get_attr_name(i)
        return attr_name
    return get_new_attr_name

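# A minimal usage sketch (names hypothetical): the returned closure hands out
# attribute names that are not yet taken on the module.
#
#     get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
#     name = get_new_observer_name(model)  # '_observer_0', or '_observer_1' if
#                                          # '_observer_0' already exists
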
def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r''' Starting from a target node, trace back until we hit an input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example:

        def forward(self, x):
            observed = self.observer(self.weight)
            return F.linear(x, observed)

    collect_producer_nodes(observed) will either return a list of nodes that
    produce the observed node, or None if we can't extract a self-contained
    graph without free variables (inputs of the forward function).
    '''
    nodes = [node]
    frontier = [node]
    while frontier:
        node = frontier.pop()
        all_args = list(node.args) + list(node.kwargs.values())
        for arg in all_args:
            if not isinstance(arg, Node):
                continue
            if arg.op == 'placeholder':
                # hit an input, can't fold in this case
                return None
            nodes.append(arg)
            if not (arg.op == 'call_function' and arg.target == getattr):
                frontier.append(arg)
    return nodes

def graph_module_from_producer_nodes(
        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
    r''' Construct a graph module from producer nodes extracted by
    the `collect_producer_nodes` function
    Args:
        root: the root module for the original graph
        producer_nodes: a list of nodes we use to construct the graph
    Return:
        A graph module constructed from the producer nodes
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes cannot be empty'
    # since we traced back from the node to getattr, reverse into execution order
    producer_nodes.reverse()
    graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])
    for producer_node in producer_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(producer_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module

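# A hedged sketch of how the two helpers above compose (names hypothetical):
# fold the subgraph that produces an observed weight into its own GraphModule.
#
#     producer_nodes = collect_producer_nodes(observed_weight_node)
#     if producer_nodes is not None:
#         weight_gm = graph_module_from_producer_nodes(model, producer_nodes)
#         folded_weight = weight_gm()  # no placeholders, so it runs with no args
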
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    devices = {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        "but got devices {}".format(devices)
    )
    device = next(iter(devices)) if len(devices) > 0 else None
    return device

def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer on the module.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    device = assert_and_get_unique_device(module)
    new_value = value.clone().detach() if isinstance(value, torch.Tensor) \
        else torch.tensor(value, device=device)
    module.register_buffer(attr_name, new_value)
    # create a get_attr node pointing at the new buffer
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node

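# A minimal usage sketch (values hypothetical): register a scale as a buffer
# and reference it from the graph.
#
#     scale_node = create_getattr_from_value(model, graph, "_scale_", 0.02)
#     # model now owns a buffer such as "_scale_0" holding tensor(0.02), and
#     # scale_node is the get_attr node targeting it
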
def create_qparam_nodes(
        node_name: str,
        scale: Any,
        zero_point: Any,
        modules: Dict[str, torch.nn.Module],
        quantized_graph: Graph,
        node_name_to_scope: Dict[str, Tuple[str, type]]
) -> Tuple[Node, Node]:
    """
    Create getattr nodes in the quantized graph for scale and zero point values.
    The nodes are registered with the root_module of the model.
    """
    root_module = modules['']
    module_path, _ = node_name_to_scope[node_name]
    scale_node = create_getattr_from_value(root_module, quantized_graph, (module_path + "_scale_"), scale)
    zero_point_node = create_getattr_from_value(root_module, quantized_graph, (module_path + "_zero_point_"), zero_point)
    return (scale_node, zero_point_node)

def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.
    """
    if cache and node in cache:
        return cache[node]

    result = False
    if not isinstance(node, Node):
        result = True
    elif node.op == 'placeholder':
        result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        if is_activation_post_process(modules[node.target]):
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)
    elif node.op == 'call_function' and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)
    elif node.op == 'get_attr':
        result = False
    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
        # x1 = x0.ndim
        result = True
    elif node.op == 'call_method' and node.target == 'size':
        # x1 = x0.size(0)
        result = True
    else:
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                for list_el in arg:
                    if isinstance(list_el, Node):
                        this_list_el_args_have_no_tensors = \
                            all_node_args_have_no_tensors(list_el, modules, cache)
                        found_one_tensor = found_one_tensor or \
                            (not this_list_el_args_have_no_tensors)
                        # once a tensor is found the final result is fixed,
                        # so there is no point in recursing further
                        if found_one_tensor:
                            result = not found_one_tensor
                            if cache:
                                cache[node] = result
                            return result
            elif isinstance(arg, int):
                pass
            else:
                if isinstance(arg, Node):
                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache)
                    found_one_tensor = found_one_tensor or \
                        (not this_arg_args_have_no_tensors)
                    # same early exit as above: once a tensor is found, the
                    # result is already determined
                    if found_one_tensor:
                        result = not found_one_tensor
                        if cache:
                            cache[node] = result
                        return result
                else:
                    found_one_tensor = True
        result = not found_one_tensor
    if cache:
        cache[node] = result
    return result

def all_node_args_except_first(node: Node) -> List[int]:
    """
    Returns all node arg indices after first
    """
    return list(range(1, len(node.args)))


def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """
    Constructs a function that takes a node as arg and returns the arg_indices
    that are valid for node.args
    """
    def arg_indices_func(node: Node) -> List[int]:
        return [i for i in arg_indices if i < len(node.args)]
    return arg_indices_func


NodeInfo = namedtuple("NodeInfo", "op target")

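# A minimal usage sketch: return_arg_list clamps the requested indices to the
# args a node actually has.
#
#     arg_indices_func = return_arg_list([1])
#     arg_indices_func(unsqueeze_node)  # [1] for x.unsqueeze(0), [] for x.relu()
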
NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
    NodeInfo("call_method", "masked_fill"): {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2]),
    },
    NodeInfo("call_method", "permute"): {
        int: all_node_args_except_first,
    },
    NodeInfo("call_method", "repeat"): {
        int: all_node_args_except_first,
    },
    NodeInfo("call_method", "reshape"): {
        int: all_node_args_except_first,
    },
    NodeInfo("call_method", "size"): {
        int: return_arg_list([1]),
    },
    NodeInfo("call_method", "transpose"): {
        int: all_node_args_except_first,
    },
    NodeInfo("call_method", torch.transpose): {
        int: all_node_args_except_first,
    },
    NodeInfo("call_method", "unsqueeze"): {
        int: return_arg_list([1]),
    },
    NodeInfo("call_method", "unsqueeze_"): {
        int: return_arg_list([1]),
    },
    NodeInfo("call_method", torch.unsqueeze): {
        int: return_arg_list([1]),
    },
    NodeInfo("call_method", "view"): {
        int: all_node_args_except_first,
    },
}

EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}

def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
    """
    Returns a dict with non-float-tensor types as keys and, as values, functions
    that retrieve the list of matching arg indices (each takes the node as an argument)
    """
    info = NodeInfo(node.op, node.target)

    return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)

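# A hedged sketch (node hypothetical): for a `view` call_method node, the
# returned dict maps int to all_node_args_except_first, so callers can skip
# observing the shape arguments.
#
#     indices_by_type = get_non_observable_arg_indexes_and_types(view_node)
#     indices_by_type[int](view_node)  # e.g. [1, 2] for x.view(batch, -1)
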
def node_return_type_is_int(node: Node) -> bool:
    """
    Returns True if this node results in an integer, even if some of the args
    are Tensors.
    """
    return node.op == 'call_method' and node.target == 'size'


def is_get_tensor_info_node(node: Node) -> bool:
    """ Returns True if this node is a node that takes a Tensor as input and outputs some
    meta information about the Tensor, e.g. shape, size etc.
    """
    result: bool = \
        node.op == "call_function" and node.target == getattr and node.args[1] == "shape"
    return result

def maybe_get_next_module(
    node: Node,
    modules: Dict[str, nn.Module],
    target_module_type: Optional[Type[nn.Module]] = None,
    target_functional_type: Any = None,
) -> Optional[Node]:
    """ Gets the next user of the node that matches the given module type or
    functional type, if one exists

    Args:
        node: The node whose users we want to look at
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """

    for user, _ in node.users.items():
        if user.op == 'call_module' and target_module_type is not None and \
                isinstance(modules[str(user.target)], target_module_type):
            return user
        elif (user.op == 'call_function' and target_functional_type is not None and
                user.target == target_functional_type):
            return user

    return None

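# A minimal usage sketch (names hypothetical): check whether a conv node feeds
# directly into a batch norm module.
#
#     bn_node = maybe_get_next_module(conv_node, modules, target_module_type=nn.BatchNorm2d)
#     if bn_node is not None:
#         ...  # found a conv -> bn pattern
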
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Creates `new_node` and copies the necessary metadata to it from `old_node`.
    """
    new_node = quantized_graph.create_node(*create_node_args)
    new_node.stack_trace = old_node.stack_trace
    return new_node

def get_skipped_module_name_and_classes(
        prepare_custom_config: PrepareCustomConfig,
        is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]:
    skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
    skipped_module_classes = copy.copy(prepare_custom_config.non_traceable_module_classes)
    if not is_standalone_module:
        # standalone module and custom module configs are applied in the top level module
        skipped_module_names += list(prepare_custom_config.standalone_module_names.keys())
        skipped_module_classes += list(prepare_custom_config.standalone_module_classes.keys())
        skipped_module_classes += get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)

    return skipped_module_names, skipped_module_classes

def _is_custom_module_lstm(
        node: Node,
        named_modules: Dict[str, torch.nn.Module],
        qconfig: QConfigAny = None,
        # QuantizeHandler, but not annotated here to avoid circular imports
        qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module LSTM flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is not None and qhandler is not None:
        assert isinstance(qhandler, torch.ao.quantization.fx.quantization_patterns.QuantizeHandler)
        return isinstance(mod, torch.nn.LSTM) and \
            activation_is_statically_quantized(qconfig) and \
            qhandler.is_custom_module()
    else:
        return isinstance(mod, torch.ao.nn.quantizable.LSTM)

def _get_module(node: Node, named_modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]:
    """
    If `node` refers to a call_module node, return the module, else None.
    """
    if node.op == "call_module" and str(node.target) in named_modules:
        return named_modules[str(node.target)]
    else:
        return None

def _insert_dequant_stub(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attach a `DeQuantStub` to the model and create a node that calls this
    `DeQuantStub` on the output of `node`, similar to how observers are inserted.
    """
    prefix = "dequant_stub_"
    get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
    dequant_stub_name = get_new_dequant_stub_name(model)
    dequant_stub = DeQuantStub()
    setattr(model, dequant_stub_name, dequant_stub)
    named_modules[dequant_stub_name] = dequant_stub
    with graph.inserting_after(node):
        return graph.call_module(dequant_stub_name, (node,))

def _insert_dequant_stubs_for_custom_module_lstm_output(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Insert DeQuantStubs after each internal output node of custom module LSTM.

    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)).
    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
    components through `getitem`. This function transforms the graph as follows:

      (1) Split the LSTM node into (output, (hidden0, hidden1))
      (2) Insert a DeQuantStub after each internal node
      (3) Recombine the DeQuantStubs into the same structure as before
      (4) Reroute all consumers of the original LSTM node and its sub-nodes
          (e.g. lstm[0])

    Before:
                   lstm_output
                        |
                        v
                  original_user(s)
    After:
                   lstm_output
                  /           \\
                 /  (getitem)  \\
                /               \\
               v                 v
             output            hidden
               |               /   \\
         (DeQuantStub)   (getitem)
               |           /     \\
               v          v       v
           output_dq   hidden0  hidden1
               |          |        |
               |   (DeQuantStub) (DeQuantStub)
               |          |        |
               |          v        v
               |     hidden0_dq  hidden1_dq
               |          \\       /
               |           (tuple)
               |             \\   /
               |              v  v
               |            hidden_dq
                \\              /
                 \\  (tuple)   /
                  v           v
                 lstm_output_dq
                        |
                        v
                 original_user(s)

    For step (4), reroute all users of the original LSTM node(s) as follows:
      lstm_output -> lstm_output_dq
      lstm_output[0] -> output_dq
      lstm_output[1] -> hidden_dq
      lstm_output[1][0] -> hidden0_dq
      lstm_output[1][1] -> hidden1_dq

    Return the node `lstm_output_dq`.
    """
    # (1) Split the LSTM node into (output, (hidden0, hidden1))
    # (2) Insert a DeQuantStub after each internal node
    with graph.inserting_after(node):
        output = graph.call_function(operator.getitem, (node, 0))
        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
    with graph.inserting_after(output_dq):
        hidden = graph.call_function(operator.getitem, (node, 1))
    with graph.inserting_after(hidden):
        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
    with graph.inserting_after(hidden0_dq):
        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)

    # (3) Recombine the DeQuantStubs into the same structure as before
    with graph.inserting_after(hidden1_dq):
        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
    with graph.inserting_after(hidden_dq):
        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))

    # (4) Reroute all consumers of the original LSTM node and its sub-nodes,
    # skipping the getitem nodes we created above
    for user in list(node.users.keys()):
        if user != output and user != hidden:
            user.replace_input_with(node, lstm_output_dq)

    # Reroute getitem consumers of the new tuples to the matching DeQuantStub
    # outputs, e.g. lstm_output[0] -> output_dq (see step (4) in the docstring)
    _reroute_tuple_getitem_pattern(graph)
    return lstm_output_dq

def _maybe_get_custom_module_lstm_from_node_arg(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[Node]:
    """
    Given an argument of a node, if the argument refers to the path through which the node
    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.

    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
    skip inserting input observers for this node. This is because custom module LSTM produces
    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
    would unnecessarily quantize the outputs again.

      lstm -> consumer

    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    This tuple can be consumed in one of four ways:

      lstm -> getitem -> DeQuantStub -> consumer                      # consume lstm[0]
      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer  # consume lstm[1]
      lstm -> getitem -> getitem -> DeQuantStub -> consumer           # consume lstm[1][0] or lstm[1][1]
      lstm -> getitem -> DeQuantStub -> tuple -> consumer             # consume lstm

    Thus, we must match against the above patterns instead of simply checking the parent node
    to determine whether this node is a consumer of a custom module LSTM.
    """
    def match_dq(a):
        return isinstance(_get_module(a, named_modules), DeQuantStub)

    def match_lstm(a):
        return _is_custom_module_lstm(a, named_modules)

    def match_getitem(a):
        return a.op == "call_function" and a.target == operator.getitem

    def match_tuple(a):
        return a.op == "call_function" and a.target == tuple

    def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
        """
        Traverse up the graph and match the args one by one.
        If there is a match, return the last matched node, or None otherwise.
        """
        a = arg
        for i, match in enumerate(match_pattern):
            if not match(a):
                return None
            # walk up to the next arg; for tuple nodes the arg is a list, e.g. ([dq, ...],)
            if i < len(match_pattern) - 1:
                if match == match_tuple:
                    a = a.args[0][0]
                else:
                    a = a.args[0]
        return a

    all_match_patterns = [
        # lstm[0]
        [match_dq, match_getitem, match_lstm],
        # lstm[1]
        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
        # lstm[1][0] or lstm[1][1]
        [match_dq, match_getitem, match_getitem, match_lstm],
        # lstm
        [match_tuple, match_dq, match_getitem, match_lstm],
    ]

    for p in all_match_patterns:
        matched_node = _match_pattern(p)
        if matched_node is not None:
            return matched_node
    return None

def _reroute_tuple_getitem_pattern(graph: Graph):
    """
    Search for patterns where N consecutive `tuple` call_function nodes are followed by
    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
    If we find this pattern, reroute the consumers of the last `getitem` to skip these
    N `tuple` and `getitem` nodes.

    Before:

        a   b     c
        |   \\   /
        \\   tuple
         \\   /
          tuple
            |
        getitem(1)
            |
        getitem(0)
            |
            d

    After:

        b
        |
        d
    """
    def find_patterns(
            node: Node,
            index_stack: List[int],
            current_pattern: List[Node],
            matched_patterns: List[List[Node]],
            seen: Set[Tuple[Node, Tuple[int, ...]]]):
        """
        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
        starting at the given node.

        We use a stack to keep track of the expected `getitem` indices, since these are
        reversed from the `tuple` indices. In the above example, the stack after
        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
        and then by getitem(0).

        TODO: traverse upwards from the output and handle the case when tuple is not a
        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
        """
        if len(index_stack) == 0 and len(current_pattern) > 0:
            matched_patterns.append(copy.copy(current_pattern))
            current_pattern.clear()

        # Avoid duplicating work
        state = (node, tuple(index_stack))
        if state in seen:
            return
        seen.add(state)

        # Iterate through the users of this node to find tuple/getitem nodes to match
        for user in node.users:
            if user.op == "call_function" and user.target == tuple:
                for i, user_arg in enumerate(user.args[0]):
                    if user_arg == node:
                        index_stack.append(i)
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
            elif user.op == "call_function" and user.target == operator.getitem:
                if len(index_stack) > 0:
                    if user.args[1] == index_stack[-1]:
                        index_stack.pop()
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
        return matched_patterns

    # Collect all matched patterns
    matched_patterns: List[List[Node]] = []
    seen: Set[Tuple[Node, Tuple[int, ...]]] = set()
    for node in graph.nodes:
        find_patterns(node, [], [], matched_patterns, seen)

    # For each pattern, redirect all consumers of the last getitem node to the
    # corresponding input of the first tuple node
    for pattern in matched_patterns:
        first_tuple = pattern[0]
        last_getitem = pattern[-1]
        assert first_tuple.op == "call_function" and first_tuple.target == tuple
        assert last_getitem.op == "call_function" and last_getitem.target == operator.getitem
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]
        for user in list(last_getitem.users.keys()):
            user.replace_input_with(last_getitem, new_input)

def _get_observer_from_activation_post_process(
    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
) -> ObserverBase:
    """
    If `activation_post_process` is an observer, return the observer.
    If `activation_post_process` is a fake quantize, return the internal observer.
    """
    if isinstance(activation_post_process, ObserverBase):
        return activation_post_process
    else:
        assert isinstance(activation_post_process, FakeQuantizeBase)
        return activation_post_process.activation_post_process

def _qconfig_satisfies_dtype_config_constraints(
        qconfig: QConfigAny,
        dtype_with_constraints: DTypeWithConstraints,
        is_activation: bool = True) -> bool:
    """
    Return whether `qconfig` satisfies the following constraints from the backend,
    specified through the activation and weight DTypeWithConstraints:

      1. QConfig specifies a quantization range that falls within the backend's, if any
      2. QConfig specifies a min scale value that is >= the backend's, if any

    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
    """
    def _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process: Union[ObserverBase, FakeQuantizeBase],
            dtype_with_constraints: DTypeWithConstraints,
            debug_string: str) -> bool:
        observer = _get_observer_from_activation_post_process(activation_post_process)
        app_quant_min = getattr(observer, "quant_min", None)
        app_quant_max = getattr(observer, "quant_max", None)
        # use the observer's eps as its min scale value for now
        app_scale_min = getattr(observer, "eps", None)
        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
        # check quantization ranges
        if backend_quant_min is not None and backend_quant_max is not None:
            if app_quant_min is None or app_quant_max is None:
                warnings.warn("QConfig %s must specify 'quant_min' and 'quant_max', ignoring %s" %
                              (debug_string, qconfig))
                return False
            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
                warnings.warn(("QConfig %s quantization range must fall within the backend's:\n"
                               "QConfig range = (%s, %s), BackendConfig range = (%s, %s), ignoring %s") %
                              (debug_string, app_quant_min, app_quant_max,
                               backend_quant_min, backend_quant_max, qconfig))
                return False
        # check min scale value
        if backend_scale_min is not None:
            if app_scale_min is None:
                warnings.warn("QConfig %s must specify 'eps', ignoring %s" % (debug_string, qconfig))
                return False
            elif app_scale_min < backend_scale_min:
                warnings.warn(("QConfig %s eps (%s) must be greater than or equal to "
                               "the backend's min scale value (%s), ignoring %s") %
                              (debug_string, app_scale_min, backend_scale_min, qconfig))
                return False
        return True

    if qconfig is None or dtype_with_constraints.dtype is None:
        return True

    activation_post_process_ctr = qconfig.activation if is_activation else qconfig.weight
    debug_string = "activation" if is_activation else "weight"
    satisfies_constraints = True
    if activation_post_process_ctr is not None:
        activation_post_process = activation_post_process_ctr()
        assert is_activation_post_process(activation_post_process)
        # if the dtypes don't match, don't check the constraints and return True early
        if activation_post_process.dtype != dtype_with_constraints.dtype:
            return True
        satisfies_constraints = _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process, dtype_with_constraints, debug_string)
    return satisfies_constraints

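# A hedged sketch (values hypothetical): validating a qconfig against a backend
# dtype config that imposes a minimum scale.
#
#     constraints = DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12)
#     ok = _qconfig_satisfies_dtype_config_constraints(qconfig, constraints)
#     # False (with a warning) if the qconfig observer's eps is below 2 ** -12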