Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/impl.py +976 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/fft.py +590 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/__init__.py +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/utils.py +533 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py +160 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/fbgemm.py +116 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/native.py +204 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuser_method_mappings.py +259 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__init__.py +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_equalize.py +820 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/detector.py +1539 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +666 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/quantize_handler.py +197 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/utils.py +885 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -69,3 +69,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distl
|
|
| 69 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 70 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 71 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 70 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 71 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8597f1985804f6c0c55b84d29a8744f0e2bc6600aaa695402499fbbbcba1decc
|
| 3 |
+
size 374848
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:38073c63ab8f022926f58f7cb39c565005f382bdfacd85822e7502a5256b6671
|
| 3 |
+
size 1509528
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1e41fea31e2f114e2b8bb3065092e62588a33b909a8fa70bc578e734128e529
|
| 3 |
+
size 176864
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/impl.py
ADDED
|
@@ -0,0 +1,976 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import functools
|
| 3 |
+
import inspect
|
| 4 |
+
import sys
|
| 5 |
+
import typing
|
| 6 |
+
import weakref
|
| 7 |
+
|
| 8 |
+
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch._C as _C
|
| 12 |
+
import torch.library as library
|
| 13 |
+
from torch._library.abstract_impl import AbstractImplCtx
|
| 14 |
+
from torch.library import get_ctx
|
| 15 |
+
|
| 16 |
+
from .autograd import autograd_kernel_indirection, construct_autograd_kernel
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
For a detailed guide on custom ops, please see
|
| 20 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
| 21 |
+
|
| 22 |
+
This file includes pieces of the implementation of our custom operator API.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
SUPPORTED_DEVICE_TYPE_TO_KEY = {
|
| 29 |
+
"cpu": "CPU",
|
| 30 |
+
"cuda": "CUDA",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# We will not let users register CustomOps with anything that could look like
|
| 34 |
+
# PyTorch internals to avoid confusion.
|
| 35 |
+
RESERVED_NS = {
|
| 36 |
+
"prim",
|
| 37 |
+
"prims",
|
| 38 |
+
"aten",
|
| 39 |
+
"at",
|
| 40 |
+
"torch",
|
| 41 |
+
"pytorch",
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def custom_op(
|
| 46 |
+
qualname: str, manual_schema: typing.Optional[str] = None
|
| 47 |
+
) -> typing.Callable:
|
| 48 |
+
r"""Creates a new CustomOp object.
|
| 49 |
+
|
| 50 |
+
WARNING: if you're a user, please do not use this directly
|
| 51 |
+
(instead use the torch._custom_ops APIs).
|
| 52 |
+
Also please see the following for a detailed guide on custom ops.
|
| 53 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
| 54 |
+
|
| 55 |
+
In PyTorch, defining an op (short for "operator") is a two step-process:
|
| 56 |
+
- we need to define (create) the op
|
| 57 |
+
- we need to implement behavior for how the operator interacts with
|
| 58 |
+
various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
|
| 59 |
+
|
| 60 |
+
This entrypoint defines the CustomOp object (the first step);
|
| 61 |
+
you must then perform the second step by calling various methods on
|
| 62 |
+
the CustomOp object.
|
| 63 |
+
|
| 64 |
+
This API is used as a decorator (see examples).
|
| 65 |
+
|
| 66 |
+
Arguments:
|
| 67 |
+
qualname (str): Should be a string that looks like
|
| 68 |
+
"namespace::operator_name". Operators in PyTorch need a namespace to
|
| 69 |
+
avoid name collisions; a given operator may only be created once.
|
| 70 |
+
If you are writing a Python library, we recommend the namespace to
|
| 71 |
+
be the name of your top-level module. The operator_name must be
|
| 72 |
+
the same as the name of the function you pass to custom_op
|
| 73 |
+
(see examples).
|
| 74 |
+
manual_schema (Optional[str]): Each PyTorch operator needs a schema that
|
| 75 |
+
tells PyTorch the types of the inputs/outputs. If None (default),
|
| 76 |
+
we will infer the schema from the type annotations on the function
|
| 77 |
+
(see examples). Otherwise, if you don't want to use type annotations,
|
| 78 |
+
you may provide us the schema string.
|
| 79 |
+
|
| 80 |
+
Example::
|
| 81 |
+
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
| 82 |
+
>>> import numpy as np
|
| 83 |
+
>>> from torch import Tensor
|
| 84 |
+
>>>
|
| 85 |
+
>>> # Step 1: define the CustomOp.
|
| 86 |
+
>>> # We need to provide the decorator a "prototype function"
|
| 87 |
+
>>> # (a function with Python ellipses as the body).
|
| 88 |
+
>>> @custom_op("my_library::numpy_sin")
|
| 89 |
+
>>> def numpy_sin(x: Tensor) -> Tensor:
|
| 90 |
+
>>> ...
|
| 91 |
+
>>>
|
| 92 |
+
>>> # numpy_sin is now an instance of class CustomOp
|
| 93 |
+
>>> print(type(numpy_sin))
|
| 94 |
+
>>>
|
| 95 |
+
>>> # Step 2: Register an implementation for various PyTorch subsystems
|
| 96 |
+
>>>
|
| 97 |
+
>>> # Register an implementation for CPU tensors
|
| 98 |
+
>>> @numpy_sin.impl('cpu')
|
| 99 |
+
>>> def numpy_sin_impl_cpu(x):
|
| 100 |
+
>>> return torch.from_numpy(np.sin(x.numpy()))
|
| 101 |
+
>>>
|
| 102 |
+
>>> # Register an implementation for CUDA tensors
|
| 103 |
+
>>> @numpy_sin.impl('cuda')
|
| 104 |
+
>>> def numpy_sin_impl_cuda(x):
|
| 105 |
+
>>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
|
| 106 |
+
>>>
|
| 107 |
+
>>> x = torch.randn(3)
|
| 108 |
+
>>> numpy_sin(x) # calls numpy_sin_impl_cpu
|
| 109 |
+
>>>
|
| 110 |
+
>>> x_cuda = x.cuda()
|
| 111 |
+
>>> numpy_sin(x) # calls numpy_sin_impl_cuda
|
| 112 |
+
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def inner(func):
|
| 116 |
+
if not inspect.isfunction(func):
|
| 117 |
+
raise ValueError(
|
| 118 |
+
f"custom_op(...)(func): Expected `func` to be a Python "
|
| 119 |
+
f"function, got: {type(func)}"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
ns, name = parse_qualname(qualname)
|
| 123 |
+
validate_namespace(ns)
|
| 124 |
+
if func.__name__ != name:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
|
| 127 |
+
f"to have name '{name}' but got '{func.__name__}'. "
|
| 128 |
+
f"Please either change the name of `func` or the qualname that "
|
| 129 |
+
f"is passed to `custom_op`"
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
schema = infer_schema(func) if manual_schema is None else manual_schema
|
| 133 |
+
schema_str = f"{name}{schema}"
|
| 134 |
+
function_schema = FunctionSchema.parse(schema_str)
|
| 135 |
+
validate_schema(function_schema)
|
| 136 |
+
if manual_schema is not None:
|
| 137 |
+
validate_function_matches_schema(function_schema, func)
|
| 138 |
+
|
| 139 |
+
lib = library.Library(ns, "FRAGMENT")
|
| 140 |
+
lib.define(schema_str)
|
| 141 |
+
ophandle = find_ophandle_or_throw(ns, function_schema.name)
|
| 142 |
+
result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
|
| 143 |
+
|
| 144 |
+
result.__name__ = func.__name__
|
| 145 |
+
result.__module__ = func.__module__
|
| 146 |
+
result.__doc__ = func.__doc__
|
| 147 |
+
|
| 148 |
+
library.impl(lib, result._opname, "Autograd")(
|
| 149 |
+
autograd_kernel_indirection(weakref.proxy(result))
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
torch._C._dispatch_set_report_error_callback(
|
| 153 |
+
ophandle, functools.partial(report_error_callback, weakref.proxy(result))
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return result
|
| 157 |
+
|
| 158 |
+
return inner
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Global dictionary holding references to all CustomOp objects
|
| 162 |
+
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
|
| 163 |
+
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
|
| 164 |
+
# An example usage is FakeTensor: FakeTensor checks if a specific operator
|
| 165 |
+
# has an implementation registered via the CustomOp API.
|
| 166 |
+
# Indexed by qualname (e.g. aten::foo)
|
| 167 |
+
global_registry: typing.Dict[str, "CustomOp"] = {}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class CustomOp:
|
| 171 |
+
r"""Class for custom operators in PyTorch.
|
| 172 |
+
|
| 173 |
+
Use the CustomOp API to create user-defined custom operators that behave
|
| 174 |
+
just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
|
| 175 |
+
comes to various PyTorch subsystems (like torch.compile).
|
| 176 |
+
|
| 177 |
+
To construct a `CustomOp`, use `custom_op`.
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
|
| 181 |
+
super().__init__()
|
| 182 |
+
if not _private_access:
|
| 183 |
+
raise RuntimeError(
|
| 184 |
+
"The CustomOp constructor is private and we do not guarantee "
|
| 185 |
+
"BC for it. Please use custom_op(...) to create a CustomOp object"
|
| 186 |
+
)
|
| 187 |
+
name = f"{cpp_ns}::{operator_name}"
|
| 188 |
+
self._schema = schema
|
| 189 |
+
self._cpp_ns = cpp_ns
|
| 190 |
+
self._lib: library.Library = lib
|
| 191 |
+
self._ophandle: _C._DispatchOperatorHandle = ophandle
|
| 192 |
+
# Has the name of the op, e.g. "foo". We cache here for convenience.
|
| 193 |
+
self._opname: str = operator_name
|
| 194 |
+
# this is _opname but with namespace. e.g. "custom::foo"
|
| 195 |
+
self._qualname: str = name
|
| 196 |
+
self.__name__ = None # mypy requires this
|
| 197 |
+
# NB: Some of these impls are registered as kernels to DispatchKeys.
|
| 198 |
+
# Modifying the _impls dict directly won't do anything in that case.
|
| 199 |
+
self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
|
| 200 |
+
# See NOTE [CustomOp autograd kernel indirection]
|
| 201 |
+
self._registered_autograd_kernel_indirection = False
|
| 202 |
+
|
| 203 |
+
global_registry[self._qualname] = self
|
| 204 |
+
|
| 205 |
+
def _register_autograd_kernel_indirection(self):
|
| 206 |
+
assert not self._registered_autograd_kernel_indirection
|
| 207 |
+
self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
|
| 208 |
+
self._registered_autograd_kernel_indirection = True
|
| 209 |
+
|
| 210 |
+
# Records the impl and the source location in self._impls
|
| 211 |
+
# Note that this doesn't cause torch.library to use the impl, that
|
| 212 |
+
# needs to be done in a separate self._lib.impl call.
|
| 213 |
+
def _register_impl(self, kind, func, stacklevel=2):
|
| 214 |
+
if self._has_impl(kind):
|
| 215 |
+
func_and_location = self._impls[kind]
|
| 216 |
+
assert func_and_location is not None # Pacify mypy
|
| 217 |
+
location = func_and_location.location
|
| 218 |
+
raise RuntimeError(
|
| 219 |
+
f"Attempting to register a {kind} impl for operator {self._qualname} "
|
| 220 |
+
f"that already has a {kind} impl registered from Python at "
|
| 221 |
+
f"{location}. This is not supported."
|
| 222 |
+
)
|
| 223 |
+
frame = inspect.getframeinfo(sys._getframe(stacklevel))
|
| 224 |
+
location = f"{frame.filename}:{frame.lineno}"
|
| 225 |
+
self._impls[kind] = FuncAndLocation(func, location)
|
| 226 |
+
|
| 227 |
+
def _get_impl(self, kind):
|
| 228 |
+
return self._impls[kind]
|
| 229 |
+
|
| 230 |
+
def _has_impl(self, kind):
|
| 231 |
+
return kind in self._impls
|
| 232 |
+
|
| 233 |
+
def _destroy(self):
|
| 234 |
+
# NOTE: [CustomOp lifetime]
|
| 235 |
+
# A CustomOp, once created, lives forever. The mechanism is that the
|
| 236 |
+
# global registry holds a reference to it. However, to make testing
|
| 237 |
+
# easier, we want to be able to destroy CustomOp objects.
|
| 238 |
+
# CustomOp._destroy does the job, though it leaves the CustomOp
|
| 239 |
+
# in a garbage state.
|
| 240 |
+
del self._lib
|
| 241 |
+
|
| 242 |
+
opnamespace = getattr(torch.ops, self._cpp_ns)
|
| 243 |
+
if hasattr(opnamespace, self._opname):
|
| 244 |
+
delattr(opnamespace, self._opname)
|
| 245 |
+
|
| 246 |
+
del global_registry[self._qualname]
|
| 247 |
+
|
| 248 |
+
def __repr__(self):
|
| 249 |
+
return f'<CustomOp(op="{self._qualname}")>'
|
| 250 |
+
|
| 251 |
+
def __call__(self, *args, **kwargs):
|
| 252 |
+
# Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
|
| 253 |
+
# Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
|
| 254 |
+
# issues from caching operators that make testing CustomOp difficult).
|
| 255 |
+
result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
|
| 256 |
+
return result
|
| 257 |
+
|
| 258 |
+
def impl(
|
| 259 |
+
self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
|
| 260 |
+
) -> typing.Callable:
|
| 261 |
+
r"""Register an implementation for a device type for this CustomOp object.
|
| 262 |
+
|
| 263 |
+
WARNING: if you're a user, please do not use this directly
|
| 264 |
+
(instead use the torch._custom_ops APIs).
|
| 265 |
+
Also please see the following for a detailed guide on custom ops.
|
| 266 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
| 267 |
+
|
| 268 |
+
If the CustomOp is passed multiple Tensor inputs with different device
|
| 269 |
+
types, it will dispatch to the registered implementation for the highest
|
| 270 |
+
priority device type among those present.
|
| 271 |
+
The supported device types, in order of priority, are {'cuda', 'cpu'}.
|
| 272 |
+
|
| 273 |
+
This API is used as a decorator (see examples).
|
| 274 |
+
|
| 275 |
+
Arguments:
|
| 276 |
+
device_types (str or Iterable[str]): the device type(s) to register the function for.
|
| 277 |
+
|
| 278 |
+
Examples::
|
| 279 |
+
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
| 280 |
+
>>> import numpy as np
|
| 281 |
+
>>> from torch import Tensor
|
| 282 |
+
>>>
|
| 283 |
+
>>> @custom_op("my_library::numpy_cos")
|
| 284 |
+
>>> def numpy_cos(x: Tensor) -> Tensor:
|
| 285 |
+
>>> ...
|
| 286 |
+
>>>
|
| 287 |
+
>>> # Register an implementation for CPU Tensors
|
| 288 |
+
>>> @numpy_cos.impl('cpu')
|
| 289 |
+
>>> def numpy_cos_impl_cpu(x):
|
| 290 |
+
>>> return torch.from_numpy(np.cos(x.numpy()))
|
| 291 |
+
>>>
|
| 292 |
+
>>> # Register an implementation for CUDA Tensors
|
| 293 |
+
>>> @numpy_cos.impl('cuda')
|
| 294 |
+
>>> def numpy_cos_impl_cuda(x):
|
| 295 |
+
>>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
|
| 296 |
+
>>>
|
| 297 |
+
>>> x = torch.randn(3)
|
| 298 |
+
>>> numpy_cos(x) # calls numpy_cos_impl_cpu
|
| 299 |
+
>>>
|
| 300 |
+
>>> x_cuda = x.cuda()
|
| 301 |
+
>>> numpy_cos(x) # calls numpy_cos_impl_cuda
|
| 302 |
+
|
| 303 |
+
"""
|
| 304 |
+
if isinstance(device_types, str):
|
| 305 |
+
device_types = [device_types]
|
| 306 |
+
for device_type in device_types:
|
| 307 |
+
validate_device_type(device_type)
|
| 308 |
+
|
| 309 |
+
def inner(f):
|
| 310 |
+
for device_type in set(device_types):
|
| 311 |
+
self._check_doesnt_have_library_impl(device_type)
|
| 312 |
+
self._register_impl(device_type, f, stacklevel=_stacklevel)
|
| 313 |
+
dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
|
| 314 |
+
library.impl(self._lib, self._opname, dispatch_key)(f)
|
| 315 |
+
return f
|
| 316 |
+
|
| 317 |
+
return inner
|
| 318 |
+
|
| 319 |
+
def _check_doesnt_have_library_impl(self, device_type):
|
| 320 |
+
if self._has_impl(device_type):
|
| 321 |
+
return
|
| 322 |
+
key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
|
| 323 |
+
if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
|
| 324 |
+
raise RuntimeError(
|
| 325 |
+
f"impl(..., device_types={device_type}): the operator {self._qualname} "
|
| 326 |
+
f"already has an implementation for this device type via a "
|
| 327 |
+
f"pre-existing torch.library or TORCH_LIBRARY registration.")
|
| 328 |
+
|
| 329 |
+
def impl_factory(self) -> typing.Callable:
|
| 330 |
+
r"""Register an implementation for a factory function."""
|
| 331 |
+
|
| 332 |
+
def inner(f):
|
| 333 |
+
self._register_impl("factory", f)
|
| 334 |
+
library.impl(self._lib, self._opname, "BackendSelect")(f)
|
| 335 |
+
return f
|
| 336 |
+
|
| 337 |
+
return inner
|
| 338 |
+
|
| 339 |
+
def impl_abstract(self, _stacklevel=2) -> typing.Callable:
    r"""Register an abstract implementation for this operator.

    WARNING: please do not use this directly (and instead use the torch._custom_ops
    APIs). Also please see the following for a detailed guide on custom ops.
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    An "abstract implementation" specifies the behavior of this operator on
    Tensors that carry no data. Given some input Tensors with certain properties
    (sizes/strides/storage_offset/device), it specifies what the properties of
    the output Tensors are.

    The abstract implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write an abstract
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The abstract implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API is used as a decorator (see examples).

    Examples::
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @custom_op('my_library::custom_linear')
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> @custom_linear.impl_abstract()
        >>> def custom_linear_abstract(x, weight):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @custom_op('my_library::custom_nonzero')
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> @custom_nonzero.impl_abstract()
        >>> def custom_nonzero_abstract(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in an abstract impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch._custom_op.get_ctx()
        >>>     nnz = ctx.create_unbacked_symint()
        >>>     shape = [x.dim(), nnz]
        >>>     result = x.new_empty(shape, dtype=torch.long)
        >>>     return result
        >>>
        >>> @custom_nonzero.impl(['cpu', 'cuda'])
        >>> def custom_nonzero_impl(x):
        >>>     x_np = to_numpy(x)
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
        >>>     # constrain the range to at least 2
        >>>     if res.shape[0] <= 1:
        >>>         raise RuntimeError("not supported")
        >>>     return torch.tensor(res, device=x.device)

    """

    def inner(f):
        self._check_doesnt_have_library_meta_impl()
        self._register_impl("abstract", f, stacklevel=_stacklevel)
        location = self._get_impl("abstract").location

        qualname = self._qualname

        # Handle DispatchKey.Meta registration
        @functools.wraps(f)
        def f_with_ctx(*args, **kwargs):
            def error_on_ctx():
                # BUGFIX: the sentence after "{qualname}." previously ran
                # directly into the period with no separating space.
                raise RuntimeError(
                    f"Attempted to call get_ctx() for the meta implementation "
                    f"for {qualname}. "
                    f"You have presumably called get_ctx() because the operator "
                    f"has a data-dependent output shape; if so, there is no "
                    f"such meta implementation and this error is the correct "
                    f"behavior. Otherwise, please remove the call to get_ctx() "
                    f"in the implementation registered with impl_abstract "
                    f"at {location}"
                )

            # Meta kernels have no ctx; make get_ctx() fail loudly instead of
            # silently misbehaving.
            with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
                return f(*args, **kwargs)

        self._lib.impl(self._opname, f_with_ctx, "Meta")
        return f

    return inner
|
| 440 |
+
|
| 441 |
+
def _check_can_register_backward(self):
    """Raise RuntimeError unless this operator can accept an autograd formula.

    Requirements checked (in order): the schema must be functional, must have
    at least one return, must not return (non-mutating) views, and every
    return type must be one of the allowed scalar/Tensor/List[Tensor] types.
    """
    def error(detail):
        raise RuntimeError(
            f"Cannot use torch._custom_ops APIs to register backward "
            f"formula for {detail}. Got operator "
            f"{self._qualname} with schema: {schema}"
        )

    schema = self._schema
    if schema.kind() != SchemaKind.functional:
        error("non-functional operator")

    rets = schema.returns
    if not schema.returns:
        error("operator with no returns")

    assert len(rets) > 0
    # A return with a non-write annotation indicates an aliasing (view) return.
    is_non_mutating_view = any(
        r.annotation is not None and not r.annotation.is_write for r in rets
    )
    if is_non_mutating_view:
        error("operator that returns views")

    # We make assumptions about the schema's return types.
    allowed_return_types = {
        BaseType(BaseTy.int): "int",
        BaseType(BaseTy.SymInt): "SymInt",
        BaseType(BaseTy.bool): "bool",
        BaseType(BaseTy.float): "float",
        BaseType(BaseTy.Tensor): "Tensor",
        ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
    }
    for ret in schema.returns:
        if ret.type in allowed_return_types:
            continue
        error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
|
| 477 |
+
|
| 478 |
+
def _check_doesnt_have_library_autograd_impl(self):
    """Raise if an autograd kernel already exists via torch.library/TORCH_LIBRARY.

    Skipped once this CustomOp has installed its own autograd kernel
    indirection. Otherwise checks CompositeImplicitAutograd (which derives its
    own autograd behavior) and the common Autograd dispatch keys.
    """
    if self._registered_autograd_kernel_indirection:
        return

    if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
        # BUGFIX: added the missing space after "...CompositeImplicitAutograd."
        # so the two sentences do not run together.
        raise RuntimeError(
            f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
            f"already has an implementation for this device type via a "
            f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
            f"CompositeImplicitAutograd operators do not need an autograd formula; "
            f"instead, the operator will decompose into its constituents and those "
            f"can have autograd formulas defined on them.")

    # We can improve this by adding "all Autograd<BACKEND> keys", but
    # realistically people will just be using this API for CPU/CUDA for now.
    for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
            # BUGFIX: "vi a" -> "via a" in the error message.
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: "
                f"the operator {self._qualname} already has an Autograd kernel "
                f"registered to DispatchKey::{key} via a pre-existing "
                f"torch.library or TORCH_LIBRARY registration. Please either "
                f"remove those registrations or don't use the torch._custom_ops APIs")
|
| 501 |
+
|
| 502 |
+
def _check_doesnt_have_library_meta_impl(self):
    """Raise if a Meta kernel already exists via torch.library/TORCH_LIBRARY.

    Skipped if this CustomOp already registered its own abstract impl. A
    CompositeExplicitAutograd registration without a Meta kernel is allowed
    through as a pragmatic escape hatch (see comment below).
    """
    if self._has_impl("abstract"):
        return

    # If the user's operator is CompositeExplicitAutograd,
    # allow them to impl_abstract. This is being pragmatic
    # (existing custom ops may have CompositeExplicitAutograd
    # registration that don't work with Meta kernels, so this
    # gives them an escape hatch).
    if (
        _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
        and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
    ):
        return

    # Otherwise, if the user's already has a Meta kernel or their
    # op is CompositeImplicitAutograd or some other alias dispatch key,
    # raise.

    # Special case for CompositeImplicitAutograd
    if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
        # BUGFIX: added the missing space after "...CompositeImplicitAutograd."
        # so the two sentences do not run together.
        raise RuntimeError(
            f"impl_abstract(...): the operator {self._qualname} "
            f"already has an implementation for this device type via a "
            f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
            f"CompositeImplicitAutograd operators do not need an abstract impl; "
            f"instead, the operator will decompose into its constituents and those "
            f"can have abstract impls defined on them.")

    if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
        # BUGFIX: "an DispatchKey::Meta" -> "a DispatchKey::Meta".
        raise RuntimeError(
            f"impl_abstract(...): the operator {self._qualname} "
            f"already has a DispatchKey::Meta implementation via a "
            f"pre-existing torch.library or TORCH_LIBRARY registration. "
            f"Please either remove that registration or don't call impl_abstract.")
|
| 537 |
+
|
| 538 |
+
# NOTE ["backward", "save_for_backward", and "autograd"]
# As a part of the explicit autograd API, a user must provide us
# a "save_for_backward" function and a "backward" function.
# When both of these have been provided, then we automatically
# construct the "autograd" kernel.
def _register_autograd_kernel(self):
    """Build and register the combined "autograd" kernel.

    Precondition (asserted): both "backward" and "save_for_backward" impls
    have been registered. The kernel is constructed from the schema, the
    output differentiability mask, and the two user-provided functions.
    """
    assert self._has_impl("backward")
    assert self._has_impl("save_for_backward")
    kernel = construct_autograd_kernel(
        self._schema,
        self._output_differentiability,
        self,
        get_op(self._qualname),
        self._get_impl("save_for_backward").func,
        self._get_impl("backward").func)
    self._register_impl("autograd", kernel)
|
| 554 |
+
|
| 555 |
+
def impl_save_for_backward(self, _stacklevel=2):
    r"""Register a function that tells us what to save for backward.

    Please see impl_backward for more details.
    """
    def inner(f):
        self._check_can_register_backward()
        self._check_doesnt_have_library_autograd_impl()
        if not self._registered_autograd_kernel_indirection:
            self._register_autograd_kernel_indirection()
        self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
        # Once both halves are present, assemble the full autograd kernel.
        if self._has_impl("backward"):
            self._register_autograd_kernel()
    return inner
|
| 569 |
+
|
| 570 |
+
def impl_backward(self, output_differentiability=None, _stacklevel=2):
    r"""Registers a backward formula.

    WARNING: if you're a user, please do not use this directly
    (instead use the torch._custom_ops APIs).
    Also please see the following for a detailed guide on custom ops.
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    In order for the CustomOp to work with autograd, you need to register
    a backward formula. There are two pieces to this:
    1. You must give us a function to specify what to save for backward.
       Call this the "save for backward" function.
    2. You must give us a function that computes gradients. Call this the
       "backward" function.

    Use `impl_save_for_backward` to define a "save for backward" function
    that specifies what gets saved for backward. The function should accept
    two arguments ``(inputs, output)`` and return the quantities to be saved
    for backward.

    During runtime, when you call the CustomOp, PyTorch will invoke the
    "save for backward" function with the inputs and output of the CustomOp.

    Use `impl_backward` to define the "backward" function. The backward
    function must accept ``(ctx, saved, *grads)``:
    - ``ctx`` is a context object where we may provide information
    - ``saved`` is exactly what gets returned from the "save for backward"
      function
    - ``grads`` is one or more gradients. The number of gradients matches
      the number of outputs of the CustomOp.

    The backward function must return a dict that maps the name of
    an input to the CustomOp to its corresponding gradient. All inputs that
    were declared to be Tensors in the CustomOp definition must be accounted
    for in the dict. The gradient may be a Tensor or None.

    """
    # Validate output_differentiability eagerly (at decoration time):
    # it must be a list of bools, one per schema return.
    if output_differentiability is not None:
        def yell():
            raise RuntimeError(
                f"impl_backward(output_differentiability): expected "
                f"output_differentiability to be a list of bools with "
                f"length equal to the number of outputs of this CustomOp "
                f"got: {output_differentiability}")

        if not isinstance(output_differentiability, list):
            yell()
        for diff in output_differentiability:
            if not isinstance(diff, bool):
                yell()
        if len(self._schema.returns) != len(output_differentiability):
            yell()

    def inner(f):
        self._check_can_register_backward()
        self._check_doesnt_have_library_autograd_impl()
        if not self._registered_autograd_kernel_indirection:
            self._register_autograd_kernel_indirection()
        self._register_impl("backward", f, stacklevel=_stacklevel)
        self._output_differentiability = output_differentiability
        # Once both halves are present, assemble the full autograd kernel.
        if self._has_impl("save_for_backward"):
            self._register_autograd_kernel()
    return inner
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
@dataclasses.dataclass
class FuncAndLocation:
    """A registered implementation paired with its registration source location."""
    # The user-provided implementation callable.
    func: typing.Callable
    # Human-readable location string recorded at registration time
    # (used in error messages).
    location: str
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Resolve an operator name to a dispatcher op handle.

    Throws (via _dispatch_find_schema_or_throw) if no matching schema exists.
    A missing overload name is treated as the default (empty) overload.
    """
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def validate_namespace(ns: str) -> None:
    """Reject namespaces containing dots or listed in RESERVED_NS."""
    if "." in ns:
        message = (
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
        raise ValueError(message)
    if ns in RESERVED_NS:
        message = (
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )
        raise ValueError(message)
|
| 661 |
+
|
| 662 |
+
def validate_schema(schema: FunctionSchema) -> None:
    """Raise ValueError if `schema` is not usable with custom_op.

    Two checks: the schema must be functional (no mutation, no views, at
    least one return), and it must not declare a 'self' argument.
    """
    if not torch._library.utils.is_functional_schema(schema):
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    """Split 'ns::opname' into (namespace, opname); reject overloads and
    names with no namespace."""
    parts = qualname.split("::", 1)
    if len(parts) < 2:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
                         f"operator name should look something like ns::foo")
    ns, opname = parts
    if '.' in opname:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return ns, opname
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
def validate_device_type(device_type: str) -> None:
    """Raise ValueError unless `device_type` is a key of
    SUPPORTED_DEVICE_TYPE_TO_KEY."""
    if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
        raise ValueError(
            f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
            f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
        )
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
def supported_param(param: inspect.Parameter) -> bool:
    """True iff the parameter kind is expressible in a custom_op schema
    (positional-or-keyword or keyword-only; no positional-only/varargs/varkwargs)."""
    kind = param.kind
    return (
        kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        or kind == inspect.Parameter.KEYWORD_ONLY
    )
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    """Raise ValueError unless `func`'s signature matches `manual_schema`.

    Checks, in order: only supported parameter kinds; no type annotations
    anywhere (to avoid ambiguity with the manual schema); positional and
    keyword-only parameter names match the schema's, in order, with no
    default values on either side.
    """
    sig = inspect.signature(func)

    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    # Compare one group (positional or kwarg-only) pairwise against the schema.
    def compare(sig_args, schema_args):
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def infer_schema(prototype_function: typing.Callable) -> str:
    """Derive a schema string like "(Tensor x) -> Tensor" from a fully
    type-annotated prototype function's signature."""
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        raise ValueError(
            f"custom_op(...)(func): {what} " f"Got func with signature {sig})"
        )

    params = [
        parse_param(name, param, error_fn) for name, param in sig.parameters.items()
    ]
    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(params)}) -> {ret}"
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def parse_param(name, param, error_fn):
    """Render one inspect.Parameter as a "<schema type> <name>" fragment.

    Calls `error_fn` (expected to raise) on: unsupported parameter kinds,
    missing annotations, annotations outside SUPPORTED_PARAM_TYPES, or
    default values (defaults are not supported).
    """
    if not supported_param(param):
        error_fn("We do not support positional-only args, varargs, or varkwargs.")

    if param.annotation is inspect.Parameter.empty:
        error_fn(f"Parameter {name} must have a type annotation.")

    if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
        error_fn(
            f"Parameter {name} has unsupported type {param.annotation}. "
            f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
        )

    if param.default is not inspect.Parameter.empty:
        error_fn(
            f"Parameter {name} has a default value; this is not supported. "
            f"If you want to use default values then create a function with "
            f"default values that calls the CustomOp"
        )

    return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}"
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
    """Return (python type, schema type string) pairs derived from `base_type`.

    The base type and its Optional are always included; the three boolean
    flags opt in to the `T[]`, `T?[]`, and `T[]?` variants respectively.
    """
    candidates = [
        (True, base_type, cpp_type),
        (True, typing.Optional[base_type], f"{cpp_type}?"),
        (list_base, typing.Sequence[base_type], f"{cpp_type}[]"),  # type: ignore[valid-type]
        (optional_base_list, typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]"),  # type: ignore[valid-type]
        (optional_list_base, typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?"),  # type: ignore[valid-type]
    ]
    return [(py_type, schema_type) for enabled, py_type, schema_type in candidates if enabled]
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def get_supported_param_types():
    """Build the mapping from python annotation types to schema type strings,
    expanding each base type into its supported derived variants."""
    data = [
        # (python type, schema type, type[] variant, type?[] variant, type[]? variant)
        (torch.Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (torch.types.Number, "Scalar", True, False, False),
        (torch.dtype, "ScalarType", False, False, False),
        (torch.device, "Device", False, False, False),
    ]
    result = []
    for line in data:
        result.extend(derived_types(*line))
    return dict(result)
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
# Return annotations that custom_op can express, mapped to schema type strings.
SUPPORTED_RETURN_TYPES = {
    torch.Tensor: "Tensor",
    typing.List[torch.Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    torch.types.Number: "Scalar",
}


def parse_return(annotation, error_fn):
    """Translate a return annotation into its schema string.

    A single supported type maps directly; a Tuple of supported types maps
    to a parenthesized comma-separated list. Unsupported types are reported
    through `error_fn`.
    """
    def unsupported():
        error_fn(
            f"Return has unsupported type {annotation}. "
            f"The valid types are: {SUPPORTED_RETURN_TYPES}."
        )

    if typing.get_origin(annotation) is not tuple:
        if annotation not in SUPPORTED_RETURN_TYPES.keys():
            unsupported()
        return SUPPORTED_RETURN_TYPES[annotation]

    members = typing.get_args(annotation)
    for member in members:
        if member not in SUPPORTED_RETURN_TYPES:
            unsupported()

    joined = ", ".join(SUPPORTED_RETURN_TYPES[member] for member in members)
    return "(" + joined + ")"
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
# Computed once at import time; maps python annotation types -> schema strings.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Raise a NotImplementedError tailored to the dispatch key that had no kernel."""
    if key == "Undefined":
        detail = (
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    elif key == "Meta":
        detail = (
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    elif key in ("CPU", "CUDA"):
        device = key.lower()
        detail = (
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    else:
        detail = (
            f"{custom_op}: No implementation for dispatch key {key}. It is likely "
            f"that we have not added this functionality yet, please either open an "
            f"issue or if you're feeling adventurous, use the low-level "
            f"torch.library API"
        )
    raise NotImplementedError(detail)
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
def custom_op_from_existing(op):
    """Wrap an already-registered operator overload in a CustomOp object."""
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
def get_op(qualname):
    """Look up the default overload of `qualname` under torch.ops.

    Raises ValueError if the namespace, name, or default overload is missing.
    """
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    if not hasattr(opnamespace, name):
        error_not_found()
    packet = getattr(opnamespace, name)
    if not hasattr(packet, 'default'):
        error_not_found()
    return packet.default
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
def _find_custom_op(qualname, also_check_torch_library=False):
    """Return the CustomOp for `qualname` from the global registry.

    If absent and `also_check_torch_library` is True, fall back to wrapping
    the existing torch.ops overload; otherwise raise RuntimeError.
    """
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f"Could not find custom op \"{qualname}\". Did you register it via "
            f"the torch._custom_ops API?")
    overload = get_op(qualname)
    result = custom_op_from_existing(overload)
    return result
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def get_abstract_impl(qualname):
|
| 951 |
+
if qualname not in torch._custom_op.impl.global_registry:
|
| 952 |
+
return None
|
| 953 |
+
custom_op = torch._custom_op.impl.global_registry[qualname]
|
| 954 |
+
if custom_op is None:
|
| 955 |
+
return None
|
| 956 |
+
if not custom_op._has_impl("abstract"):
|
| 957 |
+
return None
|
| 958 |
+
return custom_op._get_impl("abstract").func
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    """Define a new operator from an explicit schema string and return its
    default overload.

    Validates the schema, defines it in a FRAGMENT library (optionally tagged
    needs_fixed_stride_order), wires up the autograd kernel indirection, and
    installs the dispatcher error callback (via a weak proxy to avoid keeping
    the CustomOp alive).
    """
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/fft.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch._prims as prims
|
| 7 |
+
import torch._prims_common as utils
|
| 8 |
+
from torch._decomp import register_decomposition
|
| 9 |
+
from torch._prims_common import DimsType, ShapeType, TensorLikeType
|
| 10 |
+
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
# Transforms
|
| 14 |
+
"fft",
|
| 15 |
+
"fft2",
|
| 16 |
+
"fftn",
|
| 17 |
+
"hfft",
|
| 18 |
+
"hfft2",
|
| 19 |
+
"hfftn",
|
| 20 |
+
"rfft",
|
| 21 |
+
"rfft2",
|
| 22 |
+
"rfftn",
|
| 23 |
+
"ifft",
|
| 24 |
+
"ifft2",
|
| 25 |
+
"ifftn",
|
| 26 |
+
"ihfft",
|
| 27 |
+
"ihfft2",
|
| 28 |
+
"ihfftn",
|
| 29 |
+
"irfft",
|
| 30 |
+
"irfft2",
|
| 31 |
+
"irfftn",
|
| 32 |
+
# Helpers
|
| 33 |
+
"fftshift",
|
| 34 |
+
"ifftshift",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
NormType = Union[None, Literal["forward", "backward", "ortho"]]
|
| 38 |
+
_NORM_VALUES = {None, "forward", "backward", "ortho"}
|
| 39 |
+
aten = torch._ops.ops.aten
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Scale an un-normalized FFT result according to the requested norm mode."""
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        # Orthonormal transforms split the 1/N factor symmetrically.
        return x * (1 / math.sqrt(signal_numel))

    # "backward" (the default) scales only the inverse transform; "forward"
    # scales only the forward transform. Everything else passes through.
    if forward:
        needs_scaling = norm == "forward"
    else:
        needs_scaling = norm is None or norm == "backward"
    if needs_scaling:
        return x * (1 / signal_numel)
    return x
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _promote_type_fft(
|
| 58 |
+
dtype: torch.dtype, require_complex: bool, device: torch.device
|
| 59 |
+
) -> torch.dtype:
|
| 60 |
+
"""Helper to promote a dtype to one supported by the FFT primitives"""
|
| 61 |
+
if dtype.is_complex:
|
| 62 |
+
return dtype
|
| 63 |
+
|
| 64 |
+
# Promote integral to default float type
|
| 65 |
+
if not dtype.is_floating_point:
|
| 66 |
+
dtype = torch.get_default_dtype()
|
| 67 |
+
|
| 68 |
+
allowed_types = [torch.float32, torch.float64]
|
| 69 |
+
maybe_support_half = device.type in ["cuda", "meta"]
|
| 70 |
+
|
| 71 |
+
if maybe_support_half:
|
| 72 |
+
allowed_types.append(torch.float16)
|
| 73 |
+
torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
|
| 74 |
+
|
| 75 |
+
if require_complex:
|
| 76 |
+
dtype = utils.corresponding_complex_dtype(dtype)
|
| 77 |
+
|
| 78 |
+
return dtype
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Cast ``t`` to a dtype supported by the FFT primitives, if necessary."""
    target_dtype = _promote_type_fft(t.dtype, require_complex, t.device)
    return _maybe_convert_to_dtype(t, target_dtype)  # type: ignore[return-value]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _resize_fft_input(
|
| 91 |
+
x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
|
| 92 |
+
) -> TensorLikeType:
|
| 93 |
+
"""
|
| 94 |
+
Fixes the shape of x such that x.size(dims[i]) == sizes[i],
|
| 95 |
+
either by zero-padding, or by slicing x starting from 0.
|
| 96 |
+
"""
|
| 97 |
+
assert len(dims) == len(sizes)
|
| 98 |
+
must_copy = False
|
| 99 |
+
x_sizes = x.shape
|
| 100 |
+
pad_amount = [0] * len(x_sizes) * 2
|
| 101 |
+
for i in range(len(dims)):
|
| 102 |
+
if sizes[i] == -1:
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
if x_sizes[dims[i]] < sizes[i]:
|
| 106 |
+
must_copy = True
|
| 107 |
+
pad_idx = len(pad_amount) - 2 * dims[i] - 1
|
| 108 |
+
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]
|
| 109 |
+
|
| 110 |
+
if x_sizes[dims[i]] > sizes[i]:
|
| 111 |
+
x = x.narrow(dims[i], 0, sizes[i])
|
| 112 |
+
|
| 113 |
+
return torch.constant_pad_nd(x, pad_amount) if must_copy else x
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)

    ``n`` is the desired output signal length; when omitted it defaults to
    ``2 * (input.shape[dim] - 1)``, the length whose one-sided spectrum
    matches the input.  ``forward=True`` selects hfft, ``False`` irfft.
    """
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        # A real signal of length n has a one-sided spectrum of n // 2 + 1 bins.
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        # hfft(x) == irfft(conj(x)), up to normalization.
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)

    ``forward=False`` implements the inverse Hermitian transform, which is
    the conjugate of the corresponding forward transform.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = n if n is not None else input.shape[dim]
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
    ret = _apply_norm(ret, norm, dim_size, forward)
    # ihfft(x) == conj(rfft(x)); normalization has already been applied above.
    return ret if forward else torch.conj(ret)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for any complex-to-complex FFT (fft or ifft)."""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = input.shape[dim] if n is None else n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        # Zero-pad or trim the transformed dim to the requested length.
        input = _resize_fft_input(input, dims, (n,))

    transformed = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(transformed, norm, dim_size, forward)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.fft`` (one-dimensional FFT)."""
    # Real inputs go through the r2c primitive, but keep the full two-sided
    # spectrum so the result matches the complex-input path.
    if not input.dtype.is_complex:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)
    return _fft_c2c("fft", input, n, dim, norm, forward=True)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.ifft`` (1-D inverse FFT)."""
    # Real inputs use the r2c primitive with a full two-sided spectrum so the
    # output matches the complex-input path.
    if not input.dtype.is_complex:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
    return _fft_c2c("ifft", input, n, dim, norm, forward=False)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.rfft``: real-input FFT
    returning the one-sided (n // 2 + 1 bin) spectrum."""
    return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.irfft``: inverse of rfft,
    producing a real signal of length ``n`` (default 2 * (input bins - 1))."""
    return _fft_c2r("irfft", input, n, dim, norm, forward=False)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.hfft``: forward FFT of a
    signal with Hermitian symmetry, yielding a real output."""
    return _fft_c2r("hfft", input, n, dim, norm, forward=True)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.ihfft``: inverse of hfft,
    returning the one-sided spectrum of a real input."""
    return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class _ShapeAndDims(NamedTuple):
    """Canonicalized (shape, dims) pair for an n-dimensional transform."""

    # Transform length along each corresponding entry of ``dims``.
    shape: Tuple[int, ...]
    # Canonical (non-negative, unique) dimension indices to transform over.
    dims: Tuple[int, ...]
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither are optional

    Rules (as implemented below):
      * shape given, dim given:   lengths must match; -1 in shape means
        "use the input's current size in that dim".
      * shape given, dim omitted: transform the last ``len(shape)`` dims.
      * shape omitted, dim given: use the input sizes of those dims.
      * both omitted:             transform every dimension at full size.
    """
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            # A single int is treated as a one-element dim list.
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

        # Check dims are unique
        torch._check(
            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
        )

    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape,)

        # Has shape, might have dim
        torch._check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)

        torch._check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
        )
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)  # type: ignore[possibly-undefined]

    for n in ret_shape:
        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)  # type: ignore[possibly-undefined]
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _prod(xs: Iterable[int]) -> int:
|
| 332 |
+
"""Compute product of a list"""
|
| 333 |
+
prod = 1
|
| 334 |
+
for x in xs:
|
| 335 |
+
prod *= x
|
| 336 |
+
return prod
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex-to-complex FFTs (fftn or ifftn)."""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    resized = _resize_fft_input(input, dim, shape)
    transformed = prims.fft_c2c(resized, dim=dim, forward=forward)
    return _apply_norm(
        transformed, norm=norm, signal_numel=_prod(shape), forward=forward
    )
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.fftn`` (n-dimensional FFT)."""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    # The underlying transform is complex-to-complex, so promote real input.
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("fftn", x, shape, dim, norm, forward=True)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.ifftn`` (n-dim inverse FFT)."""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    # The underlying transform is complex-to-complex, so promote real input.
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.rfftn`` (n-dim real FFT).

    The output is one-sided along the last transformed dimension.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    # Normalization uses the full (logical) signal size, not the one-sided one.
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.ihfftn`` (inverse n-dim
    Hermitian FFT of a real input)."""
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)

    # Stage 1: one-sided real-to-complex transform over the last dim only.
    tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(tmp)

    # Stage 2: conjugate, then inverse c2c over the remaining dims;
    # normalization here covers the full signal size.
    tmp = prims.conj_physical(tmp)
    tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
    return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class _CanonicalizeC2rReturn(NamedTuple):
    """Canonicalized arguments for n-dimensional complex-to-real transforms."""

    # Input (spectrum) sizes per transformed dim; the last entry is one-sided.
    shape: Tuple[int, ...]
    # Canonical dimension indices being transformed.
    dim: Tuple[int, ...]
    # Length of the real output signal along dim[-1].
    last_dim_size: int
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
    as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    # Output length along the last dim: either the user-requested size, or the
    # length implied by the one-sided input spectrum, 2 * (bins - 1).
    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    # The *input* is resized to the matching one-sided spectrum length.
    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.irfftn`` (inverse n-dim real FFT)."""
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size)
    # Normalize by the size of the (real) output signal, not the input spectrum.
    return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.hfftn`` (n-dim Hermitian FFT)."""
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)

    # Forward c2c over all but the last dim, then c2r over the last dim.
    # The normalization factor is split accordingly: prod(shape[:-1]) for the
    # first stage, last_dim_size for the second.
    tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input
    tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True)
    tmp = prims.conj_physical(tmp)
    out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.fft2``: fftn over the last two dims."""
    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.ifft2``: ifftn over the last two dims."""
    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.rfft2``: rfftn over the last two dims."""
    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.irfft2``: irfftn over the last two dims."""
    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.hfft2``: hfftn over the last two dims."""
    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.fft.ihfft2``: ihfftn over the last two dims."""
    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
