| |
|
|
| import copy |
| import operator |
| import warnings |
| from typing import Any, Callable, Optional, Union |
|
|
| import torch |
| from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY |
| from torch.ao.quantization.backend_config import ( |
| BackendConfig, |
| get_native_backend_config, |
| ) |
| from torch.ao.quantization.backend_config.utils import ( |
| get_fused_module_classes, |
| get_pattern_to_dtype_configs, |
| get_qat_module_classes, |
| get_root_module_to_quantized_reference_module, |
| ) |
| from torch.ao.quantization.observer import _is_activation_post_process |
| from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny |
| from torch.ao.quantization.qconfig_mapping import QConfigMapping |
| from torch.ao.quantization.quant_type import QuantType |
| from torch.ao.quantization.quantize import _remove_qconfig |
| from torch.ao.quantization.stubs import DeQuantStub |
| from torch.ao.quantization.utils import ( |
| _parent_name, |
| activation_is_statically_quantized, |
| get_qparam_dict, |
| get_swapped_custom_module_class, |
| is_per_channel, |
| to_underlying_dtype, |
| weight_is_quantized, |
| ) |
| from torch.fx import GraphModule |
| from torch.fx.graph import Argument, Graph, Node |
| from torch.nn.utils.parametrize import type_before_parametrizations |
|
|
| |
| from ._decomposed import quantized_decomposed_lib |
| from ._equalize import convert_eq_obs, update_obs_for_equalization |
| from .custom_config import ConvertCustomConfig, PrepareCustomConfig |
| from .graph_module import _is_observed_module, _is_observed_standalone_module |
| from .lower_to_fbgemm import lower_to_fbgemm |
| from .qconfig_mapping_utils import ( |
| _compare_prepare_convert_qconfig_mappings, |
| _generate_node_name_to_qconfig, |
| _is_qconfig_supported_by_dtype_configs, |
| _update_qconfig_for_fusion, |
| _update_qconfig_for_qat, |
| ) |
| from .utils import ( |
| _get_module, |
| _is_custom_module_lstm, |
| _is_custom_module_mha, |
| assert_and_get_unique_device, |
| collect_producer_nodes, |
| create_getattr_from_value, |
| get_custom_module_class_keys, |
| graph_module_from_producer_nodes, |
| node_arg_is_weight, |
| ) |
|
|
|
|
| __all__ = [ |
| "convert", |
| "convert_custom_module", |
| "convert_standalone_module", |
| "convert_weighted_module", |
| ] |
|
|
| SUPPORTED_QDTYPES = [ |
| torch.quint8, |
| torch.qint8, |
| torch.qint32, |
| torch.uint8, |
| torch.int8, |
| torch.uint16, |
| torch.int16, |
| torch.int32, |
| torch.float8_e5m2, |
| torch.float8_e4m3fn, |
| ] |
|
|
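# Maps an observer qscheme to the decomposed choose_qparams op used in the
# dynamic quantization path, where scale/zero_point are computed from the
# input tensor at runtime.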
| _QSCHEME_TO_CHOOSE_QPARAMS_OP = { |
| torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor, |
| torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, |
| } |
|
|
|
|
| def _replace_observer_with_quantize_dequantize_node_decomposed( |
| model: torch.fx.GraphModule, |
| node: Node, |
| modules: dict[str, torch.nn.Module], |
| node_name_to_scope: dict[str, tuple[str, type]], |
| node_name_to_qconfig: dict[str, QConfigAny], |
| model_device: Optional[torch.device] = None, |
| ) -> None: |
| """Replace activation_post_process module call node with quantize and |
| dequantize node working with decomposed Tensor |
| |
| Before: |
| ... -> observer_0(x) -> ... |
| After: |
| ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> |
| torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... |
| |
| or quantize_per_channel and dequantize_per_channel |
| """ |
| graph = model.graph |
| assert modules is not None |
| assert isinstance(node.target, str) |
| module_path, prefix = _get_module_path_and_prefix( |
| node, node_name_to_scope, node_name_to_qconfig |
| ) |
| activation_post_process = modules[node.target] |
| if hasattr(activation_post_process, "convert"): |
| activation_post_process.convert(model, node) |
| return |
| |
| |
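# Skip the replacement entirely when every producer and user of this observer
# has a None qconfig; in that case the observer is simply removed below.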
| skip_replacement = all( |
| _has_none_qconfig(n, node_name_to_qconfig) |
| for n in list(node.args) + list(node.users.keys()) |
| ) |
| if skip_replacement or not _is_conversion_supported(activation_post_process): |
| |
| |
| with graph.inserting_before(node): |
| node.replace_all_uses_with(node.args[0]) |
| graph.erase_node(node) |
| return |
|
|
| |
|
|
| |
| |
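# Otherwise, convert the observer call into explicit quantize/dequantize ops,
# with the qparams taken from the observer module (static case) or computed at
# runtime via choose_qparams (dynamic case).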
| dtype = activation_post_process.dtype |
|
|
| is_dynamic = False |
| if hasattr(activation_post_process, "is_dynamic"): |
| is_dynamic = activation_post_process.is_dynamic |
|
|
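# If the observed value is not fp32 (e.g. bf16), ask the dequantize op to
# produce that dtype via its out_dtype kwarg.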
| def add_dequantize_op_kwargs(dequantize_op, input_node): |
| dequantize_op_kwargs = {} |
| if "val" in input_node.meta: |
| dq_out_dtype = input_node.meta["val"].dtype |
| if dq_out_dtype != torch.float32: |
| dequantize_op_kwargs = {"out_dtype": dq_out_dtype} |
| return dequantize_op_kwargs |
|
|
| if dtype in SUPPORTED_QDTYPES and (not is_dynamic): |
| |
| |
|
|
| |
|
|
| |
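# Statically quantized path: the observer already holds the final
# scale/zero_point, so they are attached to the graph as constants or
# get_attr nodes.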
| node_type = "call_function" |
| quantize_op: Optional[Callable] = None |
| scale, zero_point = activation_post_process.calculate_qparams() |
| if is_per_channel(activation_post_process.qscheme): |
| ch_axis = int(activation_post_process.ch_axis) |
| quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default |
| dequantize_op = ( |
| torch.ops.quantized_decomposed.dequantize_per_channel.default |
| ) |
| quant_min = activation_post_process.quant_min |
| quant_max = activation_post_process.quant_max |
| dtype_ = to_underlying_dtype(dtype) |
| qparams = { |
| "_scale_": scale, |
| "_zero_point_": zero_point, |
| "_axis_": ch_axis, |
| "_quant_min_": quant_min, |
| "_quant_max_": quant_max, |
| "_dtype_": dtype_, |
| } |
| else: |
| quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default |
| dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| scale = float(scale) |
| zero_point = int(zero_point) |
| quant_min = activation_post_process.quant_min |
| quant_max = activation_post_process.quant_max |
| dtype_ = to_underlying_dtype(dtype) |
| qparams = { |
| "_scale_": scale, |
| "_zero_point_": zero_point, |
| "_quant_min_": quant_min, |
| "_quant_max_": quant_max, |
| "_dtype_": dtype_, |
| } |
|
|
| |
| with graph.inserting_before(node): |
| input_node = node.args[0] |
| quantize_op_inputs = [input_node] |
| for key, value_or_node in qparams.items(): |
| |
| |
| if key in ["_scale_", "_zero_point_"] and ( |
| not isinstance(value_or_node, (float, int)) |
| ): |
| |
| |
| |
| |
| |
| |
| |
| |
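# Non-scalar scale/zero_point (per-channel case): register the tensor as an
# attribute on the root module and reference it through a get_attr node.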
| qparam_node = create_getattr_from_value( |
| model, |
| graph, |
| module_path + prefix + key, |
| value_or_node, |
| model_device, |
| ) |
| quantize_op_inputs.append(qparam_node) |
| else: |
| |
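# Scalar qparams (per-tensor scale/zero_point, quant_min/quant_max, dtype)
# are passed as literal arguments of the quantize op.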
| quantize_op_inputs.append(value_or_node) |
|
|
| quantized_node = graph.create_node( |
| node_type, quantize_op, tuple(quantize_op_inputs), {} |
| ) |
| |
| dq_inputs = [quantized_node] + quantize_op_inputs[1:] |
| dequantized_node = graph.call_function( |
| dequantize_op, |
| tuple(dq_inputs), |
| add_dequantize_op_kwargs(dequantize_op, input_node), |
| ) |
|
|
| node.replace_all_uses_with(dequantized_node) |
| |
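# Propagate the numeric debug handle from the observer node to the new
# dequantize node so debugging tooling can still match the values.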
| if ( |
| CUSTOM_KEY in node.meta |
| and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] |
| ): |
| if CUSTOM_KEY not in dequantized_node.meta: |
| dequantized_node.meta[CUSTOM_KEY] = {} |
| dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ |
| CUSTOM_KEY |
| ][NUMERIC_DEBUG_HANDLE_KEY] |
| graph.erase_node(node) |
| elif is_dynamic: |
| |
|
|
| |
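# Dynamically quantized path: scale/zero_point are not known ahead of time,
# so insert a choose_qparams op that computes them from the input, then feed
# its outputs into the quantize/dequantize ops.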
| node_type = "call_function" |
| quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor |
| |
| |
| |
| |
| dtype_ = to_underlying_dtype(dtype) |
| assert dtype_ in [torch.uint8, torch.int8], ( |
| "only uint8 and int8 are supported in reference flow for " |
| "dynamic quantization right now" |
| ) |
| quant_min = activation_post_process.quant_min |
| quant_max = activation_post_process.quant_max |
| qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) |
| eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) |
| |
| |
| |
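# scale/zero_point are intentionally absent here; they come from the
# choose_qparams node inserted below. qscheme and eps fall back to defaults
# when the observer does not define them.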
| qparams = { |
| "_quant_min_": quant_min, |
| "_quant_max_": quant_max, |
| "_eps_": eps, |
| "_dtype_": dtype_, |
| } |
|
|
| choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme] |
| |
| with graph.inserting_before(node): |
| input_node = node.args[0] |
| choose_qparams_op_inputs = [node.args[0]] |
| for key, value in qparams.items(): |
| |
| |
| choose_qparams_op_inputs.append(value) |
| choose_qparams_node = graph.create_node( |
| "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {} |
| ) |
| |
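# choose_qparams returns a (scale, zero_point) tuple; unpack it with
# operator.getitem nodes.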
| scale_node = graph.create_node( |
| "call_function", operator.getitem, (choose_qparams_node, 0), {} |
| ) |
| zero_point_node = graph.create_node( |
| "call_function", operator.getitem, (choose_qparams_node, 1), {} |
| ) |
| quant_min = qparams["_quant_min_"] |
| quant_max = qparams["_quant_max_"] |
| dtype = qparams["_dtype_"] |
| qparams = { |
| "_scale_": scale_node, |
| "_zero_point_": zero_point_node, |
| "_quant_min_": quant_min, |
| "_quant_max_": quant_max, |
| "_dtype_": dtype, |
| } |
|
|
| |
| with graph.inserting_before(node): |
| input_node = node.args[0] |
| quantize_op_inputs = [input_node] |
| for key, value_or_node in qparams.items(): |
| |
| |
| if key in ["_scale_", "_zero_point_"]: |
| |
| |
| qparam_node = value_or_node |
| quantize_op_inputs.append(qparam_node) |
| else: |
| |
| |
| quantize_op_inputs.append(value_or_node) |
|
|
| quantized_node = graph.create_node( |
| node_type, quantize_op, tuple(quantize_op_inputs), {} |
| ) |
| |
| dq_inputs = [quantized_node] + quantize_op_inputs[1:] |
| |
| |
| |
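# Use the .tensor overload of dequantize_per_tensor since scale and
# zero_point are graph nodes (Tensors) here, not Python scalars.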
| dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor |
| dequantized_node = graph.call_function( |
| dequantize_op, |
| tuple(dq_inputs), |
| add_dequantize_op_kwargs(dequantize_op, input_node), |
| ) |
|
|
| node.replace_all_uses_with(dequantized_node) |
| |
| if NUMERIC_DEBUG_HANDLE_KEY in node.meta: |
| dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ |
| NUMERIC_DEBUG_HANDLE_KEY |
| ] |
| graph.erase_node(node) |
| elif dtype == torch.float16: |
| |
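# fp16 path: emulate half-precision quantization by converting to float16
# and straight back to float32.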
| dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse |
| with graph.inserting_before(node): |
| input_node = node.args[0] |
| convert_fp16_node = graph.create_node( |
| "call_function", dtype_convert_op, (input_node, torch.float16), {} |
| ) |
| convert_fp32_node = graph.create_node( |
| "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {} |
| ) |
| node.replace_all_uses_with(convert_fp32_node) |
| graph.erase_node(node) |
|
|
| |
| |
|
|
|
|
| def _replace_observer_with_quantize_dequantize_node( |
| model: torch.fx.GraphModule, |
| node: Node, |
| modules: dict[str, torch.nn.Module], |
| node_name_to_scope: dict[str, tuple[str, type]], |
| node_name_to_qconfig: dict[str, QConfigAny], |
| model_device: Optional[torch.device] = None, |
| ) -> None: |
| """Replace activation_post_process module call node with quantize and |
| dequantize node |
| |
| Before: |
| ... -> observer_0(x) -> ... |
| After: |
| ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... |
| """ |
| assert modules is not None |
| assert isinstance(node.target, str) |
| graph = model.graph |
| module_path, prefix = _get_module_path_and_prefix( |
| node, node_name_to_scope, node_name_to_qconfig |
| ) |
| activation_post_process = modules[node.target] |
| |
| |
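# As in the decomposed variant above, remove the observer without replacement
# when all surrounding qconfigs are None or the observer cannot be converted.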
| skip_replacement = all( |
| _has_none_qconfig(n, node_name_to_qconfig) |
| for n in list(node.args) + list(node.users.keys()) |
| ) |
| if skip_replacement or not _is_conversion_supported(activation_post_process): |
| |
| |
| with graph.inserting_before(node): |
| node.replace_all_uses_with(node.args[0]) |
| graph.erase_node(node) |
| return |
|
|
| |
| dtype = activation_post_process.dtype |
|
|
| is_dynamic = False |
| if hasattr(activation_post_process, "is_dynamic"): |
| is_dynamic = activation_post_process.is_dynamic |
|
|
| if dtype in [ |
| torch.quint8, |
| torch.qint8, |
| torch.qint32, |
| torch.float8_e5m2, |
| torch.float8_e4m3fn, |
| ] and (not is_dynamic): |
| |
| |
|
|
| |
|
|
| |
| |
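# Statically quantized path: replace the observer with torch.quantize_per_tensor
# or torch.quantize_per_channel followed by a dequantize() call.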
| node_type = "call_function" |
| quantize_op: Optional[Callable] = None |
| scale, zero_point = activation_post_process.calculate_qparams() |
| if is_per_channel(activation_post_process.qscheme): |
| ch_axis = int(activation_post_process.ch_axis) |
| qparams = { |
| "_scale_": scale, |
| "_zero_point_": zero_point, |
| "_axis_": ch_axis, |
| "_dtype_": dtype, |
| } |
| quantize_op = torch.quantize_per_channel |
| else: |
| scale = float(scale) |
| zero_point = int(zero_point) |
| qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} |
| quantize_op = torch.quantize_per_tensor |
|
|
| |
| with graph.inserting_before(node): |
| input_node = node.args[0] |
| quantize_op_inputs = [input_node] |
| for key, value_or_node in qparams.items(): |
| |
| |
| if key in ["_scale_", "_zero_point_"]: |
| |
| |
| qparam_node = create_getattr_from_value( |
| model, |
| graph, |
| module_path + prefix + key, |
| value_or_node, |
| model_device, |
| ) |
| quantize_op_inputs.append(qparam_node) |
| else: |
| |
| quantize_op_inputs.append(value_or_node) |
|
|
| quantized_node = graph.create_node( |
| node_type, quantize_op, tuple(quantize_op_inputs), {} |
| ) |
| dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) |
| node.replace_all_uses_with(dequantized_node) |
| graph.erase_node(node) |
| elif is_dynamic: |
| |
|
|
| node_type = "call_function" |
| quantize_op = torch.quantize_per_tensor_dynamic |
| |
| |
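# reduce_range is enabled for the fbgemm/x86 engines, which require a reduced
# activation range to avoid potential overflow in their int8 kernels.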
| reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") |
| qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range} |
|
|
| with graph.inserting_before(node): |
| input_node = node.args[0] |
| quantize_op_inputs = [input_node] |
| for key, value in qparams.items(): |
| quantize_op_inputs.append(value) |
|
|
| quantized_node = graph.create_node( |
| node_type, quantize_op, tuple(quantize_op_inputs), {} |
| ) |
| dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) |
| node.replace_all_uses_with(dequantized_node) |
| graph.erase_node(node) |
| elif dtype == torch.float16: |
| node_type = "call_method" |
| quantize_op = "to" |
| qparams = {"_dtype_": dtype} |
| with graph.inserting_before(node): |
| input_node = node.args[0] |
| quantize_op_inputs = [input_node] |
| for key, value in qparams.items(): |
| |
| |
| quantize_op_inputs.append(value) |
|
|
| quantized_node = graph.create_node( |
| node_type, quantize_op, tuple(quantize_op_inputs), {} |
| ) |
| dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) |
| node.replace_all_uses_with(dequantized_node) |
| graph.erase_node(node) |
|
|
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| def _replace_observer_or_dequant_stub_with_dequantize_node( |
| node: Node, graph: Graph |
| ) -> None: |
| call_custom_module_node = node.args[0] |
| assert isinstance(call_custom_module_node, Node), ( |
| f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" |
| ) |
| node.replace_all_uses_with(call_custom_module_node) |
| graph.erase_node(node) |
| _insert_dequantize_node(call_custom_module_node, graph) |
|
|
|
|
| def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: |
| dtype = activation_post_process.dtype |
|
|
| is_dynamic = False |
| if hasattr(activation_post_process, "is_dynamic"): |
| is_dynamic = activation_post_process.is_dynamic |
|
|
| return ( |
| (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) |
| or is_dynamic |
| or dtype == torch.float16 |
| ) |
|
|
|
|
| def _has_none_qconfig( |
| node: Argument, node_name_to_qconfig: dict[str, QConfigAny] |
| ) -> bool: |
| """Check if a node has a qconfig of None, i.e. user requested to not quantize |
| the node |
| """ |
| return ( |
| isinstance(node, Node) |
| and node.name in node_name_to_qconfig |
| and node_name_to_qconfig[node.name] is None |
| ) |
|
|
|
|
| def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None: |
| """Extract the subgraph that produces the weight for dynamic quant |
| or weight only quant node and run the subgraph to observe the weight. |
| Note that the observers of dynamic quant or weight only quant ops are |
| run during the convert step. |
| """ |
| for node in observed.graph.nodes: |
| if node.op != "call_function": |
| continue |
| for node_arg in node.args: |
| |
| if node_arg and node_arg_is_weight(node, node_arg): |
| weight_observer_nodes = collect_producer_nodes(node_arg) |
| if weight_observer_nodes is None: |
| continue |
| weight_observer_module = graph_module_from_producer_nodes( |
| observed, weight_observer_nodes |
| ) |
| |
| weight_observer_module() |
|
|
|
|
| def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None: |
| """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node, |
| we'll recursively remove the dequantize Node |
| """ |
| if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize": |
| quantize_node = arg.args[0] |
| |
| |
| node.replace_input_with(arg, quantize_node) |
| elif isinstance(arg, (list, tuple)): |
| for arg_element in arg: |
| _maybe_recursive_remove_dequantize(arg_element, node, graph) |
| elif isinstance(arg, dict): |
| for arg_element in arg.values(): |
| _maybe_recursive_remove_dequantize(arg_element, node, graph) |
| else: |
| warnings.warn( |
| f"Unsupported node type in recursive remove dequantize: {type(arg)}" |
| ) |
|
|
|
|
| def _get_module_path_and_prefix( |
| obs_node: Node, |
| node_name_to_scope: dict[str, tuple[str, type]], |
| node_name_to_qconfig: dict[str, QConfigAny], |
| ) -> tuple[str, str]: |
| """Given and observer node, get the `Scope` or the fully qualified name for |
| the submodule containing the observed node, also return a prefix of "_input" |
| when the observed node is an input of a F.linear op, and not the output of another |
| quantized op. |
| TODO: this logic is hacky, we should think about how to remove it or make it more |
| general |
| """ |
| observed_node = obs_node.args[0] |
| |
| |
| |
| |
| assert isinstance(observed_node, Node), ( |
| f"Expecting observed node to be a Node, but got {observed_node}" |
| ) |
| is_input_observer_only = ( |
| node_name_to_qconfig[observed_node.name] is None |
| if observed_node.name in node_name_to_qconfig |
| else None |
| ) |
| if is_input_observer_only: |
| |
| |
| |
| users = list(obs_node.users) |
| first_linear_use_or_first_use = users[0] if users else None |
| linear_node = None |
| for n in users: |
| if n.op == "call_function" and n.target == torch.nn.functional.linear: |
| linear_node = n |
| break |
| if linear_node: |
| first_linear_use_or_first_use = linear_node |
| prefix = "_input" |
| else: |
| |
| first_linear_use_or_first_use = observed_node |
| prefix = "" |
|
|
| if ( |
| first_linear_use_or_first_use |
| and first_linear_use_or_first_use.name in node_name_to_scope |
| ): |
| module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name] |
| else: |
| |
| |
| |
| module_path = "" |
| return module_path, prefix |
|
|
|
|
| def _insert_dequantize_node(node: Node, graph: Graph) -> None: |
| """Inserts dequantize node for `node` in `graph`""" |
| with graph.inserting_after(node): |
| dequantize_node = graph.call_method("dequantize", (node,)) |
| for user_node in dict(node.users): |
| if user_node is not dequantize_node: |
| user_node.replace_input_with(node, dequantize_node) |
|
|
|
|
| def _maybe_get_observer_for_node( |
| node: Node, modules: dict[str, torch.nn.Module] |
| ) -> Optional[torch.nn.Module]: |
| """ |
| If the node is observed, return the observer |
| instance. Otherwise, return None. |
| """ |
| for maybe_obs_node in node.users.keys(): |
| if maybe_obs_node.op == "call_module": |
| maybe_obs = modules[str(maybe_obs_node.target)] |
| if _is_activation_post_process(maybe_obs): |
| return maybe_obs |
| return None |
|
|
|
|
| def convert_standalone_module( |
| node: Node, |
| modules: dict[str, torch.nn.Module], |
| model: torch.fx.GraphModule, |
| is_reference: bool, |
| backend_config: Optional[BackendConfig], |
| ) -> None: |
| """Converts a observed standalone module to a quantized standalone module by calling |
| the fx convert api, currently using the same `is_reference` flag as parent, but we may |
| changing this behavior in the future (e.g. separating quantization and lowering for |
| standalone module as well) |
| |
| Args: |
| - node: The call_module node of the observed standalone module |
| - modules: named_module of original model |
| - model: original model |
| - is_reference: a flag from parent provided by user to decide if we want to |
| produce a reference model or a fbgemm/qnnpack model |
| - backend_config: backend configuration of the target backend of quantization |
| """ |
| |
| if is_reference: |
| convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx |
| else: |
| convert_fn = torch.ao.quantization.quantize_fx.convert_fx |
| |
| |
| observed_standalone_module: GraphModule = modules[str(node.target)] |
| sm_input_quantized_idxs = observed_standalone_module.meta[ |
| "_observed_graph_module_attrs" |
| ].standalone_module_input_quantized_idxs |
| |
| args = list(node.args) |
| for idx in range(len(args)): |
| if idx in sm_input_quantized_idxs: |
| arg = args[idx] |
| if arg.op == "call_method" and arg.target == "dequantize": |
| quantize_node = arg.args[0] |
| node.replace_input_with(arg, quantize_node) |
| if len(arg.users) == 0: |
| model.graph.erase_node(arg) |
| |
| sm_output_quantized_idxs = observed_standalone_module.meta[ |
| "_observed_graph_module_attrs" |
| ].standalone_module_output_quantized_idxs |
| if len(sm_output_quantized_idxs) > 0: |
assert sm_output_quantized_idxs[0] == 0, (
    "Currently only quantized output idxs = [0] is supported"
)
|
|
| |
| |
| _insert_dequantize_node(node, model.graph) |
|
|
| |
| |
| quantized_standalone_module = convert_fn( |
| observed_standalone_module, backend_config=backend_config |
| ) |
| parent_name, name = _parent_name(node.target) |
| |
| setattr(modules[parent_name], name, quantized_standalone_module) |
| modules[str(node.target)] = quantized_standalone_module |
|
|
|
|
| def convert_weighted_module( |
| node: Node, |
| modules: dict[str, torch.nn.Module], |
| observed_node_names: set[str], |
| node_name_to_qconfig: dict[str, QConfigAny], |
| backend_config: BackendConfig, |
| is_decomposed: bool = False, |
| is_reference: bool = False, |
| model_device: Optional[torch.device] = None, |
| ) -> None: |
| """Convert a weighted module to reference quantized module in the model |
| If the QConfig of a QAT module is not set, the module will still be converted to |
| a float module. |
| |
| Args: |
| - node: The call_module node of the observed standalone module |
| - modules: named_module of original model |
| - observed_node_names: names for the set of observed fx node, we can skip |
| this conversion if the node is not observed |
| """ |
| original_module = modules[str(node.target)] |
| qconfig: QConfigAny = original_module.qconfig |
| weight_post_process = None |
| qat_module_classes = get_qat_module_classes(backend_config) |
|
|
| if isinstance(original_module, qat_module_classes): |
| |
| |
| |
| weight_post_process = original_module.weight_fake_quant |
| original_module = original_module.to_float() |
| |
| parent_name, name = _parent_name(node.target) |
| setattr(modules[parent_name], name, original_module) |
|
|
| is_observed = node.name in observed_node_names |
| |
| if ( |
| qconfig is None |
| or _has_none_qconfig(node, node_name_to_qconfig) |
| or not is_observed |
| ): |
| return |
|
|
| |
| pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) |
| dtype_configs = pattern_to_dtype_configs.get(type(original_module), []) |
| if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs): |
| return |
|
|
| |
| is_weight_quantized = weight_is_quantized(qconfig) |
|
|
| |
| |
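# Only swap to a reference quantized module when the weight itself is quantized.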
| if not is_weight_quantized: |
| return |
|
|
| fused_module = None |
| float_module = original_module |
| |
| if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule): |
| fused_module = float_module |
| float_module = fused_module[0] |
|
|
| |
| |
| wq_or_wq_dict = {"is_decomposed": is_decomposed} |
| if isinstance(float_module, torch.nn.RNNCellBase): |
| weight_post_process_ih = qconfig.weight() |
| weight_post_process_hh = qconfig.weight() |
| weight_post_process_ih(float_module.weight_ih) |
| weight_post_process_hh(float_module.weight_hh) |
| weight_qparams_ih = get_qparam_dict(weight_post_process_ih) |
| weight_qparams_hh = get_qparam_dict(weight_post_process_hh) |
| wq_or_wq_dict.update( |
| { |
| "weight_ih": weight_qparams_ih, |
| "weight_hh": weight_qparams_hh, |
| } |
| ) |
| elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): |
| |
| |
| for wn in float_module._flat_weights_names: |
| if hasattr(float_module, wn) and wn.startswith("weight"): |
| weight = getattr(float_module, wn) |
| weight_post_process = qconfig.weight() |
| if weight_post_process.dtype == torch.qint8: |
| weight_post_process(weight) |
| wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process) |
| else: |
| |
| |
| is_ptq = weight_post_process is None |
| if is_ptq: |
| weight_post_process = qconfig.weight() |
| if model_device is not None: |
| device = model_device |
| else: |
| device = assert_and_get_unique_device(float_module) |
| if device: |
| weight_post_process.to(device) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
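# Run the weight observer/fake_quant once so the collected qparams have the
# right shape; the decomposed + reference + QAT combination is skipped since
# the fake_quant has already been applied during training.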
| is_qat = not is_ptq |
| if not (is_decomposed and is_reference and is_qat): |
| weight_post_process(float_module.weight) |
|
|
| wq_or_wq_dict.update(get_qparam_dict(weight_post_process)) |
|
|
| |
| |
| |
| root_module_to_quantized_reference_module = ( |
| get_root_module_to_quantized_reference_module(backend_config) |
| ) |
| ref_qmodule_cls = root_module_to_quantized_reference_module.get( |
| type_before_parametrizations(float_module), None |
| ) |
| assert ref_qmodule_cls is not None, ( |
| f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" |
| ) |
| ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) |
| if fused_module is not None: |
| fused_module[0] = ref_qmodule |
| else: |
| parent_name, name = _parent_name(node.target) |
| setattr(modules[parent_name], name, ref_qmodule) |
|
|
|
|
| def _remove_previous_dequantize_in_custom_module( |
| node: Node, prev_node: Node, graph: Graph |
| ) -> None: |
| """ |
Given a custom module `node`, if the previous node is a dequantize, reroute the graph as follows:
| |
| Before: quantize - dequantize - custom_module |
| After: quantize - custom_module |
| \\ - dequantize |
| """ |
| |
| assert isinstance(prev_node, Node), ( |
| f"Expecting the argument for custom module node to be a Node, but got {prev_node}" |
| ) |
| if prev_node.op == "call_method" and prev_node.target == "dequantize": |
| node.replace_input_with(prev_node, prev_node.args[0]) |
| |
| if len(prev_node.users) == 0: |
| graph.erase_node(prev_node) |
|
|
|
|
| def convert_custom_module( |
| node: Node, |
| graph: Graph, |
| modules: dict[str, torch.nn.Module], |
| custom_module_class_mapping: dict[QuantType, dict[type, type]], |
| statically_quantized_custom_module_nodes: set[Node], |
| ) -> None: |
| """Converts an observed custom module to a quantized custom module based on |
| `custom_module_class_mapping` |
| For static quantization, we'll also remove the previous `dequantize` node and |
| attach the observer node for output to the module, the observer for the node |
| will be converted to a dequantize node instead of quantize-dequantize pairs |
| later in the graph. In the end we would have a quantized custom module that |
| has the same interface as a default quantized module in nn.quantized namespace, |
| i.e. quantized input and quantized output. |
| |
| Args: |
| - node: The call_module node of the observed standalone module |
| - graph: The graph containing the node |
| - modules: named_module of original model |
| - custom_module_class_mapping: mapping from observed custom module class to |
| quantized custom module class, used to swap custom modules |
| - statically_quantized_custom_module_nodes: we'll add the custom module node |
| if we find it is statically quantized, this will be used later when converting |
| observers to quant/dequant node pairs, if the observed node is a statically |
| quantized custom module nodes, we'll convert the observer to a dequantize node, |
| this is to keep the interface the same as the default quantized module. |
| TODO: maybe we want to redesign this part to align with reference model design |
| as well, but there has been some discussions around the interface, so we can do |
| it later. |
| """ |
| observed_custom_module = modules[str(node.target)] |
| qconfig = observed_custom_module.qconfig |
| if activation_is_statically_quantized(qconfig): |
| statically_quantized_custom_module_nodes.add(node) |
| if _is_custom_module_lstm(node, modules): |
| |
| |
| assert ( |
| len(node.args) == 2 |
| and isinstance(node.args[1], tuple) |
| and len(node.args[1]) == 2 |
| ) |
| (inputs, (hidden0, hidden1)) = node.args |
| assert isinstance(inputs, Node) |
| assert isinstance(hidden0, Node) |
| assert isinstance(hidden1, Node) |
| _remove_previous_dequantize_in_custom_module(node, inputs, graph) |
| _remove_previous_dequantize_in_custom_module(node, hidden0, graph) |
| _remove_previous_dequantize_in_custom_module(node, hidden1, graph) |
| elif _is_custom_module_mha(node, modules): |
| |
| |
| |
| |
| |
| |
| assert len(node.args) == 3 |
| query, key, value = node.args |
| assert isinstance(query, Node) |
| assert isinstance(key, Node) |
| assert isinstance(value, Node) |
| _remove_previous_dequantize_in_custom_module(node, query, graph) |
| _remove_previous_dequantize_in_custom_module(node, key, graph) |
| _remove_previous_dequantize_in_custom_module(node, value, graph) |
| else: |
| |
| arg = node.args[0] |
| assert isinstance(arg, Node) |
| _remove_previous_dequantize_in_custom_module(node, arg, graph) |
| |
| activation_post_process = _maybe_get_observer_for_node(node, modules) |
| assert activation_post_process is not None |
| observed_custom_module.activation_post_process = activation_post_process |
|
|
| |
| quantized_custom_module_class = get_swapped_custom_module_class( |
| observed_custom_module, custom_module_class_mapping, qconfig |
| ) |
| quantized_custom_module = quantized_custom_module_class.from_observed( |
| observed_custom_module |
| ) |
| parent_name, name = _parent_name(node.target) |
| setattr(modules[parent_name], name, quantized_custom_module) |
|
|
|
|
| def convert( |
| model: GraphModule, |
| is_reference: bool = False, |
| convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, |
| is_standalone_module: bool = False, |
| _remove_qconfig_flag: bool = True, |
| qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, |
| backend_config: Union[BackendConfig, dict[str, Any], None] = None, |
| is_decomposed: bool = False, |
| keep_original_weights: bool = False, |
| ) -> GraphModule: |
| """ |
We will convert an observed model (a module with observer calls) to a reference
quantized model. The rules are simple:
1. for each observer module call in the graph, convert it to calls to quantize
and dequantize functions based on the observer instance
2. for weighted operations like linear/conv, convert them to reference quantized
modules; this requires knowing whether the dtype configured for the weight is
supported by the backend, which is determined in the prepare step and stored in
observed_node_names, so we can decide whether to swap the module based on that set
| |
| Args: |
* `is_standalone_module`: when this flag is True, it means we are quantizing
a submodule that is not inlined in the parent module; it will be quantized
separately as one unit.
| |
| * `is_decomposed`: a boolean flag to indicate whether we want to use the |
| quantize operator for decomposed quantized tensor |
| (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone |
| quantized tensor (torch.quantize_per_tensor) |
| |
Returns:
a quantized model; when `is_standalone_module` is True, this is a quantized
standalone module whose input/output quantization is specified by
prepare_custom_config (input_quantized_idxs, output_quantized_idxs), please
see the docs for :func:`~torch.ao.quantization.prepare_fx` for details
| """ |
| if convert_custom_config is None: |
| convert_custom_config = ConvertCustomConfig() |
|
|
| if isinstance(convert_custom_config, dict): |
| warnings.warn( |
| "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " |
| "in a future version. Please pass in a ConvertCustomConfig instead.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) |
|
|
| if isinstance(qconfig_mapping, dict): |
| warnings.warn( |
| "Passing a QConfig dictionary to convert is deprecated and will not be supported " |
| "in a future version. Please pass in a QConfigMapping instead.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| qconfig_mapping = ( |
| QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None |
| ) |
| qconfig_mapping = copy.deepcopy(qconfig_mapping) |
| assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) |
|
|
| if isinstance(backend_config, dict): |
| warnings.warn( |
| "Passing a backend_config_dict to prepare is deprecated and will not be supported " |
| "in a future version. Please pass in a BackendConfig instead.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| backend_config = BackendConfig.from_dict(backend_config) |
|
|
| if backend_config is None: |
| backend_config = get_native_backend_config() |
|
|
| assert _is_observed_module(model), "incoming model must be produced by prepare_fx" |
| observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] |
| node_name_to_scope: dict[str, tuple[str, type]] = ( |
| observed_graph_module_attrs.node_name_to_scope |
| ) |
| prepare_custom_config: PrepareCustomConfig = ( |
| observed_graph_module_attrs.prepare_custom_config |
| ) |
| observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names |
| node_name_to_qconfig: dict[str, QConfigAny] = ( |
| observed_graph_module_attrs.node_name_to_qconfig |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
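# remove_duplicate=False so that an observer instance shared under several
# attribute names keeps a separate entry per name.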
| modules = dict(model.named_modules(remove_duplicate=False)) |
|
|
| |
| |
| if qconfig_mapping: |
| prepare_qconfig_mapping: QConfigMapping = ( |
| observed_graph_module_attrs.qconfig_mapping |
| ) |
| modules_copy = copy.deepcopy(modules) |
|
|
| if observed_graph_module_attrs.is_qat: |
| _update_qconfig_for_qat(qconfig_mapping, backend_config) |
| _update_qconfig_for_fusion(model, qconfig_mapping) |
|
|
| _compare_prepare_convert_qconfig_mappings( |
| prepare_qconfig_mapping, qconfig_mapping |
| ) |
| convert_node_name_to_qconfig = _generate_node_name_to_qconfig( |
| model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope |
| ) |
| |
| |
| |
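# Sanity check: at convert time a node's qconfig may only be dropped (set to
# None), never silently changed to a different qconfig.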
| for k, v in node_name_to_qconfig.items(): |
| assert k in convert_node_name_to_qconfig, ( |
| f"Expected key {k} in convert node_name_to_qconfig" |
| ) |
| if convert_node_name_to_qconfig[k] is not None: |
| assert qconfig_equals(v, convert_node_name_to_qconfig[k]), ( |
| f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " |
| f"but {v} was updated to {convert_node_name_to_qconfig[k]}" |
| ) |
| node_name_to_qconfig = convert_node_name_to_qconfig |
|
|
| custom_module_classes = get_custom_module_class_keys( |
| convert_custom_config.observed_to_quantized_mapping |
| ) |
| custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping |
|
|
| if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None: |
| |
| |
| |
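# Input-weight equalization: compute equalization scales from the recorded
# observers and fold them into the model before conversion.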
| weight_eq_obs_dict = update_obs_for_equalization(model, modules) |
| convert_eq_obs(model, modules, weight_eq_obs_dict) |
|
|
| |
| |
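# Weight observers for dynamic / weight-only quantized ops have not run yet;
# run them now so their qparams are available when swapping modules.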
| _run_weight_observers(model, backend_config) |
|
|
| |
| |
| placeholder_node_seen_cnt = 0 |
| input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes |
| output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes |
|
|
| root_module_to_quantized_reference_module = ( |
| get_root_module_to_quantized_reference_module(backend_config) |
| ) |
| |
| root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) |
| qat_module_classes = get_qat_module_classes(backend_config) |
| fused_module_classes = get_fused_module_classes(backend_config) |
| statically_quantized_custom_module_nodes: set[Node] = set() |
| model_device = assert_and_get_unique_device(model) |
|
|
| for node in list(model.graph.nodes): |
| if node.op == "placeholder": |
| cur_placeholder_node_idx = placeholder_node_seen_cnt |
| placeholder_node_seen_cnt += 1 |
| if cur_placeholder_node_idx in input_quantized_idxs: |
| |
| |
| |
| |
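# This placeholder is marked as already quantized, but reference ops expect
# fp32 inputs, so dequantize it right away.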
| _insert_dequantize_node(node, model.graph) |
| elif node.op == "output": |
| |
| if len(output_quantized_idxs) == 0: |
| continue |
| |
| |
| |
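# For outputs the user marked as quantized, drop the trailing dequantize so
# the quantized tensor is returned directly.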
| return_node = node |
| output = node.args[0] |
| |
| if isinstance(output, (list, tuple)): |
| for idx in output_quantized_idxs: |
| _maybe_recursive_remove_dequantize( |
| output[idx], return_node, model.graph |
| ) |
| elif isinstance(output, (Node, dict)): |
| |
| |
| |
| if 0 in output_quantized_idxs: |
| _maybe_recursive_remove_dequantize(output, return_node, model.graph) |
| else: |
| warnings.warn( |
| f"Unsupported node type for output_quantized_idxs: {type(output)}" |
| ) |
| elif node.op == "call_module": |
| mod = _get_module(node, modules) |
| assert mod is not None |
| if _is_activation_post_process(mod): |
| observed_node = node.args[0] |
| if observed_node in statically_quantized_custom_module_nodes: |
| _replace_observer_or_dequant_stub_with_dequantize_node( |
| node, model.graph |
| ) |
| else: |
| if is_decomposed: |
| _replace_observer_with_quantize_dequantize_node_decomposed( |
| model, |
| node, |
| modules, |
| node_name_to_scope, |
| node_name_to_qconfig, |
| model_device, |
| ) |
| else: |
| _replace_observer_with_quantize_dequantize_node( |
| model, |
| node, |
| modules, |
| node_name_to_scope, |
| node_name_to_qconfig, |
| model_device, |
| ) |
| elif isinstance(mod, DeQuantStub): |
| _replace_observer_or_dequant_stub_with_dequantize_node( |
| node, model.graph |
| ) |
| elif _is_observed_standalone_module(mod): |
| convert_standalone_module( |
| node, modules, model, is_reference, backend_config |
| ) |
| |
| |
| elif type_before_parametrizations(mod) in set(root_module_classes).union( |
| qat_module_classes |
| ).union(fused_module_classes): |
| |
| |
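# For fused modules, only convert when the first (root) submodule is one of
# the configured root module classes.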
| if ( |
| type_before_parametrizations(mod) in fused_module_classes |
| and type_before_parametrizations(mod[0]) not in root_module_classes |
| ): |
| continue |
| convert_weighted_module( |
| node, |
| modules, |
| observed_node_names, |
| node_name_to_qconfig, |
| backend_config, |
| is_decomposed, |
| is_reference, |
| model_device, |
| ) |
| elif type_before_parametrizations(mod) in custom_module_classes: |
| convert_custom_module( |
| node, |
| model.graph, |
| modules, |
| custom_module_class_mapping, |
| statically_quantized_custom_module_nodes, |
| ) |
|
|
| |
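# Remove nodes orphaned by the rewrites above, then rebuild the GraphModule.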
| model.graph.eliminate_dead_code() |
| model = GraphModule(model, model.graph) |
|
|
| |
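# When a non-reference model is requested, lower the reference patterns to
# the native quantized ops (fbgemm/qnnpack).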
| if not is_reference: |
| model = lower_to_fbgemm( |
| model, node_name_to_qconfig, node_name_to_scope, keep_original_weights |
| ) |
|
|
| |
| |
| |
| if _remove_qconfig_flag: |
| _remove_qconfig(model) |
| model.delete_all_unused_submodules() |
| model.meta.pop("_observed_graph_module_attrs", None) |
| return model |
|
|