File size: 39,682 Bytes
9dd3461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 |
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type
from torch.ao.quantization.quant_type import QuantType
import torch
import copy
import warnings
from torch.fx import (
GraphModule,
)
from torch.fx.graph import (
Graph,
Node,
Argument,
)
from ..utils import (
activation_is_statically_quantized,
weight_is_quantized,
get_qparam_dict,
_parent_name,
get_swapped_custom_module_class,
)
from ..qconfig import (
QConfigAny,
qconfig_equals
)
from ..qconfig_mapping import QConfigMapping
from ..qconfig_mapping_utils import (
update_qconfig_for_qat,
)
from .qconfig_mapping_utils import (
generate_qconfig_map,
compare_prepare_convert_qconfig_mappings,
update_qconfig_for_fusion,
is_qconfig_supported_by_dtype_configs,
)
from torch.ao.quantization.backend_config.utils import (
get_root_module_to_quantized_reference_module,
get_pattern_to_dtype_configs,
get_fused_module_classes,
get_qat_module_classes,
)
from torch.ao.quantization.backend_config import (
BackendConfig,
get_native_backend_config,
)
from .graph_module import (
QuantizedGraphModule,
is_observed_module,
is_observed_standalone_module,
)
from ._equalize import update_obs_for_equalization, convert_eq_obs
from torch.nn.utils.parametrize import type_before_parametrizations
from .utils import (
_get_module,
_is_custom_module_lstm,
get_custom_module_class_keys,
get_quantize_node_info,
create_getattr_from_value,
collect_producer_nodes,
graph_module_from_producer_nodes,
node_arg_is_weight,
)
from torch.ao.quantization.quantize import (
_remove_qconfig,
is_activation_post_process,
)
from torch.ao.quantization.stubs import DeQuantStub
from .custom_config import (
ConvertCustomConfig,
PrepareCustomConfig,
)
from .lower_to_fbgemm import lower_to_fbgemm
# TODO: revisit this list. Many helper methods shouldn't be public
# Names exported by `from <this module> import *`. Besides `convert`, this
# currently also exposes the graph-rewriting helpers defined in this module.
__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
    "duplicate_dequantize_node",
    "duplicate_quantize_dynamic_node",
    "get_module_path_and_prefix",
    "has_none_qconfig",
    "insert_dequantize_node",
    "maybe_get_observer_for_node",
    "maybe_recursive_remove_dequantize",
    "remove_extra_dequantize",
    "restore_state",
    "run_weight_observers",
]
def restore_state(
        observed: torch.nn.Module
) -> Tuple[Dict[str, Tuple[str, type]],
           PrepareCustomConfig,
           Set[str]]:
    """Read back the bookkeeping that ``prepare_fx`` attached to an observed model.

    Args:
        observed: a model produced by ``prepare_fx``.

    Returns:
        ``(node_name_to_scope, prepare_custom_config, observed_node_names)``, read
        from the ``_node_name_to_scope``, ``_prepare_custom_config`` and
        ``_observed_node_names`` attributes of ``observed``.

    Raises:
        AssertionError: if ``observed`` is not recognized as an observed module.
    """
    assert is_observed_module(observed), \
        'incoming model must be produced by prepare_fx'
    # attributes are read in the same order as before so that a missing
    # attribute surfaces the same AttributeError
    custom_config: PrepareCustomConfig = observed._prepare_custom_config  # type: ignore[assignment]
    scope_by_node_name: Dict[str, Tuple[str, type]] = observed._node_name_to_scope  # type: ignore[assignment]
    names_of_observed_nodes: Set[str] = observed._observed_node_names  # type: ignore[assignment]
    return scope_by_node_name, custom_config, names_of_observed_nodes
def has_none_qconfig(node: Argument, qconfig_map: Dict[str, QConfigAny]) -> bool:
    """Return True when the user explicitly asked not to quantize ``node``.

    That is the case only when ``node`` is an fx ``Node`` whose name is a key of
    ``qconfig_map`` mapped to ``None``. Non-Node arguments and nodes that are
    absent from the map yield False.
    """
    if not isinstance(node, Node):
        return False
    # an absent key means "no qconfig recorded", which differs from an explicit None
    return node.name in qconfig_map and qconfig_map[node.name] is None