|
| 570 |
+
"""Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
|
| 571 |
+
if dim is None:
|
| 572 |
+
return list(range(x.ndim))
|
| 573 |
+
elif not isinstance(dim, Sequence):
|
| 574 |
+
return [dim]
|
| 575 |
+
else:
|
| 576 |
+
return list(dim)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll each transformed dimension by ``size // 2`` so the zero-frequency
    bin moves to the center of the spectrum."""
    dims = _default_alldims(dim, input)
    shift = [input.shape[d] // 2 for d in dims]
    return torch.roll(input, shift, dims)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Inverse of fftshift: rolls by ``(size + 1) // 2`` so that
    ``ifftshift(fftshift(x))`` is the identity even for odd-sized dims."""
    dims = _default_alldims(dim, input)
    shift = [(input.shape[d] + 1) // 2 for d in dims]
    return torch.roll(input, shift, dims)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (706 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (367 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .modules import * # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (6.64 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .modules import * # noqa: F403
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
'Linear',
|
| 5 |
+
'Conv1d',
|
| 6 |
+
'Conv2d',
|
| 7 |
+
'Conv3d',
|
| 8 |
+
'ConvTranspose1d',
|
| 9 |
+
'ConvTranspose2d',
|
| 10 |
+
'ConvTranspose3d',
|
| 11 |
+
'RNNCell',
|
| 12 |
+
'LSTMCell',
|
| 13 |
+
'GRUCell',
|
| 14 |
+
'LSTM',
|
| 15 |
+
'GRU',
|
| 16 |
+
'Embedding',
|
| 17 |
+
'EmbeddingBag',
|
| 18 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (814 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/utils.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import enum
|
| 2 |
+
import operator
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.ao.nn.intrinsic.quantized as nniq
|
| 7 |
+
import torch.ao.nn.quantized as nnq
|
| 8 |
+
|
| 9 |
+
toq = torch.ops.quantized
|
| 10 |
+
from typing import Tuple, Callable, Dict, Set, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
from torch.fx import GraphModule
|
| 13 |
+
from torch.fx.graph import Node
|
| 14 |
+
from torch.ao.quantization import (
|
| 15 |
+
ObserverBase,
|
| 16 |
+
FakeQuantizeBase,
|
| 17 |
+
)
|
| 18 |
+
from torch.ao.quantization.utils import getattr_from_fqn
|
| 19 |
+
from torch.ao.quantization.observer import _is_activation_post_process
|
| 20 |
+
|
| 21 |
+
from .ns_types import NSNodeTargetType, NSResultsType
|
| 22 |
+
|
| 23 |
+
# TODO(future PR): consider deleting this enum and using the torch types
|
| 24 |
+
# directly. This might be tricky because it is not a one to one mapping.
|
| 25 |
+
class NodeInputOrOutputType(enum.Enum):
    """Coarse dtype category of a graph node's input or output, used when
    matching nodes for numeric-suite comparisons."""

    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    # for the purposes of numerical debugging we want to get the actual
    # dtype used in the model. We will likely need some kind of dtype
    # propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """Returns (first-input dtype class, output dtype class) for `node`.

    Known function/module/method targets are classified via the lookup sets
    in `node_type_to_io_type_map`. Dtype-transparent nodes (loggers,
    observers, "fp32 or int8" ops, `dequantize`, `to`) are resolved by
    recursing into the producer of their first input. Anything not
    recognized yields UNKNOWN.
    """

    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            # dtype-transparent function: inherit the dtype produced by the
            # node feeding its first input
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_module":
        assert node.op == "call_module"
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or is_known_fp32_or_int8_input_module
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        is_known_fp32_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
        )
        is_known_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_input_module:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif is_known_int8_input_module:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )
            # dequantize always produces fp32
            return (prev_node_output_type, NodeInputOrOutputType.FP32)

        elif node.target == "to":
            # to is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of to and return the correct output type.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )

            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            # only `.to(torch.float16)` is modeled so far
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_node_output_type, NodeInputOrOutputType.FP16)

        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
            # dtype-transparent method: inherit the producer's output dtype
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
    else:
        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+
|
| 162 |
+
def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.

    Returns None when the producer of the first input is not a recognized
    quantize op / quantized module / dtype-transparent module.
    """
    prev_node = get_normalized_nth_input(node, gm, 0)

    if not isinstance(prev_node, Node):
        return None

    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
        # scale and zero_point arrive as get_attr nodes; resolve the actual
        # objects from the owning module via their fully qualified names
        scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
        zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
        scale_obj = getattr_from_fqn(gm, scale_node.target)
        zp_obj = getattr_from_fqn(gm, zp_node.target)
        return (scale_obj, zp_obj)

    if prev_node.op == "call_function":

        # quantize - read the args directly
        if prev_node.target == torch.quantize_per_tensor:
            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            # quantized binary ops: (a, b, scale, zero_point)
            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)

        return None
        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input

    elif prev_node.op == "call_module":

        # get type of the module
        assert isinstance(prev_node.target, str)
        module_obj = getattr_from_fqn(gm, prev_node.target)
        # quantized modules expose their output qparams as attributes
        # NOTE(review): nniq.ConvReLU2d appears twice in this tuple;
        # harmless for isinstance, but redundant
        if isinstance(
            module_obj,
            (
                nnq.Linear,
                nnq.Conv1d,
                nnq.Conv2d,
                nniq.ConvReLU2d,
                nnq.Conv3d,
                nnq.BatchNorm2d,
                nnq.BatchNorm3d,
                nnq.ConvTranspose1d,
                nnq.ConvTranspose2d,
                nnq.ELU,
                nnq.GroupNorm,
                nnq.InstanceNorm1d,
                nnq.InstanceNorm2d,
                nnq.InstanceNorm3d,
                nnq.LayerNorm,
                nnq.Hardswish,
                nnq.LeakyReLU,
                nnq.ReLU6,
                nniq.BNReLU2d,
                nniq.BNReLU3d,
                nniq.ConvReLU1d,
                nniq.ConvReLU2d,
                nniq.ConvReLU3d,
                nniq.LinearReLU,
            ),
        ):
            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]

        is_known_fp32_or_int8_input_module = any(
            isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_or_int8_input_module:
            # dtype-transparent module: its input qparams equal those of its
            # own first input, so recurse one producer up
            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)

    return None
