|
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type |
|
|
from torch.ao.quantization.quant_type import QuantType |
|
|
import torch |
|
|
import copy |
|
|
import warnings |
|
|
from torch.fx import ( |
|
|
GraphModule, |
|
|
) |
|
|
from torch.fx.graph import ( |
|
|
Graph, |
|
|
Node, |
|
|
Argument, |
|
|
) |
|
|
from ..utils import ( |
|
|
activation_is_statically_quantized, |
|
|
weight_is_quantized, |
|
|
get_qparam_dict, |
|
|
_parent_name, |
|
|
get_swapped_custom_module_class, |
|
|
) |
|
|
from ..qconfig import ( |
|
|
QConfigAny, |
|
|
qconfig_equals |
|
|
) |
|
|
from ..qconfig_mapping import QConfigMapping |
|
|
from ..qconfig_mapping_utils import ( |
|
|
update_qconfig_for_qat, |
|
|
) |
|
|
from .qconfig_mapping_utils import ( |
|
|
generate_qconfig_map, |
|
|
compare_prepare_convert_qconfig_mappings, |
|
|
update_qconfig_for_fusion, |
|
|
is_qconfig_supported_by_dtype_configs, |
|
|
) |
|
|
from torch.ao.quantization.backend_config.utils import ( |
|
|
get_root_module_to_quantized_reference_module, |
|
|
get_pattern_to_dtype_configs, |
|
|
get_fused_module_classes, |
|
|
get_qat_module_classes, |
|
|
) |
|
|
from torch.ao.quantization.backend_config import ( |
|
|
BackendConfig, |
|
|
get_native_backend_config, |
|
|
) |
|
|
from .graph_module import ( |
|
|
QuantizedGraphModule, |
|
|
is_observed_module, |
|
|
is_observed_standalone_module, |
|
|
) |
|
|
from ._equalize import update_obs_for_equalization, convert_eq_obs |
|
|
from torch.nn.utils.parametrize import type_before_parametrizations |
|
|
from .utils import ( |
|
|
_get_module, |
|
|
_is_custom_module_lstm, |
|
|
get_custom_module_class_keys, |
|
|
get_quantize_node_info, |
|
|
create_getattr_from_value, |
|
|
collect_producer_nodes, |
|
|
graph_module_from_producer_nodes, |
|
|
node_arg_is_weight, |
|
|
) |
|
|
from torch.ao.quantization.quantize import ( |
|
|
_remove_qconfig, |
|
|
is_activation_post_process, |
|
|
) |
|
|
from torch.ao.quantization.stubs import DeQuantStub |
|
|
from .custom_config import ( |
|
|
ConvertCustomConfig, |
|
|
PrepareCustomConfig, |
|
|
) |
|
|
from .lower_to_fbgemm import lower_to_fbgemm |
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module; names not listed here are internal helpers.
__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
    "duplicate_dequantize_node",
    "duplicate_quantize_dynamic_node",
    "get_module_path_and_prefix",
    "has_none_qconfig",
    "insert_dequantize_node",
    "maybe_get_observer_for_node",
    "maybe_recursive_remove_dequantize",
    "remove_extra_dequantize",
    "restore_state",
    "run_weight_observers",
]
|
|
|
|
|
|
|
|
def restore_state(
    observed: torch.nn.Module
) -> Tuple[Dict[str, Tuple[str, type]],
           PrepareCustomConfig,
           Set[str]]:
    """Recover the metadata that prepare_fx stashed on the observed model.

    Returns a 3-tuple of:
      - node_name_to_scope: maps fx node name to (module path, module type)
      - prepare_custom_config: the PrepareCustomConfig used during prepare
      - observed_node_names: names of the fx nodes that were observed
    """
    # Only models produced by prepare_fx carry these private attributes.
    assert is_observed_module(observed), \
        'incoming model must be produced by prepare_fx'
    return (
        observed._node_name_to_scope,
        observed._prepare_custom_config,
        observed._observed_node_names,
    )
|
|
|
|
|
def has_none_qconfig(node: Argument, qconfig_map: Dict[str, QConfigAny]) -> bool:
    """Check whether the user explicitly requested no quantization for `node`.

    A node has a "None qconfig" when it appears in `qconfig_map` with an
    explicit None entry; non-Node arguments are never considered configured.
    """
    if not isinstance(node, Node):
        return False
    return node.name in qconfig_map and qconfig_map[node.name] is None
