|
|
|
|
|
|
|
|
from torch.ao.quantization.pt2e.utils import _is_sym_size_node |
|
|
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation |
|
|
from torch.fx import Node |
|
|
|
|
|
|
|
|
def _annotate_input_qspec_map(node: Node, input_node: Node, qspec): |
|
|
quantization_annotation = node.meta.get( |
|
|
"quantization_annotation", QuantizationAnnotation() |
|
|
) |
|
|
if quantization_annotation.input_qspec_map is None: |
|
|
quantization_annotation.input_qspec_map = {} |
|
|
quantization_annotation.input_qspec_map[input_node] = qspec |
|
|
node.meta["quantization_annotation"] = quantization_annotation |
|
|
|
|
|
|
|
|
def _annotate_output_qspec(node: Node, qspec): |
|
|
quantization_annotation = node.meta.get( |
|
|
"quantization_annotation", QuantizationAnnotation() |
|
|
) |
|
|
quantization_annotation.output_qspec = qspec |
|
|
node.meta["quantization_annotation"] = quantization_annotation |
|
|
|
|
|
|
|
|
def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]): |
|
|
""" |
|
|
This utility is used to handle cases when dynami_shape=True tracing leads |
|
|
to symint nodes in the pattern of linear module. In those cases, we need to |
|
|
distinguish between the nodes that are in input for just extracting value of |
|
|
some dimensions (and symint nodes) vs. the one that is activation. |
|
|
For example: |
|
|
graph(x, y, weight): |
|
|
size_0 = torch.ops.aten.sym_size([x], [0]) |
|
|
size_1 = torch.ops.aten.sym_size([y], [1]) |
|
|
view_size = size_0 * size_1 |
|
|
size_3 = torch.ops.aten.sym_size([x], [2]) |
|
|
vie_out = torch.ops.aten.view(x, [view_size, size_3]) |
|
|
return mm(view_out, weight) |
|
|
In the example above y node is not actual input. It exist only to extract size_1 |
|
|
""" |
|
|
if _is_sym_size_node(node): |
|
|
return True |
|
|
|
|
|
return all( |
|
|
((user not in partition_nodes) or _is_sym_size_node(user)) |
|
|
for user in node.users |
|
|
) |
|
|
|
|
|
|
|
|
def _get_module_name_filter(module_name: str):
    """Build a filter that tells whether a node comes from module ``module_name``.

    The returned callable takes a node, reads the ``"nn_module_stack"`` entry
    of its ``meta`` (treated as empty when missing), strips a leading
    ``"L['self']."`` from each recorded module path, and returns ``True`` when
    ``module_name`` equals one of the resulting paths.

    For example:
        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1

    >> module_name_filter = _get_module_name_filter("blocks.sub")
    >> print(module_name_filter(node))
    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"

    NOTE(review): the example presumes the node's ``nn_module_stack`` also holds
    an entry whose path is "blocks.sub" -- confirm against how the stack is built.
    """
    self_prefix = "L['self']."

    def _strip_self_prefix(path):
        # Drop the leading "L['self']." when present; otherwise keep the path as is.
        return path[len(self_prefix):] if path.startswith(self_prefix) else path

    def module_name_filter(n: Node) -> bool:
        module_stack = n.meta.get("nn_module_stack", {})
        # Each stack value is unpacked as a pair whose first item is the module path.
        recorded_paths = [
            _strip_self_prefix(path) for path, _ in module_stack.values()
        ]
        return module_name in recorded_paths

    return module_name_filter
|
|
|