+
|
| 243 |
+
|
| 244 |
+
def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it. If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer. For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs

    Note: unwraps at most two levels (e.g. observer then fake-quant);
    deeper chains are not traversed.
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if _is_activation_post_process(node_obj):
            assert len(node.args) == 1
            assert isinstance(node.args[0], Node)
            node = node.args[0]
            # code duplication intended, not worth refactoring
            assert isinstance(node.target, str)
            node_obj = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(node_obj):
                assert len(node.args) == 1
                assert isinstance(node.args[0], Node)
                node = node.args[0]
    return node
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first. Returns the number of
    non-param args expected for a node. For example, for

      F.linear(x, weight, bias)

    Returns 1, because x is a non-param arg and weight and bias are params.
    For

      lstm_mod(x, hid)

    Returns 2, because both x and hid are non-param args.
    """
    # LSTM modules consume two non-param inputs (x, hid); every other
    # node type is assumed to consume one.
    targets_lstm = (
        node.op == "call_module"
        and isinstance(getattr_from_fqn(gm, node.target), nn.LSTM)  # type: ignore[arg-type]
    )
    return 2 if targets_lstm else 1
+
|
| 298 |
+
|
| 299 |
+
def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function" and (
        # TODO(future PR): use relationship map instead of hardcoding
        node.target in (torch.add, torch.ops.quantized.add, operator.add)
        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
    ):
        # Log only the graph (Node) inputs of the binary op, skipping scalar
        # constants. isinstance replaces the previous `type(...) == Node`
        # exact-type check so Node subclasses are handled correctly.
        return [i for i in range(2) if isinstance(node.args[i], Node)]
    return [0]
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    if node.op in ("call_function", "call_method"):
        return torch.typename(node.target)
    if node.op == "call_module":
        assert isinstance(node.target, str)
        return torch.typename(getattr_from_fqn(gm, node.target))
    # placeholder / get_attr / output nodes have no meaningful target type
    return ""
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Rekeys the layer name of a results dictionary to use node names
    from `model_name`.

    For example, transforms

      {'base_op_1_0': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    into

      {'linear1': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models. This is why we extract
    the results first and rekey afterwards.
    """
    rekeyed = {}
    for layer_name, results_by_type in results.items():
        # find the ref_node_name recorded for `model_name`, if any
        ref_name = None
        for results_by_model in results_by_type.values():
            for candidate_model, layer_results in results_by_model.items():
                if candidate_model != model_name:
                    continue
                assert len(layer_results)
                ref_name = layer_results[0]["ref_node_name"]
        # fall back to the original key when the model has no entry here
        rekeyed[layer_name if ref_name is None else ref_name] = results_by_type
    return rekeyed
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    If `fqn` entries are filled in for one of the models in `results`, copies
    them over to any models which do not have them filled out.

    A common use case benefitting from this is comparing a model prepared by
    quantization to a quantized model. In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.
    """

    # Check in the first result to find any model with fqn entries defined.
    donor_model_name = None
    for result_type_to_results in results.values():
        for model_name_to_results in result_type_to_results.values():
            for model_name, model_results in model_name_to_results.items():
                if len(model_results) > 0 and model_results[0]["fqn"] is not None:
                    donor_model_name = model_name
                    break
            break
        break

    if not donor_model_name:
        return

    # Copy the donor's fqn entries, position by position, into every
    # other model's results.
    for result_type_to_results in results.values():
        for model_name_to_results in result_type_to_results.values():
            donor_results = model_name_to_results[donor_model_name]
            for model_name, model_results in model_name_to_results.items():
                if model_name == donor_model_name:
                    continue
                for idx in range(len(model_results)):
                    model_results[idx]["fqn"] = donor_results[idx]["fqn"]
| 413 |
+
|
| 414 |
+
def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    """Decorator for two-tensor comparison functions.

    Dequantizes the first two tensor arguments before calling `f`, recurses
    elementwise when both are tuples/lists (returning a list of results), and
    returns None when either argument is not a float tensor after
    dequantization.
    """
    def inner(*args, **kwargs):
        first, second, *rest = args

        both_sequences = (
            (isinstance(first, tuple) and isinstance(second, tuple))
            or (isinstance(first, list) and isinstance(second, list))
        )
        if both_sequences:
            # recurse pairwise and collect per-element results
            return [
                inner(el_a, el_b, *rest, **kwargs)
                for el_a, el_b in zip(first, second)
            ]

        if isinstance(first, torch.Tensor) and isinstance(second, torch.Tensor):
            if first.is_quantized:
                first = first.dequantize()
            if second.is_quantized:
                second = second.dequantize()

            # for the purposes of this util, only handle floats
            if first.dtype != torch.float or second.dtype != torch.float:
                return None

            return f(first, second, *rest, **kwargs)

    return inner
| 441 |
+
|
| 442 |
+
|
| 443 |
+
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    # SQNR in dB: 20 * log10(||x|| / ||x - y||)
    signal_power = torch.norm(x)
    noise_power = torch.norm(x - y)
    return 20 * torch.log10(signal_power / noise_power)
| 458 |
+
|
| 459 |
+
|
| 460 |
+
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    # sqrt(sum((x - y)^2) / sum(x^2)): L2 error relative to the energy of x
    squared_error = ((x - y) ** 2).sum()
    reference_energy = (x ** 2).sum()
    return torch.sqrt(squared_error / reference_energy)
| 473 |
+
|
| 474 |
+
|
| 475 |
+
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    flat_x = x.reshape(1, -1)
    flat_y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(flat_x, flat_y)
| 493 |
+
|
| 494 |
+
def op_type_supports_shadowing(node: Node) -> bool:
    """Returns True if Numeric Suite can shadow this node's op.

    Ops taking multiple tensor inputs (binary ops, cat, stack) are not
    supported yet.
    """
    # shadowing for ops with multiple tensor inputs is not implemented yet
    multi_tensor_input_ops = (
        torch.add,
        torch.mul,
        operator.add,
        operator.mul,
        torch.cat,
        torch.stack,
    )
    return not (
        node.op == 'call_function' and node.target in multi_tensor_input_ops
    )
| 500 |
+
|
| 501 |
+
def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.

    `idx` indexes the combined sequence of positional args followed by
    keyword args (kwargs in insertion order).
    """
    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True)
        if norm_args_and_kwargs is not None:
            norm_args, norm_kwargs = norm_args_and_kwargs
            assert len(norm_args) + len(norm_kwargs) > idx
            if idx < len(norm_args):
                return norm_args[idx]
            else:
                # note: in Python 3.7+ dicts are ordered
                # (norm_args is empty with normalize_to_only_use_kwargs=True,
                # so `idx` indexes the kwargs directly)
                return list(norm_kwargs.values())[idx]
        else:
            assert len(node.args) + len(node.kwargs) > idx
            if idx < len(node.args):
                return node.args[idx]  # type: ignore[return-value]
            else:
                # bugfix: the kwarg position is the overall index minus the
                # number of positional args; the previous
                # `idx + len(node.args)` indexed past the end of the kwargs
                kwargs_idx = idx - len(node.args)
                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        assert len(node.args) + len(node.kwargs) > idx
        if idx < len(node.args):
            return node.args[idx]  # type: ignore[return-value]
        else:
            # same index conversion bugfix as above
            kwargs_idx = idx - len(node.args)
            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (6.95 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-311.pyc
ADDED
|
Binary file (7.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-311.pyc
ADDED
|
Binary file (9.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-311.pyc
ADDED
|
Binary file (1.43 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-311.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-311.pyc
ADDED
|
Binary file (4.07 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-311.pyc
ADDED
|
Binary file (7.07 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
import torch
|
| 3 |
+
from torch.ao.quantization.backend_config import (
|
| 4 |
+
BackendConfig,
|
| 5 |
+
DTypeConfig,
|
| 6 |
+
ObservationType,
|
| 7 |
+
BackendPatternConfig,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
weighted_op_quint8_dtype_config = DTypeConfig(
|
| 11 |
+
input_dtype=torch.quint8,
|
| 12 |
+
output_dtype=torch.quint8,
|
| 13 |
+
weight_dtype=torch.qint8,
|
| 14 |
+
bias_dtype=torch.float,
|
| 15 |
+
)
|
| 16 |
+
from typing import List
|
| 17 |
+
|
| 18 |
+
def get_linear_configs():
    """Returns BackendPatternConfigs for decomposed linear (addmm / mm)."""
    linear_configs = []
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    dtype_configs = [weighted_op_quint8_dtype_config]

    # TODO: need to fix the way we insert observers for this pattern
    # should be solved in the new fusion API
    # reason that this doesn't work: the pattern is a bit complicated and we don't
    # have a way to specify which input of the pattern we would like to observe
    # pattern:
    # bias input weight
    # \     |    /
    #  \    |   t
    #   \   |  /
    #    addmm
    # we want to observe "weight" as weight, but there is not way to convey this
    # information with current pattern language
    #
    # right now:
    # original:
    #   weight - t \
    #   input  - addmm
    # observed (no hack):
    #   weight - t - observer \
    #   input  - observer - addmm
    # target:
    #   weight - observer - t \
    #   input  - observer - addmm

    # def root_node_getter(node_pattern):
    #     addmm, bias, act, weight = node_pattern
    #     return addmm

    # linear_configs.append(
    #     BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default))
    #     .set_observation_type(observation_type)  # noqa: E131
    #     .set_dtype_configs(dtype_configs)
    #     ._set_root_node_getter(root_node_getter))

    # linear with bias is decomposed to addmm(bias, input, t(weight))
    linear_configs.append(
        BackendPatternConfig(torch.ops.aten.addmm.default)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 2, "bias": 0})
    )
    # linear is decomposed to `t - mm` if bias is not present
    linear_configs.append(
        BackendPatternConfig(torch.ops.aten.mm.default)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 1})
    )
    return linear_configs
| 71 |
+
|
| 72 |
+
def get_conv_configs():
    """Returns BackendPatternConfigs for aten convolution, optionally fused with relu."""
    conv_configs = []
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    dtype_configs = [weighted_op_quint8_dtype_config]
    conv_patterns = (
        torch.ops.aten.convolution.default,
        (torch.ops.aten.convolution.default, torch.ops.aten.relu.default),
        # TODO: remove when functionalization is supported in PT2 mode
        (torch.ops.aten.convolution.default, torch.ops.aten.relu_.default),
    )
    for pattern in conv_patterns:
        conv_configs.append(
            BackendPatternConfig(pattern)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )
    return conv_configs
| 96 |
+
|
| 97 |
+
def get_pooling_configs():
    """Returns the BackendPatternConfig for (getitem, max_pool2d_with_indices).

    The output shares its observer with the input; the maxpool op itself is
    the root of the complex pattern.
    """
    backend_pattern_configs = []
    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
    dtype_configs = [weighted_op_quint8_dtype_config]

    def root_node_getter(node_pattern):
        # pattern is (getitem, maxpool, index); observers attach around maxpool
        getitem, maxpool, index = node_pattern
        return maxpool

    backend_pattern_configs.append(
        BackendPatternConfig()
        ._set_pattern_complex_format((operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_root_node_getter(root_node_getter)
    )

    return backend_pattern_configs
| 115 |
+
|
| 116 |
+
def get_relu_configs():
    """Returns the BackendPatternConfig for aten relu (output shares input observer)."""
    dtype_configs = [weighted_op_quint8_dtype_config]
    relu_config = (
        BackendPatternConfig(torch.ops.aten.relu.default)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    return [relu_config]
| 125 |
+
|
| 126 |
+
def get_binary_op_configs():
    """Returns BackendPatternConfigs for aten add (in/out-of-place, optionally fused with relu)."""
    binary_op_configs: List[BackendPatternConfig] = []
    dtype_configs = [weighted_op_quint8_dtype_config]
    # maps number-of-tensor-args -> how to observe the output
    num_tensor_args_to_observation_type_mapping = {
        # TODO: this is not used right now since we have extra check in prepare
        # will need to change this to NO_OBSERVER later after we implemented
        # Tensor dtype inference properly
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
    for op_with_quantized_bop_scalar_variant in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]:
        bop_patterns = [
            (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default),
            op_with_quantized_bop_scalar_variant,
            # TODO: remove when functionalization is supported in pt2_mode
            (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default),
        ]
        for bop_pattern in bop_patterns:
            binary_op_configs.append(
                BackendPatternConfig(bop_pattern)
                .set_dtype_configs(dtype_configs)  # noqa: E131
                ._set_num_tensor_args_to_observation_type(num_tensor_args_to_observation_type_mapping))

    return binary_op_configs
| 151 |
+
|
| 152 |
+
def get_qnnpack_pt2e_backend_config():
    """Returns the qnnpack BackendConfig for the PyTorch 2.0 export flow,
    assembled from the per-op config helpers above."""
    return (
        BackendConfig("qnnpack_pytorch_2.0_export")
        .set_backend_pattern_configs(get_linear_configs())
        .set_backend_pattern_configs(get_binary_op_configs())
        .set_backend_pattern_configs(get_conv_configs())
        .set_backend_pattern_configs(get_pooling_configs())
        .set_backend_pattern_configs(get_relu_configs())
    )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/fbgemm.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ._common_operator_config_utils import (
|
| 3 |
+
_get_binary_op_configs,
|
| 4 |
+
_get_bn_configs,
|
| 5 |
+
_get_cat_config,
|
| 6 |
+
_get_conv_configs,
|
| 7 |
+
_get_default_op_configs,
|
| 8 |
+
_get_embedding_op_configs,
|
| 9 |
+
_get_fixed_qparams_op_configs,
|
| 10 |
+
_get_linear_configs,
|
| 11 |
+
_get_rnn_op_configs,
|
| 12 |
+
_get_share_qparams_op_configs,
|
| 13 |
+
_get_tensor_info_op_configs,
|
| 14 |
+
)
|
| 15 |
+
from .backend_config import BackendConfig, DTypeConfig
|
| 16 |
+
|
| 17 |
+
__all__ = [
    "get_fbgemm_backend_config",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================

# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py
# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max,
# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized
# values to within [0, 127].

# Static quantization for ops with quantized weights (e.g. linear, conv):
# quint8 activations in/out, qint8 weights, float bias.
fbgemm_weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Static quantization for ops without weights: quint8 activations in/out.
fbgemm_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# All-fp16 config: activations, weights, and bias all in float16.
fbgemm_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic int8 quantization: input quantized on the fly (quint8), float
# output, statically quantized qint8 weights.
fbgemm_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic fp16 quantization: float16 input/weights, float output.
fbgemm_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quantization (used for embedding ops): float activations,
# quint8 weights.
fbgemm_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

# Weight-only 4-bit quantization (packed quint4x2 weights).
fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# =====================
|
| 79 |
+
# | BACKEND CONFIGS |
|
| 80 |
+
# =====================
|
| 81 |
+
|
| 82 |
+
def get_fbgemm_backend_config() -> BackendConfig:
|
| 83 |
+
"""
|
| 84 |
+
Return the `BackendConfig` for PyTorch's native FBGEMM backend.
|
| 85 |
+
"""
|
| 86 |
+
conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config]
|
| 87 |
+
linear_dtype_configs = [
|
| 88 |
+
fbgemm_weighted_op_quint8_dtype_config,
|
| 89 |
+
fbgemm_default_dynamic_int8_dtype_config,
|
| 90 |
+
fbgemm_default_dynamic_float16_dtype_config,
|
| 91 |
+
]
|
| 92 |
+
binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
|
| 93 |
+
default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
|
| 94 |
+
fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
|
| 95 |
+
share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
|
| 96 |
+
tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
|
| 97 |
+
rnn_op_dtype_configs = [
|
| 98 |
+
fbgemm_default_dynamic_int8_dtype_config,
|
| 99 |
+
fbgemm_default_dynamic_float16_dtype_config,
|
| 100 |
+
]
|
| 101 |
+
embedding_op_dtype_configs = [
|
| 102 |
+
fbgemm_weight_only_quint8_dtype_config,
|
| 103 |
+
fbgemm_weight_only_quint4x2_dtype_config,
|
| 104 |
+
]
|
| 105 |
+
return BackendConfig("fbgemm") \
|
| 106 |
+
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
|
| 107 |
+
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
|
| 108 |
+
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
|
| 109 |
+
.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
|
| 110 |
+
.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
|
| 111 |
+
.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
|
| 112 |
+
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
|
| 113 |
+
.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
|
| 114 |
+
.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
|
| 115 |
+
.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
|
| 116 |
+
.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/native.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ._common_operator_config_utils import (
|
| 3 |
+
_get_binary_op_configs,
|
| 4 |
+
_get_bn_configs,
|
| 5 |
+
_get_cat_config,
|
| 6 |
+
_get_conv_configs,
|
| 7 |
+
_get_default_op_configs,
|
| 8 |
+
_get_embedding_op_configs,
|
| 9 |
+
_get_fixed_qparams_op_configs,
|
| 10 |
+
_get_linear_configs,
|
| 11 |
+
_get_ln_configs,
|
| 12 |
+
_get_rnn_op_configs,
|
| 13 |
+
_get_share_qparams_op_configs,
|
| 14 |
+
_get_tensor_info_op_configs,
|
| 15 |
+
)
|
| 16 |
+
from .backend_config import BackendConfig, DTypeConfig
|
| 17 |
+
|
| 18 |
+
__all__ = [
    "get_test_only_legacy_native_backend_config",
    "default_op_quint8_dtype_config",
    "default_op_fp16_dtype_config",
    "default_dynamic_int8_dtype_config",
    "default_dynamic_float16_dtype_config",
    "input_output_only_quint8_dtype_config",
    "weight_only_quint8_dtype_config",
    "weight_only_quint4x2_dtype_config",
    "get_native_backend_config",
    "get_native_backend_config_dict",
    "get_test_only_legacy_native_backend_config_dict",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================

# weighted op int8 dtype config
# this is config for ops that has quantized weights, like linear, conv
weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Static quantization for ops without weights: quint8 activations in/out.
default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# All-fp16 config: activations, weights, and bias all in float16.
default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic int8 quantization: input quantized on the fly, float output,
# statically quantized qint8 weights.
default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    # currently the dtype check is not yet enabled, so we provided the dtype_configs but
    # it is not really used yet,
    # we will enable it a bit later after we moved everything to backend_config_dict
    is_dynamic=True,
)

# Dynamic fp16 quantization: float16 input/weights, float output.
default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    # currently the dtype check is not yet enabled, so we provided the dtype_configs but
    # it is not really used yet,
    # we will enable it a bit later after we moved everything to backend_config_dict
    is_dynamic=True,
)

# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights
input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)

# Weight-only quantization (used for embedding ops): float activations,
# quint8 weights.
weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

# Weight-only 4-bit quantization (packed quint4x2 weights).
weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# =====================
|
| 101 |
+
# | BACKEND CONFIGS |
|
| 102 |
+
# =====================
|
| 103 |
+
|
| 104 |
+
def get_test_only_legacy_native_backend_config() -> BackendConfig:
|
| 105 |
+
"""
|
| 106 |
+
Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
|
| 107 |
+
"""
|
| 108 |
+
conv_dtype_configs = [weighted_op_quint8_dtype_config]
|
| 109 |
+
linear_dtype_configs = [
|
| 110 |
+
weighted_op_quint8_dtype_config,
|
| 111 |
+
default_dynamic_int8_dtype_config,
|
| 112 |
+
default_dynamic_float16_dtype_config,
|
| 113 |
+
default_op_fp16_dtype_config,
|
| 114 |
+
]
|
| 115 |
+
binary_op_dtype_configs = [
|
| 116 |
+
default_op_quint8_dtype_config,
|
| 117 |
+
default_op_fp16_dtype_config,
|
| 118 |
+
]
|
| 119 |
+
default_op_dtype_configs = [default_op_quint8_dtype_config]
|
| 120 |
+
fixed_qparams_op_dtype_configs = [
|
| 121 |
+
default_op_quint8_dtype_config,
|
| 122 |
+
default_op_fp16_dtype_config,
|
| 123 |
+
]
|
| 124 |
+
share_qparams_op_dtype_configs = [
|
| 125 |
+
default_op_quint8_dtype_config,
|
| 126 |
+
default_op_fp16_dtype_config
|
| 127 |
+
]
|
| 128 |
+
tensor_info_op_dtype_configs = [
|
| 129 |
+
default_op_quint8_dtype_config,
|
| 130 |
+
]
|
| 131 |
+
rnn_op_dtype_configs = [
|
| 132 |
+
default_dynamic_int8_dtype_config,
|
| 133 |
+
default_dynamic_float16_dtype_config,
|
| 134 |
+
]
|
| 135 |
+
embedding_op_dtype_configs = [
|
| 136 |
+
weight_only_quint8_dtype_config,
|
| 137 |
+
weight_only_quint4x2_dtype_config,
|
| 138 |
+
]
|
| 139 |
+
layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
|
| 140 |
+
return BackendConfig("_native_and_fp16") \
|
| 141 |
+
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
|
| 142 |
+
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
|
| 143 |
+
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
|
| 144 |
+
.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
|
| 145 |
+
.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
|
| 146 |
+
.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
|
| 147 |
+
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
|
| 148 |
+
.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
|
| 149 |
+
.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
|
| 150 |
+
.set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
|
| 151 |
+
.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
|
| 152 |
+
.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
|
| 153 |
+
|
| 154 |
+
def get_native_backend_config() -> BackendConfig:
|
| 155 |
+
"""
|
| 156 |
+
Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
|
| 157 |
+
"""
|
| 158 |
+
# TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
|
| 159 |
+
conv_dtype_configs = [weighted_op_quint8_dtype_config]
|
| 160 |
+
linear_dtype_configs = [
|
| 161 |
+
weighted_op_quint8_dtype_config,
|
| 162 |
+
default_dynamic_int8_dtype_config,
|
| 163 |
+
default_dynamic_float16_dtype_config,
|
| 164 |
+
]
|
| 165 |
+
binary_op_dtype_configs = [default_op_quint8_dtype_config]
|
| 166 |
+
default_op_dtype_configs = [default_op_quint8_dtype_config]
|
| 167 |
+
fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
|
| 168 |
+
share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
|
| 169 |
+
tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
|
| 170 |
+
rnn_op_dtype_configs = [
|
| 171 |
+
default_dynamic_int8_dtype_config,
|
| 172 |
+
default_dynamic_float16_dtype_config,
|
| 173 |
+
]
|
| 174 |
+
embedding_op_dtype_configs = [
|
| 175 |
+
weight_only_quint8_dtype_config,
|
| 176 |
+
weight_only_quint4x2_dtype_config,
|
| 177 |
+
]
|
| 178 |
+
layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
|
| 179 |
+
return BackendConfig("native") \
|
| 180 |
+
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
|
| 181 |
+
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
|
| 182 |
+
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
|
| 183 |
+
.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
|
| 184 |
+
.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
|
| 185 |
+
.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
|
| 186 |
+
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
|
| 187 |
+
.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
|
| 188 |
+
.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
|
| 189 |
+
.set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
|
| 190 |
+
.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
|
| 191 |
+
.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
|
| 192 |
+
|
| 193 |
+
def get_native_backend_config_dict():
|
| 194 |
+
"""
|
| 195 |
+
Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
|
| 196 |
+
"""
|
| 197 |
+
return get_native_backend_config().to_dict()
|
| 198 |
+
|
| 199 |
+
def get_test_only_legacy_native_backend_config_dict():
|
| 200 |
+
"""
|
| 201 |
+
Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional
|
| 202 |
+
fp16 ops in dictionary form.
|
| 203 |
+
"""
|
| 204 |
+
return get_test_only_legacy_native_backend_config().to_dict()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuser_method_mappings.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.ao.nn.intrinsic as nni
|
| 3 |
+
|
| 4 |
+
from typing import Any, Union, Callable, List, Tuple, Dict, Optional, Type
|
| 5 |
+
from torch.ao.quantization.utils import Pattern, get_combined_dict, MatchAllNode
|
| 6 |
+
import itertools
|
| 7 |
+
|
| 8 |
+
# Public API of this module; helpers prefixed with "_" are intentionally
# excluded from the star-import surface.
__all__ = [
    "fuse_conv_bn",
    "fuse_conv_bn_relu",
    "fuse_linear_bn",
    "fuse_convtranspose_bn",
    "get_fuser_method",
    "get_fuser_method_new",
]
|
| 16 |
+
|
| 17 |
+
def fuse_conv_bn(is_qat, conv, bn):
    r"""Fuse a conv module with its trailing batch norm module.

    Args:
        is_qat: whether the fusion is for quantization aware training (True)
                or post training quantization (False)
        conv: conv module (nn.Conv1d / nn.Conv2d / nn.Conv3d)
        bn: spatial batch norm module to fold into ``conv``

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert conv.training == bn.training, \
        "Conv and BN both must be in the same mode (train or eval)."

    # PTQ path: fold the BN statistics directly into the conv weights.
    if not is_qat:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

    # QAT path: wrap both modules into an intrinsic ConvBn module so that
    # BN statistics keep updating during training.
    assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
    assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
    assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
    qat_module_classes = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }
    qat_module_class = qat_module_classes.get(type(conv), None)
    if qat_module_class is None:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}")
    return qat_module_class(conv, bn)
|
| 54 |
+
|
| 55 |
+
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    r"""Return the fused conv, bn and relu modules.

    Given the conv, bn and relu modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion
        conv: Module instance of type conv1d/conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that follows the batch norm

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU(inplace=False)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
    """
    # All three modules must agree on train/eval mode (the message only
    # mentions Conv and BN, but relu.training is checked as well).
    assert conv.training == bn.training == relu.training, \
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module : Optional[Type[nn.Sequential]] = None
    if is_qat:
        # QAT: keep conv/bn/relu as an intrinsic ConvBnReLU module so BN
        # statistics keep updating during training.
        map_to_fused_module_train = {
            nn.Conv1d: nni.ConvBnReLU1d,
            nn.Conv2d: nni.ConvBnReLU2d,
            nn.Conv3d: nni.ConvBnReLU3d,
        }
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
        fused_module = map_to_fused_module_train.get(type(conv), None)
        if fused_module is not None:
            return fused_module(conv, bn, relu)
        else:
            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}")
    else:
        # PTQ: fold BN into the conv weights first, then pair the folded
        # conv with the relu in an intrinsic ConvReLU module.
        map_to_fused_module_eval = {
            nn.Conv1d: nni.ConvReLU1d,
            nn.Conv2d: nni.ConvReLU2d,
            nn.Conv3d: nni.ConvReLU3d,
        }
        fused_module = map_to_fused_module_eval.get(type(conv), None)
        if fused_module is not None:
            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
            return fused_module(fused_conv, relu)
        else:
            raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}")
|
| 103 |
+
|
| 104 |
+
def fuse_linear_bn(is_qat, linear, bn):
    r"""Fuse a linear module with its trailing BatchNorm1d module.

    Args:
        is_qat: whether the fusion is for quantization aware training (True)
                or post training quantization (False)
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer

    Examples::

        >>> m1 = nn.Linear(20, 10)
        >>> b1 = nn.BatchNorm1d(10)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_linear_bn(m1, b1)
    """
    assert linear.training == bn.training, \
        "Linear and BN both must be in the same mode (train or eval)."

    # PTQ path: fold the BN parameters directly into the linear weights.
    if not is_qat:
        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

    # QAT path: keep both modules wrapped in an intrinsic LinearBn1d.
    assert bn.num_features == linear.out_features, \
        "Output features of Linear must match num_features of BatchNorm1d"
    assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
    assert bn.track_running_stats, \
        "Only support fusing BatchNorm1d with tracking_running_stats set to True"
    return nni.LinearBn1d(linear, bn)
|
| 133 |
+
|
| 134 |
+
def fuse_convtranspose_bn(is_qat, convt, bn):
    r"""Return the fused ConvTranspose and bn modules.

    Given ConvTranspose and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion; QAT fusion is not supported
                and raises
        convt: Module instance of type ConvTransposeNd
        bn: BatchNormNd instance that needs to be fused with the transposed conv.
            batch norm N should match the ConvTranspose N

    Examples::

        >>> m1 = nn.ConvTranspose2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_convtranspose_bn(m1, b1)
    """
    assert convt.training == bn.training, \
        "ConvTranspose and BN both must be in the same mode (train or eval)."

    if is_qat:
        raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in QAT.")
    else:
        # PTQ: fold BN into the transposed conv; transpose=True tells the
        # generic conv-bn fusion helper to account for the transposed
        # weight layout (in/out channel axes swapped).
        return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
|
| 157 |
+
|
| 158 |
+
def _sequential_wrapper2(sequential):
|
| 159 |
+
"""Return a sequential wrapped that for is_qat and two modules.
|
| 160 |
+
Given a sequential class for two modules, return a function that takes
|
| 161 |
+
is_qat, and then two modules as argument, that ignores the is_qat flag
|
| 162 |
+
and always returns the sequential that combines the two input modules
|
| 163 |
+
"""
|
| 164 |
+
def fuser_method(is_qat, m1, m2):
|
| 165 |
+
return sequential(m1, m2)
|
| 166 |
+
return fuser_method
|
| 167 |
+
|
| 168 |
+
# Default mapping from a tuple of module types to the fuser method that
# combines them. Each value is either a fuser function taking
# (is_qat, *modules) or a sequential class wrapped to that signature.
_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d),
    (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d),
    (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d),
    (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d),
    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}
|
| 186 |
+
|
| 187 |
+
def get_fuser_method(op_list, additional_fuser_method_mapping=None):
    """Get fuser method for the given list of module types.

    Looks ``op_list`` up in the default mapping merged with any
    user-provided additions. Despite older docs, a missing entry does not
    return None: it fails the assert below.
    """
    extra_mappings = additional_fuser_method_mapping
    if extra_mappings is None:
        extra_mappings = {}
    combined_mappings = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD,
                                          extra_mappings)
    fuser_method = combined_mappings.get(op_list, None)
    assert fuser_method is not None, f"did not find fuser method for: {op_list} "
    return fuser_method
|
| 200 |
+
|
| 201 |
+
def _reverse2(f):
|
| 202 |
+
def reversed(is_qat, x, y):
|
| 203 |
+
return f(is_qat, y, x)
|
| 204 |
+
return reversed
|
| 205 |
+
|
| 206 |
+
def _reverse3(f):
|
| 207 |
+
def reversed(is_qat, x, w):
|
| 208 |
+
y, z = w
|
| 209 |
+
return f(is_qat, z, y, x)
|
| 210 |
+
return reversed
|
| 211 |
+
|
| 212 |
+
def _get_valid_patterns(op_pattern):
|
| 213 |
+
"""Return a list of valid patterns generated from the op_pattern.
|
| 214 |
+
|
| 215 |
+
Returns a list of valid patterns generated from the op_pattern,
|
| 216 |
+
since MatchAllNode can match all types of nodes,
|
| 217 |
+
e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like
|
| 218 |
+
(MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode)
|
| 219 |
+
|
| 220 |
+
Example Input:
|
| 221 |
+
(torch.add, (torch.nn.ReLU, torch.nn.Conv2d))
|
| 222 |
+
|
| 223 |
+
Example Output:
|
| 224 |
+
[(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)),
|
| 225 |
+
(torch.add, (torch.nn.ReLU, MatchAllNode)),
|
| 226 |
+
(torch.add, (MatchAllNode, torch.nn.Conv2d)),
|
| 227 |
+
(torch.add, (MatchAllNode, MatchAllNode)),
|
| 228 |
+
(MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)),
|
| 229 |
+
(MatchAllNode, (torch.nn.ReLU, MatchAllNode)),
|
| 230 |
+
(MatchAllNode, (MatchAllNode, torch.nn.Conv2d)),
|
| 231 |
+
(MatchAllNode, (MatchAllNode, MatchAllNode)),
|
| 232 |
+
]
|
| 233 |
+
"""
|
| 234 |
+
result: List[Any]
|
| 235 |
+
if isinstance(op_pattern, (tuple, list)):
|
| 236 |
+
sub_combs = []
|
| 237 |
+
for sub_pattern in op_pattern:
|
| 238 |
+
sub_combs.append(_get_valid_patterns(sub_pattern))
|
| 239 |
+
result = list(itertools.product(*sub_combs))
|
| 240 |
+
else:
|
| 241 |
+
result = [op_pattern, MatchAllNode]
|
| 242 |
+
return result
|
| 243 |
+
|
| 244 |
+
def get_fuser_method_new(
        op_pattern: Pattern,
        fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]]):
    """Get fuser method, honoring MatchAllNode wildcards in the mapping keys.

    This will be made default after we deprecate the get_fuser_method
    Would like to implement this first and have a separate PR for deprecation

    Raises (via assert) when no candidate pattern is found in the mapping.
    """
    op_patterns = _get_valid_patterns(op_pattern)
    fuser_method = None
    # Fix: iterate with a separate variable so the assert message below
    # reports the caller's original pattern, not the last wildcard candidate
    # (the old loop reused `op_pattern` as the loop variable).
    for candidate_pattern in op_patterns:
        fuser_method = fuser_method_mapping.get(candidate_pattern, None)
        if fuser_method is not None:
            break
    assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
    return fuser_method
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .prepare import prepare
|
| 2 |
+
from .convert import convert
|
| 3 |
+
from .fuse import fuse
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-311.pyc
ADDED
|
Binary file (46.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-311.pyc
ADDED
|
Binary file (40.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-311.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-311.pyc
ADDED
|
Binary file (7.17 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-311.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-311.pyc
ADDED
|
Binary file (1.02 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-311.pyc
ADDED
|
Binary file (8.89 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-311.pyc
ADDED
|
Binary file (4.54 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-311.pyc
ADDED
|
Binary file (65.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-311.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-311.pyc
ADDED
|
Binary file (9.98 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_equalize.py
ADDED
|
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.ao.nn.intrinsic as nni
|
| 10 |
+
from torch.fx import GraphModule
|
| 11 |
+
from torch.fx.graph import Node
|
| 12 |
+
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
|
| 13 |
+
|
| 14 |
+
from ..observer import _with_args, ObserverBase, PerChannelMinMaxObserver
|
| 15 |
+
from ..utils import _parent_name, check_min_max_valid
|
| 16 |
+
|
| 17 |
+
from .utils import (
|
| 18 |
+
get_new_attr_name_with_prefix,
|
| 19 |
+
maybe_get_next_module,
|
| 20 |
+
node_arg_is_weight,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# User-extendable registry of custom module classes that support input-weight
# equalization; consulted by custom_module_supports_equalization().
CUSTOM_MODULE_SUPP_LIST: List[Any] = []
|
| 24 |
+
|
| 25 |
+
def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
    """Reshape ``scale`` so it broadcasts against ``input`` along ``axis``.

    The result has the same rank as ``input``: every dimension is 1 except
    ``axis``, which matches ``input``'s extent there, so an element-wise
    multiply applies the scale along that axis only.
    """
    broadcast_shape = [input.size(axis) if dim == axis else 1
                       for dim in range(input.ndim)]
    return scale.view(broadcast_shape)
|
| 31 |
+
|
| 32 |
+
# Maps each per-tensor qscheme to its per-channel counterpart.  The
# equalization observers expose a per-tensor qscheme but are backed by
# PerChannelMinMaxObserver instances, which need the per-channel form.
# NOTE: "qsheme" (sic) is kept as-is; other code references this exact name.
qsheme_mapping_per_tensor_to_per_channel = {
    torch.per_tensor_affine: torch.per_channel_affine,
    torch.per_tensor_symmetric: torch.per_channel_symmetric,
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.
    This observer is intended to be used along with a WeightEqualizationObserver
    to calculate the equalization scale.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 quant_min=None, quant_max=None, factory_kwargs=None) -> None:
        super().__init__()

        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            raise TypeError("Input qscheme must be per-tensor")

        self.dtype = dtype
        self.qscheme = qscheme

        # Per-column tracking is implemented with a per-channel observer on
        # ch_axis=1, so the per-tensor qscheme is mapped to its per-channel form.
        per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
        self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                  qscheme=per_channel_qscheme,
                                                  quant_min=quant_min,
                                                  quant_max=quant_max,
                                                  factory_kwargs=factory_kwargs)

        # A scalar 1 means "no equalization scale computed yet"; both
        # set_equalization_scale and calculate_scaled_minmax check for it.
        self.equalization_scale = torch.tensor(1)
        # Broadcast shape for the scale; filled in on the first forward pass.
        self.equalization_shape: List[int] = []

    def forward(self, x_orig):
        # 2d-5d inputs cover Linear and Conv1d/2d/3d layers.
        if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
            raise ValueError("InputEqualizationObserver only supports Linear and Conv layers")

        # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers)
        self.equalization_shape = [1] * x_orig.ndim
        self.equalization_shape[1] = x_orig.size(1)

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        # Per-column running (min, max) from the backing per-channel observer.
        return (self.input_obs.min_val, self.input_obs.max_val)

    def set_equalization_scale(self, equalization_scale):
        # Reshape the equalization scale along axis=1 so that it can be
        # multiplied with the input along axis=1
        if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
            # Scalar 1 is the uncalibrated default; keep the stored default.
            return
        self.equalization_scale = torch.reshape(equalization_scale, self.equalization_shape)

    def calculate_scaled_minmax(self):
        r""" Returns the scaled min/max inputs, or (None, None) if the
        equalization scale was never calculated.
        """
        if self.equalization_scale.nelement() == 1 and self.equalization_scale == torch.tensor(1):
            warnings.warn(
                "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " +
                "Will not scale the next quantization observer."
            )
            return None, None

        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index
        (min_inputs, max_inputs) = self.get_input_minmax()
        # min/max are 1-d per-column vectors, so the scale broadcasts on axis 0.
        equalization_scale_reshaped = reshape_scale(self.equalization_scale, 0, min_inputs)
        min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped))
        max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped))

        return min_input_scaled, max_input_scaled

    with_args = classmethod(_with_args)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

    This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
    to record the running minimum and maximum of columns of incoming weight
    tensors. This observer is intended to be used along with an
    InputEqualizationObserver to calculate the equalization scale.

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.
    """

    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
                 quant_max=None, factory_kwargs=None) -> None:
        super().__init__()

        self.dtype = dtype
        self.qscheme = qscheme
        # Weight columns live on axis 1 (in_features for Linear weights).
        self.ch_axis = 1

        # Unlike the input observer, weights may come with a per-channel
        # qscheme already; only per-tensor schemes need mapping.
        per_channel_qscheme = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
        self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                       qscheme=per_channel_qscheme,
                                                       quant_min=quant_min,
                                                       quant_max=quant_max,
                                                       factory_kwargs=factory_kwargs)

        # Scalar 1 means "not yet set"; set via set_equalization_scale.
        self.equalization_scale = torch.tensor(1)

    def forward(self, w_orig):
        # 2d-5d weights cover Linear and Conv1d/2d/3d layers.
        if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
            # Bug fix: the message previously said "InputEqualizationObserver"
            # (copy-paste from the input observer); this is the weight observer.
            raise ValueError("WeightEqualizationObserver only supports Linear and Conv layers")

        return self.weight_col_obs(w_orig)

    def get_weight_col_minmax(self):
        # Per-column running (min, max) of all observed weight tensors.
        return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)

    def set_equalization_scale(self, equalization_scale):
        self.equalization_scale = equalization_scale

    with_args = classmethod(_with_args)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def calculate_equalization_scale(input_obs: _InputEqualizationObserver,
                                 weight_obs: _WeightEqualizationObserver) -> torch.Tensor:
    r""" Calculates and returns the per-column equalization scale from the
    ranges recorded by the two observers.

    Args:
        input_obs: Observer that tracks the ranges for the input columns
        weight_obs: Observer that tracks the ranges for the weight columns
    """

    min_in, max_in = input_obs.get_input_minmax()
    min_w, max_w = weight_obs.get_weight_col_minmax()

    # Both observers must have seen data before a scale can be computed.
    if not (check_min_max_valid(min_in, max_in) and check_min_max_valid(min_w, max_w)):
        warnings.warn(
            "Must run observer before calling calculate_equalization_scale. "
            "Returning default equalization scale torch.tensor(1)."
        )
        return torch.tensor(1)

    if min_in.shape != min_w.shape:
        raise ValueError(
            "Input and Weight must have the same column dimension. "
            f"Found {min_in.shape} and {min_w.shape} shapes instead."
        )

    scale = torch.sqrt((max_w - min_w) / (max_in - min_in))
    # Replace zeros first, then any inf/nan entries, with 1 so downstream
    # reciprocal multiplications are safe.
    scale[scale == 0.] = 1
    return torch.nan_to_num(scale, nan=1, posinf=1, neginf=1)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class EqualizationQConfig(namedtuple('EqualizationQConfig', ['input_activation', 'weight'])):
    """
    Describes how to quantize a layer or a part of the network specifically for
    input-weight equalization by providing settings (observer classes) for
    inputs, outputs, and weights.

    Note that EqualizationQConfig needs to contain observer **classes** (like
    MinMaxObserver) or a callable that returns instances on invocation, not the
    concrete observer instances themselves.
    Quantization function will instantiate observers multiple times for each of
    the layers.

    Observer classes have usually reasonable default arguments, but they can be
    overwritten with `with_args` method (that behaves like functools.partial):

    my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
                                     weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
    """
    def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
        # Reject observer *instances*: the quantization machinery instantiates
        # the stored callables once per layer, so sharing one instance is wrong.
        if any(isinstance(obs, nn.Module) for obs in (input_activation, weight)):
            raise ValueError("EqualizationQConfig received observer instance, please pass observer class instead. "
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super().__new__(cls, input_activation, weight)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# Default observer factories for input-weight equalization: unsigned 8-bit
# symmetric for activations, signed 8-bit symmetric-per-channel for weights.
input_equalization_observer = _InputEqualizationObserver.with_args(
    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric)
weight_equalization_observer = _WeightEqualizationObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
# Ready-made qconfig combining the two default observer factories above.
default_equalization_qconfig = EqualizationQConfig(input_activation=input_equalization_observer,
                                                   weight=weight_equalization_observer)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def fused_module_supports_equalization(module) -> bool:
    """Return True if ``module`` is one of the fused module types that
    support input-weight equalization (exact type match, not isinstance)."""
    supported_fused_types = (nni.LinearReLU, nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)
    return any(type(module) is fused_type for fused_type in supported_fused_types)
|
| 245 |
+
|
| 246 |
+
def nn_module_supports_equalization(module) -> bool:
    """Return True if ``module`` is a plain torch.nn layer type that supports
    input-weight equalization (exact type match, not isinstance)."""
    return type(module) in {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d}
|
| 249 |
+
|
| 250 |
+
def custom_module_supports_equalization(module) -> bool:
    """Return True if ``module``'s type was registered in the user-extendable
    CUSTOM_MODULE_SUPP_LIST of equalizable custom modules."""
    return any(type(module) == supported for supported in CUSTOM_MODULE_SUPP_LIST)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def node_supports_equalization(node: Node, modules) -> bool:
    """ Return True if the given graph node is an op we know how to equalize.

    Currently only nn.Linear/F.linear and nn.Conv/F.conv layers are supported.
    """
    if node.op == 'call_function':
        return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
    if node.op == 'call_module':
        target_module = modules[str(node.target)]
        return (nn_module_supports_equalization(target_module)
                or fused_module_supports_equalization(target_module)
                or custom_module_supports_equalization(target_module))
    return False
|
| 266 |
+
|
| 267 |
+
def is_equalization_observer(observer: nn.Module) -> bool:
    """Return True if ``observer`` is an input or weight equalization observer."""
    equalization_observer_types = (_InputEqualizationObserver, _WeightEqualizationObserver)
    return isinstance(observer, equalization_observer_types)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
###############################################################################
|
| 272 |
+
# Functions for equalization during convert #
|
| 273 |
+
###############################################################################
|
| 274 |
+
|
| 275 |
+
def get_op_node_and_weight_eq_obs(
    input_eq_obs_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module]
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
    """ Gets the following weight equalization observer. There should always
    exist a weight equalization observer after an input equalization observer.

    Returns the operation node that follows the input equalization observer node
    and the weight equalization observer, or (None, None) if no weight
    equalization observer is found for a functional op.
    """

    # Find the op node that comes directly after the input equalization observer
    op_node = None
    for user in input_eq_obs_node.users.keys():
        if node_supports_equalization(user, modules):
            op_node = user
            break

    assert op_node is not None
    if op_node.op == 'call_module':
        # If the op_node is a nn.Linear layer, then it must have a
        # WeightEqualizationObserver configuration; the observer is freshly
        # instantiated here via the qconfig's `.weight()` factory.
        maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig")
        assert maybe_equalization_node_name_to_config is not None
        equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config  # type: ignore[assignment]
        assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
        weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight()

        assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
        return op_node, weight_eq_obs

    elif op_node.op == 'call_function':
        # For functional ops the weight observer already exists as a module
        # in the graph; locate it through the op's weight argument.
        weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
        if weight_node is not None:
            weight_eq_obs = modules[str(weight_node.target)]
            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
            return op_node, weight_eq_obs

    return None, None
|
| 315 |
+
|
| 316 |
+
def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]:
    """ Gets the weight equalization observer node if it exists.

    ``op_node`` must be a functional op (call_function); its weight argument,
    when present, is expected to be a call to a _WeightEqualizationObserver
    module inserted during prepare.
    """
    assert op_node.op == 'call_function'
    for node_arg in op_node.args:
        if node_arg_is_weight(op_node, node_arg):
            assert (isinstance(node_arg, Node) and node_arg.op == 'call_module' and
                    isinstance(modules[str(node_arg.target)], _WeightEqualizationObserver))
            return node_arg
    return None