|
|
|
|
|
def run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
    """Run the weight observers for dynamic / weight-only quantized ops.

    For each call_function node, extract the subgraph that produces its weight
    argument and execute that subgraph so the weight observer records
    statistics. Observers for dynamic quant or weight-only quant ops are run
    here, during the convert step, rather than during calibration.
    """
    for node in observed.graph.nodes:
        if node.op != "call_function":
            continue
        for node_arg in node.args:
            # Only weight arguments (as classified by the backend config)
            # need to be observed here.
            if not node_arg or not node_arg_is_weight(node, node_arg, backend_config):
                continue
            producer_nodes = collect_producer_nodes(node_arg)
            if producer_nodes is None:
                continue
            weight_subgraph = graph_module_from_producer_nodes(
                observed, producer_nodes)
            # Execute the extracted subgraph; this runs the observer on the weight.
            weight_subgraph()
|
|
|
|
|
|
|
|
def duplicate_quantize_dynamic_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
    """Give every user of a shared `quantize_per_tensor_dynamic` node its own copy.

    When one dynamic quantize node feeds multiple users, each user receives a
    private duplicate (same args/kwargs) and the shared original is erased, so
    that later per-use pattern matching can fire independently.
    """
    quantized_root = quantized
    graph = quantized.graph
    for node in graph.nodes:
        if node.op != "call_function" or node.target != torch.quantize_per_tensor_dynamic:
            continue
        users = list(node.users)
        if len(users) <= 1:
            continue
        for user in users:
            with graph.inserting_before(node):
                duplicate = graph.create_node(
                    "call_function",
                    torch.quantize_per_tensor_dynamic,
                    node.args,
                    node.kwargs)
            user.replace_input_with(node, duplicate)
        graph.erase_node(node)

    return QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
|
|
|
|
|
def duplicate_dequantize_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
    """
    If a dequantize node has multiple uses, duplicate it and create one dequantize node for each use.
    This is to enable the pattern matching to map from individual quant - dequant - ref_module to
    final quantized module.
    """
    quantized_root = quantized
    graph = quantized.graph
    for node in graph.nodes:
        is_dequant = (
            (node.op == "call_method" and node.target == "dequantize")
            or (node.op == "call_function" and node.target == torch.dequantize)
        )
        if not is_dequant:
            continue
        users = list(node.users)
        if len(users) <= 1:
            continue
        for user in users:
            with graph.inserting_before(node):
                private_dq = graph.create_node("call_method", "dequantize", node.args, {})
            user.replace_input_with(node, private_dq)
        graph.erase_node(node)

    return QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