def run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
    """Observe the weights of dynamic-quant / weight-only-quant call_function ops.

    For every ``call_function`` node, each positional argument that the backend
    config classifies as a weight has its producer subgraph extracted into a small
    GraphModule, which is then executed so the weight observers in it see the
    weight. These observers are run here, during the convert step.
    """
    call_function_nodes = (n for n in observed.graph.nodes if n.op == "call_function")
    for op_node in call_function_nodes:
        for candidate in op_node.args:
            # falsy args are skipped before asking the backend config
            if not candidate or not node_arg_is_weight(op_node, candidate, backend_config):
                continue
            producer_nodes = collect_producer_nodes(candidate)
            if producer_nodes is None:
                continue
            # build the weight-producing subgraph and run it to trigger its observers
            graph_module_from_producer_nodes(observed, producer_nodes)()
# this method is temporary and will be removed soon
def duplicate_quantize_dynamic_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
    """Give every user of a shared ``torch.quantize_per_tensor_dynamic`` node its own copy.

    Each ``quantize_per_tensor_dynamic`` call_function node with more than one user
    is replaced by one new node per user (same args and kwargs, inserted before the
    original), after which the original node is erased. The rewritten graph is
    wrapped in a new ``QuantizedGraphModule`` built from the input module and its
    ``preserved_attr_names``.
    """
    root_module = quantized
    graph = quantized.graph
    for qnode in graph.nodes:
        is_dynamic_quant = (
            qnode.op == "call_function"
            and qnode.target == torch.quantize_per_tensor_dynamic
        )
        if not is_dynamic_quant:
            continue
        consumers = list(qnode.users)
        if len(consumers) <= 1:
            continue
        for consumer in consumers:
            with graph.inserting_before(qnode):
                private_copy = graph.create_node(
                    "call_function",
                    torch.quantize_per_tensor_dynamic,
                    qnode.args,
                    qnode.kwargs)
            consumer.replace_input_with(qnode, private_copy)
        graph.erase_node(qnode)
    return QuantizedGraphModule(root_module, graph, root_module.preserved_attr_names)
def duplicate_dequantize_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
    """
    Split every multi-use dequantize node into one dequantize node per use.

    Both ``x.dequantize()`` (call_method) and ``torch.dequantize(x)`` (call_function)
    are recognized; the per-use copies are always emitted as ``call_method
    "dequantize"`` with the original args. Giving each use its own node lets pattern
    matching map individual quant - dequant - ref_module chains to the final
    quantized module.
    """
    root_module = quantized
    graph = quantized.graph
    for dq_node in graph.nodes:
        is_method_dequant = dq_node.op == "call_method" and dq_node.target == "dequantize"
        is_function_dequant = dq_node.op == "call_function" and dq_node.target == torch.dequantize
        if not (is_method_dequant or is_function_dequant):
            continue
        consumers = list(dq_node.users)
        if len(consumers) <= 1:
            continue
        for consumer in consumers:
            with graph.inserting_before(dq_node):
                fresh_dq = graph.create_node("call_method", "dequantize", dq_node.args, {})
            consumer.replace_input_with(dq_node, fresh_dq)
        graph.erase_node(dq_node)
    return QuantizedGraphModule(root_module, graph, root_module.preserved_attr_names)