|
| 326 |
+
|
| 327 |
+
def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]:
    """ Gets the following input equalization observer if it exists.

    For example, in the case of connecting linear layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
    If the node being passed in is the linear1 node, then we want to return eq_obs2,
    the following equalization observer for linear2.

    However, if there are no connecting layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
    Then we want to return None.

    In the case of an unfused linear-relu layer with a connecting linear layer:
        linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
    Since it is unfused, we want to skip over the relu layer and return eq_obs2,
    the following equalization observer for linear2.
    """

    assert node_supports_equalization(node, modules)

    # Locate the following nn.ReLU or F.relu node if it exists
    maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
    if maybe_relu_node is None:
        maybe_relu_node = maybe_get_next_module(node, modules, target_functional_type=F.relu)

    # Locate the following output observer if it exists.
    # We will skip the relu node if it exists.
    maybe_obs_node = (
        maybe_get_next_module(node, modules, ObserverBase)
        if maybe_relu_node is None
        else maybe_get_next_module(maybe_relu_node, modules, ObserverBase)
    )
    if maybe_obs_node is None:
        return None

    maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver)
    if maybe_eq_obs_node is None:
        return None

    # NOTE(review): indexes `modules` by the node's *name* (str(node)) rather
    # than node.target — presumably identical for observer modules inserted
    # during prepare; confirm against how observers are named.
    maybe_eq_obs = modules[str(maybe_eq_obs_node)]
    assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
    return maybe_eq_obs
|
| 369 |
+
|
| 370 |
+
def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
    """ If the next next node is an InputEqualizationObserver then we want to
    return its equalization scale, else we return None.

    This is used in the case where there are two connecting linear layers:
        linear1 -> LinearOutObs -> InputEqObs -> linear2
    In this case, the node given is linear1 and we want to locate the InputEqObs.
    """
    next_obs = maybe_get_next_input_eq_obs(node, modules)
    if next_obs is None:
        return None
    scale = next_obs.equalization_scale
    # A still-default scalar 1 means the observer was never calibrated.
    if scale.nelement() == 1 and scale == torch.tensor(1):
        return None
    return scale
|
| 385 |
+
|
| 386 |
+
def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
    """ Scales the following input quantization observer's min/max values by
    updating the values with the scaled min/max values calculated by the input
    equalization observer.

    ``node`` must be the graph node of an _InputEqualizationObserver; its
    first argument is expected to be the preceding quantization observer node.
    """
    input_eq_obs = modules[str(node.target)]
    assert isinstance(input_eq_obs, _InputEqualizationObserver)

    input_quant_obs_node = node.args[0]
    assert isinstance(input_quant_obs_node, Node)

    input_quant_obs = modules[str(input_quant_obs_node.target)]
    # Silently skip if the predecessor is not a quantization observer.
    if not isinstance(input_quant_obs, ObserverBase):
        return

    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
    # (None, None) means the equalization scale was never calculated; leave
    # the quantization observer untouched.
    if min_input_scaled is None and max_input_scaled is None:
        return
    input_quant_obs.min_val = min_input_scaled
    input_quant_obs.max_val = max_input_scaled
|
| 406 |
+
|
| 407 |
+
def scale_weight_node(
    node: Node,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """ Scale the weights for input-weight equalization by multiplying the
    weight by 1/equalization_scale and next_equalization_scale.

    Mutates the op module in place (weight, and bias if present).

    Args:
        node: Current node whose weights we want to scale
        modules: Mapping from module name to module instance
        equalization_scale: Current node's calculated equalization scale
        next_equalization_scale: Next node's calculated equalization scale if
           the following node needs to be equalized, None otherwise
    """
    if equalization_scale is None:
        return

    # For fused modules (e.g. LinearReLU) the weighted layer is the first child.
    if fused_module_supports_equalization(modules[str(node.target)]):
        op_module = modules[str(node.target)][0]  # type: ignore[index]
    else:
        op_module = modules[str(node.target)]
    assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    weight = op_module.weight
    assert isinstance(weight, torch.Tensor)

    # Scale the weights by the reciprocal of the equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))

    if next_equalization_scale is None:
        op_module.weight = nn.Parameter(scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight)
    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)

    op_module.weight = nn.Parameter(scaled_weight)

    # Multiply the bias element wise by the next equalization scale
    bias = op_module.bias
    if bias is None:
        return
    assert isinstance(bias, torch.Tensor)

    # Reshape the equalization scale so that we can multiply it element-wise to the bias
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
    op_module.bias = nn.Parameter(scaled_bias)
|
| 462 |
+
|
| 463 |
+
def scale_weight_functional(
    op_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """ Scales the weight value for functional layers.

    The weight (and bias, if present) are located through the graph and
    rewritten in place on their owning submodules.
    """
    if equalization_scale is None:
        return

    # From the given op_node, the path looks like:
    #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
    # So we want to trace back from the op_node to get the equalization observer
    # node, then the quantization observer node, and then finally the weight
    # node which contains the weight values.

    # Get the equalization observer node
    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if weight_eq_obs_node is None:
        return

    # Get the quantization observer node
    weight_quant_obs_node = weight_eq_obs_node.args[0]
    if weight_quant_obs_node is None:
        return
    assert (isinstance(weight_quant_obs_node, Node) and
            isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))

    # Get the get_attr(weight) node
    weight_node = weight_quant_obs_node.args[0]
    if weight_node is None:
        return
    assert isinstance(weight_node, Node) and weight_node.op == 'get_attr'

    weight_parent_name, weight_name = _parent_name(weight_node.target)
    weight = getattr(modules[weight_parent_name], weight_name)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))

    if next_equalization_scale is None:
        setattr(modules[weight_parent_name], weight_name, scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, scaled_weight)
    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)

    setattr(modules[weight_parent_name], weight_name, scaled_weight)
    # Sanity check: the rewritten attribute must be visible as the graph buffer.
    assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)

    # Multiply the bias element wise by the next equalization scale
    bias_node = None
    for node in op_node.args:
        # Find the node containing the bias values (matched by name).
        if isinstance(node, Node) and node.op == 'get_attr' and 'bias' in node.name:
            bias_node = node
            break
    if bias_node is None:
        return

    bias_parent_name, bias_name = _parent_name(bias_node.target)
    bias = getattr(modules[bias_parent_name], bias_name)

    # Reshape the equalization scale so that we can multiply it element-wise to the bias
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
    setattr(modules[bias_parent_name], bias_name, scaled_bias)
|
| 537 |
+
|
| 538 |
+
def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None:
    """ Given the operation node, find the corresponding weight quantization
    observer and reset its recorded min/max values.

    No-op when the op has no weight equalization observer, or when that
    observer has no quantization observer feeding it.
    """
    eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if eq_obs_node is None:
        return

    # The quantization observer is the input to the equalization observer
    quant_obs_node = eq_obs_node.args[0]
    if quant_obs_node is None:
        return
    assert isinstance(quant_obs_node, Node)

    quant_obs = modules[str(quant_obs_node.target)]
    assert isinstance(quant_obs, ObserverBase)
    quant_obs.reset_min_max_vals()  # type: ignore[operator]
|
| 554 |
+
|
| 555 |
+
def remove_node(model: GraphModule, node: Node, prev_node: Node):
    """ Removes ``node`` from the graph by rerouting every one of its users to
    ``prev_node`` and then erasing it.

    Args:
        model: The GraphModule whose graph is modified in place
        node: The node to delete (e.g. an InputEqualizationObserver node)
        prev_node: The node that takes over all of ``node``'s uses
    """
    # Snapshot the users first: replace_input_with mutates node.users while
    # we iterate.
    for consumer in list(node.users.keys()):
        consumer.replace_input_with(node, prev_node)

    # Node now has no users, so it can be erased safely
    model.graph.erase_node(node)
|
| 567 |
+
|
| 568 |
+
def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
    """ Update every observer's equalization scale.

    For each InputEqualizationObserver in the graph, locate the matching
    WeightEqualizationObserver, calibrate it if needed, compute the
    equalization scale from the pair, and store the scale on both observers.

    Returns:
        A dict mapping operation node names to the WeightEqualizationObserver
        used for that operation.
    """
    op_name_to_weight_obs: Dict[str, _WeightEqualizationObserver] = {}

    for node in model.graph.nodes:
        if not (node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver)):
            continue

        input_obs = modules[node.target]
        assert isinstance(input_obs, _InputEqualizationObserver)
        op_node, weight_obs = get_op_node_and_weight_eq_obs(node, model, modules)

        if op_node is None or weight_obs is None:
            continue

        if op_node.op == 'call_module':
            # The weight observer was just created, so calibrate it on the
            # module's weight before computing the scale
            target_module = modules[str(op_node.target)]
            if fused_module_supports_equalization(target_module):
                inner = target_module[0]  # type: ignore[index]
                assert nn_module_supports_equalization(inner)
                weight_obs(inner.weight)
            else:
                weight_obs(target_module.weight)

        # Compute the scale from the input/weight pair and record it on both
        scale = calculate_equalization_scale(input_obs, weight_obs)
        input_obs.set_equalization_scale(scale)
        weight_obs.set_equalization_scale(scale)

        op_name_to_weight_obs[op_node.name] = weight_obs

    return op_name_to_weight_obs
|
| 605 |
+
|
| 606 |
+
def convert_eq_obs(
    model: GraphModule,
    modules: Dict[str, nn.Module],
    weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
) -> None:
    """ Converts the equalization operations and updates the other nodes in the
    following way:
        - Removes the input equalization observers and inserts a mul operator
          along with an equalization scale node wherever applicable (we do not
          want to insert a mul operator between connecting linear layers).
        - Updates the input quantization observers with the scaled input min/max
          values.
        - Scales the weights by the current and next equalization scales.
        - Removes the weight equalization observer node if it exists.

    Before (after prepare):
                                       weight values
                                             |
                                        WeightQuantObs
                                             |
                                         WeightEqObs
                                             |
        x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs

    After this function:
                                          scaled weight values
                                                  |
           equalization scale                WeightQuantObs
                  |                               |
        x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs

    After convert:
           equalization scale                 scaled weight values
                  |                                    |
        x -> mul -> quantize_per_tensor -> quantized::linear

    Note that although the equalization observer appeared after the quantization
    observer after prepare_fx, the mul node appears before the quantization node
    after convert_fx. This is because placing the equalization observer after
    the quantization observer in prepare_fx would allow us to keep the invariant
    that the graph before the current node inserts its observers is not
    modified.

    Having the equalization observer before the quantization observer would also
    cause some inconsistences between the ordering of the quantization and
    equalization observers.
    For example, a single linear layer would look like:
        x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
    But between two connected linear layers, it would look like:
        linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2

    Args:
        model: GraphModule produced by prepare_fx with equalization observers
        modules: name -> module mapping for ``model``
        weight_eq_obs_dict: op node name -> its WeightEqualizationObserver
            (as produced by update_obs_for_equalization)
    """
    for node in model.graph.nodes:
        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
            inp_quant_obs_node = node.args[0]
            prev_node = inp_quant_obs_node.args[0]

            # If the previous node is a layer that needs to be equalized, then
            # we will remove the current node because we do not need to add any
            # equalization nodes between two layers that need to be equalized

            # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2
            # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2
            if node_supports_equalization(prev_node, modules) or "relu" in prev_node.name:
                remove_node(model, node, inp_quant_obs_node)
                continue

            # Update the following input quantization observer's min/max values
            scale_input_observer(node, modules)

            # Remove the InputEqualization node and add a mul operator before
            # the quantization observer node that appears before the equalization node
            # Before: x -> input_quant_obs -> input_eq_obs -> linear
            # After: x -> mul -> input_quant_obs -> linear

            # Create a node containing the equalization scale
            with model.graph.inserting_before(inp_quant_obs_node):
                get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
                name = get_new_eq_scale_name(modules)
                setattr(model, name, modules[node.target].equalization_scale)
                eq_scale_node = model.graph.create_node('get_attr', name)

            # Create a node multiplying the input with the equalization scale
            with model.graph.inserting_after(eq_scale_node):
                inputs = (prev_node, eq_scale_node)
                mul_node = model.graph.create_node("call_function", torch.mul, inputs)

            # Set the mul node to be the input_quant_obs_node's input instead
            # of the previous node
            inp_quant_obs_node.replace_input_with(prev_node, mul_node)
            remove_node(model, node, inp_quant_obs_node)

        elif weight_eq_obs_dict.get(node.name, None) is not None:
            weight_eq_obs = weight_eq_obs_dict.get(node.name)
            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
            equalization_scale = weight_eq_obs.equalization_scale

            # A scalar scale of 1 is a no-op; drop it so the scaling helpers
            # can skip the multiply entirely
            if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
                equalization_scale = None  # type: ignore[assignment]
            maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)

            # Scale the weight nodes
            if node.op == 'call_module':
                scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)
            elif node.op == 'call_function':
                scale_weight_functional(node, model, modules, equalization_scale, maybe_next_equalization_scale)

            weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
            if weight_eq_obs_node is None:
                # NOTE(review): this `return` aborts processing of all
                # remaining graph nodes, not just this one — looks like it may
                # have been intended as `continue`; confirm before changing.
                return
            assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)

            # Clear the quantization observer's min/max values so that they
            # can get updated later based on the new scale values
            clear_weight_quant_obs_node(node, modules)

            # Erase the weight equalization observer node
            prev_node = weight_eq_obs_node.args[0]
            remove_node(model, weight_eq_obs_node, prev_node)
        else:
            # Bug fix: the original message was missing the closing quote and
            # a separator, producing "...'call_functionInstead got..."
            raise ValueError("Expected operation node to be 'call_module' or 'call_function'. " +
                             f"Instead got node {node.name} as '{node.op}'.")
|
| 727 |
+
|
| 728 |
+
def _convert_equalization_ref(model: GraphModule):
    """ Reference function which applies changes needed for equalization, but
    does not quantize the nodes.

    Returns a new GraphModule built from the mutated graph.
    """
    named = dict(model.named_modules(remove_duplicate=False))

    # First compute the equalization scales and push the scaled min/max
    # values into the observers, then rewrite the graph and scale the weights
    obs_dict = update_obs_for_equalization(model, named)
    convert_eq_obs(model, named, obs_dict)

    return GraphModule(model, model.graph)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
###############################################################################
# Functions for running the equalized model on the Numeric Suite              #
###############################################################################
|
| 745 |
+
|
| 746 |
+
def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor) -> Dict[str, float]:
    """ Runs the Numeric Suite on model_a and model_b and returns a dictionary
    containing the SQNR between layers in model_a and model_b.

    Note: In order to support equalized models, this function has a hacky fix in
    which we do not match any torch.mul operators. This is because equalized
    models contain extra mul operators to scale the input by the equalization
    scale, but this edge case has not been resolved yet within the numeric suite code.

    Args:
        model_a: A float model
        model_b: A quantized model
        x: Inputs to use during calibration
    """
    import torch.ao.ns._numeric_suite_fx as ns
    from torch.ao.ns.fx.mappings import get_unmatchable_types_map

    # Hacky fix: keep torch.mul (the equalization scaling op) unmatched
    unmatchable_types_map = get_unmatchable_types_map()
    unmatchable_types_map["funs_unmatchable"].add(torch.mul)

    model_a_ns, model_b_ns = ns.add_loggers(
        'fp32', model_a,
        'int8', model_b,
        ns.OutputLogger,
        unmatchable_types_map=unmatchable_types_map
    )

    # Calibrate both instrumented models on the same input
    model_a_ns(x)
    model_b_ns(x)

    activation_comparison_dict = ns.extract_logger_info(
        model_a_ns,
        model_b_ns,
        ns.OutputLogger,
        'int8')
    ns.extend_logger_results_with_comparison(
        activation_comparison_dict,
        'fp32', 'int8',
        torch.ao.ns.fx.utils.compute_sqnr, 'sqnr'
    )

    # Flatten the comparison results into a layer-name -> SQNR mapping
    layer_sqnr_dict: Dict[str, float] = {}
    for entry in activation_comparison_dict.values():
        record = entry['node_output']['int8'][0]
        layer_sqnr_dict[record['fqn']] = record['sqnr'][0]

    return layer_sqnr_dict
