| import torch |
| from torch._inductor.constant_folding import constant_fold |
| from torch._inductor.fx_passes.freezing_patterns import freezing_passes |
|
|
|
|
# Public API of this module.
__all__ = ["lower_pt2e_quantized_to_x86"]
|
|
|
|
| def lower_pt2e_quantized_to_x86( |
| model: torch.fx.GraphModule, |
| example_inputs: tuple[torch.Tensor, ...], |
| ) -> torch.fx.GraphModule: |
| """Lower a PT2E-qantized model to x86 backend. |
| |
| Args: |
| * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. |
| * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. |
| |
| Return: |
| A GraphModule lowered to x86 backend. |
| """ |
|
|
| def _post_autograd_decomp_table(): |
| decomp_table = torch.export.default_decompositions() |
|
|
| |
| |
| for k in list(decomp_table.keys()): |
| if not torch._export.utils._is_cia_op(k): |
| del decomp_table[k] |
|
|
| return decomp_table |
|
|
| def _node_replace(m): |
| |
| aten = torch.ops.aten |
| g = m.graph |
| for node in g.nodes: |
| if node.target == aten.t.default: |
| with g.inserting_before(node): |
| x = node.args[0] |
| dims = [1, 0] |
| perm_node = g.call_function(aten.permute.default, args=(x, dims)) |
| node.replace_all_uses_with(perm_node) |
| g.erase_node(node) |
|
|
| g.lint() |
| m.recompile() |
|
|
| lowered_model = ( |
| torch.export.export_for_training(model, example_inputs, strict=True) |
| .run_decompositions(_post_autograd_decomp_table()) |
| .module() |
| ) |
| _node_replace(lowered_model) |
| freezing_passes(lowered_model, example_inputs) |
| constant_fold(lowered_model) |
| return lowered_model |
|
|