|
|
import sys |
|
|
import torch |
|
|
import types |
|
|
from typing import List |
|
|
|
|
|
|
|
|
def _get_qengine_id(qengine: str) -> int: |
|
|
if qengine == 'none' or qengine == '' or qengine is None: |
|
|
ret = 0 |
|
|
elif qengine == 'fbgemm': |
|
|
ret = 1 |
|
|
elif qengine == 'qnnpack': |
|
|
ret = 2 |
|
|
elif qengine == 'onednn': |
|
|
ret = 3 |
|
|
elif qengine == 'x86': |
|
|
ret = 4 |
|
|
else: |
|
|
ret = -1 |
|
|
raise RuntimeError("{} is not a valid value for quantized engine".format(qengine)) |
|
|
return ret |
|
|
|
|
|
|
|
|
def _get_qengine_str(qengine: int) -> str: |
|
|
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn', 4 : 'x86'} |
|
|
return all_engines.get(qengine, '*undefined') |
|
|
|
|
|
class _QEngineProp:
    """Descriptor that reads and writes the current quantized engine.

    Reading calls ``torch._C._get_qengine()`` and converts the resulting id
    with ``_get_qengine_str``; writing converts the given name with
    ``_get_qengine_id`` and passes the id to ``torch._C._set_qengine()``.
    """

    def __get__(self, obj, objtype) -> str:
        """Return the name of the engine currently reported by torch._C."""
        current_id = torch._C._get_qengine()
        return _get_qengine_str(current_id)

    def __set__(self, obj, val: str) -> None:
        """Set the engine from its name; invalid names raise RuntimeError."""
        requested_id = _get_qengine_id(val)
        torch._C._set_qengine(requested_id)
|
|
|
|
|
class _SupportedQEnginesProp(object): |
|
|
def __get__(self, obj, objtype) -> List[str]: |
|
|
qengines = torch._C._supported_qengines() |
|
|
return [_get_qengine_str(qe) for qe in qengines] |
|
|
|
|
|
def __set__(self, obj, val) -> None: |
|
|
raise RuntimeError("Assignment not supported") |
|
|
|
|
|
class QuantizedEngine(types.ModuleType):
    """Module subclass that wraps another object and exposes engine settings.

    ``engine`` and ``supported_engines`` are descriptors declared on the
    class; attribute names not found through normal lookup are forwarded to
    the wrapped object stored in ``self.m``.
    """

    # Descriptors must live on the class for __get__/__set__ to be invoked
    # on attribute access through an instance.
    engine = _QEngineProp()
    supported_engines = _SupportedQEnginesProp()

    def __init__(self, m, name):
        """Create the module named ``name`` and remember ``m`` as the wrapped object."""
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        """Forward lookups that failed on this object to the wrapped object."""
        wrapped = self.m
        return wrapped.__getattribute__(attr)
|
|
|
|
|
|
|
|
|
|
|
# Replace this module's entry in sys.modules with a QuantizedEngine instance
# that wraps the original module object, so lookups of this module name in
# sys.modules yield the wrapper from here on.
sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)

# Annotation-only declarations (no value is assigned here) for the two
# attributes that QuantizedEngine defines as descriptors.
# NOTE(review): presumably present for static type checkers - confirm.
engine: str
supported_engines: List[str]
|
|
|