File size: 3,095 Bytes
ad5f26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import annotations

from typing import TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer


if TYPE_CHECKING:
    import torch
    from torch.fx import Node

# Public API of this module: only the composing quantizer is exported.
__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.

    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by
    another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters since that is the order in which the quantizers will be
    applied. A quantizer may only add annotations: if it replaces (by object identity) or removes
    the ``quantization_annotation`` that an earlier quantizer placed on a node, :meth:`annotate`
    raises ``RuntimeError``.

    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = (
        XNNPackQuantizer()
    )  # to handle ops not quantized by previous two quantizers
    composed_quantizer = ComposableQuantizer(
        [embedding_quantizer, linear_quantizer, xnnpack_quantizer]
    )
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: list[Quantizer]):
        super().__init__()
        # Sub-quantizers, applied in list order by annotate() and
        # transform_for_annotation().
        self.quantizers = quantizers
        # The annotation object seen on each node after the most recent
        # sub-quantizer ran; used to detect later quantizers replacing or
        # removing it. Reset at the start of every annotate() call.
        self._graph_annotations: dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        """Record new node annotations in ``gm`` and reject changes to recorded ones.

        Args:
            gm: The graph module that ``quantizer`` has just annotated.
            quantizer: The sub-quantizer that just ran; only its class name is
                used, in error messages.

        Raises:
            RuntimeError: If a node recorded by an earlier pass now carries a
                different ``quantization_annotation`` object, or no longer has
                one at all.
        """
        recorded = self._graph_annotations
        for n in gm.graph.nodes:
            if "quantization_annotation" not in n.meta:
                # A previously annotated node must not lose its annotation.
                if n in recorded:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )
                continue

            annotation = n.meta["quantization_annotation"]
            # Compare by identity: an earlier quantizer's annotation object must
            # be left in place (in-place mutation of it is not detected here).
            if n in recorded and recorded[n] is not annotation:
                raise RuntimeError(
                    f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                )
            recorded[n] = annotation

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Run each sub-quantizer's ``annotate`` on ``model`` in list order.

        Just handling global spec for now. Sub-quantizers are expected to
        annotate ``model`` in place; their return values are ignored. After each
        one runs, its annotations are validated against those of the quantizers
        before it.

        Args:
            model: The graph module to annotate.

        Returns:
            The same ``model`` object that was passed in.

        Raises:
            RuntimeError: If a sub-quantizer replaced or removed an annotation
                set by an earlier sub-quantizer.
        """
        # Start from a clean record so a repeated call (or a call on another
        # model) is not validated against stale annotations, and nodes from
        # previously annotated graphs are not kept alive by this instance.
        self._graph_annotations = {}
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Chain each sub-quantizer's ``transform_for_annotation`` in list order.

        Each sub-quantizer receives the module returned by the previous one.

        Args:
            model: The graph module to transform.

        Returns:
            The module returned by the last sub-quantizer (``model`` itself if
            there are no sub-quantizers).
        """
        for quantizer in self.quantizers:
            model = quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        """No-op: the sub-quantizers' ``validate`` methods are not called here."""
        pass