|
|
|
|
|
def remove_extra_dequantize(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
    """
    Removes duplicate dequant nodes in the graph, for an operator that has multiple dequant nodes as a user,
    replace them with a single dequant node that can be shared across all the uses.
    """
    quantized_root = quantized
    for node in quantized.graph.nodes:
        users = list(node.users)
        # collect only this node's users that are dequantize calls
        # (either the Tensor method or torch.dequantize)
        dequant_users = [user for user in node.users if user.op == "call_method" and user.target == "dequantize" or
                         (user.op == "call_function" and user.target == torch.dequantize)]

        if len(dequant_users) > 1:
            with quantized.graph.inserting_after(node):
                # NOTE(review): args are taken from users[0], not (node,) —
                # presumably users[0] is one of the dequant users so its args
                # are (node,); confirm this holds for all callers.
                unique_dq = quantized.graph.create_node("call_method", "dequantize", users[0].args, {})
                for dequant in dequant_users:
                    # reroute everything that consumed each duplicate dequant
                    # to the single shared one, then drop the duplicate
                    dequant.replace_all_uses_with(unique_dq)
                    quantized.graph.erase_node(dequant)

    # re-wrap so preserved attributes survive the graph rewrite
    quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
    return quantized
|
|
|
|
|
def maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
    """Remove dequantize nodes feeding `node`, recursing into containers.

    If `arg` is a dequantize Node, reroute `node` to consume the dequantize's
    input (the quantized value) directly. Lists/tuples/dicts are walked
    recursively; any other value triggers a warning.
    """
    is_dequant_node = (
        isinstance(arg, Node)
        and arg.op == "call_method"
        and arg.target == "dequantize"
    )
    if is_dequant_node:
        # consume the quantized producer directly, bypassing the dequantize
        node.replace_input_with(arg, arg.args[0])
    elif isinstance(arg, (list, tuple)):
        for element in arg:
            maybe_recursive_remove_dequantize(element, node, graph)
    elif isinstance(arg, dict):
        for element in arg.values():
            maybe_recursive_remove_dequantize(element, node, graph)
    else:
        warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
|
|
|
|
|
def get_module_path_and_prefix(
        obs_node: Node,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        qconfig_map: Dict[str, QConfigAny]):
    """ Given and observer node, get the `Scope` or the fully qualified name for
    the submodule containing the observed node, also return a prefix of "_input"
    when the observed node is an input of a F.linear op, and not the output of another
    quantized op.
    TODO: this logic is hacky, we should think about how to remove it or make it more
    general
    """
    # the observer always observes its single input
    observed_node = obs_node.args[0]
    # an observer can be inserted for the input of the next operator or for
    # the output of the previous operator (sometimes both at once).
    # NOTE: this is a tri-state value — True when the observed node has an
    # explicit None qconfig (observer exists only for the consumer's input),
    # False when it has a real qconfig, and None when it has no entry at all;
    # both False and None fall through to the else branch below.
    assert isinstance(observed_node, Node), \
        f"Expecting observed node to be a Node, but got {observed_node}"
    is_input_observer_only = qconfig_map[observed_node.name] is None if observed_node.name in qconfig_map else None
    if is_input_observer_only:
        # The quantize op is at the input of the consuming op: scan the
        # observer's users for an F.linear call and, if found, use that node's
        # scope to build the FQN; otherwise fall back to the first user.
        users = list(obs_node.users)
        first_linear_use_or_first_use = users[0] if users else None
        linear_node = None
        for n in users:
            if n.op == "call_function" and n.target == torch.nn.functional.linear:
                linear_node = n
                break
        if linear_node:
            first_linear_use_or_first_use = linear_node
        prefix = "_input"
    else:
        # quantize op is at the output of the producing op: use the observed
        # (producer) node itself to look up the scope
        first_linear_use_or_first_use = observed_node
        prefix = ""

    if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
    else:
        # no scope information available — fall back to the root module path
        module_path = ""
    return module_path, prefix
|
|
|
|
|
def insert_dequantize_node(
        node: Node,
        graph: Graph):
    """Insert a dequantize call right after `node` and reroute its users.

    Every existing user of `node` (other than the new dequantize itself) is
    rewired to consume the dequantized value.
    """
    with graph.inserting_after(node):
        dq = graph.call_method("dequantize", (node,))
    # snapshot the users first: replace_input_with mutates node.users
    for consumer in list(node.users):
        if consumer is not dq:
            consumer.replace_input_with(node, dq)
|
|
|
|
|
def maybe_get_observer_for_node(
        node: Node,
        modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """Return the observer module observing `node`, or None.

    Scans the node's users for a call_module node whose target module is an
    activation post-process (observer/fake-quant) and returns that module.
    """
    for user in node.users:
        if user.op != 'call_module':
            continue
        candidate = modules[str(user.target)]
        if is_activation_post_process(candidate):
            return candidate
    return None
|
|
|
|
|
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config: Optional[BackendConfig]):
    """ Converts a observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as parent, but we may
    changing this behavior in the future (e.g. separating quantization and lowering for
    standalone module as well)

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - model: original model
      - is_reference: a flag from parent provided by user to decide if we want to
        produce a reference model or a fbgemm/qnnpack model
      - backend_config: backend configuration of the target backend of quantization
    """
    # select the convert entry point matching the parent's mode
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx

    observed_standalone_module: GraphModule = modules[str(node.target)]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_input_quantized_idxs\
        .tolist()

    # remove the dequantize nodes for inputs the standalone module expects
    # in quantized form, rerouting to the quantize node directly
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":
                quantize_node = arg.args[0]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:
                    model.graph.erase_node(arg)

    sm_output_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_output_quantized_idxs \
        .tolist()
    if len(sm_output_quantized_idxs) > 0:
        # BUGFIX: the message used to be split into two statements, so the
        # second half was a discarded no-op expression and the assert message
        # was silently truncated; join it into a single string.
        assert sm_output_quantized_idxs[0] == 0, (
            "Currently only quantized output idxs = [0] is supported")

        # the output stays quantized, so downstream float consumers need an
        # explicit dequantize after this node
        insert_dequantize_node(node, model.graph)

    # convert the observed standalone module and swap it into the parent
    quantized_standalone_module = convert_fn(
        observed_standalone_module,
        backend_config=backend_config)
    parent_name, name = _parent_name(node.target)

    setattr(modules[parent_name], name, quantized_standalone_module)
    # keep the local module map in sync for later passes
    modules[str(node.target)] = quantized_standalone_module
|
|
|
|
|
def convert_weighted_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        observed_node_names: Set[str],
        qconfig_map: Dict[str, QConfigAny],
        backend_config: BackendConfig):
    """ Convert a weighted module to reference quantized module in the model
    If the QConfig of a QAT module is not set, the module will still be converted to
    a float module.

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - observed_node_names: names for the set of observed fx node, we can skip
        this conversion if the node is not observed
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(
            original_module,
            qat_module_classes):
        # For QAT modules, reuse the module's own weight fake-quant (it has
        # already been run during training) and swap the module back to float.
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()
        # replace the qat module with its float counterpart in the parent
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # skip conversion when the user disabled quantization for this node or
    # the node was never observed during prepare
    if qconfig is None or has_none_qconfig(node, qconfig_map) or not is_observed:
        return

    # skip conversion if the backend does not support this qconfig for this
    # module type
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    is_weight_quantized = weight_is_quantized(qconfig)

    # only swap to a reference quantized module when the weight is quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # for fused modules (e.g. ConvReLU), the weighted module is the first child
    if isinstance(original_module, torch.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]

    # weight qparams (single dict) or a dict of per-weight qparam dicts for
    # multi-weight modules (RNN cells / LSTM)
    wq_or_wq_dict = {}
    if isinstance(float_module, torch.nn.RNNCellBase):
        # RNN cells carry two weights; observe each separately
        weight_post_process_ih = qconfig.weight()
        weight_post_process_hh = qconfig.weight()
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict = {
            "weight_ih": weight_qparams_ih,
            "weight_hh": weight_qparams_hh,
        }
    elif isinstance(float_module, torch.nn.LSTM):
        # LSTM stores its weights as flat attributes (weight_ih_l0, ...);
        # observe each weight_* attribute with a fresh observer
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()
                # only run the observer for qint8 weights; other dtypes keep
                # the unrun observer's qparams
                if weight_post_process.dtype == torch.qint8:
                    weight_post_process(weight)
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # non-QAT path: weight_post_process was not set above, so create the
        # observer from the qconfig and run it on the weight now
        if weight_post_process is None:
            weight_post_process = qconfig.weight()
        weight_post_process(float_module.weight)
        wq_or_wq_dict = get_qparam_dict(weight_post_process)

    # look up the reference quantized module class for this float module type
    # (type_before_parametrizations handles parametrized/sparse modules)
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None)
    assert (
        ref_qmodule_cls is not None
    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)
    if fused_module is not None:
        # swap only the weighted child inside the fused container
        fused_module[0] = ref_qmodule
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)
|
|
|
|
|
def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph): |
|
|
""" |
|
|
Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows: |
|
|
|
|
|
Before: quantize - dequantize - custom_module |
|
|
After: quantize - custom_module |
|
|
\\ - dequantize |
|
|
""" |
|
|
|
|
|
assert isinstance(prev_node, Node), \ |
|
|
f"Expecting the argument for custom module node to be a Node, but got {prev_node}" |
|
|
if prev_node.op == "call_method" and prev_node.target == "dequantize": |
|
|
node.replace_input_with(prev_node, prev_node.args[0]) |
|
|
|
|
|
if len(prev_node.users) == 0: |
|
|
graph.erase_node(prev_node) |
|
|
|
|
|
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
        statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed standalone module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.
    TODO: maybe we want to redesign this part to align with reference model design
    as well, but there has been some discussions around the interface, so we can do
    it later.
    """
    observed_custom_module = modules[str(node.target)]
    # NOTE: the original code also assigned the result of
    # maybe_get_observer_for_node to an unused local here; that call is a pure
    # scan of node.users, so the dead assignment has been removed.
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # the inputs are in the form (input, (hidden0, hidden1));
            # ensure all three producers have their dequantizes removed
            assert (
                len(node.args) == 2 and
                isinstance(node.args[1], tuple) and
                len(node.args[1]) == 2
            )
            (inputs, (hidden0, hidden1)) = node.args
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        else:
            # remove the previous dequant node so the module sees quantized input
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)

            # absorb the output observer into the module so the swapped class
            # can quantize its own output
            activation_post_process = maybe_get_observer_for_node(node, modules)
            assert activation_post_process is not None
            observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module for its quantized counterpart
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
|
|
|
|
|
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
        backend_config: Union[BackendConfig, Dict[str, Any], None] = None) -> torch.nn.Module:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    # NOTE(review): `is_standalone_module` is accepted but not referenced in
    # this body — presumably kept for API compatibility; confirm with callers.
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    # accept the deprecated dict form of the configs, with a warning
    if isinstance(convert_custom_config, Dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.")
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
    # deep-copy so later in-place updates don't leak back to the caller
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    assert(qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping))

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    # recover the metadata stashed on the model by prepare_fx
    node_name_to_scope, prepare_custom_config, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map

    # mapping from fully qualified module name to module instance
    # (duplicates kept so shared modules resolve under every name)
    modules = dict(model.named_modules(remove_duplicate=False))

    # if a convert-time qconfig_mapping was passed, validate it against the
    # one used at prepare time: entries may only be unchanged or set to None
    if qconfig_mapping:
        prepare_qconfig_mapping: QConfigMapping = model._qconfig_mapping
        modules_copy = copy.deepcopy(modules)

        if model._is_qat:
            update_qconfig_for_qat(qconfig_mapping, {})
        update_qconfig_for_fusion(model, qconfig_mapping)

        compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)
        convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)

        # a non-None entry must agree exactly with the prepare-time qconfig
        for k, v in qconfig_map.items():
            assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
            if convert_qconfig_map[k] is not None:
                assert qconfig_equals(v, convert_qconfig_map[k]), \
                    "Expected k {} to have the same value in prepare and convert QConfigMappings, " \
                    "but {} was updated to {}".format(k, v, convert_qconfig_map[k])
        qconfig_map = convert_qconfig_map

    custom_module_classes = get_custom_module_class_keys(convert_custom_config.observed_to_quantized_mapping)
    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping

    # apply cross-layer equalization if it was configured during prepare
    if model._equalization_qconfig_map is not None:
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # run weight observers for dynamic / weight-only quantized ops now,
    # since they are not run during calibration
    run_weight_observers(model, backend_config)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module,
            graph: Graph,
            node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)

        # skip replacement when every producer and consumer of this observer
        # has an explicit None qconfig
        skip_replacement = all([
            has_none_qconfig(n, qconfig_map) for n in
            list(node.args) + list(node.users.keys())])
        if skip_replacement or maybe_quantize_node_info is None:
            # no quantize op info for this observer — just drop the observer
            # call and pass the input straight through
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        else:
            # otherwise replace the observer call with quantize + dequantize
            node_type, quantize_op, qparams = maybe_quantize_node_info

            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # scale and zero_point become buffers on the module so
                    # they can be accessed via get_attr nodes; other qparams
                    # (e.g. dtype) are inlined as literal arguments
                    if key in ['_scale_', '_zero_point_']:
                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph):
        # used for observers following statically quantized custom modules
        # (and DeQuantStubs): the custom module already outputs a quantized
        # tensor, so only a dequantize is needed
        call_custom_module_node = node.args[0]
        assert isinstance(call_custom_module_node, Node), \
            f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
        node.replace_all_uses_with(call_custom_module_node)
        graph.erase_node(node)
        insert_dequantize_node(call_custom_module_node, graph)

    # state used to honor user overrides for quantized graph inputs/outputs
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes

    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    # convert to tuple so it works with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config)
    fused_module_classes = get_fused_module_classes(backend_config)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    # iterate over a snapshot since the passes below mutate the graph
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # the user declared this input arrives already quantized;
                # reference modules take float, so dequantize it first
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            if len(output_quantized_idxs) == 0:
                continue
            # the user wants these outputs kept quantized — strip the
            # trailing dequantize(s) feeding the output node
            return_node = node
            output = node.args[0]
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # a single Node or a dict is treated as one output; only
                # index 0 is meaningful in output_quantized_idxs here
                if 0 in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            mod = _get_module(node, modules)
            assert mod is not None
            if is_activation_post_process(mod):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    # custom module already emits quantized output: the
                    # observer collapses to a single dequantize
                    replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
                else:
                    replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif isinstance(mod, DeQuantStub):
                replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
            elif is_observed_standalone_module(mod):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config)
            # type_before_parametrizations is used instead of type to handle
            # modules wrapped by parametrizations (e.g. fx quant + sparsity)
            elif type_before_parametrizations(mod) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # for fused modules, only convert when the inner weighted
                # module is a recognized root module class
                if type_before_parametrizations(mod) in fused_module_classes and \
                   type_before_parametrizations(mod[0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config)
            elif type_before_parametrizations(mod) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    preserved_attributes = set(convert_custom_config.preserved_attributes)
    model = QuantizedGraphModule(model, copy.deepcopy(model.graph), preserved_attributes)

    # clean up nodes orphaned by the observer replacement passes above
    model.graph.eliminate_dead_code()
    model.recompile()

    # lower the reference model to a fbgemm/qnnpack model unless the caller
    # asked for the reference form
    if not is_reference:
        model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)

    # remove qconfig attributes and leftover activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    return model
|
|
|