def remove_extra_dequantize(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
    """
    Removes duplicate dequant nodes in the graph: for a node that has multiple dequant
    nodes as users, replace them with a single dequant node that is shared across all
    the uses.

    Both ``x.dequantize()`` (call_method) and ``torch.dequantize(x)`` (call_function)
    users are merged; the shared node is emitted as ``node.dequantize()`` right after
    ``node``. The rewritten graph is wrapped in a new ``QuantizedGraphModule``.
    """
    quantized_root = quantized
    graph = quantized.graph
    for node in graph.nodes:
        dequant_users = [
            user for user in node.users
            if (user.op == "call_method" and user.target == "dequantize")
            or (user.op == "call_function" and user.target == torch.dequantize)
        ]
        if len(dequant_users) > 1:
            with graph.inserting_after(node):
                # Build the shared dequantize from `node` directly instead of copying
                # the args of node's first user: that user is not necessarily a
                # dequantize (e.g. `add(node, y)` would have produced `node.dequantize(y)`).
                unique_dq = graph.create_node("call_method", "dequantize", (node,), {})
            for dequant in dequant_users:
                dequant.replace_all_uses_with(unique_dq)
                graph.erase_node(dequant)
    return QuantizedGraphModule(quantized_root, graph, quantized_root.preserved_attr_names)
def maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
    """ If ``arg`` is a dequantize Node, or a (nested) list/tuple/dict containing
    dequantize Nodes, make ``node`` consume each dequantize's input directly.

    Both dequantize spellings are recognized: ``x.dequantize()`` (call_method) and
    ``torch.dequantize(x)`` (call_function, single-Node input only). This matches the
    other dequantize helpers in this module. Only the use by ``node`` is rewired; the
    dequantize node itself stays in the graph because it may have other users.
    Any other value, including a Node that is not a dequantize, emits a warning.
    ``graph`` is currently unused and kept for interface compatibility.
    """
    is_dequantize_node = isinstance(arg, Node) and (
        (arg.op == "call_method" and arg.target == "dequantize") or
        # the list overload of torch.dequantize has a non-Node first arg; leave it alone
        (arg.op == "call_function" and arg.target == torch.dequantize and
         len(arg.args) > 0 and isinstance(arg.args[0], Node))
    )
    if is_dequantize_node:
        quantize_node = arg.args[0]
        # we only replace the specific use since dequantize could be used by other nodes
        # as well
        node.replace_input_with(arg, quantize_node)
    elif isinstance(arg, (list, tuple)):
        for arg_element in arg:
            maybe_recursive_remove_dequantize(arg_element, node, graph)
    elif isinstance(arg, dict):
        for arg_element in arg.values():
            maybe_recursive_remove_dequantize(arg_element, node, graph)
    else:
        warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
def get_module_path_and_prefix(
        obs_node: Node,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        qconfig_map: Dict[str, QConfigAny]):
    """ Given an observer node, return ``(module_path, prefix)``.

    ``module_path`` is the fully qualified name (from ``node_name_to_scope``) of the
    submodule that the observer is attributed to, or ``""`` when it cannot be found.
    ``prefix`` is ``"_input"`` when the observed node has an explicit ``None`` entry in
    ``qconfig_map`` (the observer exists only because it feeds the next op). In that
    case the path is taken from the first ``F.linear`` user of the observer, or else
    its first user. Otherwise ``prefix`` is ``""`` and the path comes from the
    observed node itself.
    TODO: this logic is hacky, we should think about how to remove it or make it more
    general
    """
    observed_node = obs_node.args[0]
    assert isinstance(observed_node, Node), \
        f"Expecting observed node to be a Node, but got {observed_node}"
    # An observer can serve as the input of the next op, the output of the previous
    # op, or both; a None qconfig on the observed node means "input only".
    if observed_node.name in qconfig_map:
        is_input_observer_only = qconfig_map[observed_node.name] is None
    else:
        is_input_observer_only = None

    if is_input_observer_only:
        consumers = list(obs_node.users)
        linear_consumer = next(
            (c for c in consumers
             if c.op == "call_function" and c.target == torch.nn.functional.linear),
            None)
        if linear_consumer is not None:
            path_node = linear_consumer
        elif consumers:
            path_node = consumers[0]
        else:
            path_node = None
        prefix = "_input"
    else:
        # observer at the output of an op: the observed node determines the path
        path_node = observed_node
        prefix = ""

    if path_node is None or path_node.name not in node_name_to_scope:
        # TODO: it's not used, so actually we can skip quantization
        # but this requires changing return type of quantize_node
        # we can fix it later if needed
        return "", prefix
    module_path, _ = node_name_to_scope[path_node.name]
    return module_path, prefix
def insert_dequantize_node(
        node: Node,
        graph: Graph):
    """ Insert ``node.dequantize()`` directly after ``node`` in ``graph`` and route
    every pre-existing user of ``node`` through the new dequantize node.
    """
    with graph.inserting_after(node):
        dq_node = graph.call_method("dequantize", (node,))
    # take a snapshot first: rewiring a user removes it from node.users
    previous_users = [u for u in node.users if u is not dq_node]
    for u in previous_users:
        u.replace_input_with(node, dq_node)
def maybe_get_observer_for_node(
        node: Node,
        modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """
    Return the first ``call_module`` user of ``node`` whose module (looked up in
    ``modules`` by target name) is an activation post process, i.e. the observer
    of ``node``. Return None when no such user exists.
    """
    module_users = (u for u in node.users if u.op == 'call_module')
    for user in module_users:
        candidate = modules[str(user.target)]
        if is_activation_post_process(candidate):
            return candidate
    return None
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config: Optional[BackendConfig]):
    """ Converts an observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as parent, but we may
    change this behavior in the future (e.g. separating quantization and lowering for
    standalone module as well)

    The parent graph is also rewired around ``node``:
      - for each input index listed in the standalone module's
        ``_standalone_module_input_quantized_idxs``, a ``dequantize`` feeding that input
        is bypassed (and erased once it has no other users), so the quantized value is
        passed in directly
      - if ``_standalone_module_output_quantized_idxs`` is non-empty (only ``[0]`` is
        supported), a ``dequantize`` node is inserted after ``node``

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model; the converted module is stored back
        into this dict and set on its parent module
      - model: original model
      - is_reference: a flag from parent provided by user to decide if we want to
        produce a reference model or a fbgemm/qnnpack model
      - backend_config: backend configuration of the target backend of quantization
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_input_quantized_idxs \
        .tolist()  # type: ignore[operator]
    # remove the dequantize nodes for inputs; iterate over a snapshot because
    # replace_input_with rewrites node.args
    handled_dequant_nodes: Set[Node] = set()
    for idx, arg in enumerate(list(node.args)):
        if idx not in sm_input_quantized_idxs:
            continue
        # a non-Node (constant) cannot be a dequantize; a dequantize passed at two
        # quantized indices was already rewired (and possibly erased) the first time
        if not isinstance(arg, Node) or arg in handled_dequant_nodes:
            continue
        if arg.op == "call_method" and arg.target == "dequantize":
            handled_dequant_nodes.add(arg)
            quantize_node = arg.args[0]
            node.replace_input_with(arg, quantize_node)
            if len(arg.users) == 0:
                model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_output_quantized_idxs \
        .tolist()  # type: ignore[operator]
    if len(sm_output_quantized_idxs) > 0:
        # the message used to be split across two statements, truncating it
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module,
        backend_config=backend_config)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
def convert_weighted_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        observed_node_names: Set[str],
        qconfig_map: Dict[str, QConfigAny],
        backend_config: BackendConfig):
    """ Convert a weighted module to a reference quantized module in the model.

    If the module is an instance of one of the backend's QAT module classes, it is
    first converted with ``to_float()`` and swapped into its parent module. This
    happens before any qconfig check, so a QAT module whose QConfig is not set
    still ends up as a float module.

    The (float) module is then replaced by a reference quantized module only when
    all of the following hold; otherwise the function returns early:
      - the module's ``qconfig`` is not None and ``qconfig_map`` does not map the node to None
      - the node name is in ``observed_node_names``
      - the qconfig is supported by the backend's dtype configs for this module type
      - ``weight_is_quantized(qconfig)`` is True

    Args:
      - node: The call_module node of the weighted module to convert
      - modules: named_module of original model; replacements are set on the parent
        module found in this dict via ``setattr``
      - observed_node_names: names for the set of observed fx node, we can skip
        this conversion if the node is not observed
      - qconfig_map: mapping from node name to qconfig; a None value means the user
        requested not to quantize that node
      - backend_config: backend configuration used to look up QAT module classes,
        supported dtype configs and the reference quantized module classes
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
    # set below for QAT modules so their trained weight fake_quant can be reused
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(
            original_module,
            qat_module_classes):
        # Converting qat module to a float module, we need to attach
        # weight fake_quant to the module, weight fake_quant is assumed to be run during
        # QAT so we don't need to run it again here
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()  # type: ignore[operator]
        # change qat module to float module
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # If a qconfig is not defined for this node, then skip converting to a reference module
    if qconfig is None or has_none_qconfig(node, qconfig_map) or not is_observed:
        return

    # skip converting to reference quantized module if the qconfig is not supported
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
    is_weight_quantized = weight_is_quantized(qconfig)

    # the condition for swapping the module to reference quantized module is:
    # weights need to be quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # extract the individual float_module and fused module
    # (for a fused module, its first child is the one whose weight gets quantized)
    if isinstance(original_module, torch.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]  # type: ignore[index]

    # TODO: move this to the reference quantized module
    # weight_qparams or weight_qparams dict
    wq_or_wq_dict = {}
    if isinstance(float_module, torch.nn.RNNCellBase):
        # RNN cells have two weights; each gets its own fresh observer from the qconfig
        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict = {
            "weight_ih": weight_qparams_ih,
            "weight_hh": weight_qparams_hh,
        }
    elif isinstance(float_module, torch.nn.LSTM):
        # format for wq_or_wq_dict (flattened attributes):
        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
                # NOTE(review): the observer only sees the weight for qint8; for other
                # dtypes the qparams come from an observer that never ran -- confirm intended
                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
                    weight_post_process(weight)  # type: ignore[operator, misc]
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        if weight_post_process is None:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
        # run weight observer
        # TODO: This is currently a hack for QAT to get the right shapes for scale and zero point.
        # In the future, we should require the user to calibrate the model after calling prepare
        # Issue: https://github.com/pytorch/pytorch/issues/73941
        weight_post_process(float_module.weight)  # type: ignore[operator]
        wq_or_wq_dict = get_qparam_dict(weight_post_process)

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None)
    assert (
        ref_qmodule_cls is not None
    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        # replace the weighted child inside the fused module in place
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)
def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph):
    """
    Given a custom module `node` and its input `prev_node`, bypass `prev_node` if it is
    a dequantize (``call_method "dequantize"``):
    Before: quantize - dequantize - custom_module
    After: quantize - custom_module
                 \\ - dequantize
    The dequantize node is erased from `graph` once nothing else uses it.
    """
    # expecting the input node for a custom module node to be a Node
    assert isinstance(prev_node, Node), \
        f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
    is_dequantize = prev_node.op == "call_method" and prev_node.target == "dequantize"
    if not is_dequantize:
        return
    node.replace_input_with(prev_node, prev_node.args[0])
    # Remove the dequantize node if it doesn't have other users
    if not prev_node.users:
        graph.erase_node(prev_node)
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
        statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named_module of original model; the quantized module is set on the
        parent module found in this dict
      - custom_module_class_mapping: mapping from QuantType to a mapping from observed
        custom module class to quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.
    TODO: maybe we want to redesign this part to align with reference model design
    as well, but there has been some discussions around the interface, so we can do
    it later.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # The inputs are tuples in the form (input, (hidden0, hidden1))
            # Ensure all three input nodes are quantized
            assert (
                len(node.args) == 2 and
                isinstance(node.args[1], tuple) and
                len(node.args[1]) == 2
            ), f"Expecting custom module LSTM args of the form (input, (hidden0, hidden1)), but got {node.args}"
            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
            # validate all three before mutating the graph
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        else:
            # remove the previous dequant node to ensure the inputs are quantized
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)
        # absorb the following observer into the module conversion; the lookup only
        # inspects node.users, which removing input dequantize nodes does not change
        activation_post_process = maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None, \
            f"Expecting an output observer for statically quantized custom module {node}"
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
def convert(
model: GraphModule, is_reference: bool = False,
convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
is_standalone_module: bool = False,
_remove_qconfig_flag: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None) -> torch.nn.Module:
"""
We will convert an observed model (a module with observer calls) to a reference
quantized model, the rule is simple:
1. for each observer module call in the graph, we'll convert it to calls to
quantize and dequantize functions based on the observer instance
2. for weighted operations like linear/conv, we need to convert them to reference
quantized module, this requires us to know whether the dtype configured for the
weight is supported in the backend, this is done in prepare step and the result
is stored in observed_node_names, we can decide whether we need to swap the
module based on this set
standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
Returns a quantized standalone module, whether input/output is quantized is
specified by prepare_custom_config, with
input_quantized_idxs, output_quantized_idxs, please
see docs for prepare_fx for details
"""
if convert_custom_config is None:
convert_custom_config = ConvertCustomConfig()
if isinstance(convert_custom_config, Dict):
warnings.warn(
"Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
"in a future version. Please pass in a ConvertCustomConfig instead.")
convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)
if isinstance(qconfig_mapping, Dict):
warnings.warn(
"Passing a QConfig dictionary to convert is deprecated and will not be supported "
"in a future version. Please pass in a QConfigMapping instead.")
qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
qconfig_mapping = copy.deepcopy(qconfig_mapping)
assert(qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping))
if isinstance(backend_config, Dict):
warnings.warn(
"Passing a backend_config_dict to prepare is deprecated and will not be supported "
"in a future version. Please pass in a BackendConfig instead.")
backend_config = BackendConfig.from_dict(backend_config)
if backend_config is None:
backend_config = get_native_backend_config()
node_name_to_scope, prepare_custom_config, observed_node_names = restore_state(model)
qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment]
# mapping from fully qualified module name to module instance
# for example,
# {
# '': Model(...),
# 'linear': Linear(...),
# 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
# }
# We use remove_duplicate=False here because torch.cat uses
# the same activation_post_process module instance but different names
modules = dict(model.named_modules(remove_duplicate=False))
# TODO refactor this code once we update the prepare logic to have additional information on
# which graph nodes have been observed and share that with convert to decide which observers to ignore.
if qconfig_mapping:
prepare_qconfig_mapping: QConfigMapping = model._qconfig_mapping # type: ignore[assignment]
modules_copy = copy.deepcopy(modules)
if model._is_qat:
update_qconfig_for_qat(qconfig_mapping, {})
update_qconfig_for_fusion(model, qconfig_mapping)
compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping) # type: ignore[arg-type]
convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)
# check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map
# or are set to None in the convert_qconfig_map.
for k, v in qconfig_map.items():
assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
if convert_qconfig_map[k] is not None:
assert qconfig_equals(v, convert_qconfig_map[k]), \
"Expected k {} to have the same value in prepare and convert QConfigMappings, " \
"but {} was updated to {}".format(k, v, convert_qconfig_map[k])
qconfig_map = convert_qconfig_map
custom_module_classes = get_custom_module_class_keys(convert_custom_config.observed_to_quantized_mapping)
custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping
if model._equalization_qconfig_map is not None:
# If we want to do equalization then do the following:
# Calculate the equalization scale, update the observers with the scaled
# inputs, and scale the weight
weight_eq_obs_dict = update_obs_for_equalization(model, modules)
convert_eq_obs(model, modules, weight_eq_obs_dict)
# always run weight observers in the top level forward method
# for dynamic quant ops or weight only quant ops
run_weight_observers(model, backend_config)
graph_inputs: List[str] = []
for node in model.graph.nodes:
if node.op == 'placeholder':
graph_inputs.append(node.name)
# TODO: move this outside of this function
def replace_observer_with_quantize_dequantize_node(
model: torch.nn.Module,
graph: Graph,
node: Node,
modules: Dict[str, torch.nn.Module],
node_name_to_scope: Dict[str, Tuple[str, type]],
qconfig_map: Dict[str, QConfigAny]) -> None:
""" Replace activation_post_process module call node with quantize and
dequantize node
Before:
... -> observer_0(x) -> ...
After:
... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
"""
assert modules is not None
assert isinstance(node.target, str)
module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
observer_module = modules[node.target]
maybe_quantize_node_info = get_quantize_node_info(observer_module)
# Skip replacing observers to quant/dequant nodes if the qconfigs of all
# consumers and producers of this observer are None
skip_replacement = all([
has_none_qconfig(n, qconfig_map) for n in
list(node.args) + list(node.users.keys())])
if skip_replacement or maybe_quantize_node_info is None:
# didn't find correponding quantize op and info for the observer_module
# so we just remove the observer
with graph.inserting_before(node):
node.replace_all_uses_with(node.args[0])
graph.erase_node(node)
else:
# otherwise, we can convert the observer moduel call to quantize/dequantize node
node_type, quantize_op, qparams = maybe_quantize_node_info
# replace observer node with quant - dequant node
with graph.inserting_before(node):
input_node = node.args[0]
inputs = [input_node]
for key, value in qparams.items():
# TODO: we can add the information of whether a value needs to
# be registered as an attribute in qparams dict itself
if key in ['_scale_', '_zero_point_']:
# For scale and zero_point values we register them as buffers in the root module.
# TODO: maybe need more complex attr name here
qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
inputs.append(qparam_node)
else:
# for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
inputs.append(value)
quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
node.replace_all_uses_with(dequantized_node)
graph.erase_node(node)
# this is a temporary hack for custom modules; we may want to implement
# this properly after the custom module class design is finalized
# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None:
    """Replace an observer / DeQuantStub node that follows a custom module call.

    ``node.args[0]`` is expected to be the node that calls the custom module.
    Every user of ``node`` is rewired to consume that call node directly,
    ``node`` is erased from ``graph``, and ``insert_dequantize_node`` is then
    applied to the call node (presumably inserting a ``dequantize`` after it;
    see that helper for the exact behavior).

    Args:
        node: the observer or DeQuantStub ``call_module`` node to remove.
        graph: the graph owning ``node``; it is modified in place.

    Raises:
        AssertionError: if ``node.args[0]`` is not a ``Node`` (only when
            assertions are enabled).
    """
    call_custom_module_node = node.args[0]
    assert isinstance(call_custom_module_node, Node), \
        f"Expecting the call custom module node to be a Node, but got {call_custom_module_node}"
    # Rewire users before erasing: Graph.erase_node refuses to remove a node
    # that still has users.
    node.replace_all_uses_with(call_custom_module_node)
    graph.erase_node(node)
    insert_dequantize_node(call_custom_module_node, graph)