|
| 795 |
+
|
| 796 |
+
def get_equalization_qconfig_dict(
    layer_sqnr_dict: Dict[str, float],
    num_layers_to_equalize: int
) -> Any:
    """ Given the layer to SQNR dictionary, find the layers with the highest
    quantization errors, and return an equalization_qconfig_dict
    specifying to only equalize those top layers.

    Args:
        layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found
            when comparing an equalized model against a float model)
        num_layers_to_equalize: Number of layers with the highest quantization
            errors to equalize
    """
    # Lowest SQNR == highest quantization error, so sort ascending by SQNR
    # and keep the first num_layers_to_equalize entries
    worst_layers = sorted(layer_sqnr_dict.items(), key=lambda item: item[1])[:num_layers_to_equalize]

    # Map each of those layers to the default equalization qconfig
    module_to_qconfig_list = [
        (layer_name, default_equalization_qconfig) for layer_name, _ in worst_layers
    ]
    return {"module_name": module_to_qconfig_list}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (239 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-311.pyc
ADDED
|
Binary file (66.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-311.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-311.pyc
ADDED
|
Binary file (29.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/detector.py
ADDED
|
@@ -0,0 +1,1539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Set, Tuple, Callable, List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.ao.nn.qat as nnqat
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from torch.ao.quantization.fake_quantize import FakeQuantize
|
| 8 |
+
from torch.ao.quantization.fx.graph_module import GraphModule
|
| 9 |
+
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
|
| 10 |
+
from torch.ao.quantization.qconfig import (
|
| 11 |
+
QConfig,
|
| 12 |
+
default_qconfig,
|
| 13 |
+
_assert_valid_qconfig,
|
| 14 |
+
)
|
| 15 |
+
from torch.ao.quantization.observer import (
|
| 16 |
+
ObserverBase,
|
| 17 |
+
default_dynamic_quant_observer,
|
| 18 |
+
default_per_channel_weight_observer,
|
| 19 |
+
default_observer,
|
| 20 |
+
default_weight_observer,
|
| 21 |
+
)
|
| 22 |
+
from torch.ao.quantization.fx._equalize import (
|
| 23 |
+
default_equalization_qconfig,
|
| 24 |
+
EqualizationQConfig,
|
| 25 |
+
)
|
| 26 |
+
from torch.ao.quantization.observer import _is_activation_post_process
|
| 27 |
+
|
| 28 |
+
# Keys used in the dicts that detectors return from
# determine_observer_insert_points() to describe each observer insertion.
DETECTOR_TARGET_NODE_KEY = "target_node"
DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"
DETECTOR_IS_POST_OBS_KEY = "is_post_observer"
DETECTOR_OBS_ARGS_KEY = "observer_args"
|
| 33 |
+
|
| 34 |
+
# Mapping related code
|
| 35 |
+
class DetectorQConfigInfo:
    r"""
    This class contains the QConfig information for a single module.
    The list of variables / values this contains can grow depending on the
    extensibility of the qconfig mapping feature set but this currently includes:
    - if activation observer is dynamic
    - if weight observer is per channel

    Args:
        module_fqn (str): The fully qualified name (fqn) of the module that this
            information contains info relevant to qconfig for
    """

    def __init__(self, module_fqn: str):
        super().__init__()
        self.module_fqn = module_fqn

        # Recommendation flags; detectors flip these to True when their
        # analysis actually suggests the corresponding option
        self.is_activation_dynamic = False
        self.is_weight_per_channel = False

        # equalization related options
        self.is_equalization_recommended = False

    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
        r"""
        Args:
            module (torch.nn.Module): The module we are generating
            the qconfig for

        Returns the generated quantization QConfig according to what a valid
        configuration is: the recommendations are tried in priority order and
        the first combination that validates against the module wins; if none
        validate, the default qconfig is returned.
        """
        # Fall back to the default if no recommendation validates
        module_qconfig = default_qconfig

        # Candidate (dynamic_activation, per_channel_weight) combinations,
        # ordered from most to least preferred
        recommendations_list = []
        recommendations_list.append((self.is_activation_dynamic, self.is_weight_per_channel))
        recommendations_list.append((self.is_activation_dynamic, False))  # only trying dynamic rec
        # Bug fix (comment only): this candidate tries per-channel alone,
        # not dynamic — the original comment said "only trying dynamic"
        recommendations_list.append((False, self.is_weight_per_channel))  # only trying per channel rec

        # now we try each of the combinations
        for rec in recommendations_list:
            # rec[0] -> dynamic recommended
            # rec[1] -> per channel recommended
            activation = default_dynamic_quant_observer if rec[0] else default_observer
            weight = default_per_channel_weight_observer if rec[1] else default_weight_observer
            test_config = QConfig(activation, weight)
            try:
                _assert_valid_qconfig(test_config, module)
                module_qconfig = test_config
                break
            except AssertionError:
                # if not a valid configuration, we move on to the next one in priority
                continue

        # return the QConfig chosen
        return module_qconfig

    def generate_equalization_qconfig(self) -> EqualizationQConfig:
        r"""
        This returns the equalization configuration for a module.

        For now, it just returns the default, but as more equalization options become
        possible, this method can get more fleshed out with more nuanced granularity.

        Returns the generated equalization QConfig according to what a valid configuration is
        """
        # in this case, we just return default equalization config
        # we know this is valid because only valid modules would even
        # have this option
        return default_equalization_qconfig
|
| 111 |
+
|
| 112 |
+
# Adding base class for detectors
|
| 113 |
+
class DetectorBase(ABC):
|
| 114 |
+
r""" Base Detector Module
|
| 115 |
+
Any detector class should derive from this class.
|
| 116 |
+
|
| 117 |
+
Concrete detectors should follow the same general API, which includes:
|
| 118 |
+
- A method to calculate and return observer insertion points
|
| 119 |
+
- Should return both the fqns and the Observer class to insert
|
| 120 |
+
- A method to return a report based on the detector
|
| 121 |
+
- Should return a str-based report and dict info in Tuple[str,Dict] format
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(self):
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.detector_config_info = None
|
| 127 |
+
|
| 128 |
+
@abstractmethod
|
| 129 |
+
def determine_observer_insert_points(self, model) -> Dict:
|
| 130 |
+
r"""
|
| 131 |
+
Args
|
| 132 |
+
model (nn.Module or subclass): model to find observer insertion points
|
| 133 |
+
|
| 134 |
+
Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
|
| 135 |
+
This dict maps string keys to detector specific information
|
| 136 |
+
"""
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
@abstractmethod
|
| 140 |
+
def get_detector_name(self) -> str:
|
| 141 |
+
r""" Returns the name of the current detector """
|
| 142 |
+
pass
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@abstractmethod
|
| 146 |
+
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
|
| 147 |
+
r""" Returns the DetectorQConfigInfo for each module_fqn relevant
|
| 148 |
+
Args
|
| 149 |
+
model (nn.Module or subclass): model to find observer insertion points
|
| 150 |
+
|
| 151 |
+
Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
|
| 152 |
+
A DetectorQConfigInfo with the information to generate a QConfig for a specific module
|
| 153 |
+
"""
|
| 154 |
+
pass
|
| 155 |
+
|
| 156 |
+
def _get_targeting_node(self, prepared_fx_model: GraphModule, target_fqn: str) -> torch.fx.node.Node:
|
| 157 |
+
r"""
|
| 158 |
+
Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn.
|
| 159 |
+
|
| 160 |
+
If it's not found, it means it is most likely inside a fused layer
|
| 161 |
+
We just go one layer up in terms of the fqn we are searching for until we find parent node
|
| 162 |
+
If we get to empty string, then we know that it doesn't exist
|
| 163 |
+
|
| 164 |
+
The reason for the recursion is that if the model that we are looking for got fused,
|
| 165 |
+
we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module,
|
| 166 |
+
which would have fqn as x.linear so they will not match.
|
| 167 |
+
To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear,
|
| 168 |
+
or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module
|
| 169 |
+
even in cases with fusion
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
prepared_fx_model (GraphModule): The prepared Fx GraphModule
|
| 173 |
+
target_fqn (str): The fqn of the layer we are trying to target
|
| 174 |
+
|
| 175 |
+
Returns the node object we are trying to add observers around
|
| 176 |
+
"""
|
| 177 |
+
for node in prepared_fx_model.graph.nodes:
|
| 178 |
+
# if the node's target is our target, return it
|
| 179 |
+
if node.target == target_fqn:
|
| 180 |
+
return node
|
| 181 |
+
|
| 182 |
+
# getting here means node not found
|
| 183 |
+
# if no "." we are already at base and failed
|
| 184 |
+
parent_fqn_sep_index = target_fqn.rfind(".")
|
| 185 |
+
if parent_fqn_sep_index == -1:
|
| 186 |
+
raise ValueError("passed in target_fqn not found in graph's targets.")
|
| 187 |
+
else:
|
| 188 |
+
# recursively call it with parent fqn
|
| 189 |
+
return self._get_targeting_node(prepared_fx_model, target_fqn[:parent_fqn_sep_index])
|
| 190 |
+
|
| 191 |
+
@abstractmethod
|
| 192 |
+
def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]:
|
| 193 |
+
r"""
|
| 194 |
+
Args
|
| 195 |
+
model (nn.Module or subclass): model to find observer insertion points
|
| 196 |
+
|
| 197 |
+
Returns a Tuple of two elements:
|
| 198 |
+
Str: string report of the suggested improvements
|
| 199 |
+
Dict: contains useful data collected by the observer pertinent to this report
|
| 200 |
+
"""
|
| 201 |
+
pass
|
| 202 |
+
|
| 203 |
+
class PerChannelDetector(DetectorBase):
|
| 204 |
+
r""" This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
|
| 205 |
+
Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.
|
| 206 |
+
|
| 207 |
+
per_channel quantization can lead to major benefits in the form of accuracy.
|
| 208 |
+
Therefore, if the backend used by the user supports it, it is recommended to use
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
backend (str, optional): the backend the user wishes to use in production
|
| 212 |
+
Default value is current torch.backends.quantized.engine
|
| 213 |
+
"""
|
| 214 |
+
|
| 215 |
+
# Keys for return dictionary
|
| 216 |
+
BACKEND_KEY = "backend"
|
| 217 |
+
PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
|
| 218 |
+
PER_CHAN_USED_KEY = "per_channel_quantization_used"
|
| 219 |
+
|
| 220 |
+
# Default map for representing supported per channel quantization modules for different backends
|
| 221 |
+
DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
|
| 222 |
+
"fbgemm": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
|
| 223 |
+
"qnnpack": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
|
| 224 |
+
"onednn": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
|
| 225 |
+
"x86": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
def __init__(self, backend: str = torch.backends.quantized.engine):
|
| 229 |
+
super().__init__()
|
| 230 |
+
|
| 231 |
+
# store the backend information
|
| 232 |
+
self.backend_chosen = backend
|
| 233 |
+
self.supported_modules = set()
|
| 234 |
+
if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
|
| 235 |
+
self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen]
|
| 236 |
+
else:
|
| 237 |
+
raise ValueError(f"Not configured to work with {self.backend_chosen}. Try a different default backend")
|
| 238 |
+
|
| 239 |
+
def get_detector_name(self) -> str:
|
| 240 |
+
r""" returns the string name of this detector"""
|
| 241 |
+
return "per_channel_detector"
|
| 242 |
+
|
| 243 |
+
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
|
| 244 |
+
r""" Returns the DetectorQConfigInfo for each module_fqn relevant
|
| 245 |
+
Args
|
| 246 |
+
model (nn.Module or subclass): model to find observer insertion points
|
| 247 |
+
|
| 248 |
+
Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
|
| 249 |
+
A DetectorQConfigInfo with the information to generate a QConfig for a specific module
|
| 250 |
+
"""
|
| 251 |
+
# run the helper function to populate the dictionary
|
| 252 |
+
per_channel_info = self._detect_per_channel_helper(model)
|
| 253 |
+
|
| 254 |
+
# we actually have a qconfig info object we are populating
|
| 255 |
+
module_fqn_to_detector_qconfig_info = {}
|
| 256 |
+
|
| 257 |
+
for module_fqn in per_channel_info:
|
| 258 |
+
# create a detector info instance
|
| 259 |
+
detector_qconfig_info = DetectorQConfigInfo(module_fqn)
|
| 260 |
+
|
| 261 |
+
# see if per channel quantization is supported
|
| 262 |
+
per_chan_supported: bool = per_channel_info[module_fqn][self.PER_CHAN_SUPPORTED_KEY]
|
| 263 |
+
detector_qconfig_info.is_weight_per_channel = per_chan_supported
|
| 264 |
+
module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
|
| 265 |
+
|
| 266 |
+
return module_fqn_to_detector_qconfig_info
|
| 267 |
+
|
| 268 |
+
def determine_observer_insert_points(self, model: nn.Module) -> Dict:
|
| 269 |
+
r"""
|
| 270 |
+
There is no observers inserted for the PerChannelDetector.
|
| 271 |
+
|
| 272 |
+
Returns an empty dictionary since no observers are added or needed
|
| 273 |
+
"""
|
| 274 |
+
return {}
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _detect_per_channel_helper(self, model: nn.Module):
|
| 278 |
+
r"""
|
| 279 |
+
determines if per_channel quantization is supported in modules and submodules.
|
| 280 |
+
|
| 281 |
+
Returns a dictionary in the higher level _detect_per_channel function.
|
| 282 |
+
Each entry maps the fully-qualified-name to information on whether per_channel quantization.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
model: The current module that is being checked to see if it is per_channel quantizable
|
| 286 |
+
|
| 287 |
+
Returns dictionary mapping fqns to if per_channel quantization is possible
|
| 288 |
+
"""
|
| 289 |
+
# create dict we will return
|
| 290 |
+
per_channel_info: Dict = {}
|
| 291 |
+
|
| 292 |
+
# get the fully qualified name and check if in list of modules to include and list of modules to ignore
|
| 293 |
+
for fqn, module in model.named_modules():
|
| 294 |
+
|
| 295 |
+
is_in_include_list = sum([isinstance(module, x) for x in self.supported_modules]) > 0
|
| 296 |
+
|
| 297 |
+
# check if the module per_channel is supported
|
| 298 |
+
# based on backend
|
| 299 |
+
per_channel_supported = False
|
| 300 |
+
|
| 301 |
+
if is_in_include_list:
|
| 302 |
+
per_channel_supported = True
|
| 303 |
+
|
| 304 |
+
# assert statement for MyPy
|
| 305 |
+
q_config_file = module.qconfig
|
| 306 |
+
assert isinstance(q_config_file, QConfig)
|
| 307 |
+
|
| 308 |
+
# this object should either be fake quant or observer
|
| 309 |
+
q_or_s_obj = module.qconfig.weight.p.func()
|
| 310 |
+
assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))
|
| 311 |
+
|
| 312 |
+
per_channel_used = False # will be true if found in qconfig
|
| 313 |
+
|
| 314 |
+
if hasattr(q_or_s_obj, "ch_axis"): # then we know that per_channel quantization used
|
| 315 |
+
|
| 316 |
+
# all fake quants have channel axis so need to check is_per_channel
|
| 317 |
+
if isinstance(q_or_s_obj, FakeQuantize):
|
| 318 |
+
if hasattr(q_or_s_obj, "is_per_channel") and q_or_s_obj.is_per_channel:
|
| 319 |
+
per_channel_used = True
|
| 320 |
+
elif isinstance(q_or_s_obj, ObserverBase):
|
| 321 |
+
# should be an observer otherwise
|
| 322 |
+
per_channel_used = True
|
| 323 |
+
else:
|
| 324 |
+
raise ValueError("Should be either observer or fake quant")
|
| 325 |
+
|
| 326 |
+
per_channel_info[fqn] = {
|
| 327 |
+
self.PER_CHAN_SUPPORTED_KEY: per_channel_supported,
|
| 328 |
+
self.PER_CHAN_USED_KEY: per_channel_used,
|
| 329 |
+
self.BACKEND_KEY: self.backend_chosen
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
return per_channel_info
|
| 333 |
+
|
| 334 |
+
def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any]]:
|
| 335 |
+
r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
|
| 336 |
+
Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.
|
| 337 |
+
|
| 338 |
+
Looks at q_config format and backend to determine if per_channel can be utilized.
|
| 339 |
+
Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
model: The prepared and calibrated model we want to check if using per_channel
|
| 343 |
+
|
| 344 |
+
Returns a tuple with two elements:
|
| 345 |
+
String report of potential actions to improve model (if per_channel quantization is available in backend)
|
| 346 |
+
Dictionary mapping per_channel quantizable elements to:
|
| 347 |
+
whether per_channel quantization is supported by the backend
|
| 348 |
+
if it is being utilized in the current model
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
# run the helper function to populate the dictionary
|
| 352 |
+
per_channel_info = self._detect_per_channel_helper(model)
|
| 353 |
+
|
| 354 |
+
# String to let the user know of further optimizations
|
| 355 |
+
further_optims_str = f"Further Optimizations for backend {self.backend_chosen}: \n"
|
| 356 |
+
|
| 357 |
+
optimizations_possible = False
|
| 358 |
+
for fqn in per_channel_info:
|
| 359 |
+
fqn_dict = per_channel_info[fqn]
|
| 360 |
+
if fqn_dict[self.PER_CHAN_SUPPORTED_KEY] and not fqn_dict[self.PER_CHAN_USED_KEY]:
|
| 361 |
+
optimizations_possible = True
|
| 362 |
+
further_optims_str += f"Module {fqn} can be configured to use per_channel quantization.\n"
|
| 363 |
+
|
| 364 |
+
if optimizations_possible:
|
| 365 |
+
further_optims_str += (
|
| 366 |
+
"To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
|
| 367 |
+
)
|
| 368 |
+
else:
|
| 369 |
+
further_optims_str += "No further per_channel optimizations possible."
|
| 370 |
+
|
| 371 |
+
# return the string and the dictionary form of same information
|
| 372 |
+
return (further_optims_str, per_channel_info)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class DynamicStaticDetector(DetectorBase):
|
| 376 |
+
r"""
|
| 377 |
+
Determines whether dynamic or static quantization is more appropriate for a given module.
|
| 378 |
+
|
| 379 |
+
Takes advantage of the ModelReportObserver that records range information.
|
| 380 |
+
Stationary distribution of data are strictly above tolerance level for the comparison statistic:
|
| 381 |
+
|
| 382 |
+
S = average_batch_activation_range/epoch_activation_range
|
| 383 |
+
|
| 384 |
+
Nonstationary distributions are below or at the tolerance level for this metric.
|
| 385 |
+
|
| 386 |
+
If the distribution of data right after the module is non-stationary, recommend dynamic quantization
|
| 387 |
+
Otherwise recommend static quantization
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5
|
| 391 |
+
"""
|
| 392 |
+
# names for the pre and post observers that are inserted
|
| 393 |
+
DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
|
| 394 |
+
DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"
|
| 395 |
+
|
| 396 |
+
# naming conventions for stationary vs non-stationary data
|
| 397 |
+
STATIONARY_STR = "stationary"
|
| 398 |
+
NON_STATIONARY_STR = "non-stationary"
|
| 399 |
+
|
| 400 |
+
# naming for activation
|
| 401 |
+
INPUT_ACTIVATION_PREFIX = "input_activation_"
|
| 402 |
+
OUTPUT_ACTIVATION_PREFIX = "output_activation_"
|
| 403 |
+
|
| 404 |
+
# naming conventions for the keys of the return module info
|
| 405 |
+
TOLERANCE_KEY = "dynamic_static_tolerance"
|
| 406 |
+
DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
|
| 407 |
+
PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
|
| 408 |
+
POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
|
| 409 |
+
PRE_OBS_DATA_DIST_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
|
| 410 |
+
POST_OBS_DATA_DIST_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
|
| 411 |
+
IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"
|
| 412 |
+
|
| 413 |
+
# modules that are supported both dynamic and static for this report function
|
| 414 |
+
DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}
|
| 415 |
+
|
| 416 |
+
# modules that will be supported soon for both
|
| 417 |
+
DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}
|
| 418 |
+
|
| 419 |
+
def __init__(self, tolerance=0.5):
|
| 420 |
+
super().__init__()
|
| 421 |
+
|
| 422 |
+
# set tolerance level and initialize a set to keep track of useful fqn locations
|
| 423 |
+
self.tolerance = tolerance
|
| 424 |
+
self.useful_observer_fqns: Set[str] = set()
|
| 425 |
+
|
| 426 |
+
def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
|
| 427 |
+
r"""
|
| 428 |
+
Determines where observers need to be inserted for the Dynamic vs Static detector.
|
| 429 |
+
For this detector, we want to place observers on either side of linear layers in the model.
|
| 430 |
+
|
| 431 |
+
Currently inserts observers for:
|
| 432 |
+
linear layers
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
prepared_fx_model (GraphModule): The prepared Fx GraphModule
|
| 436 |
+
|
| 437 |
+
Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
|
| 438 |
+
key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
|
| 439 |
+
key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
|
| 440 |
+
key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
|
| 441 |
+
key "observer_args" -> The arguments that are meant to be passed into the observer
|
| 442 |
+
"""
|
| 443 |
+
|
| 444 |
+
# observer for this detector is ModelReportObserver
|
| 445 |
+
obs_ctr = ModelReportObserver
|
| 446 |
+
|
| 447 |
+
# return dict
|
| 448 |
+
obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}
|
| 449 |
+
|
| 450 |
+
for fqn, module in prepared_fx_model.named_modules():
|
| 451 |
+
# make sure module is supported
|
| 452 |
+
if self._is_supported(module, insert=True):
|
| 453 |
+
# if it's a supported type, we want to get node and add observer insert locations
|
| 454 |
+
targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
|
| 455 |
+
|
| 456 |
+
# add entry for pre-observer
|
| 457 |
+
pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
|
| 458 |
+
|
| 459 |
+
obs_fqn_to_info[pre_obs_fqn] = {
|
| 460 |
+
DETECTOR_TARGET_NODE_KEY: targeted_node,
|
| 461 |
+
DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
|
| 462 |
+
DETECTOR_IS_POST_OBS_KEY: False,
|
| 463 |
+
DETECTOR_OBS_ARGS_KEY: targeted_node.args
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
# add entry for post-observer
|
| 467 |
+
post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME
|
| 468 |
+
|
| 469 |
+
obs_fqn_to_info[post_obs_fqn] = {
|
| 470 |
+
DETECTOR_TARGET_NODE_KEY: targeted_node,
|
| 471 |
+
DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
|
| 472 |
+
DETECTOR_IS_POST_OBS_KEY: True,
|
| 473 |
+
DETECTOR_OBS_ARGS_KEY: (targeted_node,)
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
return obs_fqn_to_info
|
| 477 |
+
|
| 478 |
+
def get_detector_name(self) -> str:
|
| 479 |
+
r""" returns the string name of this detector"""
|
| 480 |
+
return "dynamic_vs_static_detector"
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
|
| 484 |
+
r""" Returns the DetectorQConfigInfo for each module_fqn relevant
|
| 485 |
+
Args
|
| 486 |
+
model (nn.Module or subclass): model to find observer insertion points
|
| 487 |
+
|
| 488 |
+
Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
|
| 489 |
+
A DetectorQConfigInfo with the information to generate a QConfig for a specific module
|
| 490 |
+
"""
|
| 491 |
+
# run the helper function to populate the dictionary
|
| 492 |
+
dynamic_static_info = self._generate_dict_info(model)
|
| 493 |
+
|
| 494 |
+
# we actually have a qconfig info object we are populating
|
| 495 |
+
module_fqn_to_detector_qconfig_info = {}
|
| 496 |
+
|
| 497 |
+
for module_fqn in dynamic_static_info:
|
| 498 |
+
# create a detector info instance
|
| 499 |
+
detector_qconfig_info = DetectorQConfigInfo(module_fqn)
|
| 500 |
+
|
| 501 |
+
# see if per channel quantization is supported
|
| 502 |
+
dynamic_static_recommended: bool = dynamic_static_info[module_fqn][self.DEFAULT_DYNAMIC_REC_KEY]
|
| 503 |
+
detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
|
| 504 |
+
module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
|
| 505 |
+
|
| 506 |
+
return module_fqn_to_detector_qconfig_info
|
| 507 |
+
|
| 508 |
+
def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
|
| 509 |
+
r"""Returns whether the given module is supported for observers
|
| 510 |
+
|
| 511 |
+
Args
|
| 512 |
+
module: The module to check and ensure is supported
|
| 513 |
+
insert: True if this is check for observer insertion, false if for report gen
|
| 514 |
+
|
| 515 |
+
Returns True if the module is supported by observer, False otherwise
|
| 516 |
+
"""
|
| 517 |
+
# check to see if module is of a supported type
|
| 518 |
+
is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0
|
| 519 |
+
|
| 520 |
+
# check if it will be supported
|
| 521 |
+
future_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED]) > 0
|
| 522 |
+
|
| 523 |
+
# supported
|
| 524 |
+
supported = is_supported_type or future_supported_type
|
| 525 |
+
|
| 526 |
+
# this is check for observer insertion
|
| 527 |
+
if insert:
|
| 528 |
+
return supported
|
| 529 |
+
else:
|
| 530 |
+
# this is for report gen and we also need to check if it contains observers
|
| 531 |
+
has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(module, self.DEFAULT_POST_OBSERVER_NAME)
|
| 532 |
+
return supported and has_obs
|
| 533 |
+
|
| 534 |
+
def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]:
|
| 535 |
+
r"""
|
| 536 |
+
Helper function for generate_detector_report that does the generation of the dictionary.
|
| 537 |
+
This process is done as specified in generate_detector_report documentation
|
| 538 |
+
|
| 539 |
+
Args:
|
| 540 |
+
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
|
| 541 |
+
|
| 542 |
+
Returns a Dictionary mapping modules with ModelReportObservers around them to:
|
| 543 |
+
whether dynamic quantization is recommended
|
| 544 |
+
their S metric of input to module
|
| 545 |
+
whether input to module is stationary or non-stationary
|
| 546 |
+
their S metric of output of module
|
| 547 |
+
whether output of module is stationary or non-stationary
|
| 548 |
+
the tolerance level to decided whether input/output is stationary or non-stationary
|
| 549 |
+
whether it is currently supported or planned for the future
|
| 550 |
+
"""
|
| 551 |
+
# store modules dynamic vs static information
|
| 552 |
+
module_dynamic_static_info = {}
|
| 553 |
+
|
| 554 |
+
# This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info
|
| 555 |
+
# This information primary includes whether the data distributions around a supported module is stationary or not
|
| 556 |
+
# Based on this, it is recorded whether dynamic or static quantization is recommended
|
| 557 |
+
|
| 558 |
+
# loop through all submodules included nested ones
|
| 559 |
+
for fqn, module in model.named_modules():
|
| 560 |
+
# if module is Linear has the ModelReportObserver attached to it
|
| 561 |
+
if self._is_supported(module):
|
| 562 |
+
# get pre and post observers for the module
|
| 563 |
+
pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
|
| 564 |
+
post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)
|
| 565 |
+
|
| 566 |
+
# get the statistics for each module
|
| 567 |
+
pre_stat = pre_obs.get_batch_to_epoch_ratio()
|
| 568 |
+
post_stat = post_obs.get_batch_to_epoch_ratio()
|
| 569 |
+
|
| 570 |
+
# record module, pre and post stat, and whether to do dynamic or static based off it
|
| 571 |
+
# true if post observer data distribution is non-stationary, false if it's stationary
|
| 572 |
+
dynamic_recommended = post_stat <= self.tolerance
|
| 573 |
+
|
| 574 |
+
# specify the classifications for whether data distributions considered stationary or non-stationary
|
| 575 |
+
pre_obs_dist_classif = self.STATIONARY_STR if pre_stat > self.tolerance else self.NON_STATIONARY_STR
|
| 576 |
+
post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR
|
| 577 |
+
|
| 578 |
+
# check if current support or future support
|
| 579 |
+
is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0
|
| 580 |
+
|
| 581 |
+
# store the set of important information for this module
|
| 582 |
+
module_info = {
|
| 583 |
+
self.TOLERANCE_KEY: self.tolerance,
|
| 584 |
+
self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
|
| 585 |
+
self.PRE_OBS_COMP_STAT_KEY: pre_stat,
|
| 586 |
+
self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif,
|
| 587 |
+
self.POST_OBS_COMP_STAT_KEY: post_stat,
|
| 588 |
+
self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif,
|
| 589 |
+
self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
module_dynamic_static_info[fqn] = module_info
|
| 593 |
+
|
| 594 |
+
return module_dynamic_static_info
|
| 595 |
+
|
| 596 |
+
def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
|
| 597 |
+
r"""
|
| 598 |
+
Determines whether dynamic or static quantization is more appropriate for a given module.
|
| 599 |
+
|
| 600 |
+
Takes advantage of the ModelReportObserver that records range information.
|
| 601 |
+
Stationary distribution of data are strictly above tolerance level for the comparison statistic:
|
| 602 |
+
|
| 603 |
+
S = average_batch_activation_range/epoch_activation_range
|
| 604 |
+
|
| 605 |
+
Nonstationary distributions are below or at the tolerance level for this metric.
|
| 606 |
+
|
| 607 |
+
If the distribution of data right after the module is non-stationary, recommend dynamic quantization
|
| 608 |
+
Otherwise recommend static quantization
|
| 609 |
+
|
| 610 |
+
This will then generate suggestions for dynamic vs static quantization focused around Linear.
|
| 611 |
+
|
| 612 |
+
Args:
|
| 613 |
+
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
|
| 614 |
+
|
| 615 |
+
Returns a tuple with two elements:
|
| 616 |
+
String report of of whether dynamic or static quantization is recommended for certain modules
|
| 617 |
+
Dictionary mapping modules with ModelReportObservers around them to:
|
| 618 |
+
whether dynamic quantization is recommended
|
| 619 |
+
their S metric of input to module
|
| 620 |
+
whether input to module is stationary or non-stationary
|
| 621 |
+
their S metric of output of module
|
| 622 |
+
whether output of module is stationary or non-stationary
|
| 623 |
+
the tolerance level to decided whether input/output is stationary or non-stationary
|
| 624 |
+
whether it is currently supported or planned for the future
|
| 625 |
+
"""
|
| 626 |
+
|
| 627 |
+
# get the dictionary of the information to format the string report
|
| 628 |
+
module_dynamic_static_info = self._generate_dict_info(model)
|
| 629 |
+
|
| 630 |
+
dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"
|
| 631 |
+
|
| 632 |
+
modules_added: bool = False # check to make sure at least 1 module added.
|
| 633 |
+
|
| 634 |
+
dynamic_benefit = " You will get more accurate results if you use dynamic quantization"
|
| 635 |
+
static_benefit = " You can increase model efficiency if you use static quantization"
|
| 636 |
+
future_support_str = ". This layer is not yet supported for dynamic quantization"
|
| 637 |
+
# This for loop goes through the information collected in module_dynamic_static_info and:
|
| 638 |
+
# Populates the string based report with the information from module_dynamic_static_info
|
| 639 |
+
# Compiles the complete report by appending relevant formatted strings
|
| 640 |
+
|
| 641 |
+
for module_fqn in module_dynamic_static_info.keys():
|
| 642 |
+
|
| 643 |
+
# there is at least 1 module for suggestion
|
| 644 |
+
modules_added = True
|
| 645 |
+
module_info = module_dynamic_static_info[module_fqn]
|
| 646 |
+
suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n"
|
| 647 |
+
|
| 648 |
+
# decide what string formatting values will be
|
| 649 |
+
quantization_type = ""
|
| 650 |
+
quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."
|
| 651 |
+
|
| 652 |
+
benefit_str = ""
|
| 653 |
+
|
| 654 |
+
# strings for if dynamic quantized per tensor is needed
|
| 655 |
+
recommend_per_tensor = ". We recommend to add a {} before this module if it is static."
|
| 656 |
+
rec_lay_to_add = "dynamic quantize per tensor layer"
|
| 657 |
+
dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
|
| 658 |
+
dynamic_per_tensor_reasoning_string = (
|
| 659 |
+
" This is because the input to this module has a non-stationary distribution"
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
# start composing explanation
|
| 663 |
+
if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
|
| 664 |
+
quantization_type = "dynamic"
|
| 665 |
+
# check if currently supported or future supported
|
| 666 |
+
benefit_str = dynamic_benefit
|
| 667 |
+
if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
|
| 668 |
+
benefit_str += future_support_str
|
| 669 |
+
else:
|
| 670 |
+
quantization_type = "static"
|
| 671 |
+
benefit_str = static_benefit
|
| 672 |
+
|
| 673 |
+
# now set the quantization explanation string
|
| 674 |
+
quantization_reasoning = (
|
| 675 |
+
quantization_reasoning.format(
|
| 676 |
+
module_fqn, module_info[self.PRE_OBS_DATA_DIST_KEY], module_info[self.POST_OBS_DATA_DIST_KEY]
|
| 677 |
+
)
|
| 678 |
+
+ benefit_str
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
# if we have a non-stationary input -> linear -> stationary we suggested static
|
| 682 |
+
# however, we want to also recommend they add a dynamic quantize per tensor right if this change is made
|
| 683 |
+
if (
|
| 684 |
+
module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR
|
| 685 |
+
and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR
|
| 686 |
+
):
|
| 687 |
+
quantization_reasoning = (
|
| 688 |
+
quantization_reasoning + dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
# format the overall suggestion string with the specific inputs
|
| 692 |
+
module_suggestion_string = suggestion_string_template.format(
|
| 693 |
+
module_fqn, quantization_type, quantization_reasoning
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# append to overall suggestion
|
| 697 |
+
dynamic_vs_static_string += module_suggestion_string
|
| 698 |
+
|
| 699 |
+
if not modules_added:
|
| 700 |
+
dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"
|
| 701 |
+
|
| 702 |
+
# return the string as well as the dictionary of information
|
| 703 |
+
return (dynamic_vs_static_string, module_dynamic_static_info)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
class InputWeightEqualizationDetector(DetectorBase):
|
| 707 |
+
r"""
|
| 708 |
+
Determines whether input-weight equalization can help improve quantization for certain modules.
|
| 709 |
+
|
| 710 |
+
Specifically, this list of modules includes:
|
| 711 |
+
linear
|
| 712 |
+
conv
|
| 713 |
+
|
| 714 |
+
Determines whether input-weight equalization is recommended based on the comp stat:
|
| 715 |
+
s_c = sqrt(w_c/W)/sqrt(i_c/I)
|
| 716 |
+
where:
|
| 717 |
+
w_c is range of weight for channel c, W is range of weight over all channels
|
| 718 |
+
i_c is range of input for channel c, I is range of input over all channels
|
| 719 |
+
|
| 720 |
+
if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization
|
| 721 |
+
|
| 722 |
+
Args:
|
| 723 |
+
ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
|
| 724 |
+
Should be between 0 and 1 (both non-inclusive)
|
| 725 |
+
ch_axis (int, optional): The channel axis being observed to determine input weight equalization
|
| 726 |
+
Default: 1
|
| 727 |
+
|
| 728 |
+
* :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
|
| 729 |
+
Should be between 0 and 1
|
| 730 |
+
|
| 731 |
+
* :attr:`ch_axis`: The channel axis being observed to determine input weight equalization
|
| 732 |
+
|
| 733 |
+
* :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization
|
| 734 |
+
|
| 735 |
+
* :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
|
| 736 |
+
"""
|
| 737 |
+
|
| 738 |
+
SUPPORTED_MODULES: Set[Callable] = {nn.Linear,
|
| 739 |
+
nn.Conv1d,
|
| 740 |
+
nn.Conv2d,
|
| 741 |
+
nn.Conv3d,
|
| 742 |
+
nnqat.Linear,
|
| 743 |
+
nnqat.Conv1d,
|
| 744 |
+
nnqat.Conv2d,
|
| 745 |
+
nnqat.Conv3d}
|
| 746 |
+
|
| 747 |
+
# names for the pre and post observers that are inserted
|
| 748 |
+
DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
|
| 749 |
+
|
| 750 |
+
# weight / activation prefix for each of the below info
|
| 751 |
+
WEIGHT_PREFIX = "weight_"
|
| 752 |
+
ACTIVATION_PREFIX = "input_activation_"
|
| 753 |
+
|
| 754 |
+
# string names for keys of info dictionaries
|
| 755 |
+
PER_CHANNEL_MAX_KEY = "per_channel_max"
|
| 756 |
+
PER_CHANNEL_MIN_KEY = "per_channel_min"
|
| 757 |
+
GLOBAL_MAX_KEY = "global_max"
|
| 758 |
+
GLOBAL_MIN_KEY = "global_min"
|
| 759 |
+
|
| 760 |
+
# keys for return dict of recommendations
|
| 761 |
+
RECOMMENDED_KEY = "input_weight_equalization_recommended"
|
| 762 |
+
COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
|
| 763 |
+
THRESHOLD_KEY = "input_weight_threshold"
|
| 764 |
+
CHANNEL_KEY = "input_weight_channel_axis"
|
| 765 |
+
|
| 766 |
+
# default weight and info strings
|
| 767 |
+
WEIGHT_STR = "weight"
|
| 768 |
+
INPUT_STR = "input"
|
| 769 |
+
|
| 770 |
+
# default for what ratio we recommend input weight
|
| 771 |
+
DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4
|
| 772 |
+
|
| 773 |
+
def __init__(self, ratio_threshold: float, ch_axis: int = 1):
    """Store the comparison threshold and channel axis after validating them.

    Args:
        ratio_threshold: acceptance threshold for the s_c statistic; must lie
            strictly between 0 and 1
        ch_axis: channel axis the observers will track (default 1)

    Raises:
        ValueError: if ratio_threshold is outside the open interval (0, 1)
    """
    # reject invalid thresholds up front so later math can assume (0, 1)
    if ratio_threshold <= 0 or ratio_threshold >= 1:
        raise ValueError("Make sure threshold is > 0 and < 1")

    self.ratio_threshold: float = ratio_threshold
    self.ch_axis: int = ch_axis
def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
    r"""Returns whether the given module is supported for observers

    Args
        module: The module to check and ensure is supported
        insert: True if this is check for observer insertion, false if for report gen

    Returns True if the module is supported by observer, False otherwise
    """
    # membership by exact type: any() short-circuits instead of building a
    # full list and summing booleans (sum([...]) > 0 in the original)
    is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES)

    # for observer insertion the type check alone is enough
    if insert:
        return is_supported_type

    # for report generation the pre-observer must also already be attached
    has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
    return is_supported_type and has_obs
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
    r"""Builds a DetectorQConfigInfo for every module this detector covers.

    Args
        model (nn.Module or subclass): model to find observer insertion points

    Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
        A DetectorQConfigInfo with the information to generate a QConfig for a specific module
    """
    # gather input ranges, weight ranges, and the derived per-channel s_c stats
    input_values: Dict[str, Dict] = self._extract_input_info(model)
    weight_values: Dict[str, Dict] = self._extract_weight_info(model)
    comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)

    # fold everything into the per-module recommendation dictionary
    equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)

    # translate each recommendation into a DetectorQConfigInfo entry
    qconfig_infos = {}
    for module_fqn, mod_info in equalization_info.items():
        info = DetectorQConfigInfo(module_fqn)
        # flag whether equalization looked beneficial for this module
        info.is_equalization_recommended = mod_info[self.RECOMMENDED_KEY]
        qconfig_infos[module_fqn] = info

    return qconfig_infos
def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
    r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
    For this detector, we want to place observers in front of supported layers.

    Currently inserts observers for:
        linear layers
        conv layers

    Args:
        prepared_fx_model (GraphModule): The prepared Fx GraphModule

    Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
        key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
        key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
        key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
        key "observer_args" -> The arguments that are meant to be passed into the observer
    """
    insert_points: Dict[str, Dict[str, Any]] = {}

    for fqn, module in prepared_fx_model.named_modules():
        # only supported layer types receive an observer
        if not self._is_supported(module, insert=True):
            continue

        # locate the graph node that corresponds to this module
        target = self._get_targeting_node(prepared_fx_model, fqn)

        # register a ModelReportObserver as a pre-observer on that node
        pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
        insert_points[pre_obs_fqn] = {
            DETECTOR_TARGET_NODE_KEY: target,
            DETECTOR_OBS_TO_INSERT_KEY: ModelReportObserver(ch_axis=self.ch_axis),
            DETECTOR_IS_POST_OBS_KEY: False,
            DETECTOR_OBS_ARGS_KEY: target.args,
        }

    return insert_points
def get_detector_name(self) -> str:
    r"""Return the canonical name identifying this detector."""
    return "input_weight_equalization_detector"
def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]:
    r"""Collects the input-range statistics recorded by each pre-observer.

    Args
        model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

    Returns a dict mapping relevant module fqns (str) to a dict with keys:
        "input_activation_per_channel_max" : maps to the per_channel max values
        "input_activation_per_channel_min" : maps to the per_channel min values
        "input_activation_global_max" : maps to the global max recorded
        "input_activation_global_min" : maps to the global min recorded
    """
    gathered: Dict[str, Dict] = {}

    for fqn, module in model.named_modules():
        # skip modules without a supported type / attached pre-observer
        if not self._is_supported(module):
            continue

        observer = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
        gathered[fqn] = {
            self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: observer.max_val,
            self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: observer.min_val,
            # global extrema are reduced from the per-channel tensors
            self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(observer.max_val),
            self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(observer.min_val),
        }

    return gathered
def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
    r"""Computes per-channel and global weight ranges for each supported layer.

    Args
        model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

    Returns a dict mapping module fqns (str) to a dict with keys:
        "per_channel_max" : maps to the per_channel max values
        "per_channel_min" : maps to the per_channel min values
        "global_max" : maps to the global max recorded
        "global_min" : maps to the global min recorded
    """
    weight_info: Dict[str, Dict] = {}

    for fqn, module in model.named_modules():
        if not self._is_supported(module):
            continue

        # no observer needed here: ranges come straight from the weights
        device = module.weight.device
        min_val: torch.Tensor = torch.tensor([float('inf')], device=device)
        max_val: torch.Tensor = torch.tensor([float('-inf')], device=device)
        weights = module.weight

        # bring the channel axis to the front so each row is one channel
        axis_order = list(range(len(weights.size())))
        axis_order[self.ch_axis] = 0
        axis_order[0] = self.ch_axis
        per_channel = weights.permute(axis_order)

        # dtype must match the running min/max buffers for in-place comparisons
        per_channel = per_channel.to(min_val.dtype)
        per_channel = torch.flatten(per_channel, start_dim=1)
        if min_val.numel() == 0 or max_val.numel() == 0:
            min_val, max_val = torch.aminmax(per_channel, dim=1)
        else:
            cur_min, cur_max = torch.aminmax(per_channel, dim=1)
            min_val = torch.min(cur_min, min_val)
            max_val = torch.max(cur_max, max_val)

        weight_info[fqn] = {
            self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
            self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
            self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
            self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
        }

    return weight_info
def _calculate_range_ratio(self, info_dict: Dict, info_str: str, module_fqn: str) -> torch.Tensor:
    r"""Computes the per-channel range divided by the global range.

    Args:
        info_dict (dict): A dictionary of either input or weight range info
        info_str (str): A str describing whether currently looking at weight or input info
            Either "weight" or "input"
        module_fqn (str): The fqn of the module we are looking at

    Returns a tensor of values, where each value is the s_c stat for a different channel
    """
    # choose the key prefix that matches the kind of info we were handed
    prefix = self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX

    span_per_channel = info_dict[prefix + self.PER_CHANNEL_MAX_KEY] - info_dict[prefix + self.PER_CHANNEL_MIN_KEY]
    span_global = info_dict[prefix + self.GLOBAL_MAX_KEY] - info_dict[prefix + self.GLOBAL_MIN_KEY]

    # a zero global range means every value is constant; dividing would be meaningless
    if span_global == 0:
        range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
        raise ValueError(
            "The range of the {} data for module {} is 0, which means you have a constant value channel. {}".format(
                info_str, module_fqn, range_zero_explanation
            )
        )

    return span_per_channel / span_global
def _generate_comparison_values(self, input_info: Dict, weight_info: Dict) -> Dict[str, torch.Tensor]:
    r"""Computes the per-channel comparison statistic s_c = sqrt(w_c/W)/sqrt(i_c/I).

    Args:
        input_info (dict): A dict mapping each observer to input range information
        weight_info (dict): A dict mapping each observer to weight range information

    Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
        Each value is a different s_c value for a different channel
    """
    module_fqn_to_channel: Dict[str, torch.Tensor] = {}

    for module_fqn in input_info:
        # both dicts must describe the same modules
        if module_fqn not in weight_info:
            raise KeyError(f"Unable to find weight range stats for module {module_fqn}")

        weight_ratio = self._calculate_range_ratio(weight_info[module_fqn], self.WEIGHT_STR, module_fqn)
        input_ratio = self._calculate_range_ratio(input_info[module_fqn], self.INPUT_STR, module_fqn)

        # grouped layers: tile the weight ratios so they line up with input channels
        n_weight = len(weight_ratio)
        n_input = len(input_ratio)
        if n_weight != n_input:
            assert n_input % n_weight == 0, "input channels should be divisible by weight channels."
            weight_ratio = weight_ratio.repeat(n_input // n_weight)

        # the per-channel s metric
        module_fqn_to_channel[module_fqn] = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)

    return module_fqn_to_channel
def _generate_dict_info(self, input_info: Dict, weight_info: Dict, comp_stats: Dict) -> Dict[str, Dict]:
    r"""Assembles the per-module report dictionary from the gathered statistics.

    Args:
        input_info (dict): A dict mapping each module to input range information
        weight_info (dict): A dict mapping each module to weight range information
        comp_stats (dict): A dict mapping each module to its corresponding comp stat

    Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
        whether input weight equalization is recommended
        their s_c metric compared to the threshold
        the threshold used to make the recommendation
        the channel used for recording data
        the input channel range info
        the weight channel range info
    """
    report: Dict[str, Dict] = {}

    for module_fqn in input_info:
        stat = comp_stats[module_fqn]

        # a channel benefits when its s_c falls inside [threshold, 1/threshold]
        channel_recs: list = [
            self.ratio_threshold <= entry.item() <= 1 / self.ratio_threshold
            for entry in stat
        ]

        # merge the recommendation with the raw input and weight range info
        report[module_fqn] = {
            self.RECOMMENDED_KEY: channel_recs,
            self.COMP_METRIC_KEY: stat,
            self.THRESHOLD_KEY: self.ratio_threshold,
            self.CHANNEL_KEY: self.ch_axis,
            **input_info[module_fqn],
            **weight_info[module_fqn],
        }

    return report
def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
    r"""Determines whether input weight equalization is appropriate for a given module.

    Takes advantage of the ModelReport Observer which records per channel information of input range
    It then uses the passed in weight info in conjunction to compute the desired ratio
    Finally, it gives suggestions based on this information for each module of interest

    Args:
        model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

    Returns a tuple with two elements:
        String report of whether input weight equalization is recommended for certain modules
        Dictionary mapping modules of interest to:
            whether input weight equalization is recommended
            their s_c metric compared to the threshold
            the threshold used to make the recommendation
            the channel used for recording data
            the input channel range info
            the weight channel range info
    """
    # collect the range statistics and the derived s_c comparison values
    input_values: Dict[str, Dict] = self._extract_input_info(model)
    weight_values: Dict[str, Dict] = self._extract_weight_info(model)
    comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)
    input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)

    # assemble the human-readable report piece by piece
    report_parts = ["Input-Weight Equalization suggestions: \n"]

    for module_fqn, mod_info in input_weight_equalization_info.items():
        # per-module header
        report_parts.append("For Module {} looked at with axis {}: \n".format(module_fqn, self.ch_axis))

        # count how many channels would benefit from equalization
        recommendation_per_channel = mod_info[self.RECOMMENDED_KEY]
        num_recs = sum(recommendation_per_channel)
        num_channels = len(recommendation_per_channel)

        if num_recs / num_channels >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO:
            verdict = "to use"
            reason = "{}/{} channels would benefit and we expect significant reduction in quantization error.".format(
                num_recs, num_channels
            )
        else:
            verdict = "to not use"
            reason = "we don't expect much improvement from input-weight equalization based on {}".format(
                "{}/{} channels benefitting from input-weight equalization being applied.".format(num_recs, num_channels)
            )

        report_parts.append("\tWe suggest {} input weight equalization because {}\n".format(verdict, reason))

    # no supported modules at all: tell the user why the report is empty
    if len(report_parts) == 1:
        report_parts.append("No applicable layers for suggestions. Only linear and conv valid.\n")

    return ("".join(report_parts), input_weight_equalization_info)
