import copy
import operator
import warnings
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
from torch.ao.quantization import QConfigAny, QuantType
from torch.ao.quantization.backend_config import DTypeWithConstraints
from torch.ao.quantization.fake_quantize import (
    FakeQuantizeBase,
    FixedQParamsFakeQuantize,
)
from torch.ao.quantization.observer import (
    _is_activation_post_process,
    FixedQParamsObserver,
    ObserverBase,
)
from torch.ao.quantization.qconfig import (
    float16_dynamic_qconfig,
    float16_static_qconfig,
    qconfig_equals,
)
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
    _assert_and_get_unique_device,
    activation_is_statically_quantized,
)
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node

# Imported for the side effect of registering the quantized_decomposed ops
from ._decomposed import quantized_decomposed_lib  # noqa: F401
from .custom_config import PrepareCustomConfig


__all__ = [
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "EMPTY_ARG_DICT",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_qconv_prepack_op",
    "get_skipped_module_name_and_classes",
    "graph_module_from_producer_nodes",
    "maybe_get_next_module",
    "NodeInfo",
    "node_arg_is_bias",
    "node_arg_is_weight",
    "NON_OBSERVABLE_ARG_DICT",
    "NON_QUANTIZABLE_WEIGHT_OPS",
    "return_arg_list",
    "ObservedGraphModuleAttrs",
]

NON_QUANTIZABLE_WEIGHT_OPS = {
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.instance_norm,
}


@dataclass
class ObservedGraphModuleAttrs:
    node_name_to_qconfig: dict[str, QConfigAny]
    node_name_to_scope: dict[str, tuple[str, type]]
    prepare_custom_config: PrepareCustomConfig
    equalization_node_name_to_qconfig: dict[str, Any]
    qconfig_mapping: QConfigMapping
    is_qat: bool
    observed_node_names: set[str]
    is_observed_standalone_module: bool = False
    standalone_module_input_quantized_idxs: Optional[list[int]] = None
    standalone_module_output_quantized_idxs: Optional[list[int]] = None


def node_arg_is_weight(node: Node, arg: Any) -> bool:
    """Return whether `arg` is the weight argument of `node`."""
    weight_index = None
    if "target_dtype_info" in node.meta:
        weight_index = node.meta["target_dtype_info"].get("weight_index", None)
    if (
        weight_index is not None
        and weight_index < len(node.args)
        and node.args[weight_index] is arg
    ):
        return True
    return node.kwargs.get("weight") is arg


def node_arg_is_bias(node: Node, arg: Any) -> bool:
    """Return whether `arg` is the bias argument of `node`."""
    bias_index = None
    if "target_dtype_info" in node.meta:
        bias_index = node.meta["target_dtype_info"].get("bias_index", None)
    if (
        bias_index is not None
        and bias_index < len(node.args)
        and node.args[bias_index] is arg
    ):
        return True
    return node.kwargs.get("bias") is arg
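

# Illustrative sketch (not part of the module; the meta entry below is a
# hypothetical example): for a functional linear node traced as
# `call_function(F.linear, args=(x, w), kwargs={"bias": b})`, the weight is
# found either through the "weight_index" recorded in node.meta or through
# the "weight"/"bias" kwargs:
#
#   node.meta["target_dtype_info"] = {"weight_index": 1}
#   node_arg_is_weight(node, node.args[1])       # True, via weight_index
#   node_arg_is_bias(node, node.kwargs["bias"])  # True, via the "bias" kwarg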


def get_custom_module_class_keys(
    custom_module_mapping: dict[QuantType, dict[type, type]],
) -> list[Any]:
    r"""Get all the unique custom module keys in the custom config dict
    e.g.
    Input:
    {
        QuantType.STATIC: {
            CustomModule1: ObservedCustomModule
        },
        QuantType.DYNAMIC: {
            CustomModule2: DynamicObservedCustomModule
        },
        QuantType.WEIGHT_ONLY: {
            CustomModule3: WeightOnlyObservedCustomModule
        },
    }

    Output:
    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
    [CustomModule1, CustomModule2, CustomModule3]
    """
    # using a set to dedup across the inner dicts
    float_custom_module_classes: set[Any] = set()
    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
        float_custom_module_classes |= quant_mode_custom_module_classes
    return list(float_custom_module_classes)


def get_linear_prepack_op_for_dtype(dtype):
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    elif dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    else:
        raise Exception(f"can't get linear prepack op for dtype: {dtype}")


def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    prepack_ops = {
        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
        torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
        torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
    }
    prepack_op = prepack_ops.get(conv_op, None)
    assert prepack_op, f"Didn't find prepack op for {conv_op}"
    return prepack_op


def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    """Return a function that generates a new attribute name for a module with
    the given prefix, avoiding collisions with existing attribute names.
    """
    prefix = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        def get_attr_name(i: int):
            return prefix + str(i)

        i = 0
        attr_name = get_attr_name(i)
        while hasattr(module, attr_name):
            i += 1
            attr_name = get_attr_name(i)
        return attr_name

    return get_new_attr_name
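

# Illustrative sketch (hypothetical names): generated names are the prefix plus
# the first free integer suffix, so repeated calls on the same module walk
# forward through "scale_0", "scale_1", ...
#
#   get_name = get_new_attr_name_with_prefix("scale_")
#   mod = torch.nn.Module()
#   get_name(mod)                    # "scale_0"
#   mod.scale_0 = torch.tensor(1.0)
#   get_name(mod)                    # "scale_1", since "scale_0" is taken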


def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
    r"""Starting from a target node, trace back until we hit an input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example::

        def forward(self, x):
            observed = self.observer(self.weight)
            return F.linear(x, observed)

    collect_producer_nodes(observed) will either return a list of nodes that
    produce the observed node, or None if we can't extract a self-contained
    graph without free variables (inputs of the forward function).
    """
    nodes = [node]
    frontier = [node]
    while frontier:
        node = frontier.pop()
        all_args = list(node.args) + list(node.kwargs.values())
        for arg in all_args:
            if not isinstance(arg, Node):
                continue
            if arg.op == "placeholder":
                # hit input, can't fold in this case
                return None
            nodes.append(arg)
            if not (arg.op == "call_function" and arg.target == getattr):
                frontier.append(arg)
    return nodes


def graph_module_from_producer_nodes(
    root: GraphModule, producer_nodes: list[Node]
) -> GraphModule:
    r"""Construct a graph module from producer nodes extracted by the
    `collect_producer_nodes` function
    Args:
        root: the root module for the original graph
        producer_nodes: a list of nodes we use to construct the graph
    Return:
        A graph module constructed from the producer nodes
    """
    assert len(producer_nodes) > 0, "list of producer nodes can not be empty"
    # since we traced back from the target node to getattr
    producer_nodes.reverse()
    graph = Graph()
    env: dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])

    for producer_node in producer_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(producer_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module
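

# Illustrative sketch (hypothetical `gm` and `linear_node`): the two helpers are
# typically used together to fold a weight-producing subgraph such as
# `self.observer(self.weight)` into a standalone GraphModule that can be run
# eagerly to compute the folded weight:
#
#   weight_arg = linear_node.args[1]          # the observed weight node
#   producers = collect_producer_nodes(weight_arg)
#   if producers is not None:                 # None if the chain reaches an input
#       weight_gm = graph_module_from_producer_nodes(gm, producers)
#       folded_weight = weight_gm()           # no placeholders, so no args needed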


def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    return _assert_and_get_unique_device(module)


def create_getattr_from_value(
    module: torch.nn.Module,
    graph: Graph,
    prefix: str,
    value: Any,
    device: Optional[torch.device] = None,
) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    if device is None:
        device = assert_and_get_unique_device(module)
    new_value = (
        value.detach().clone()
        if isinstance(value, torch.Tensor)
        else torch.tensor(value, device=device)
    )
    module.register_buffer(attr_name, new_value)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node
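

# Illustrative sketch (hypothetical GraphModule `gm` and `quantize_node`): bake a
# scale constant into the module as a buffer and reference it from the graph via
# a get_attr node:
#
#   with gm.graph.inserting_before(quantize_node):
#       scale_node = create_getattr_from_value(gm, gm.graph, "_scale_", 0.0125)
#   # gm now has a buffer "_scale_0" (or the next free suffix) holding the
#   # value, and `scale_node` can be passed as an arg to other nodes.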


def all_node_args_have_no_tensors(
    node: Node, modules: dict[str, torch.nn.Module], cache: dict[Node, bool]
) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.
    """
    if cache and node in cache:
        return cache[node]

    result = False
    if not isinstance(node, Node):
        result = True
    elif node.op == "placeholder":
        result = False
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        if _is_activation_post_process(modules[node.target]):
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)
        # any other call_module keeps the default result of False
    elif node.op == "call_function" and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)
    elif node.op == "get_attr":
        result = False
    elif node.target is getattr and node.args[1] in ["ndim", "shape"]:
        # x1 = x0.ndim
        result = True
    elif node.op == "call_method" and node.target == "size":
        # x1 = x0.size(0)
        result = True
    else:
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                for list_el in arg:
                    if isinstance(list_el, Node):
                        this_list_el_args_have_no_tensors = (
                            all_node_args_have_no_tensors(list_el, modules, cache)
                        )
                        found_one_tensor = found_one_tensor or (
                            not this_list_el_args_have_no_tensors
                        )
                        # If found_one_tensor is True, there is no point in
                        # recursing further, as the end result will always
                        # be True.
                        if found_one_tensor:
                            result = not found_one_tensor
                            if cache:
                                cache[node] = result
                            return result
            elif isinstance(arg, int):
                pass
            else:
                if isinstance(arg, Node):
                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(
                        arg, modules, cache
                    )
                    found_one_tensor = found_one_tensor or (
                        not this_arg_args_have_no_tensors
                    )
                    # If found_one_tensor is True, there is no point in
                    # recursing further, as the end result will always
                    # be True.
                    if found_one_tensor:
                        result = not found_one_tensor
                        if cache:
                            cache[node] = result
                        return result
                else:
                    found_one_tensor = True
                    result = not found_one_tensor
    if cache:
        cache[node] = result
    return result


def all_node_args_except_first(node: Node) -> list[int]:
    """
    Returns all node arg indices after the first
    """
    return list(range(1, len(node.args)))


def return_arg_list(arg_indices: list[int]) -> Callable[[Node], list[int]]:
    """
    Constructs a function that takes a node as its argument and returns the
    subset of arg_indices that are valid for that node's args.
    """

    def arg_indices_func(node: Node) -> list[int]:
        return [i for i in arg_indices if i < len(node.args)]

    return arg_indices_func
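

# Illustrative sketch: `return_arg_list([1])` keeps index 1 only when the node
# actually has a second positional arg, e.g. for a traced `x.unsqueeze(0)`:
#
#   get_indices = return_arg_list([1])
#   get_indices(unsqueeze_node)  # [1] if args == (x, 0), [] if args == (x,)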


NodeInfo = namedtuple("NodeInfo", "op target")

# Maps a node's (op, target) to a dict from non-tensor arg type to a function
# that, given the node, returns the indices of args of that type; these args
# do not need to be observed.
NON_OBSERVABLE_ARG_DICT: dict[
    NodeInfo, dict[Union[type, torch.dtype], Callable[[Node], list[int]]]
] = {
    NodeInfo("call_method", "masked_fill"): {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2]),
    },
    NodeInfo("call_method", "permute"): {int: all_node_args_except_first},
    NodeInfo("call_method", "repeat"): {int: all_node_args_except_first},
    NodeInfo("call_method", "reshape"): {int: all_node_args_except_first},
    NodeInfo("call_method", "size"): {int: return_arg_list([1])},
    NodeInfo("call_method", "transpose"): {int: all_node_args_except_first},
    # the functional variants appear as call_function nodes, not call_method
    NodeInfo("call_function", torch.transpose): {int: all_node_args_except_first},
    NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])},
    NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])},
    NodeInfo("call_function", torch.unsqueeze): {int: return_arg_list([1])},
    NodeInfo("call_method", "view"): {int: all_node_args_except_first},
}

EMPTY_ARG_DICT: dict[Union[type, torch.dtype], Callable[[Node], list[int]]] = {}


def get_non_observable_arg_indexes_and_types(
    node: Node,
) -> dict[Union[type, torch.dtype], Callable[[Node], list[int]]]:
    """
    Returns a dict mapping each non-float-tensor type to a function that,
    given the node, retrieves the list of arg indices of that type.
    """
    info = NodeInfo(node.op, node.target)

    return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
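

# Illustrative sketch (hypothetical `node` traced from `x.masked_fill(mask, 0.0)`):
#
#   arg_types = get_non_observable_arg_indexes_and_types(node)
#   arg_types[torch.bool](node)  # [1], the mask arg
#   arg_types[float](node)       # [2], the fill value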


def maybe_get_next_module(
    node: Node,
    modules: dict[str, nn.Module],
    target_module_type: Optional[type[nn.Module]] = None,
    target_functional_type: Any = None,
) -> Optional[Node]:
    """Gets the next user of `node` that matches the target module type or
    functional type, if one exists

    Args:
        node: The node whose users we want to look at
        modules: A mapping from qualified module names to module instances
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """

    for user in node.users.keys():
        if (
            user.op == "call_module"
            and target_module_type is not None
            and isinstance(modules[str(user.target)], target_module_type)
        ):
            return user
        elif (
            user.op == "call_function"
            and target_functional_type is not None
            and user.target == target_functional_type
        ):
            return user

    return None


def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Creates `new_node` and copies the necessary metadata to it from `old_node`.
    """
    new_node = quantized_graph.create_node(*create_node_args)
    new_node.stack_trace = old_node.stack_trace
    return new_node


def get_skipped_module_name_and_classes(
    prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool
) -> tuple[list[str], list[type[Any]]]:
    skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
    skipped_module_classes = copy.copy(
        prepare_custom_config.non_traceable_module_classes
    )
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        skipped_module_names += list(
            prepare_custom_config.standalone_module_names.keys()
        )
        skipped_module_classes += list(
            prepare_custom_config.standalone_module_classes.keys()
        )
        skipped_module_classes += get_custom_module_class_keys(
            prepare_custom_config.float_to_observed_mapping
        )

    return skipped_module_names, skipped_module_classes


def _is_custom_module_lstm(
    node: Node,
    named_modules: dict[str, torch.nn.Module],
    qconfig: QConfigAny = None,
    # QuantizeHandler, but we cannot include the type here due to circular imports
    qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module LSTM flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is not None and qhandler is not None:
        assert isinstance(
            qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
        )
        return (
            isinstance(mod, torch.nn.LSTM)
            and activation_is_statically_quantized(qconfig)
            and qhandler.is_custom_module()
        )
    else:
        return isinstance(mod, torch.ao.nn.quantizable.LSTM)


def _is_custom_module_mha(
    node: Node,
    named_modules: dict[str, torch.nn.Module],
    qconfig: QConfigAny = None,
    # QuantizeHandler, but we cannot include the type here due to circular imports
    qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module MultiheadAttention flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is not None and qhandler is not None:
        assert isinstance(
            qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
        )
        return (
            isinstance(mod, torch.nn.MultiheadAttention)
            and activation_is_statically_quantized(qconfig)
            and qhandler.is_custom_module()
        )
    else:
        return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)


def _get_module(
    node: Node, named_modules: dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """
    If `node` refers to a call_module node, return the module, else None.
    """
    if node.op == "call_module" and str(node.target) in named_modules:
        return named_modules[str(node.target)]
    else:
        return None


def _insert_dequant_stub(
    node: Node,
    model: torch.nn.Module,
    named_modules: dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attach a `DeQuantStub` to the model and create a node that calls this
    `DeQuantStub` on the output of `node`, similar to how observers are inserted.
    """
    prefix = "dequant_stub_"
    get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
    dequant_stub_name = get_new_dequant_stub_name(model)
    dequant_stub = DeQuantStub()
    setattr(model, dequant_stub_name, dequant_stub)
    named_modules[dequant_stub_name] = dequant_stub
    with graph.inserting_after(node):
        return graph.call_module(dequant_stub_name, (node,))


def _insert_dequant_stubs_for_custom_module_lstm_output(
    node: Node,
    model: torch.nn.Module,
    named_modules: dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Insert DeQuantStubs after each internal output node of custom module LSTM.

    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)).
    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
    components through `getitem`. This function transforms the graph as follows:

    (1) Split the LSTM node into (output, (hidden0, hidden1))
    (2) Insert a DeQuantStub after each internal node
    (3) Recombine the DeQuantStubs into the same structure as before
    (4) Reroute all consumers of the original LSTM node and its sub-nodes
        (e.g. lstm[0])

    Before:
                   lstm_output
                        |
                        v
                  original_user(s)
    After:
                   lstm_output
                  /           \\
                 / (getitem)   \\
                /               \\
               v                 v
             output            hidden
               |              /      \\
         (DeQuantStub)     (getitem)
               |            /        \\
               v           v          v
           output_dq    hidden0    hidden1
               |           |          |
               |    (DeQuantStub) (DeQuantStub)
               |           |          |
               |           v          v
               |     hidden0_dq   hidden1_dq
               |           \\        /
               |            (tuple)
               |             \\    /
               |              v  v
               |           hidden_dq
                \\             /
                 \\  (tuple)  /
                  v          v
                  lstm_output_dq
                        |
                        v
                  original_user(s)

    For step (4), reroute all users of the original LSTM node(s) as follows:
      lstm_output -> lstm_output_dq
      lstm_output[0] -> output_dq
      lstm_output[1] -> hidden_dq
      lstm_output[1][0] -> hidden0_dq
      lstm_output[1][1] -> hidden1_dq

    Return the node `lstm_output_dq`.
    """
    # (1) Split the LSTM node into (output, (hidden0, hidden1))
    # (2) Insert a DeQuantStub after each internal node
    with graph.inserting_after(node):
        output = graph.call_function(operator.getitem, (node, 0))
        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
    with graph.inserting_after(output_dq):
        hidden = graph.call_function(operator.getitem, (node, 1))
    with graph.inserting_after(hidden):
        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
    with graph.inserting_after(hidden0_dq):
        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)

    # (3) Recombine the DeQuantStubs into the same structure as before
    with graph.inserting_after(hidden1_dq):
        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
    with graph.inserting_after(hidden_dq):
        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))

    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
    for user in list(node.users.keys()):
        if user != output and user != hidden:
            user.replace_input_with(node, lstm_output_dq)
    # The getitem and tuple nodes inserted here may interfere with reference
    # quantized pattern matching, so reroute "tuple followed by matching getitem"
    # chains to point directly at the underlying nodes
    _reroute_tuple_getitem_pattern(graph)
    return lstm_output_dq


def _maybe_get_custom_module_lstm_from_node_arg(
    arg: Node,
    named_modules: dict[str, torch.nn.Module],
) -> Optional[Node]:
    """
    Given an argument of a node, if the argument refers to the path through which the node
    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.

    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
    skip inserting input observers for this node. This is because custom module LSTM produces
    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
    would unnecessarily quantize the outputs again.

      lstm -> consumer

    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    This tuple can be consumed in one of four ways:

      lstm -> getitem -> DeQuantStub -> consumer                       # consume lstm[0]
      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer   # consume lstm[1]
      lstm -> getitem -> getitem -> DeQuantStub -> consumer            # consume lstm[1][0] or lstm[1][1]
      lstm -> getitem -> DeQuantStub -> tuple -> consumer              # consume lstm

    Thus, we must match against the above patterns instead of simply checking the parent node
    to determine whether this node is a consumer of a custom module LSTM.
    """

    def match_dq(a):
        return isinstance(_get_module(a, named_modules), DeQuantStub)

    def match_lstm(a):
        return _is_custom_module_lstm(a, named_modules)

    def match_getitem(a):
        return a.op == "call_function" and a.target == operator.getitem

    def match_tuple(a):
        return a.op == "call_function" and a.target == tuple

    def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]:
        """
        Traverse up the graph and match the args one by one.
        If there is a match, return the last matched node, or None otherwise.
        """
        a = arg
        for i, match in enumerate(match_pattern):
            if not match(a):
                return None
            # Match the next arg; for tuple, the arg is a tuple of a list, e.g. ([dq_1, other_node],)
            if i < len(match_pattern) - 1:
                if match == match_tuple:
                    a = a.args[0][0]
                else:
                    a = a.args[0]
        return a

    all_match_patterns = [
        [match_dq, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
        [match_dq, match_getitem, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_lstm],
    ]

    for p in all_match_patterns:
        matched_node = _match_pattern(p)
        if matched_node is not None:
            return matched_node
    return None


def _reroute_tuple_getitem_pattern(graph: Graph):
    """
    Search for patterns where N consecutive `tuple` call_function nodes are followed by
    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
    If we find this pattern, reroute the consumers of the last `getitem` to skip these
    N `tuple` and `getitem` nodes.

    Before:

        a   b     c
        |    \\   /
        \\   tuple
         \\   /
          tuple
            |
        getitem(1)
            |
        getitem(0)
            |
            d

    After:

        b
        |
        d
    """

    def find_patterns(
        node: Node,
        index_stack: list[int],
        current_pattern: list[Node],
        matched_patterns: list[list[Node]],
        seen: set[tuple[Node, tuple[int, ...]]],
    ):
        """
        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
        starting at the given node.

        We use a stack to keep track of the expected `getitem` indices, since these are
        reversed from the `tuple` indices. In the above example, the stack after
        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
        and then by getitem(0).

        TODO: traverse upwards from the output and handle the case when tuple is not a
        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
        """
        if len(index_stack) == 0 and len(current_pattern) > 0:
            matched_patterns.append(copy.copy(current_pattern))
            current_pattern.clear()

        # Avoid duplicating work
        state = (node, tuple(index_stack))
        if state in seen:
            return
        seen.add(state)

        # Continue matching on tuple and getitem users of this node
        for user in node.users:
            if user.op == "call_function" and user.target == tuple:
                for i, user_arg in enumerate(user.args[0]):
                    if user_arg == node:
                        index_stack.append(i)
                        current_pattern.append(user)
                        find_patterns(
                            user, index_stack, current_pattern, matched_patterns, seen
                        )
            elif user.op == "call_function" and user.target == operator.getitem:
                if len(index_stack) > 0:
                    if user.args[1] == index_stack[-1]:
                        index_stack.pop()
                        current_pattern.append(user)
                        find_patterns(
                            user, index_stack, current_pattern, matched_patterns, seen
                        )
        return matched_patterns

    # Find all matched patterns
    matched_patterns: list[list[Node]] = []
    seen: set[tuple[Node, tuple[int, ...]]] = set()
    for node in graph.nodes:
        find_patterns(node, [], [], matched_patterns, seen)

    # For each pattern, redirect all consumers of the last getitem node to the
    # correct input of the first tuple node
    for pattern in matched_patterns:
        first_tuple = pattern[0]
        last_getitem = pattern[-1]
        assert first_tuple.op == "call_function" and first_tuple.target == tuple
        assert (
            last_getitem.op == "call_function"
            and last_getitem.target == operator.getitem
        )
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]
        for user in list(last_getitem.users.keys()):
            user.replace_input_with(last_getitem, new_input)


def _get_observer_from_activation_post_process(
    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
) -> ObserverBase:
    """
    If `activation_post_process` is an observer, return the observer.
    If `activation_post_process` is a fake quantize, return the internal observer.
    """
    if isinstance(activation_post_process, ObserverBase):
        return activation_post_process
    else:
        assert isinstance(activation_post_process, FakeQuantizeBase)
        return activation_post_process.activation_post_process


def _qconfig_satisfies_dtype_config_constraints(
    qconfig: QConfigAny,
    dtype_with_constraints: DTypeWithConstraints,
    is_activation: bool = True,
) -> bool:
    """
    Return whether `qconfig` satisfies the following constraints from the backend,
    specified through the activation and weight DTypeWithConstraints.

    1. QConfig specified a quantization range that falls within the backend's, if any
    2. QConfig specified a min scale value that is >= the backend's, if any
    3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
       scale and zero point that match the backend's, if any

    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
    """

    # TODO: log warnings only when the user enabled a debug flag
    def _activation_post_process_satisfies_dtype_config_constraints(
        activation_post_process: Union[ObserverBase, FakeQuantizeBase],
        dtype_with_constraints: DTypeWithConstraints,
        debug_string: str,
    ) -> bool:
        observer = _get_observer_from_activation_post_process(activation_post_process)
        app_quant_min = getattr(observer, "quant_min", None)
        app_quant_max = getattr(observer, "quant_max", None)
        # TODO: for now, use the observer's existing eps value as scale_min. In the
        # future, we should resolve the differences between the two
        app_scale_min = getattr(observer, "eps", None)
        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
        # check quantization ranges
        if backend_quant_min is not None and backend_quant_max is not None:
            if app_quant_min is None or app_quant_max is None:
                warnings.warn(
                    f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}"
                )
                return False
            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
                warnings.warn(
                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
                    f"ignoring {qconfig}"
                )
                return False
        # check the min scale value
        if backend_scale_min is not None:
            if app_scale_min is None:
                warnings.warn(
                    f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}"
                )
                return False
            if app_scale_min < backend_scale_min:
                warnings.warn(
                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
                )
                return False
        # check fixed scale and zero point
        if (
            backend_scale_exact_match is not None
            and backend_zero_point_exact_match is not None
        ):
            # For tests only, accept the following qconfigs for now
            # TODO: handle fp16 qconfigs properly
            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
                if qconfig_equals(qconfig, accepted_qconfig):
                    return True
            suggestion_str = (
                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
                '    qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n'
                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
            )
            if not isinstance(
                activation_post_process, FixedQParamsObserver
            ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize):
                warnings.warn(
                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
            if (
                observer.scale != backend_scale_exact_match
                or observer.zero_point != backend_zero_point_exact_match
            ):
                warnings.warn(
                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
                    f"ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
        return True

    if qconfig is None or dtype_with_constraints.dtype is None:
        return True

    activation_post_process_ctr = (
        qconfig.activation if is_activation else qconfig.weight
    )
    debug_string = "activation" if is_activation else "weight"
    satisfies_constraints = True
    if activation_post_process_ctr is not None:
        activation_post_process = activation_post_process_ctr()
        assert _is_activation_post_process(activation_post_process)
        # If the dtypes don't match, return True early without checking the constraints
        if activation_post_process.dtype != dtype_with_constraints.dtype:
            return True
        satisfies_constraints = (
            _activation_post_process_satisfies_dtype_config_constraints(
                activation_post_process, dtype_with_constraints, debug_string
            )
        )
    return satisfies_constraints
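

# Illustrative sketch (hypothetical bound values): checking a default qconfig
# against a backend dtype config that constrains the quantization range:
#
#   from torch.ao.quantization import get_default_qconfig
#
#   constraints = DTypeWithConstraints(
#       dtype=torch.quint8,
#       quant_min_lower_bound=0,
#       quant_max_upper_bound=255,
#   )
#   qconfig = get_default_qconfig("fbgemm")
#   # True iff the qconfig's activation observer emits torch.quint8 and its
#   # (quant_min, quant_max) range falls within [0, 255]
#   _qconfig_satisfies_dtype_config_constraints(qconfig, constraints)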