# additional state to override inputs to be quantized, if specified
# by the user
placeholder_node_seen_cnt = 0
input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
# convert to a tuple so that it can work with isinstance(module, tuple_of_classes)
root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
qat_module_classes = get_qat_module_classes(backend_config)
fused_module_classes = get_fused_module_classes(backend_config)
statically_quantized_custom_module_nodes: Set[Node] = set()
for node in list(model.graph.nodes):
if node.op == 'placeholder':
cur_placeholder_node_idx = placeholder_node_seen_cnt
placeholder_node_seen_cnt += 1
if cur_placeholder_node_idx in input_quantized_idxs:
# Inputs are assumed to be quantized if the user specified the
# input_quantized_idxs override.
# we need to dequantize the inputs since all operators take
# floating point inputs in reference quantized models
insert_dequantize_node(node, model.graph)
elif node.op == "output":
# If the argument is empty we don't need to do anything
if len(output_quantized_idxs) == 0:
continue
# Results are kept quantized if the user specified the
# output_quantized_idxs override.
# Remove the dequantize operator for the node in the end if any
return_node = node
output = node.args[0]
# outputs can be Node, list, tuple, dict, other cases are not supported yet
if isinstance(output, (list, tuple)):
for idx in output_quantized_idxs:
maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
elif isinstance(output, (Node, dict)):
# we treat dict as a single argument currently, but it can be extended
# to support {"key": dtype} after we change output_quantized_idxs to
# dict
if 0 in output_quantized_idxs:
maybe_recursive_remove_dequantize(output, return_node, model.graph)
else:
warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
elif node.op == "call_module":
mod = _get_module(node, modules)
assert mod is not None
if is_activation_post_process(mod):
observed_node = node.args[0]
if observed_node in statically_quantized_custom_module_nodes:
replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
else:
replace_observer_with_quantize_dequantize_node(
model, model.graph, node, modules, node_name_to_scope,
qconfig_map)
elif isinstance(mod, DeQuantStub):
replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
elif is_observed_standalone_module(mod):
convert_standalone_module(
node, modules, model, is_reference, backend_config)
# below this point `type_before_parametrizations` is used
# instead of `type` to handle situations with fx quant + sparsity
elif type_before_parametrizations(mod) in set(
root_module_classes).union(qat_module_classes).union(fused_module_classes):
# extra check for fused module classes to make sure they are fused module classes
# of target modules
if type_before_parametrizations(mod) in fused_module_classes and \
type_before_parametrizations(mod[0]) not in root_module_classes: # type: ignore[index]
continue
convert_weighted_module(
node, modules, observed_node_names, qconfig_map, backend_config)
elif type_before_parametrizations(mod) in custom_module_classes:
convert_custom_module(
node, model.graph, modules, custom_module_class_mapping,
statically_quantized_custom_module_nodes)
preserved_attributes = set(convert_custom_config.preserved_attributes)
model = QuantizedGraphModule(model, copy.deepcopy(model.graph), preserved_attributes)
# remove deadcode after converting observers to quant/dequant ops
model.graph.eliminate_dead_code()
model.recompile()
# TODO: maybe move this to quantize_fx.py
if not is_reference:
model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
# TODO: this looks hacky, we want to check why we need this and see if we can
# remove this
# removes qconfig and activation_post_process modules
if _remove_qconfig_flag:
_remove_qconfig(model)
return model
|