Phi2-Fine-Tuning
/
phivenv
/Lib
/site-packages
/torch
/ao
/quantization
/quantizer
/composable_quantizer.py
| from __future__ import annotations | |
| from typing import TYPE_CHECKING | |
| from .quantizer import QuantizationAnnotation, Quantizer | |
| if TYPE_CHECKING: | |
| import torch | |
| from torch.fx import Node | |
| __all__ = [ | |
| "ComposableQuantizer", | |
| ] | |
| class ComposableQuantizer(Quantizer): | |
| """ | |
| ComposableQuantizer allows users to combine more than one quantizer into a single quantizer. | |
| This allows users to quantize a model with multiple quantizers. E.g., embedding quantization | |
| maybe supported by one quantizer while linear layers and other ops might be supported by another | |
| quantizer. | |
| ComposableQuantizer is initialized with a list of `Quantizer` instances. | |
| The order of the composition matters since that is the order in which the quantizers will be | |
| applies. | |
| Example: | |
| ``` | |
| embedding_quantizer = EmbeddingQuantizer() | |
| linear_quantizer = MyLinearQuantizer() | |
| xnnpack_quantizer = ( | |
| XNNPackQuantizer() | |
| ) # to handle ops not quantized by previous two quantizers | |
| composed_quantizer = ComposableQuantizer( | |
| [embedding_quantizer, linear_quantizer, xnnpack_quantizer] | |
| ) | |
| prepared_m = prepare_pt2e(model, composed_quantizer) | |
| ``` | |
| """ | |
| def __init__(self, quantizers: list[Quantizer]): | |
| super().__init__() | |
| self.quantizers = quantizers | |
| self._graph_annotations: dict[Node, QuantizationAnnotation] = {} | |
| def _record_and_validate_annotations( | |
| self, gm: torch.fx.GraphModule, quantizer: Quantizer | |
| ) -> None: | |
| for n in gm.graph.nodes: | |
| if "quantization_annotation" in n.meta: | |
| # check if the annotation has been changed by | |
| # comparing QuantizationAnnotation object id | |
| if n in self._graph_annotations and ( | |
| id(self._graph_annotations[n]) | |
| != id(n.meta["quantization_annotation"]) | |
| ): | |
| raise RuntimeError( | |
| f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}" | |
| ) | |
| else: | |
| self._graph_annotations[n] = n.meta["quantization_annotation"] | |
| else: | |
| if n in self._graph_annotations: | |
| raise RuntimeError( | |
| f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}" | |
| ) | |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: | |
| """just handling global spec for now""" | |
| for quantizer in self.quantizers: | |
| quantizer.annotate(model) | |
| self._record_and_validate_annotations(model, quantizer) | |
| return model | |
| def transform_for_annotation( | |
| self, model: torch.fx.GraphModule | |
| ) -> torch.fx.GraphModule: | |
| for quantizer in self.quantizers: | |
| model = quantizer.transform_for_annotation(model) | |
| return model | |
| def validate(self, model: torch.fx.GraphModule) -> None: | |
| pass | |