class OutlierDetector(DetectorBase):
|
| 1179 |
+
r"""
|
| 1180 |
+
Determines whether there are significant outliers in activation data around a certain layer.
|
| 1181 |
+
|
| 1182 |
+
This is ideally used in conjunction with information on stationary vs. non-stationary distribution:
|
| 1183 |
+
If the data is stationary, and there are significant outliers, then we want to flag them
|
| 1184 |
+
We want to do this on a per channel basis for detecting outliers
|
| 1185 |
+
|
| 1186 |
+
Determines whether activation data is flagged as outlier based on if data is stationary and:
|
| 1187 |
+
p_r = avg(100th percentile / "reference_percentile"th percentile)
|
| 1188 |
+
where:
|
| 1189 |
+
p_r is average percentile ratio across all batches in the epoch
|
| 1190 |
+
reference_percentile is a percentile value between 0 and 1 exclusive
|
| 1191 |
+
|
| 1192 |
+
if p_r is above some threshold, then we consider the activations to have significant outliers
|
| 1193 |
+
|
| 1194 |
+
Args:
|
| 1195 |
+
ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
|
| 1196 |
+
Should be >= 1
|
| 1197 |
+
Default: 3.5
|
| 1198 |
+
reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
|
| 1199 |
+
Should be between 0 and 1
|
| 1200 |
+
Default: 0.975
|
| 1201 |
+
fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
|
| 1202 |
+
If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user
|
| 1203 |
+
regardless of whether we detected outliers or not in channel to take a closer look at channel results
|
| 1204 |
+
Should be between 0 and 1
|
| 1205 |
+
Default: 0.95
|
| 1206 |
+
ch_axis (int, optional): The channel axis being observed to determine input weight equalization
|
| 1207 |
+
Default: 1
|
| 1208 |
+
|
| 1209 |
+
* :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
|
| 1210 |
+
The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold
|
| 1211 |
+
If it is significantly greater, then we consider it an outlier
|
| 1212 |
+
This threshold was calculated based on the ratio of the percentiles in a normal distribution
|
| 1213 |
+
The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
|
| 1214 |
+
|
| 1215 |
+
* :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
|
| 1216 |
+
Should be between 0 and 1
|
| 1217 |
+
The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
|
| 1218 |
+
|
| 1219 |
+
* :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
|
| 1220 |
+
Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
|
| 1221 |
+
Should be between 0 and 1
|
| 1222 |
+
|
| 1223 |
+
* :attr:`ch_axis`: The channel axis being observed to determine outliers
|
| 1224 |
+
|
| 1225 |
+
* :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
|
| 1226 |
+
"""
|
| 1227 |
+
|
| 1228 |
+
# name given to the pre-observer inserted by this detector
DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

# prefix applied to pre-activation statistics keys
INPUT_ACTIVATION_PREFIX = "input_activation_"

# keys of the per-module report dictionaries
OUTLIER_KEY = "outliers_detected"
NUM_BATCHES_KEY = "outlier_detection_batches_used"
IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
CONSTANT_COUNTS_KEY = "constant_batch_counts"
def __init__(
    self,
    ratio_threshold: float = 3.5,
    reference_percentile: float = 0.975,
    fraction_batches_used_threshold: float = 0.95,
    ch_axis: int = 1,
):
    """Stores the outlier-detection parameters after basic validation.

    Args:
        ratio_threshold: p_r values above this flag a channel as having outliers
        reference_percentile: denominator percentile (0-1) for the p_r ratio
        fraction_batches_used_threshold: minimum fraction (0-1) of batches that
            must contribute for a channel's result to count as significant
        ch_axis: channel axis observed (default 1)
    """
    self.ratio_threshold = ratio_threshold

    # percentile and batch-fraction parameters must be valid fractions
    assert 0 <= reference_percentile <= 1
    assert 0 <= fraction_batches_used_threshold <= 1
    self.reference_percentile = reference_percentile
    self.fraction_batches_used_threshold = fraction_batches_used_threshold
    self.ch_axis = ch_axis
def get_detector_name(self) -> str:
    r"""Return the canonical name identifying this detector."""
    return "outlier_detector"
def _supports_insertion(self, module: nn.Module) -> bool:
    r"""Returns whether the given module is supported for observers insertion

    Any module that doesn't have children and isn't an observer itself is supported

    Args
        module: The module to check and ensure is supported

    Returns True if the module is supported by observer, False otherwise
    """
    # leaf modules only (no children), and never on an observer module itself
    is_leaf = len(list(module.children())) == 0
    return is_leaf and not _is_activation_post_process(module)
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
    r"""Returns the DetectorQConfigInfo for each relevant module_fqn.

    Args
        model (nn.Module or subclass): model to find observer insertion points

    Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
        A DetectorQConfigInfo with the information to generate a QConfig for a specific module
    """
    # outlier detection never influences qconfig generation, so nothing to report
    return {}
def _supports_report_gen(self, module: nn.Module) -> bool:
    r"""Returns whether the given module is supported for report generation

    Any module that has a model report pre-observer is supported

    Args
        module: The module to check and ensure is supported

    Returns True if the module is supported by observer, False otherwise
    """
    # report generation only needs the pre-observer to have been attached
    return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
    r"""Determines where observers need to be inserted for the Outlier Detector.

    For this detector, we want to place observers in front of supported layers.

    Currently inserts observers for:
        all layers that do not have children (leaf level layers)

    Args:
        prepared_fx_model (GraphModule): The prepared Fx GraphModule

    Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
        key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
        key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
        key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
        key "observer_args" -> The arguments that are meant to be passed into the observer
    """
    insert_points: Dict[str, Dict[str, Any]] = {}

    for fqn, module in prepared_fx_model.named_modules():
        # only leaf, non-observer modules receive an observer
        if not self._supports_insertion(module):
            continue

        # locate the graph node corresponding to this module
        target = self._get_targeting_node(prepared_fx_model, fqn)

        # register a ModelReportObserver as a pre-observer tracking the
        # reference percentile needed for the p_r ratio
        pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
        insert_points[pre_obs_fqn] = {
            DETECTOR_TARGET_NODE_KEY: target,
            DETECTOR_OBS_TO_INSERT_KEY: ModelReportObserver(
                ch_axis=self.ch_axis, comp_percentile=self.reference_percentile
            ),
            DETECTOR_IS_POST_OBS_KEY: False,
            DETECTOR_OBS_ARGS_KEY: target.args,
        }

    return insert_points
def _calculate_outlier_info(
    self,
    percentile_ratios: torch.Tensor,
    counted_batches: torch.Tensor,
    total_batches: int,
) -> Dict[str, List[bool]]:
    r"""Classifies each channel's percentile ratio and the significance of its sample size.

    Args:
        percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
        counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
        total_batches (int): The total number of batches that passed through observer in this epoch

    Returns a dictionary mapping:
        "outliers_detected" : list of bools per channel that are true if it is considered an outlier
        "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
            where o_r = counted_batches / total_batches
    """
    # plain python lists make the per-channel comparisons straightforward
    ratios = percentile_ratios.tolist()
    batch_counts = counted_batches.tolist()

    # a channel's data is significant when enough of the batches contributed
    sufficient = [
        count / total_batches >= self.fraction_batches_used_threshold for count in batch_counts
    ]

    # a channel has outliers when its average ratio exceeds the threshold
    flagged = [ratio > self.ratio_threshold for ratio in ratios]

    return {self.OUTLIER_KEY: flagged, self.IS_SUFFICIENT_BATCHES_KEY: sufficient}
def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]:
|
| 1385 |
+
r"""
|
| 1386 |
+
Helper function for generate_detector_report that does the generation of the dictionary.
|
| 1387 |
+
This process is done as specified in generate_detector_report documentation
|
| 1388 |
+
|
| 1389 |
+
Args:
|
| 1390 |
+
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
|
| 1391 |
+
|
| 1392 |
+
Returns a dict mapping relevant module fqns to:
|
| 1393 |
+
whether there were outliers found in activation before
|
| 1394 |
+
the number of batches used for each channel
|
| 1395 |
+
whether fraction of applicable batches used is above fraction_batches_used_threshold
|
| 1396 |
+
their p_r metric compared to the threshold
|
| 1397 |
+
the threshold used to make the recommendation
|
| 1398 |
+
the reference_percentile used to make the recommendation
|
| 1399 |
+
the channel axis used to determine individual channels
|
| 1400 |
+
the constant batch counts per channel
|
| 1401 |
+
the per channel max values
|
| 1402 |
+
"""
|
| 1403 |
+
# return dictionary mapping observer fqns to desired info
|
| 1404 |
+
info_dict: Dict[str, Dict] = {}
|
| 1405 |
+
|
| 1406 |
+
for fqn, module in model.named_modules():
|
| 1407 |
+
# if module is supported and it has a pre-observer
|
| 1408 |
+
if self._supports_report_gen(module):
|
| 1409 |
+
# get pre observer for the module
|
| 1410 |
+
pre_obs: ModelReportObserver = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
|
| 1411 |
+
|
| 1412 |
+
# get the number of batches and calculated ratio thresholds
|
| 1413 |
+
num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
|
| 1414 |
+
average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
|
| 1415 |
+
channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
|
| 1416 |
+
total_batches: int = pre_obs.num_batches_tracked
|
| 1417 |
+
|
| 1418 |
+
# also get the max values
|
| 1419 |
+
max_vals: torch.Tensor = pre_obs.max_val
|
| 1420 |
+
|
| 1421 |
+
# we have to specifically modify how we are recording negative ratio for pre-relu layers
|
| 1422 |
+
for index, ratio_val in enumerate(average_ratios):
|
| 1423 |
+
# check if we have a negative ratio
|
| 1424 |
+
# a ratio might be negative if we have a situation where the 100th percentile is
|
| 1425 |
+
# > 0 while the nth percentile is < 0, in which case this would not be detected
|
| 1426 |
+
# as an outlier. Since we care more about magnitude, we make it positive.
|
| 1427 |
+
if ratio_val.item() < 0:
|
| 1428 |
+
# first make it positive
|
| 1429 |
+
average_ratios[index] = -ratio_val
|
| 1430 |
+
|
| 1431 |
+
if ratio_val.item() < 1:
|
| 1432 |
+
# if it's less than 1 we have the flip it as well
|
| 1433 |
+
average_ratios[index] = 1 / ratio_val
|
| 1434 |
+
|
| 1435 |
+
outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches)
|
| 1436 |
+
|
| 1437 |
+
# calculate whether ratios were outliers
|
| 1438 |
+
info_dict[fqn] = {
|
| 1439 |
+
self.CHANNEL_AXIS_KEY: self.ch_axis,
|
| 1440 |
+
self.REF_PERCENTILE_KEY: self.reference_percentile,
|
| 1441 |
+
self.RATIO_THRES_KEY: self.ratio_threshold,
|
| 1442 |
+
self.COMP_METRIC_KEY: average_ratios,
|
| 1443 |
+
self.NUM_BATCHES_KEY: num_batches,
|
| 1444 |
+
self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
|
| 1445 |
+
self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY],
|
| 1446 |
+
self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
|
| 1447 |
+
self.MAX_VALS_KEY: max_vals
|
| 1448 |
+
}
|
| 1449 |
+
|
| 1450 |
+
return info_dict
|
| 1451 |
+
|
| 1452 |
+
    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether input weight equalization is appropriate for a given module.

        Takes advantage of the ModelReport Observer which records the relevant percentile information

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of of whether there are outliers in the activations around certain modules
            Dictionary mapping modules of interest to:
                whether there were outliers found in activation before
                the number of batches used for each channel
                whether fraction of applicable batches used is above fraction_batches_used_threshold
                their p_r metric compared to the threshold
                the threshold used to make the recommendation
                the reference_percentile used to make the recommendation
                the channel axis used to determine individual channels
                the constant batch counts per channel
                the per channel max values
        """
        # generate the information dictionary of outlier information
        info_dict = self._generate_info_dict(model)

        # now we can generate report based on this information
        outlier_string = "Outlier detection report: \n"

        # added module check: tracks whether any module contributed anything to the report
        added_module: bool = False

        # some strings to be formatted depending on module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
        channel_max_value_str = "a max value across all batches of {}"
        note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
        note_distribution = "stationary distributions"
        note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"

        # suggestion for constant batch check since that can make it no outliers
        constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
        constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why."

        # compile the suggestion string
        for module_fqn in info_dict:
            # get module specific info
            mod_info: Dict[str, Any] = info_dict[module_fqn]
            # check to see if we already added high level model desc
            # (the per-module header is emitted at most once, lazily, when the
            # first outlier or constant-batch finding for this module appears)
            added_model_desc = False
            # look at each individual channel and add a suggestion
            for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
                if outlier_detected:
                    # we found at least 1 outlier
                    if not added_model_desc:
                        # add the module level description
                        outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
                        added_model_desc = True

                    # we mark that we found at least one outlier
                    added_module = True
                    max_value_found_str = channel_max_value_str.format(mod_info[self.MAX_VALS_KEY][index])
                    channel_str = channel_suggestion_str.format(index, max_value_found_str)
                    outlier_string += channel_str

                # also check if we found constant batch
                # (checked independently of outlier_detected: a channel can have
                # constant-value batches without being flagged as an outlier)
                if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0:
                    # make sure we add a module level highlight.
                    if not added_model_desc:
                        # add the module level description
                        outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
                        added_model_desc = True

                    constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][index]
                    formatted_str = constant_str.format(index, constant_values_for_channel, constant_suggestion)
                    outlier_string += formatted_str
                    # we also added at least one thing to description
                    added_module = True


        # if found outlier, give suggestion, else give default response
        if added_module:
            # compose the note string
            note_composed = note_string.format(note_distribution, note_rec)
            outlier_string += note_composed
        else:
            outlier_string += "There were no outliers found in the activations.\n"

        return (outlier_string, info_dict)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
ADDED
|
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Any, Set, Dict, List, Tuple, OrderedDict
|
| 3 |
+
from collections import OrderedDict as OrdDict
|
| 4 |
+
|
| 5 |
+
# Optional dependency: tabulate, used for the pretty-printed table views.
try:
    from tabulate import tabulate
    got_tabulate = True
except ImportError:
    got_tabulate = False


# Optional dependency: matplotlib, used for the plot / histogram views.
try:
    import matplotlib.pyplot as plt
    got_matplotlib = True
except ImportError:
    got_matplotlib = False
|
| 19 |
+
|
| 20 |
+
class ModelReportVisualizer:
|
| 21 |
+
r"""
|
| 22 |
+
The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics
|
| 23 |
+
that were generated by the ModelReport API. However, at a higher level, the class aims to provide
|
| 24 |
+
some level of visualization of statistics to PyTorch in order to make it easier to parse data and
|
| 25 |
+
diagnose any potential issues with data or a specific model. With respect to the visualizations,
|
| 26 |
+
the ModelReportVisualizer class currently supports several methods of visualizing data.
|
| 27 |
+
|
| 28 |
+
Supported Visualization Methods Include:
|
| 29 |
+
- Table format
|
| 30 |
+
- Plot format (line graph)
|
| 31 |
+
- Histogram format
|
| 32 |
+
|
| 33 |
+
For all of the existing visualization methods, there is the option to filter data based on:
|
| 34 |
+
- A module fqn prefix
|
| 35 |
+
- Feature [required for the plot and histogram]
|
| 36 |
+
|
| 37 |
+
* :attr:`generated_reports` The reports generated by the ModelReport class in the structure below
|
| 38 |
+
Ensure that features that are the same across different reports contain the same name
|
| 39 |
+
Ensure that objects representing the same features are the same type / dimension (where applicable)
|
| 40 |
+
|
| 41 |
+
Note:
|
| 42 |
+
Currently, the ModelReportVisualizer class supports visualization of data generated by the
|
| 43 |
+
ModelReport class. However, this structure is extensible and should allow the visualization of
|
| 44 |
+
other information as long as the information is structured in the following general format:
|
| 45 |
+
|
| 46 |
+
Report Structure
|
| 47 |
+
-- module_fqn [module with attached detectors]
|
| 48 |
+
|
|
| 49 |
+
-- feature keys [not every detector extracts same information]
|
| 50 |
+
[same collected info has same keys, unless can be specific to detector]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
The goal behind the class is that the generated visualizations can be used in conjunction with the generated
|
| 54 |
+
report for people to get a better understanding of issues and what the fix might be. It is also just to provide
|
| 55 |
+
a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as
|
| 56 |
+
that grows in size.
|
| 57 |
+
|
| 58 |
+
General Use Flow Expected
|
| 59 |
+
1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects
|
| 60 |
+
2.) Prepare your model with prepare_fx
|
| 61 |
+
3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers
|
| 62 |
+
4.) Calibrate your model with data
|
| 63 |
+
5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
|
| 64 |
+
6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance
|
| 65 |
+
7.) Use instance to view different views of data as desired, applying filters as needed
|
| 66 |
+
8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram
|
| 67 |
+
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
    # keys for table dict
    # (keys of the dict returned by generate_filtered_tables: one entry holds
    # the per-tensor table, the other the per-channel table)
    TABLE_TENSOR_KEY = "tensor_level_info"
    TABLE_CHANNEL_KEY = "channel_level_info"

    # Constants for header vals
    # number of leading non-feature columns in each table
    # (tensor table: idx, layer_fqn; channel table: idx, layer_fqn, channel)
    NUM_NON_FEATURE_TENSOR_HEADERS = 2
    NUM_NON_FEATURE_CHANNEL_HEADERS = 3

    # Constants for row index in header
    # column index of the channel number within a channel-table row
    CHANNEL_NUM_INDEX = 2
