|
|
from collections import OrderedDict |
|
|
from typing import Dict, Any |
|
|
from torch.ao.quantization.quantization_types import Pattern |
|
|
from ..fake_quantize import FixedQParamsFakeQuantize |
|
|
|
|
|
from ..observer import ObserverBase |
|
|
import copy |
|
|
|
|
|
|
|
|
# Placeholder alias for the quantize-handler type. It is plain ``Any``, so
# annotations that mention it are not checked any more precisely than that.
QuantizeHandler = Any

# Insertion-ordered registry filled in by ``register_fusion_pattern``:
# maps a fusion pattern to the object that was registered for it.
DEFAULT_FUSION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()
|
|
def register_fusion_pattern(pattern):
    """Return a decorator that registers its target as the fusion handler for ``pattern``.

    Args:
        pattern: Key under which the decorated object is stored in
            ``DEFAULT_FUSION_PATTERNS``. It is used as a dict key, so it must
            be hashable; registering the same pattern again overwrites the
            earlier entry.

    Returns:
        A decorator that records the decorated object and returns it unchanged.
    """

    def insert(fn):
        DEFAULT_FUSION_PATTERNS[pattern] = fn
        # Hand the object back untouched so the decorated name stays usable.
        return fn

    return insert
|
|
|
|
|
def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    """Return a shallow copy of the registered fusion patterns.

    Because a copy is returned, adding or removing entries in the result does
    not touch ``DEFAULT_FUSION_PATTERNS``. The copy is shallow, so the
    registered objects themselves are shared rather than duplicated.
    """
    return copy.copy(DEFAULT_FUSION_PATTERNS)
|
|
|
|
|
# Insertion-ordered registry filled in by ``register_quant_pattern``:
# maps a quantization pattern to the object that was registered for it.
DEFAULT_QUANTIZATION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()

# Per-pattern entries for the output activation post process. Both maps are
# written by ``register_quant_pattern``, and only when it is given a
# ``fixed_qparams_observer``. ``get_default_output_activation_post_process_map``
# hands out the fake-quantize map when training and the observer map otherwise.
DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: Dict[Pattern, Any] = {}
DEFAULT_OUTPUT_OBSERVER_MAP: Dict[Pattern, Any] = {}
|
|
|
|
|
|
|
|
def register_quant_pattern(pattern, fixed_qparams_observer=None):
    """Return a decorator that registers its target as the quantize handler for ``pattern``.

    Args:
        pattern: Key under which the decorated object is stored in
            ``DEFAULT_QUANTIZATION_PATTERNS``. It is used as a dict key, so it
            must be hashable; registering the same pattern again overwrites
            the earlier entry.
        fixed_qparams_observer: Optional observer for the pattern's output.
            When it is not ``None`` it is stored as-is in
            ``DEFAULT_OUTPUT_OBSERVER_MAP``, and
            ``FixedQParamsFakeQuantize.with_args(observer=...)`` built from it
            is stored in ``DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP``.

    Returns:
        A decorator that records the decorated object and returns it unchanged.
    """

    def insert(fn):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if fixed_qparams_observer is not None:
            # Keep the two output maps in step: one entry for training
            # (fake quantize) and one for the non-training path (observer).
            DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = FixedQParamsFakeQuantize.with_args(
                observer=fixed_qparams_observer
            )
            DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
        return fn

    return insert
|
|
|
|
|
|
|
|
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    """Return a shallow copy of the registered quantization patterns.

    Because a copy is returned, adding or removing entries in the result does
    not touch ``DEFAULT_QUANTIZATION_PATTERNS``. The copy is shallow, so the
    registered objects themselves are shared rather than duplicated.
    """
    return copy.copy(DEFAULT_QUANTIZATION_PATTERNS)
|
|
|
|
|
|
|
|
|
|
|
def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]:
    """Return a shallow copy of the per-pattern output activation post-process map.

    Args:
        is_training: Truthy selects ``DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP``;
            falsy selects ``DEFAULT_OUTPUT_OBSERVER_MAP``.

    Returns:
        A shallow copy of the selected map, so adding or removing entries in
        the result does not affect the module-level registry.
    """
    if is_training:
        return copy.copy(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
    return copy.copy(DEFAULT_OUTPUT_OBSERVER_MAP)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sorted_patterns_dict(patterns_dict: Dict[Pattern, QuantizeHandler]) -> Dict[Pattern, QuantizeHandler]:
    """
    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
    e.g. match (F.relu, F.linear) before F.relu.

    This works for current use cases, but we may need to have a more clever way to sort
    things to address more complex patterns.

    Args:
        patterns_dict: Mapping from pattern to handler. It is not modified.

    Returns:
        A new ``OrderedDict`` holding the same items: tuple patterns first,
        ordered from most to fewest entries, followed by all non-tuple
        patterns. Items that compare equal keep their original relative
        order because ``sorted`` is stable.
    """

    def get_len(pattern):
        """Count the entries of ``pattern``, descending into nested tuples.

        This makes sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) (3 entries) comes
        before (nn.BatchNorm, nn.Conv2d) (2 entries) so that the former is
        matched first. Anything that is not a tuple counts as one entry.
        """
        if isinstance(pattern, tuple):
            return sum(get_len(item) for item in pattern)
        return 1

    def sort_key(item):
        pattern = item[0]
        # Tuples get a non-positive key (more entries -> more negative -> earlier);
        # every non-tuple pattern gets 1, which places it after all tuples.
        return -get_len(pattern) if isinstance(pattern, tuple) else 1

    return OrderedDict(sorted(patterns_dict.items(), key=sort_key))
|
|
|