|
| 80 |
+
|
| 81 |
+
def __init__(self, generated_reports: OrderedDict[str, Any]):
|
| 82 |
+
r"""
|
| 83 |
+
Initializes the ModelReportVisualizer instance with the necessary reports.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
generated_reports (Dict[str, Any]): The reports generated by the ModelReport class
|
| 87 |
+
can also be a dictionary generated in another manner, as long as format is same
|
| 88 |
+
"""
|
| 89 |
+
self.generated_reports = generated_reports
|
| 90 |
+
|
| 91 |
+
def get_all_unique_module_fqns(self) -> Set[str]:
|
| 92 |
+
r"""
|
| 93 |
+
The purpose of this method is to provide a user the set of all module_fqns so that if
|
| 94 |
+
they wish to use some of the filtering capabilities of the ModelReportVisualizer class,
|
| 95 |
+
they don't need to manually parse the generated_reports dictionary to get this information.
|
| 96 |
+
|
| 97 |
+
Returns all the unique module fqns present in the reports the ModelReportVisualizer
|
| 98 |
+
instance was initialized with.
|
| 99 |
+
"""
|
| 100 |
+
# returns the keys of the ordered dict
|
| 101 |
+
return set(self.generated_reports.keys())
|
| 102 |
+
|
| 103 |
+
def get_all_unique_feature_names(self, plottable_features_only: bool = True) -> Set[str]:
|
| 104 |
+
r"""
|
| 105 |
+
The purpose of this method is to provide a user the set of all feature names so that if
|
| 106 |
+
they wish to use the filtering capabilities of the generate_table_view(), or use either of
|
| 107 |
+
the generate_plot_view() or generate_histogram_view(), they don't need to manually parse
|
| 108 |
+
the generated_reports dictionary to get this information.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
plottable_features_only (bool): True if the user is only looking for plottable features,
|
| 112 |
+
False otherwise
|
| 113 |
+
plottable features are those that are tensor values
|
| 114 |
+
Default: True (only return those feature names that are plottable)
|
| 115 |
+
|
| 116 |
+
Returns all the unique module fqns present in the reports the ModelReportVisualizer
|
| 117 |
+
instance was initialized with.
|
| 118 |
+
"""
|
| 119 |
+
unique_feature_names = set()
|
| 120 |
+
for module_fqn in self.generated_reports:
|
| 121 |
+
# get dict of the features
|
| 122 |
+
feature_dict: Dict[str, Any] = self.generated_reports[module_fqn]
|
| 123 |
+
|
| 124 |
+
# loop through features
|
| 125 |
+
for feature_name in feature_dict:
|
| 126 |
+
# if we need plottable, ensure type of val is tensor
|
| 127 |
+
if not plottable_features_only or type(feature_dict[feature_name]) == torch.Tensor:
|
| 128 |
+
unique_feature_names.add(feature_name)
|
| 129 |
+
|
| 130 |
+
# return our compiled set of unique feature names
|
| 131 |
+
return unique_feature_names
|
| 132 |
+
|
| 133 |
+
def _get_filtered_data(self, feature_filter: str, module_fqn_filter: str) -> OrderedDict[str, Any]:
|
| 134 |
+
r"""
|
| 135 |
+
Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
feature_filter (str): The feature filter, if we want to filter the set of data to only include
|
| 139 |
+
a certain set of features that include feature_filter
|
| 140 |
+
If feature = "", then we do not filter based on any features
|
| 141 |
+
module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with
|
| 142 |
+
this prefix will be included
|
| 143 |
+
If module_fqn_filter = "" we do not filter based on module fqn, and include all modules
|
| 144 |
+
|
| 145 |
+
First, the data is filtered based on module_fqn, and then filtered based on feature
|
| 146 |
+
Returns an OrderedDict (sorted in order of model) mapping:
|
| 147 |
+
module_fqns -> feature_names -> values
|
| 148 |
+
"""
|
| 149 |
+
# create return dict
|
| 150 |
+
filtered_dict: OrderedDict[str, Any] = OrdDict()
|
| 151 |
+
|
| 152 |
+
for module_fqn in self.generated_reports:
|
| 153 |
+
# first filter based on module
|
| 154 |
+
if module_fqn_filter == "" or module_fqn_filter in module_fqn:
|
| 155 |
+
# create entry for module and loop through features
|
| 156 |
+
filtered_dict[module_fqn] = {}
|
| 157 |
+
module_reports = self.generated_reports[module_fqn]
|
| 158 |
+
for feature_name in module_reports:
|
| 159 |
+
# check if filtering on features and do so if desired
|
| 160 |
+
if feature_filter == "" or feature_filter in feature_name:
|
| 161 |
+
filtered_dict[module_fqn][feature_name] = module_reports[feature_name]
|
| 162 |
+
|
| 163 |
+
# we have populated the filtered dict, and must return it
|
| 164 |
+
|
| 165 |
+
return filtered_dict
|
| 166 |
+
|
| 167 |
+
def _generate_tensor_table(
|
| 168 |
+
self,
|
| 169 |
+
filtered_data: OrderedDict[str, Dict[str, Any]],
|
| 170 |
+
tensor_features: List[str]
|
| 171 |
+
) -> Tuple[List, List]:
|
| 172 |
+
r"""
|
| 173 |
+
Takes in the filtered data and features list and generates the tensor headers and table
|
| 174 |
+
|
| 175 |
+
Currently meant to generate the headers and table for both the tensor information.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping:
|
| 179 |
+
module_fqns -> feature_names -> values
|
| 180 |
+
tensor_features (List[str]): A list of the tensor level features
|
| 181 |
+
|
| 182 |
+
Returns a tuple with:
|
| 183 |
+
A list of the headers of the tensor table
|
| 184 |
+
A list of lists containing the table information row by row
|
| 185 |
+
The 0th index row will contain the headers of the columns
|
| 186 |
+
The rest of the rows will contain data
|
| 187 |
+
"""
|
| 188 |
+
# now we compose the tensor information table
|
| 189 |
+
tensor_table: List[List[Any]] = []
|
| 190 |
+
tensor_headers: List[str] = []
|
| 191 |
+
|
| 192 |
+
# append the table row to the table only if we have features
|
| 193 |
+
if len(tensor_features) > 0:
|
| 194 |
+
# now we add all the data
|
| 195 |
+
for index, module_fqn in enumerate(filtered_data):
|
| 196 |
+
# we make a new row for the tensor table
|
| 197 |
+
tensor_table_row = [index, module_fqn]
|
| 198 |
+
for feature in tensor_features:
|
| 199 |
+
# we iterate in same order of added features
|
| 200 |
+
|
| 201 |
+
if feature in filtered_data[module_fqn]:
|
| 202 |
+
# add value if applicable to module
|
| 203 |
+
feature_val = filtered_data[module_fqn][feature]
|
| 204 |
+
else:
|
| 205 |
+
# add that it is not applicable
|
| 206 |
+
feature_val = "Not Applicable"
|
| 207 |
+
|
| 208 |
+
# if it's a tensor we want to extract val
|
| 209 |
+
if isinstance(feature_val, torch.Tensor):
|
| 210 |
+
feature_val = feature_val.item()
|
| 211 |
+
|
| 212 |
+
# we add to our list of values
|
| 213 |
+
tensor_table_row.append(feature_val)
|
| 214 |
+
|
| 215 |
+
tensor_table.append(tensor_table_row)
|
| 216 |
+
|
| 217 |
+
# add row of headers of we actually have something, otherwise just empty
|
| 218 |
+
if len(tensor_table) != 0:
|
| 219 |
+
tensor_headers = ["idx", "layer_fqn"] + tensor_features
|
| 220 |
+
|
| 221 |
+
return (tensor_headers, tensor_table)
|
| 222 |
+
|
| 223 |
+
def _generate_channels_table(
|
| 224 |
+
self,
|
| 225 |
+
filtered_data: OrderedDict[str, Any],
|
| 226 |
+
channel_features: List[str],
|
| 227 |
+
num_channels: int
|
| 228 |
+
) -> Tuple[List, List]:
|
| 229 |
+
r"""
|
| 230 |
+
Takes in the filtered data and features list and generates the channels headers and table
|
| 231 |
+
|
| 232 |
+
Currently meant to generate the headers and table for both the channels information.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping:
|
| 236 |
+
module_fqns -> feature_names -> values
|
| 237 |
+
channel_features (List[str]): A list of the channel level features
|
| 238 |
+
num_channels (int): Number of channels in the channel data
|
| 239 |
+
|
| 240 |
+
Returns a tuple with:
|
| 241 |
+
A list of the headers of the channel table
|
| 242 |
+
A list of lists containing the table information row by row
|
| 243 |
+
The 0th index row will contain the headers of the columns
|
| 244 |
+
The rest of the rows will contain data
|
| 245 |
+
"""
|
| 246 |
+
# now we compose the table for the channel information table
|
| 247 |
+
channel_table: List[List[Any]] = []
|
| 248 |
+
channel_headers: List[str] = []
|
| 249 |
+
|
| 250 |
+
# counter to keep track of number of entries in
|
| 251 |
+
channel_table_entry_counter: int = 0
|
| 252 |
+
|
| 253 |
+
if len(channel_features) > 0:
|
| 254 |
+
# now we add all channel data
|
| 255 |
+
for module_fqn in filtered_data:
|
| 256 |
+
# we iterate over all channels
|
| 257 |
+
for channel in range(num_channels):
|
| 258 |
+
# we make a new row for the channel
|
| 259 |
+
new_channel_row = [channel_table_entry_counter, module_fqn, channel]
|
| 260 |
+
for feature in channel_features:
|
| 261 |
+
if feature in filtered_data[module_fqn]:
|
| 262 |
+
# add value if applicable to module
|
| 263 |
+
feature_val = filtered_data[module_fqn][feature][channel]
|
| 264 |
+
else:
|
| 265 |
+
# add that it is not applicable
|
| 266 |
+
feature_val = "Not Applicable"
|
| 267 |
+
|
| 268 |
+
# if it's a tensor we want to extract val
|
| 269 |
+
if type(feature_val) is torch.Tensor:
|
| 270 |
+
feature_val = feature_val.item()
|
| 271 |
+
|
| 272 |
+
# add value to channel specific row
|
| 273 |
+
new_channel_row.append(feature_val)
|
| 274 |
+
|
| 275 |
+
# add to table and increment row index counter
|
| 276 |
+
channel_table.append(new_channel_row)
|
| 277 |
+
channel_table_entry_counter += 1
|
| 278 |
+
|
| 279 |
+
# add row of headers of we actually have something, otherwise just empty
|
| 280 |
+
if len(channel_table) != 0:
|
| 281 |
+
channel_headers = ["idx", "layer_fqn", "channel"] + channel_features
|
| 282 |
+
|
| 283 |
+
return (channel_headers, channel_table)
|
| 284 |
+
|
| 285 |
+
def generate_filtered_tables(self, feature_filter: str = "", module_fqn_filter: str = "") -> Dict[str, Tuple[List, List]]:
|
| 286 |
+
r"""
|
| 287 |
+
Takes in optional filter values and generates two tables with desired information.
|
| 288 |
+
|
| 289 |
+
The generated tables are presented in both a list-of-lists format
|
| 290 |
+
|
| 291 |
+
The reason for the two tables are that they handle different things:
|
| 292 |
+
1.) the first table handles all tensor level information
|
| 293 |
+
2.) the second table handles and displays all channel based information
|
| 294 |
+
|
| 295 |
+
The reasoning for this is that having all the info in one table can make it ambiguous which collected
|
| 296 |
+
statistics are global, and which are actually per-channel, so it's better to split it up into two
|
| 297 |
+
tables. This also makes the information much easier to digest given the plethora of statistics collected
|
| 298 |
+
|
| 299 |
+
Tensor table columns:
|
| 300 |
+
idx layer_fqn feature_1 feature_2 feature_3 .... feature_n
|
| 301 |
+
---- --------- --------- --------- --------- ---------
|
| 302 |
+
|
| 303 |
+
Per-Channel table columns:
|
| 304 |
+
idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n
|
| 305 |
+
---- --------- ------- --------- --------- --------- ---------
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
feature_filter (str, optional): Filters the features presented to only those that
|
| 309 |
+
contain this filter substring
|
| 310 |
+
Default = "", results in all the features being printed
|
| 311 |
+
module_fqn_filter (str, optional): Only includes modules that contains this string
|
| 312 |
+
Default = "", results in all the modules in the reports to be visible in the table
|
| 313 |
+
|
| 314 |
+
Returns a dictionary with two keys:
|
| 315 |
+
(Dict[str, Tuple[List, List]]) A dict containing two keys:
|
| 316 |
+
"tensor_level_info", "channel_level_info"
|
| 317 |
+
Each key maps to a tuple with:
|
| 318 |
+
A list of the headers of each table
|
| 319 |
+
A list of lists containing the table information row by row
|
| 320 |
+
The 0th index row will contain the headers of the columns
|
| 321 |
+
The rest of the rows will contain data
|
| 322 |
+
|
| 323 |
+
Example Use:
|
| 324 |
+
>>> # xdoctest: +SKIP("undefined variables")
|
| 325 |
+
>>> mod_report_visualizer.generate_filtered_tables(
|
| 326 |
+
... feature_filter = "per_channel_min",
|
| 327 |
+
... module_fqn_filter = "block1"
|
| 328 |
+
... ) # generates table with per_channel_min info for all modules in block 1 of the model
|
| 329 |
+
"""
|
| 330 |
+
# first get the filtered data
|
| 331 |
+
filtered_data: OrderedDict[str, Any] = self._get_filtered_data(feature_filter, module_fqn_filter)
|
| 332 |
+
|
| 333 |
+
# now we split into tensor and per-channel data
|
| 334 |
+
tensor_features: Set[str] = set()
|
| 335 |
+
channel_features: Set[str] = set()
|
| 336 |
+
|
| 337 |
+
# keep track of the number of channels we have
|
| 338 |
+
num_channels: int = 0
|
| 339 |
+
|
| 340 |
+
for module_fqn in filtered_data:
|
| 341 |
+
for feature_name in filtered_data[module_fqn]:
|
| 342 |
+
# get the data for that specific feature
|
| 343 |
+
feature_data = filtered_data[module_fqn][feature_name]
|
| 344 |
+
|
| 345 |
+
# check if not zero dim tensor
|
| 346 |
+
is_tensor: bool = isinstance(feature_data, torch.Tensor)
|
| 347 |
+
is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0
|
| 348 |
+
|
| 349 |
+
if is_not_zero_dim or isinstance(feature_data, list):
|
| 350 |
+
# works means per channel
|
| 351 |
+
channel_features.add(feature_name)
|
| 352 |
+
num_channels = len(feature_data)
|
| 353 |
+
else:
|
| 354 |
+
# means is per-tensor
|
| 355 |
+
tensor_features.add(feature_name)
|
| 356 |
+
|
| 357 |
+
# we make them lists for iteration purposes
|
| 358 |
+
tensor_features_list: List[str] = sorted(tensor_features)
|
| 359 |
+
channel_features_list: List[str] = sorted(channel_features)
|
| 360 |
+
|
| 361 |
+
# get the tensor info
|
| 362 |
+
tensor_headers, tensor_table = self._generate_tensor_table(filtered_data, tensor_features_list)
|
| 363 |
+
|
| 364 |
+
# get the channel info
|
| 365 |
+
channel_headers, channel_table = self._generate_channels_table(
|
| 366 |
+
filtered_data, channel_features_list, num_channels
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# let's now create the dictionary to return
|
| 370 |
+
table_dict = {
|
| 371 |
+
self.TABLE_TENSOR_KEY : (tensor_headers, tensor_table),
|
| 372 |
+
self.TABLE_CHANNEL_KEY : (channel_headers, channel_table)
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
# return the two tables
|
| 376 |
+
return table_dict
|
| 377 |
+
|
| 378 |
+
def generate_table_visualization(self, feature_filter: str = "", module_fqn_filter: str = ""):
|
| 379 |
+
r"""
|
| 380 |
+
Takes in optional filter values and prints out formatted tables of the information.
|
| 381 |
+
|
| 382 |
+
The reason for the two tables printed out instead of one large one are that they handle different things:
|
| 383 |
+
1.) the first table handles all tensor level information
|
| 384 |
+
2.) the second table handles and displays all channel based information
|
| 385 |
+
|
| 386 |
+
The reasoning for this is that having all the info in one table can make it ambiguous which collected
|
| 387 |
+
statistics are global, and which are actually per-channel, so it's better to split it up into two
|
| 388 |
+
tables. This also makes the information much easier to digest given the plethora of statistics collected
|
| 389 |
+
|
| 390 |
+
Tensor table columns:
|
| 391 |
+
idx layer_fqn feature_1 feature_2 feature_3 .... feature_n
|
| 392 |
+
---- --------- --------- --------- --------- ---------
|
| 393 |
+
|
| 394 |
+
Per-Channel table columns:
|
| 395 |
+
|
| 396 |
+
idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n
|
| 397 |
+
---- --------- ------- --------- --------- --------- ---------
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
feature_filter (str, optional): Filters the features presented to only those that
|
| 401 |
+
contain this filter substring
|
| 402 |
+
Default = "", results in all the features being printed
|
| 403 |
+
module_fqn_filter (str, optional): Only includes modules that contains this string
|
| 404 |
+
Default = "", results in all the modules in the reports to be visible in the table
|
| 405 |
+
|
| 406 |
+
Example Use:
|
| 407 |
+
>>> # xdoctest: +SKIP("undefined variables")
|
| 408 |
+
>>> mod_report_visualizer.generate_table_visualization(
|
| 409 |
+
... feature_filter = "per_channel_min",
|
| 410 |
+
... module_fqn_filter = "block1"
|
| 411 |
+
... )
|
| 412 |
+
>>> # prints out neatly formatted table with per_channel_min info
|
| 413 |
+
>>> # for all modules in block 1 of the model
|
| 414 |
+
"""
|
| 415 |
+
# see if we got tabulate
|
| 416 |
+
if not got_tabulate:
|
| 417 |
+
print("Make sure to install tabulate and try again.")
|
| 418 |
+
return None
|
| 419 |
+
|
| 420 |
+
# get the table dict and the specific tables of interest
|
| 421 |
+
table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
|
| 422 |
+
tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
|
| 423 |
+
channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
|
| 424 |
+
|
| 425 |
+
# get the table string and print it out
|
| 426 |
+
# now we have populated the tables for each one
|
| 427 |
+
# let's create the strings to be returned
|
| 428 |
+
table_str = ""
|
| 429 |
+
# the tables will have some headers columns that are non-feature
|
| 430 |
+
# ex. table index, module name, channel index, etc.
|
| 431 |
+
# we want to look at header columns for features, that come after those headers
|
| 432 |
+
if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS:
|
| 433 |
+
# if we have at least one tensor level feature to be added we add tensor table
|
| 434 |
+
table_str += "Tensor Level Information \n"
|
| 435 |
+
table_str += tabulate(tensor_table, headers=tensor_headers)
|
| 436 |
+
if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS:
|
| 437 |
+
# if we have at least one channel level feature to be added we add tensor table
|
| 438 |
+
table_str += "\n\n Channel Level Information \n"
|
| 439 |
+
table_str += tabulate(channel_table, headers=channel_headers)
|
| 440 |
+
|
| 441 |
+
# if no features at all, let user know
|
| 442 |
+
if table_str == "":
|
| 443 |
+
table_str = "No data points to generate table with."
|
| 444 |
+
|
| 445 |
+
print(table_str)
|
| 446 |
+
|
| 447 |
+
def _get_plottable_data(self, feature_filter: str, module_fqn_filter: str) -> Tuple[List, List[List], bool]:
|
| 448 |
+
r"""
|
| 449 |
+
Takes in the feature filters and module filters and outputs the x and y data for plotting
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
feature_filter (str): Filters the features presented to only those that
|
| 453 |
+
contain this filter substring
|
| 454 |
+
module_fqn_filter (str): Only includes modules that contains this string
|
| 455 |
+
|
| 456 |
+
Returns a tuple of three elements
|
| 457 |
+
The first is a list containing relevant x-axis data
|
| 458 |
+
The second is a list containing the corresponding y-axis data
|
| 459 |
+
If the data is per channel
|
| 460 |
+
"""
|
| 461 |
+
# get the table dict and the specific tables of interest
|
| 462 |
+
table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
|
| 463 |
+
tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
|
| 464 |
+
channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
|
| 465 |
+
|
| 466 |
+
# make sure it is only 1 feature that is being plotted
|
| 467 |
+
# get the number of features in each of these
|
| 468 |
+
tensor_info_features_count = len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
|
| 469 |
+
channel_info_features_count = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
|
| 470 |
+
|
| 471 |
+
# see if valid tensor or channel plot
|
| 472 |
+
is_valid_per_tensor_plot: bool = tensor_info_features_count == 1
|
| 473 |
+
is_valid_per_channel_plot: bool = channel_info_features_count == 1
|
| 474 |
+
|
| 475 |
+
# offset should either be one of tensor or channel table or neither
|
| 476 |
+
feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
|
| 477 |
+
table = tensor_table
|
| 478 |
+
|
| 479 |
+
# if a per_channel plot, we have different offset and table
|
| 480 |
+
if is_valid_per_channel_plot:
|
| 481 |
+
feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
|
| 482 |
+
table = channel_table
|
| 483 |
+
|
| 484 |
+
x_data: List = []
|
| 485 |
+
y_data: List[List] = []
|
| 486 |
+
# the feature will either be a tensor feature or channel feature
|
| 487 |
+
if is_valid_per_tensor_plot:
|
| 488 |
+
for table_row_num, row in enumerate(table):
|
| 489 |
+
# get x_value to append
|
| 490 |
+
x_val_to_append = table_row_num
|
| 491 |
+
# the index of the feature will the 0 + num non feature columns
|
| 492 |
+
tensor_feature_index = feature_column_offset
|
| 493 |
+
row_value = row[tensor_feature_index]
|
| 494 |
+
if not type(row_value) == str:
|
| 495 |
+
x_data.append(x_val_to_append)
|
| 496 |
+
y_data.append(row_value)
|
| 497 |
+
elif is_valid_per_channel_plot:
|
| 498 |
+
# gather the x_data and multiple y_data
|
| 499 |
+
# calculate the number of channels
|
| 500 |
+
num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
|
| 501 |
+
for channel in range(num_channels):
|
| 502 |
+
y_data.append([]) # separate data list per channel
|
| 503 |
+
|
| 504 |
+
for table_row_num, row in enumerate(table):
|
| 505 |
+
# get x_value to append
|
| 506 |
+
x_val_to_append = table_row_num
|
| 507 |
+
current_channel = row[self.CHANNEL_NUM_INDEX] # initially chose current channel
|
| 508 |
+
new_module_index: int = table_row_num // num_channels
|
| 509 |
+
x_val_to_append = new_module_index
|
| 510 |
+
|
| 511 |
+
# the index of the feature will the 0 + num non feature columns
|
| 512 |
+
tensor_feature_index = feature_column_offset
|
| 513 |
+
row_value = row[tensor_feature_index]
|
| 514 |
+
if not type(row_value) == str:
|
| 515 |
+
# only append if new index we are appending
|
| 516 |
+
if len(x_data) == 0 or x_data[-1] != x_val_to_append:
|
| 517 |
+
x_data.append(x_val_to_append)
|
| 518 |
+
|
| 519 |
+
# append value for that channel
|
| 520 |
+
y_data[current_channel].append(row_value)
|
| 521 |
+
else:
|
| 522 |
+
# more than one feature was chosen
|
| 523 |
+
error_str = "Make sure to pick only a single feature with your filter to plot a graph."
|
| 524 |
+
error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names."
|
| 525 |
+
error_str += " Pick one of those features to plot."
|
| 526 |
+
raise ValueError(error_str)
|
| 527 |
+
|
| 528 |
+
# return x, y values, and if data is per-channel
|
| 529 |
+
return (x_data, y_data, is_valid_per_channel_plot)
|
| 530 |
+
|
| 531 |
+
def generate_plot_visualization(self, feature_filter: str, module_fqn_filter: str = ""):
|
| 532 |
+
r"""
|
| 533 |
+
Takes in a feature and optional module_filter and plots of the desired data.
|
| 534 |
+
|
| 535 |
+
For per channel features, it averages the value across the channels and plots a point
|
| 536 |
+
per module. The reason for this is that for models with hundreds of channels, it can
|
| 537 |
+
be hard to differentiate one channel line from another, and so the point of generating
|
| 538 |
+
a single average point per module is to give a sense of general trends that encourage
|
| 539 |
+
further deep dives.
|
| 540 |
+
|
| 541 |
+
Note:
|
| 542 |
+
Only features in the report that have tensor value data are plottable by this class
|
| 543 |
+
When the tensor information is plotted, it will plot:
|
| 544 |
+
idx as the x val, feature value as the y_val
|
| 545 |
+
When the channel information is plotted, it will plot:
|
| 546 |
+
the first idx of each module as the x val, feature value as the y_val [for each channel]
|
| 547 |
+
The reason for this is that we want to be able to compare values across the
|
| 548 |
+
channels for same layer, and it will be hard if values are staggered by idx
|
| 549 |
+
This means each module is represented by only 1 x value
|
| 550 |
+
Args:
|
| 551 |
+
feature_filter (str): Filters the features presented to only those that
|
| 552 |
+
contain this filter substring
|
| 553 |
+
module_fqn_filter (str, optional): Only includes modules that contains this string
|
| 554 |
+
Default = "", results in all the modules in the reports to be visible in the table
|
| 555 |
+
|
| 556 |
+
Example Use:
|
| 557 |
+
>>> # xdoctest: +SKIP("undefined variables")
|
| 558 |
+
>>> mod_report_visualizer.generate_plot_visualization(
|
| 559 |
+
... feature_filter = "per_channel_min",
|
| 560 |
+
... module_fqn_filter = "block1"
|
| 561 |
+
... )
|
| 562 |
+
>>> # outputs line plot of per_channel_min information for all
|
| 563 |
+
>>> # modules in block1 of model each channel gets it's own line,
|
| 564 |
+
>>> # and it's plotted across the in-order modules on the x-axis
|
| 565 |
+
"""
|
| 566 |
+
# checks if we have matplotlib and let's user know to install it if don't
|
| 567 |
+
if not got_matplotlib:
|
| 568 |
+
print("make sure to install matplotlib and try again.")
|
| 569 |
+
return None
|
| 570 |
+
|
| 571 |
+
# get the x and y data and if per channel
|
| 572 |
+
x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter)
|
| 573 |
+
|
| 574 |
+
# plot based on whether data is per channel or not
|
| 575 |
+
ax = plt.subplot()
|
| 576 |
+
ax.set_ylabel(feature_filter)
|
| 577 |
+
ax.set_title(feature_filter + " Plot")
|
| 578 |
+
plt.xticks(x_data) # only show ticks for actual points
|
| 579 |
+
|
| 580 |
+
if data_per_channel:
|
| 581 |
+
ax.set_xlabel("First idx of module")
|
| 582 |
+
# set the legend as well
|
| 583 |
+
# plot a single line that is average of the channel values
|
| 584 |
+
num_modules = len(y_data[0]) # all y_data have same length, so get num modules
|
| 585 |
+
num_channels = len(y_data) # we want num channels to be able to calculate average later
|
| 586 |
+
|
| 587 |
+
avg_vals = [sum(y_data[:][index]) / num_channels for index in range(num_modules)]
|
| 588 |
+
|
| 589 |
+
# plot the three things we measured
|
| 590 |
+
ax.plot(x_data, avg_vals, label=f"Average Value Across {num_channels} Channels")
|
| 591 |
+
ax.legend(loc='upper right')
|
| 592 |
+
else:
|
| 593 |
+
ax.set_xlabel("idx")
|
| 594 |
+
ax.plot(x_data, y_data)
|
| 595 |
+
|
| 596 |
+
# actually show the plot
|
| 597 |
+
plt.show()
|
| 598 |
+
|
| 599 |
+
def generate_histogram_visualization(self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10):
|
| 600 |
+
r"""
|
| 601 |
+
Takes in a feature and optional module_filter and plots the histogram of desired data.
|
| 602 |
+
|
| 603 |
+
Note:
|
| 604 |
+
Only features in the report that have tensor value data can be viewed as a histogram
|
| 605 |
+
If you want to plot a histogram from all the channel values of a specific feature for
|
| 606 |
+
a specific model, make sure to specify both the model and the feature properly
|
| 607 |
+
in the filters and you should be able to see a distribution of the channel data
|
| 608 |
+
|
| 609 |
+
Args:
|
| 610 |
+
feature_filter (str, optional): Filters the features presented to only those that
|
| 611 |
+
contain this filter substring
|
| 612 |
+
Default = "", results in all the features being printed
|
| 613 |
+
module_fqn_filter (str, optional): Only includes modules that contains this string
|
| 614 |
+
Default = "", results in all the modules in the reports to be visible in the table
|
| 615 |
+
num_bins (int, optional): The number of bins to create the histogram with
|
| 616 |
+
Default = 10, the values will be split into 10 equal sized bins
|
| 617 |
+
|
| 618 |
+
Example Use:
|
| 619 |
+
>>> # xdoctest: +SKIP
|
| 620 |
+
>>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization(
|
| 621 |
+
... feature_filter = "per_channel_min",
|
| 622 |
+
... module_fqn_filter = "block1"
|
| 623 |
+
... )
|
| 624 |
+
# outputs histogram of per_channel_min information for all modules in block1 of model
|
| 625 |
+
information is gathered across all channels for all modules in block 1 for the
|
| 626 |
+
per_channel_min and is displayed in a histogram of equally sized bins
|
| 627 |
+
"""
|
| 628 |
+
# checks if we have matplotlib and let's user know to install it if don't
|
| 629 |
+
if not got_matplotlib:
|
| 630 |
+
print("make sure to install matplotlib and try again.")
|
| 631 |
+
return None
|
| 632 |
+
|
| 633 |
+
# get the x and y data and if per channel
|
| 634 |
+
x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter)
|
| 635 |
+
|
| 636 |
+
# for histogram, we just care about plotting the y data
|
| 637 |
+
# plot based on whether data is per channel or not
|
| 638 |
+
ax = plt.subplot()
|
| 639 |
+
ax.set_xlabel(feature_filter)
|
| 640 |
+
ax.set_ylabel("Frequency")
|
| 641 |
+
ax.set_title(feature_filter + " Histogram")
|
| 642 |
+
|
| 643 |
+
if data_per_channel:
|
| 644 |
+
# set the legend as well
|
| 645 |
+
# combine all the data
|
| 646 |
+
all_data = []
|
| 647 |
+
for channel_info in y_data:
|
| 648 |
+
all_data.extend(channel_info)
|
| 649 |
+
|
| 650 |
+
val, bins, _ = plt.hist(
|
| 651 |
+
all_data,
|
| 652 |
+
bins=num_bins,
|
| 653 |
+
stacked=True,
|
| 654 |
+
rwidth=0.8,
|
| 655 |
+
)
|
| 656 |
+
plt.xticks(bins)
|
| 657 |
+
else:
|
| 658 |
+
val, bins, _ = plt.hist(
|
| 659 |
+
y_data,
|
| 660 |
+
bins=num_bins,
|
| 661 |
+
stacked=False,
|
| 662 |
+
rwidth=0.8,
|
| 663 |
+
)
|
| 664 |
+
plt.xticks(bins)
|
| 665 |
+
|
| 666 |
+
plt.show()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/quantize_handler.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from typing import Callable, Dict, List, Optional, Type
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from torch.ao.quantization.backend_config import (
|
| 7 |
+
BackendConfig,
|
| 8 |
+
DTypeConfig,
|
| 9 |
+
ObservationType,
|
| 10 |
+
)
|
| 11 |
+
from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls
|
| 12 |
+
from torch.fx.graph import Node
|
| 13 |
+
|
| 14 |
+
from .utils import all_node_args_have_no_tensors
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"QuantizeHandler",
|
| 19 |
+
"BinaryOpQuantizeHandler",
|
| 20 |
+
"CatQuantizeHandler",
|
| 21 |
+
"ConvReluQuantizeHandler",
|
| 22 |
+
"LinearReLUQuantizeHandler",
|
| 23 |
+
"BatchNormQuantizeHandler",
|
| 24 |
+
"EmbeddingQuantizeHandler",
|
| 25 |
+
"RNNDynamicQuantizeHandler",
|
| 26 |
+
"DefaultNodeQuantizeHandler",
|
| 27 |
+
"FixedQParamsOpQuantizeHandler",
|
| 28 |
+
"CopyNodeQuantizeHandler",
|
| 29 |
+
"GeneralTensorShapeOpQuantizeHandler",
|
| 30 |
+
"CustomModuleQuantizeHandler",
|
| 31 |
+
"StandaloneModuleQuantizeHandler",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
def _default_root_node_getter(node_pattern):
|
| 35 |
+
if node_pattern is None:
|
| 36 |
+
return node_pattern
|
| 37 |
+
while not isinstance(node_pattern, Node):
|
| 38 |
+
node_pattern = node_pattern[-1]
|
| 39 |
+
return node_pattern
|
| 40 |
+
|
| 41 |
+
# Base Pattern Handler
|
| 42 |
+
class QuantizeHandler(ABC): # noqa: B024
|
| 43 |
+
""" Base handler class for the quantizer patterns
|
| 44 |
+
"""
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
node_pattern: NodePattern,
|
| 48 |
+
modules: Dict[str, torch.nn.Module],
|
| 49 |
+
root_node_getter: Optional[Callable] = None,
|
| 50 |
+
is_custom_module=False,
|
| 51 |
+
is_standalone_module=False):
|
| 52 |
+
""" Records pattern information in __init__, which will be used
|
| 53 |
+
in convert
|
| 54 |
+
"""
|
| 55 |
+
self.node_pattern = node_pattern
|
| 56 |
+
self.modules = modules
|
| 57 |
+
if root_node_getter is None:
|
| 58 |
+
root_node_getter = _default_root_node_getter
|
| 59 |
+
self.root_node = root_node_getter(node_pattern)
|
| 60 |
+
self.is_custom_module_ = is_custom_module
|
| 61 |
+
self.is_standalone_module_ = is_standalone_module
|
| 62 |
+
self.num_tensor_args = 0
|
| 63 |
+
# determine how many of the first two args are Tensors (versus scalars)
|
| 64 |
+
# this distinguishes things like "x + y" from "x + 2" or "2 + x"
|
| 65 |
+
if isinstance(self.root_node, Node):
|
| 66 |
+
cache_for_no_tensor_check: Dict[Node, bool] = {}
|
| 67 |
+
for arg_idx in range(len(self.root_node.args)):
|
| 68 |
+
arg = self.root_node.args[arg_idx]
|
| 69 |
+
if isinstance(arg, Node) and (
|
| 70 |
+
not all_node_args_have_no_tensors(
|
| 71 |
+
arg, self.modules, cache_for_no_tensor_check)):
|
| 72 |
+
self.num_tensor_args += 1
|
| 73 |
+
|
| 74 |
+
def is_general_tensor_value_op(self) -> bool:
|
| 75 |
+
"""
|
| 76 |
+
Returns True if the operator works for both floating point and
|
| 77 |
+
quantized input, and does some computation based on the input Tensor,
|
| 78 |
+
or the ops that only re-arranges the Tensor values or query some metadata
|
| 79 |
+
about the Tensor
|
| 80 |
+
so we need to insert observer/fake_quant for the output of the
|
| 81 |
+
operator (same observer instance as input)
|
| 82 |
+
since the distribution of values is different for input and output
|
| 83 |
+
Tensors (for HistogramObserver) while they share the same quantization
|
| 84 |
+
parameters
|
| 85 |
+
Example operator: avgpool2d, reshape, transpose, maxpool2d
|
| 86 |
+
Example observed operator:
|
| 87 |
+
observer_0 - avgpool2d - observer_0 (same observer instance as input)
|
| 88 |
+
"""
|
| 89 |
+
return False
|
| 90 |
+
|
| 91 |
+
def is_custom_module(self):
|
| 92 |
+
return self.is_custom_module_
|
| 93 |
+
|
| 94 |
+
def is_standalone_module(self):
|
| 95 |
+
return self.is_standalone_module_
|
| 96 |
+
|
| 97 |
+
def _get_quantize_handler_cls(
|
| 98 |
+
observation_type: ObservationType,
|
| 99 |
+
dtype_configs: List[DTypeConfig],
|
| 100 |
+
num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> Type[QuantizeHandler]:
|
| 101 |
+
"""
|
| 102 |
+
Return a configurable QuantizeHandler that matches the given specifications from the backend.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
class ConfigurableQuantizeHandler(QuantizeHandler):
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
node_pattern: NodePattern,
|
| 109 |
+
modules: Dict[str, torch.nn.Module],
|
| 110 |
+
root_node_getter: Optional[Callable] = None):
|
| 111 |
+
super().__init__(node_pattern, modules, root_node_getter)
|
| 112 |
+
if num_tensor_args_to_observation_type:
|
| 113 |
+
assert self.num_tensor_args in num_tensor_args_to_observation_type, \
|
| 114 |
+
f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
|
| 115 |
+
f" in num_tensor_args_to_observation_type for {node_pattern}"
|
| 116 |
+
self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
|
| 117 |
+
else:
|
| 118 |
+
self.observation_type = observation_type
|
| 119 |
+
self.dtype_configs = dtype_configs
|
| 120 |
+
|
| 121 |
+
def is_general_tensor_value_op(self) -> bool:
|
| 122 |
+
return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
|
| 123 |
+
|
| 124 |
+
return ConfigurableQuantizeHandler
|
| 125 |
+
|
| 126 |
+
def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
|
| 127 |
+
"""
|
| 128 |
+
Note: Quantize handler is just a holder for some check methods like
|
| 129 |
+
(should_insert_observer_for_output), maybe this can be a enum as well,
|
| 130 |
+
we can refactor this after we convert the path for fbgemm/qnnpack fully to the
|
| 131 |
+
new path, this is not exposed to backend developers
|
| 132 |
+
"""
|
| 133 |
+
pattern_to_quantize_handlers = {}
|
| 134 |
+
for pattern, config in backend_config._pattern_complex_format_to_config.items():
|
| 135 |
+
observation_type = config.observation_type
|
| 136 |
+
dtype_configs = config.dtype_configs
|
| 137 |
+
num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
|
| 138 |
+
pattern_to_quantize_handlers[pattern] = \
|
| 139 |
+
_get_quantize_handler_cls(
|
| 140 |
+
observation_type,
|
| 141 |
+
dtype_configs,
|
| 142 |
+
num_tensor_args_to_observation_type)
|
| 143 |
+
return pattern_to_quantize_handlers
|
| 144 |
+
|
| 145 |
+
# TODO: remove this class, this is still exposed in torch.ao.quantization
|
| 146 |
+
# but we should be able to break bc
|
| 147 |
+
class BinaryOpQuantizeHandler(QuantizeHandler):
|
| 148 |
+
pass
|
| 149 |
+
|
| 150 |
+
class CatQuantizeHandler(QuantizeHandler):
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
# TODO: remove this class
|
| 154 |
+
class ConvReluQuantizeHandler(QuantizeHandler):
|
| 155 |
+
pass
|
| 156 |
+
|
| 157 |
+
# TODO: remove this class
|
| 158 |
+
class LinearReLUQuantizeHandler(QuantizeHandler):
|
| 159 |
+
pass
|
| 160 |
+
|
| 161 |
+
# TODO: remove this class
|
| 162 |
+
class BatchNormQuantizeHandler(QuantizeHandler):
|
| 163 |
+
pass
|
| 164 |
+
|
| 165 |
+
# TODO: remove this class
|
| 166 |
+
class EmbeddingQuantizeHandler(QuantizeHandler):
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
# TODO: remove this class
|
| 170 |
+
class RNNDynamicQuantizeHandler(QuantizeHandler):
|
| 171 |
+
pass
|
| 172 |
+
|
| 173 |
+
# TODO: remove this class
|
| 174 |
+
class DefaultNodeQuantizeHandler(QuantizeHandler):
|
| 175 |
+
""" Common quantized op, first input and first output will be quantized
|
| 176 |
+
"""
|
| 177 |
+
pass
|
| 178 |
+
|
| 179 |
+
# TODO: remove this class
|
| 180 |
+
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
|
| 181 |
+
pass
|
| 182 |
+
|
| 183 |
+
# TODO: remove
|
| 184 |
+
class CopyNodeQuantizeHandler(QuantizeHandler):
|
| 185 |
+
pass
|
| 186 |
+
|
| 187 |
+
# TODO: remove
|
| 188 |
+
class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
|
| 189 |
+
pass
|
| 190 |
+
|
| 191 |
+
# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
|
| 192 |
+
class CustomModuleQuantizeHandler(QuantizeHandler):
|
| 193 |
+
pass
|
| 194 |
+
|
| 195 |
+
# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
|
| 196 |
+
class StandaloneModuleQuantizeHandler(QuantizeHandler):
|
| 197 |
+
pass
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/utils.py
ADDED
|
@@ -0,0 +1,885 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.ao.quantization import (
|
| 5 |
+
QConfigAny,
|
| 6 |
+
QuantType,
|
| 7 |
+
)
|
| 8 |
+
from torch.ao.quantization.backend_config import (
|
| 9 |
+
DTypeWithConstraints,
|
| 10 |
+
)
|
| 11 |
+
from torch.ao.quantization.fake_quantize import (
|
| 12 |
+
FakeQuantizeBase,
|
| 13 |
+
FixedQParamsFakeQuantize,
|
| 14 |
+
)
|
| 15 |
+
from torch.ao.quantization.observer import (
|
| 16 |
+
FixedQParamsObserver,
|
| 17 |
+
ObserverBase,
|
| 18 |
+
)
|
| 19 |
+
from torch.ao.quantization.qconfig import (
|
| 20 |
+
float16_static_qconfig,
|
| 21 |
+
float16_dynamic_qconfig,
|
| 22 |
+
qconfig_equals,
|
| 23 |
+
)
|
| 24 |
+
from torch.ao.quantization.stubs import DeQuantStub
|
| 25 |
+
from torch.ao.quantization.utils import (
|
| 26 |
+
activation_is_statically_quantized,
|
| 27 |
+
)
|
| 28 |
+
from torch.ao.quantization.observer import _is_activation_post_process
|
| 29 |
+
from torch.ao.quantization.qconfig_mapping import QConfigMapping
|
| 30 |
+
|
| 31 |
+
from torch.fx import GraphModule, map_arg
|
| 32 |
+
|
| 33 |
+
from torch.fx.graph import (
|
| 34 |
+
Graph,
|
| 35 |
+
Node,
|
| 36 |
+
)
|
| 37 |
+
from .custom_config import PrepareCustomConfig
|
| 38 |
+
# importing the lib so that the quantized_decomposed ops are registered
|
| 39 |
+
from ._decomposed import quantized_decomposed_lib # noqa: F401
|
| 40 |
+
|
| 41 |
+
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
|
| 42 |
+
from dataclasses import dataclass
|
| 43 |
+
from collections import namedtuple
|
| 44 |
+
import operator
|
| 45 |
+
import warnings
|
| 46 |
+
|
| 47 |
+
# Public API of this module, as re-exported by the quantization FX namespace.
# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "EMPTY_ARG_DICT",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_qconv_prepack_op",
    "get_skipped_module_name_and_classes",
    "graph_module_from_producer_nodes",
    "maybe_get_next_module",
    "NodeInfo",
    "node_arg_is_bias",
    "node_arg_is_weight",
    "NON_OBSERVABLE_ARG_DICT",
    "NON_QUANTIZABLE_WEIGHT_OPS",
    "return_arg_list",
    "ObservedGraphModuleAttrs",
]
|
| 72 |
+
|
| 73 |
+
# Functional norm ops whose `weight` argument is treated as non-quantizable
# (per the name; consumers of this set decide how weights for these ops are handled).
NON_QUANTIZABLE_WEIGHT_OPS = {torch.nn.functional.layer_norm, torch.nn.functional.group_norm, torch.nn.functional.instance_norm}
|
| 74 |
+
|
| 75 |
+
@dataclass
class ObservedGraphModuleAttrs:
    """Bundle of metadata attached to a GraphModule after the prepare step.

    Carries everything the later (convert) step needs to know about how the
    module was observed.
    """
    # Maps node name -> the QConfig resolved for that node.
    node_name_to_qconfig: Dict[str, QConfigAny]
    # Maps node name -> (scope path, scope type) pair.
    node_name_to_scope: Dict[str, Tuple[str, type]]
    prepare_custom_config: PrepareCustomConfig
    # Maps node name -> equalization qconfig entry.
    equalization_node_name_to_qconfig: Dict[str, Any]
    qconfig_mapping: QConfigMapping
    # True when prepared for quantization-aware training.
    is_qat: bool
    # Names of nodes that were observed during prepare.
    observed_node_names: Set[str]
    # Set when this module was prepared as a standalone module.
    is_observed_standalone_module: bool = False
    # Input/output indices that are already quantized for a standalone
    # module; None when not a standalone module.
    standalone_module_input_quantized_idxs: Optional[List[int]] = None
    standalone_module_output_quantized_idxs: Optional[List[int]] = None
|
| 87 |
+
|
| 88 |
+
def node_arg_is_weight(node: Node, arg: Any) -> bool:
    """Return True if `arg` is the weight argument of `node`.

    The weight is identified either positionally, via the optional
    ``weight_index`` recorded in ``node.meta["target_dtype_info"]``, or via
    the ``weight`` keyword argument.
    """
    positional_index = None
    dtype_info = node.meta.get("target_dtype_info")
    if dtype_info is not None:
        positional_index = dtype_info.get("weight_index", None)
    matches_positional = (
        positional_index is not None
        and positional_index < len(node.args)
        and node.args[positional_index] is arg
    )
    return matches_positional or node.kwargs.get("weight") is arg
|
| 96 |
+
|
| 97 |
+
def node_arg_is_bias(node: Node, arg: Any) -> bool:
    """Return True if `arg` is the bias argument of `node`.

    The bias is identified either positionally, via the optional
    ``bias_index`` recorded in ``node.meta["target_dtype_info"]``, or via the
    ``bias`` keyword argument.
    """
    positional_index = None
    dtype_info = node.meta.get("target_dtype_info")
    if dtype_info is not None:
        positional_index = dtype_info.get("bias_index", None)
    matches_positional = (
        positional_index is not None
        and positional_index < len(node.args)
        and node.args[positional_index] is arg
    )
    return matches_positional or node.kwargs.get("bias") is arg
|
| 105 |
+
|
| 106 |
+
def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
    r"""Collect the unique float custom-module classes across all quant modes.

    Example::

        Input: {
            QuantType.STATIC: {CustomModule1: ObservedCustomModule},
            QuantType.DYNAMIC: {CustomModule2: DynamicObservedCustomModule},
            QuantType.WEIGHT_ONLY: {CustomModule3: WeightOnlyObservedCustomModule},
        }
        Output: [CustomModule1, CustomModule2, CustomModule3]
    """
    # a set dedups classes that appear under more than one quant mode
    float_custom_module_classes: Set[Any] = set()
    for quant_mode in (QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY):
        float_custom_module_classes.update(custom_module_mapping.get(quant_mode, {}).keys())
    return list(float_custom_module_classes)
|
| 133 |
+
|
| 134 |
+
def get_linear_prepack_op_for_dtype(dtype):
    """Return the quantized linear prepack op for `dtype`.

    Args:
        dtype: ``torch.float16`` or ``torch.qint8``.

    Raises:
        Exception: if `dtype` has no corresponding prepack op.
    """
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    if dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    # Bug fix: the previous version raised Exception("msg:", dtype), which
    # produces a two-element args tuple instead of a readable message.
    # Exception type is kept for backward compatibility with existing callers.
    raise Exception(f"can't get linear prepack op for dtype: {dtype}")
|
| 141 |
+
|
| 142 |
+
def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    """Return the quantized prepack op matching a functional conv/conv_transpose op."""
    quantized = torch.ops.quantized
    op_to_prepack = {
        torch.nn.functional.conv1d: quantized.conv1d_prepack,
        torch.nn.functional.conv2d: quantized.conv2d_prepack,
        torch.nn.functional.conv3d: quantized.conv3d_prepack,
        torch.nn.functional.conv_transpose1d: quantized.conv_transpose1d_prepack,
        torch.nn.functional.conv_transpose2d: quantized.conv_transpose2d_prepack,
        torch.nn.functional.conv_transpose3d: quantized.conv_transpose3d_prepack,
    }
    prepack_op = op_to_prepack.get(conv_op)
    assert prepack_op, f"Didn't find prepack op for {conv_op}"
    return prepack_op
|
| 154 |
+
|
| 155 |
+
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    """Return a function that picks an unused attribute name with `prefix`.

    Example::

        >>> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
        >>> new_name = get_new_observer_name(module)  # an unused name, e.g. '_observer0'

    Dots in the prefix are replaced with underscores so the result is a valid
    attribute name.
    """
    prefix = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        # Linear scan for the first suffix not already taken on `module`.
        counter = 0
        while hasattr(module, f"{prefix}{counter}"):
            counter += 1
        return f"{prefix}{counter}"

    return get_new_attr_name
|
| 173 |
+
|
| 174 |
+
def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r"""Trace backwards from `node` and collect its self-contained producer chain.

    Starting from a target node, walk the argument chain until input
    (placeholder) or getattr nodes are reached. This extracts the chain of
    operators from getattr up to the target node, for example::

        def forward(self, x):
            observed = self.observer(self.weight)
            return F.linear(x, observed)

    collect_producer_nodes(observed) returns the list of nodes that produce
    the observed node, or None if no self-contained graph can be extracted
    (i.e. the chain depends on a free variable — an input of forward()).
    """
    producers = [node]
    to_visit = [node]
    while to_visit:
        current = to_visit.pop()
        for candidate in list(current.args) + list(current.kwargs.values()):
            if not isinstance(candidate, Node):
                continue
            if candidate.op == 'placeholder':
                # reached a forward() input -> not self contained, give up
                return None
            producers.append(candidate)
            # getattr leaves terminate the walk; everything else is explored
            if candidate.op != 'call_function' or candidate.target != getattr:
                to_visit.append(candidate)
    return producers
|
| 200 |
+
|
| 201 |
+
def graph_module_from_producer_nodes(
        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
    r"""Construct a graph module from nodes extracted by `collect_producer_nodes`.

    Args:
        root: the root module for the original graph
        producer_nodes: a list of nodes we use to construct the graph
    Return:
        A graph module constructed from the producer nodes
    """
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # `collect_producer_nodes` traced backwards from the target to getattr,
    # so restore execution order first.
    producer_nodes.reverse()
    graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n])

    for producer in producer_nodes:
        env[producer] = graph.node_copy(producer, load_arg)
    # after the reverse, the last producer is the original target node
    graph.output(load_arg(producer_nodes[-1]))
    return GraphModule(root, graph)
|
| 224 |
+
|
| 225 |
+
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    devices = {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}
    # As a temp workaround for AIMP HHC publish we added CPU check. Remove it
    # later. T163614564
    # (This and the fix below: the original used bare string-literal
    # expression statements — a triple-quoted "comment" and a stray "" — which
    # are dead runtime expressions, not comments. Removed; behavior unchanged.)
    if {torch.device("cpu"), torch.device("meta")} == devices:
        warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.")
        devices = {torch.device("cpu")}
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        f"but got devices {devices}"
    )
    return next(iter(devices)) if len(devices) > 0 else None
|
| 245 |
+
|
| 246 |
+
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Register `value` as a new buffer on `module` (under an unused name derived
    from `prefix`) and return a get_attr node in `graph` that reads it.
    """
    attr_name = get_new_attr_name_with_prefix(prefix)(module)
    device = assert_and_get_unique_device(module)
    if isinstance(value, torch.Tensor):
        buffer = value.clone().detach()
    else:
        buffer = torch.tensor(value, device=device)
    module.register_buffer(attr_name, buffer)
    # Create get_attr with value
    return graph.create_node("get_attr", attr_name)
|
| 260 |
+
|
| 261 |
+
def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.

    Fixes over the previous version:
    * the memo `cache` is checked/populated whenever a dict is supplied — the
      old ``if cache:`` guard skipped an *empty* dict, so a cache passed in
      empty was never written to and the memoization was a no-op
    * removed an unreachable duplicate ``elif node.op == 'call_module'``
      branch; the first 'call_module' arm always matched first, and when the
      module was not an activation post process, `result` simply kept its
      initial False value — preserved here
    """
    if cache is not None and node in cache:
        return cache[node]

    result = False  # will be overwritten below
    if not isinstance(node, Node):
        # non-Node args (ints, strings, ...) are primitives by definition
        result = True
    elif node.op == 'placeholder':
        result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        # observers/fake-quants pass their input through; look through them
        if _is_activation_post_process(modules[node.target]):
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == 'call_function' and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == 'get_attr':
        result = False
    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
        # x1 = x0.ndim
        result = True
    elif node.op == 'call_method' and node.target == 'size':
        # x1 = x0.size(0)
        result = True
    else:
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                for list_el in arg:
                    if isinstance(list_el, Node):
                        this_list_el_args_have_no_tensors = \
                            all_node_args_have_no_tensors(list_el, modules, cache)
                        found_one_tensor = found_one_tensor or \
                            (not this_list_el_args_have_no_tensors)
                        # If found_one_tensor is True, there is no point in
                        # recursing further as the end result will always
                        # be False.
                        # TODO(future PR): remove this entire function and
                        # change to dtype inference without recursion.
                        if found_one_tensor:
                            result = not found_one_tensor
                            if cache is not None:
                                cache[node] = result
                            return result
            elif isinstance(arg, int):
                pass
            else:
                if isinstance(arg, Node):
                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache)
                    found_one_tensor = found_one_tensor or \
                        (not this_arg_args_have_no_tensors)
                    if found_one_tensor:
                        result = not found_one_tensor
                        if cache is not None:
                            cache[node] = result
                        return result
                else:
                    # anything else (e.g. a tensor constant) is assumed to be a tensor
                    found_one_tensor = True
        result = not found_one_tensor
    if cache is not None:
        cache[node] = result
    return result
|
| 335 |
+
|
| 336 |
+
def all_node_args_except_first(node: Node) -> List[int]:
    """
    Return the indices of every node arg except the first one.
    """
    return [idx for idx in range(len(node.args)) if idx > 0]
|
| 341 |
+
|
| 342 |
+
def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """
    Build a function that, given a node, filters `arg_indices` down to the
    indices actually present in that node's args.
    """
    def arg_indices_func(node: Node) -> List[int]:
        n_args = len(node.args)
        return [idx for idx in arg_indices if idx < n_args]
    return arg_indices_func
|
| 350 |
+
|
| 351 |
+
# Lightweight (op, target) key used to look up per-node configuration below.
NodeInfo = namedtuple("NodeInfo", "op target")
|
| 352 |
+
|
| 353 |
+
# this dict identifies which indices of a node are non tensors
# so that they can be propagated correctly since inserting observers
# for them would cause errors
#
# Keyed by NodeInfo(op, target); each value maps an argument type to a
# function that, given the node, returns the arg indices of that type which
# must not be observed.
#
# NOTE(review): the two entries keyed with function objects
# (torch.transpose, torch.unsqueeze) use op "call_method", whose targets are
# method-name strings — those entries look like they can never match; confirm
# whether they were meant to be "call_function" keys.
NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
    NodeInfo("call_method", "masked_fill") : {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2])
    },
    NodeInfo("call_method", "permute") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "repeat") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "reshape") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "size") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "transpose") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", torch.transpose) : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "unsqueeze") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "unsqueeze_") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", torch.unsqueeze) : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "view") : {
        int: all_node_args_except_first
    },
}
|
| 393 |
+
|
| 394 |
+
# Shared default returned when a node has no non-observable-arg configuration.
EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}
|
| 395 |
+
|
| 396 |
+
def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
    """
    Return the non-float-tensor-type -> index-retrieval-function mapping for
    `node` (each retrieval function takes the node and returns arg indices),
    or an empty dict when the node has no such configuration.
    """
    return NON_OBSERVABLE_ARG_DICT.get(NodeInfo(node.op, node.target), EMPTY_ARG_DICT)
|
| 404 |
+
|
| 405 |
+
def maybe_get_next_module(
    node: Node,
    modules: Dict[str, nn.Module],
    target_module_type: Optional[Type[nn.Module]] = None,
    target_functional_type: Any = None,
) -> Optional[Node]:
    """ Return the first user of `node` that matches the requested module or
    functional type, or None when no user matches.

    Args:
        node: The node whose users we want to look at
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """
    for user in node.users:
        module_matches = (
            user.op == 'call_module'
            and target_module_type is not None
            and isinstance(modules[str(user.target)], target_module_type)
        )
        functional_matches = (
            user.op == 'call_function'
            and target_functional_type is not None
            and user.target == target_functional_type
        )
        if module_matches or functional_matches:
            return user
    return None
|
| 429 |
+
|
| 430 |
+
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Create a node in `quantized_graph` from `create_node_args` and carry over
    the necessary metadata (the stack trace) from `old_node`.
    """
    replacement = quantized_graph.create_node(*create_node_args)
    replacement.stack_trace = old_node.stack_trace
    return replacement
|
| 441 |
+
|
| 442 |
+
def get_skipped_module_name_and_classes(
        prepare_custom_config: PrepareCustomConfig,
        is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]:
    """Return (module names, module classes) that tracing should skip.

    Non-traceable modules are always skipped; standalone-module and custom
    module configs are only added at the top level.
    """
    skipped_module_names = list(prepare_custom_config.non_traceable_module_names)
    skipped_module_classes = list(prepare_custom_config.non_traceable_module_classes)
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        skipped_module_names.extend(prepare_custom_config.standalone_module_names)
        skipped_module_classes.extend(prepare_custom_config.standalone_module_classes)
        skipped_module_classes.extend(get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping))
    return skipped_module_names, skipped_module_classes
|
| 454 |
+
|
| 455 |
+
def _is_custom_module_lstm(
|
| 456 |
+
node: Node,
|
| 457 |
+
named_modules: Dict[str, torch.nn.Module],
|
| 458 |
+
qconfig: QConfigAny = None,
|
| 459 |
+
# QuantizeHandler, but we cannot include the type here due to circular imports
|
| 460 |
+
qhandler: Optional[Any] = None,
|
| 461 |
+
) -> bool:
|
| 462 |
+
"""
|
| 463 |
+
Return whether this refers to the custom module LSTM flow.
|
| 464 |
+
"""
|
| 465 |
+
mod = _get_module(node, named_modules)
|
| 466 |
+
if qconfig is not None and qhandler is not None:
|
| 467 |
+
assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined]
|
| 468 |
+
return isinstance(mod, torch.nn.LSTM) and \
|
| 469 |
+
activation_is_statically_quantized(qconfig) and \
|
| 470 |
+
qhandler.is_custom_module()
|
| 471 |
+
else:
|
| 472 |
+
return isinstance(mod, torch.ao.nn.quantizable.LSTM)
|
| 473 |
+
|
| 474 |
+
def _is_custom_module_mha(
|
| 475 |
+
node: Node,
|
| 476 |
+
named_modules: Dict[str, torch.nn.Module],
|
| 477 |
+
qconfig: QConfigAny = None,
|
| 478 |
+
# QuantizeHandler, but we cannot include the type here due to circular imports
|
| 479 |
+
qhandler: Optional[Any] = None,
|
| 480 |
+
) -> bool:
|
| 481 |
+
"""
|
| 482 |
+
Return whether this refers to the custom module MultiheadAttention flow.
|
| 483 |
+
"""
|
| 484 |
+
mod = _get_module(node, named_modules)
|
| 485 |
+
if qconfig is not None and qhandler is not None:
|
| 486 |
+
assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined]
|
| 487 |
+
return isinstance(mod, torch.nn.MultiheadAttention) and \
|
| 488 |
+
activation_is_statically_quantized(qconfig) and \
|
| 489 |
+
qhandler.is_custom_module()
|
| 490 |
+
else:
|
| 491 |
+
return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
|
| 492 |
+
|
| 493 |
+
def _get_module(node: Node, named_modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]:
|
| 494 |
+
"""
|
| 495 |
+
If `node` refers to a call_module node, return the module, else None.
|
| 496 |
+
"""
|
| 497 |
+
if node.op == "call_module" and str(node.target) in named_modules:
|
| 498 |
+
return named_modules[str(node.target)]
|
| 499 |
+
else:
|
| 500 |
+
return None
|
| 501 |
+
|
| 502 |
+
def _insert_dequant_stub(
|
| 503 |
+
node: Node,
|
| 504 |
+
model: torch.nn.Module,
|
| 505 |
+
named_modules: Dict[str, torch.nn.Module],
|
| 506 |
+
graph: Graph,
|
| 507 |
+
) -> Node:
|
| 508 |
+
"""
|
| 509 |
+
Attach a `DeQuantStub` to the model and create a node that calls this
|
| 510 |
+
`DeQuantStub` on the output of `node`, similar to how observers are inserted.
|
| 511 |
+
"""
|
| 512 |
+
prefix = "dequant_stub_"
|
| 513 |
+
get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
|
| 514 |
+
dequant_stub_name = get_new_dequant_stub_name(model)
|
| 515 |
+
dequant_stub = DeQuantStub()
|
| 516 |
+
setattr(model, dequant_stub_name, dequant_stub)
|
| 517 |
+
named_modules[dequant_stub_name] = dequant_stub
|
| 518 |
+
with graph.inserting_after(node):
|
| 519 |
+
return graph.call_module(dequant_stub_name, (node,))
|
| 520 |
+
|
| 521 |
+
def _insert_dequant_stubs_for_custom_module_lstm_output(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Insert DeQuantStubs after each internal output node of custom module LSTM.

    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)),
    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
    components through `getitem`. This function transforms the graph as follows:

      (1) Split the LSTM node into (output, (hidden0, hidden1))
      (2) Insert a DeQuantStub after each internal node
      (3) Recombine the DeQuantStubs into the same structure as before
      (4) Reroute all consumers of the original LSTM node and its sub-nodes
          (e.g. lstm[0])

    Before:
                   lstm_output
                        |
                        v
                  original_user(s)
    After:
                   lstm_output
                  /           \\
                 /  (getitem)  \\
                /               \\
               v                 v
             output            hidden
               |               /   \\
         (DeQuantStub)        (getitem)
               |             /       \\
               v            v         v
           output_dq     hidden0    hidden1
               |            |          |
               |    (DeQuantStub) (DeQuantStub)
               |            |          |
               |            v          v
               |      hidden0_dq  hidden1_dq
               |            \\        /
               |              (tuple)
               |              \\   /
               |               v  v
               |             hidden_dq
               \\               /
                \\   (tuple)   /
                 v            v
                 lstm_output_dq
                       |
                       v
                original_user(s)

    For step (4), reroute all users of the original LSTM node(s) as follows:
      lstm_output -> lstm_output_dq
      lstm_output[0] -> output_dq
      lstm_output[1] -> hidden_dq
      lstm_output[1][0] -> hidden0_dq
      lstm_output[1][1] -> hidden1_dq

    Return the node `lstm_output_dq`.
    """
    # (1) Split the LSTM node into (output, (hidden0, hidden1))
    # (2) Insert a DeQuantStub after each internal node
    # NOTE: insertion points are chained (each `inserting_after` uses the node
    # created just before) so the new nodes appear in topological order.
    with graph.inserting_after(node):
        output = graph.call_function(operator.getitem, (node, 0))
        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
    with graph.inserting_after(output_dq):
        hidden = graph.call_function(operator.getitem, (node, 1))
    with graph.inserting_after(hidden):
        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
    with graph.inserting_after(hidden0_dq):
        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)

    # (3) Recombine the DeQuantStubs into the same structure as before
    with graph.inserting_after(hidden1_dq):
        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
    with graph.inserting_after(hidden_dq):
        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))

    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
    # (iterate over a copy of the users since rerouting mutates node.users;
    # `output` and `hidden` are the getitems created above and must keep
    # consuming the original node)
    for user in list(node.users.keys()):
        if user != output and user != hidden:
            user.replace_input_with(node, lstm_output_dq)
    # The getitem and tuple nodes we added here may interfere with reference quantized
    # pattern matching, so we need to redirect the consumers of internal nodes to the
    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
    # in order to preserve reference patterns like "dequantize - consumer - quantize".
    _reroute_tuple_getitem_pattern(graph)
    return lstm_output_dq
|
| 614 |
+
|
| 615 |
+
def _maybe_get_custom_module_lstm_from_node_arg(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[Node]:
    """
    Given an argument of a node, if the argument refers to the path through which the node
    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.

    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
    skip inserting input observers for this node. This is because custom module LSTM produces
    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
    would unnecessarily quantize the outputs again.

      lstm -> consumer

    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    This tuple can be consumed in one of four ways:

      lstm -> getitem -> DeQuantStub -> consumer                       # consume lstm[0]
      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer   # consume lstm[1]
      lstm -> getitem -> getitem -> DeQuantStub -> consumer            # consume lstm[1][0] or lstm[1][1]
      lstm -> getitem -> DeQuantStub -> tuple -> consumer              # consume lstm

    Thus, we must match against the above patterns instead of simply checking the parent node
    to determine whether this node is a consumer of a custom module LSTM.
    """
    # Single-node predicates used to express the four patterns above.
    def match_dq(a):
        return isinstance(_get_module(a, named_modules), DeQuantStub)

    def match_lstm(a):
        return _is_custom_module_lstm(a, named_modules)

    def match_getitem(a):
        return a.op == "call_function" and a.target == operator.getitem

    def match_tuple(a):
        return a.op == "call_function" and a.target == tuple

    def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
        """
        Traverse up the graph and match the args one by one.
        If there is a match, return the last matched node, or None otherwise.
        """
        a = arg
        for i, match in enumerate(match_pattern):
            if not match(a):
                return None
            # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
            if i < len(match_pattern) - 1:
                if match == match_tuple:
                    a = a.args[0][0]  # type: ignore[assignment,index]
                else:
                    a = a.args[0]  # type: ignore[assignment]
        return a

    # Patterns are listed bottom-up (consumer side first, LSTM last); they
    # correspond, in order, to consuming lstm[0], lstm[1], lstm[1][x], lstm.
    all_match_patterns = [
        [match_dq, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
        [match_dq, match_getitem, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_lstm],
    ]

    for p in all_match_patterns:
        matched_node = _match_pattern(p)
        if matched_node is not None:
            return matched_node
    return None
|
| 683 |
+
|
| 684 |
+
def _reroute_tuple_getitem_pattern(graph: Graph):
    """
    Search for patterns where N consecutive `tuple` call_function nodes are followed by
    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
    If we find this pattern, reroute the consumers of the last `getitem` to skip these
    N `tuple` and `getitem` nodes.

    Before:

        a   b     c
        |   \\   /
        \\   tuple
         \\   /
          tuple
            |
        getitem(1)
            |
        getitem(0)
            |
            d

    After:

        b
        |
        d
    """
    def find_patterns(
            node: Node,
            index_stack: List[int],
            current_pattern: List[Node],
            matched_patterns: List[List[Node]],
            seen: Set[Tuple[Node, Tuple[int, ...]]]):
        """
        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
        starting at the given node.

        We use a stack to keep track of the expected `getitem` indices, since these are
        reversed from the `tuple` indices. In the above example, the stack after
        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
        and then by getitem(0).

        TODO: traverse upwards from the output and handle the case when tuple is not a
        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
        """
        # A fully-unwound stack with a non-empty pattern means every tuple was
        # matched by a reversing getitem: record the pattern and start fresh.
        if len(index_stack) == 0 and len(current_pattern) > 0:
            matched_patterns.append(copy.copy(current_pattern))
            current_pattern.clear()

        # Avoid duplicating work
        state = (node, tuple(index_stack))
        if state in seen:
            return
        seen.add(state)

        # Iterate through users of this node to find tuple/getitem nodes to match
        for user in node.users:
            if user.op == "call_function" and user.target == tuple:
                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
                    if user_arg == node:
                        # entering a tuple: remember which slot we came in through
                        index_stack.append(i)
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
            elif user.op == "call_function" and user.target == operator.getitem:
                if len(index_stack) > 0:
                    # only a getitem that reverses the most recent tuple slot
                    # continues the pattern
                    if user.args[1] == index_stack[-1]:
                        index_stack.pop()
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
        # NOTE: the return value is unused by the caller; results are
        # accumulated in the mutable `matched_patterns` argument.
        return matched_patterns

    # Collect all matched patterns
    matched_patterns: List[List[Node]] = []
    seen: Set[Tuple[Node, Tuple[int, ...]]] = set()  # (node, index_stack)
    for node in graph.nodes:
        find_patterns(node, [], [], matched_patterns, seen)

    # For each pattern, redirect all consumers of the last getitem node to the correct input
    # of the first tuple node
    for pattern in matched_patterns:
        first_tuple = pattern[0]
        last_getitem = pattern[-1]
        assert first_tuple.op == "call_function" and first_tuple.target == tuple
        assert last_getitem.op == "call_function" and last_getitem.target == operator.getitem
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
        # copy the users list since replace_input_with mutates it
        for user in list(last_getitem.users.keys()):
            user.replace_input_with(last_getitem, new_input)
|
| 772 |
+
|
| 773 |
+
def _get_observer_from_activation_post_process(
|
| 774 |
+
activation_post_process: Union[ObserverBase, FakeQuantizeBase],
|
| 775 |
+
) -> ObserverBase:
|
| 776 |
+
"""
|
| 777 |
+
If `activation_post_process` is an observer, return the observer.
|
| 778 |
+
If `activation_post_process` is a fake quantize, return the internal observer.
|
| 779 |
+
"""
|
| 780 |
+
if isinstance(activation_post_process, ObserverBase):
|
| 781 |
+
return activation_post_process
|
| 782 |
+
else:
|
| 783 |
+
assert isinstance(activation_post_process, FakeQuantizeBase)
|
| 784 |
+
return activation_post_process.activation_post_process # type: ignore[return-value]
|
| 785 |
+
|
| 786 |
+
def _qconfig_satisfies_dtype_config_constraints(
        qconfig: QConfigAny,
        dtype_with_constraints: DTypeWithConstraints,
        is_activation: bool = True) -> bool:
    """
    Return whether `qconfig` satisfies the following constraints from the backend,
    specified through the activation and weight DTypeWithConstraints.

        1. QConfig specified a quantization range that falls within the backend's, if any
        2. QConfig specified a min scale value that is >= the backend's, if any
        3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
           scale and zero point that match the backend's, if any

    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
    """
    # TODO: log warnings only when the user enabled a debug flag
    # NOTE: this closure deliberately captures `qconfig` from the enclosing scope
    # so that the warning messages below can include the offending qconfig.
    def _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process: Union[ObserverBase, FakeQuantizeBase],
            dtype_with_constraints: DTypeWithConstraints,
            debug_string: str) -> bool:
        # Unwrap fake quantizes down to the observer that holds the qparams.
        observer = _get_observer_from_activation_post_process(activation_post_process)
        # `getattr` with a None default: not every observer type defines these,
        # and an absent attribute is handled explicitly in the checks below.
        app_quant_min = getattr(observer, "quant_min", None)
        app_quant_max = getattr(observer, "quant_max", None)
        # TODO: for now, just use the existing eps value as scale_min. In the future, we should
        # resolve the differences between the two, either by renaming eps or some other way
        app_scale_min = getattr(observer, "eps", None)
        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
        # check quantization ranges
        # Both bounds must be set on the backend side for the range check to apply.
        if backend_quant_min is not None and backend_quant_max is not None:
            if app_quant_min is None or app_quant_max is None:
                warnings.warn(f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}")
                return False
            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
                warnings.warn(
                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
                    f"ignoring {qconfig}"
                )
                return False
        # check scale min
        if backend_scale_min is not None:
            if app_scale_min is None:
                warnings.warn(f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}")
                return False
            if app_scale_min < backend_scale_min:
                warnings.warn(
                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
                )
                return False
        # check fixed scale and zero point
        # Only enforced when the backend pins BOTH values (fixed-qparams ops).
        if backend_scale_exact_match is not None and backend_zero_point_exact_match is not None:
            # For tests only, accept the following qconfigs for now
            # TODO: handle fp16 qconfigs properly
            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
                if qconfig_equals(qconfig, accepted_qconfig):
                    return True
            suggestion_str = (
                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
                "    qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n"
                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
            )
            # The isinstance check is against the ORIGINAL (possibly wrapped)
            # activation_post_process, not the unwrapped observer.
            if not isinstance(activation_post_process, FixedQParamsObserver) and \
                    not isinstance(activation_post_process, FixedQParamsFakeQuantize):
                warnings.warn(
                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
            # NOTE(review): observer.scale / observer.zero_point are presumably
            # scalar tensors or numbers on fixed-qparams observers, so `!=`
            # is a plain value comparison — confirm against FixedQParamsObserver.
            if observer.scale != backend_scale_exact_match or observer.zero_point != backend_zero_point_exact_match:
                warnings.warn(
                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
                    f"ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
        return True

    # Nothing to validate when either side declares no dtype requirement.
    if qconfig is None or dtype_with_constraints.dtype is None:
        return True

    # Pick the relevant partial-observer constructor from the qconfig.
    activation_post_process_ctr = qconfig.activation if is_activation else qconfig.weight
    debug_string = "activation" if is_activation else "weight"
    satisfies_constraints = True
    if activation_post_process_ctr is not None:
        # Instantiate the observer/fake-quantize so its attributes can be inspected.
        activation_post_process = activation_post_process_ctr()
        assert _is_activation_post_process(activation_post_process)
        # If dtypes don't match, don't check the activation_post_process and return True early
        if activation_post_process.dtype != dtype_with_constraints.dtype:
            return True
        satisfies_constraints = _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process, dtype_with_constraints, debug_string)
    return satisfies_constraints
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-311.pyc
ADDED
|
Binary file (9.73 kB). View file
|
|
|