Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 +3 -0
- .venv/lib/python3.11/site-packages/torch/_export/converter.py +1584 -0
- .venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py +523 -0
- .venv/lib/python3.11/site-packages/torch/_export/pass_base.py +441 -0
- .venv/lib/python3.11/site-packages/torch/_export/tools.py +146 -0
- .venv/lib/python3.11/site-packages/torch/_export/verifier.py +456 -0
- .venv/lib/python3.11/site-packages/torch/_export/wrappers.py +121 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__init__.py +55 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/computation.py +27 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/config.py +17 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/debug.py +22 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/device_context.py +26 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py +225 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/metrics.py +22 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py +7 -0
- .venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py +35 -0
- .venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py +52 -0
- .venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py +43 -0
- .venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py +647 -0
- .venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py +328 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py +9 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py +39 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py +54 -0
.gitattributes
CHANGED
|
@@ -123,3 +123,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 123 |
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 124 |
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 125 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 123 |
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 124 |
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 125 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
|
| 126 |
+
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 127 |
+
.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94fab98c15040558c3c80f2c1a2f5fda9baa72afc39a88bdcc82185f49d241c3
|
| 3 |
+
size 86326864
|
.venv/lib/python3.11/site-packages/torch/_export/converter.py
ADDED
|
@@ -0,0 +1,1584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import builtins
|
| 3 |
+
import logging
|
| 4 |
+
import operator
|
| 5 |
+
import typing
|
| 6 |
+
import warnings
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.export._trace
|
| 12 |
+
from torch import _C
|
| 13 |
+
from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import (
|
| 14 |
+
replace_quantized_ops_with_standard_ops,
|
| 15 |
+
)
|
| 16 |
+
from torch.export.exported_program import ExportedProgram
|
| 17 |
+
from torch.export.graph_signature import (
|
| 18 |
+
ConstantArgument,
|
| 19 |
+
CustomObjArgument,
|
| 20 |
+
InputKind,
|
| 21 |
+
InputSpec,
|
| 22 |
+
OutputKind,
|
| 23 |
+
OutputSpec,
|
| 24 |
+
TensorArgument,
|
| 25 |
+
)
|
| 26 |
+
from torch.fx import subgraph_rewriter
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
log = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _get_param_count_list(method_graph, args_params):
|
| 33 |
+
param_count_list = []
|
| 34 |
+
for input_, arg_params_ in zip(method_graph.inputs(), args_params):
|
| 35 |
+
if "PackedParams" in str(input_.type()):
|
| 36 |
+
in_vars, _ = torch.jit._flatten(arg_params_)
|
| 37 |
+
param_count_list.append(len(in_vars))
|
| 38 |
+
else:
|
| 39 |
+
param_count_list.append(arg_params_ is not None)
|
| 40 |
+
|
| 41 |
+
return param_count_list
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _trace_and_get_graph_from_model(model, args):
|
| 45 |
+
# A basic sanity check: make sure the state_dict keys are the same
|
| 46 |
+
# before and after running the model. Fail fast!
|
| 47 |
+
orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
|
| 48 |
+
|
| 49 |
+
# Disable Autocast cache because it replaces kernel's weight and bias
|
| 50 |
+
# by (undesired) constants.
|
| 51 |
+
# No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
|
| 52 |
+
prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
|
| 53 |
+
torch.set_autocast_cache_enabled(False)
|
| 54 |
+
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
|
| 55 |
+
model,
|
| 56 |
+
args,
|
| 57 |
+
strict=False,
|
| 58 |
+
_force_outplace=False,
|
| 59 |
+
_return_inputs_states=True,
|
| 60 |
+
)
|
| 61 |
+
torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
|
| 62 |
+
|
| 63 |
+
if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
|
| 64 |
+
raise RuntimeError(
|
| 65 |
+
"state_dict changed after running the tracer; "
|
| 66 |
+
"something weird is happening in your model!"
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
return trace_graph, torch_out
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _create_jit_graph(
    model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
    """Produce a TorchScript graph plus parameter values for ``model``.

    Three cases are handled:
      * ``torch.jit.ScriptModule`` — freeze the module, lift its parameters
        to graph inputs, and propagate input shapes through ``forward``.
      * ``torch.jit.ScriptFunction`` — use its graph directly (no params).
      * any other model — trace it first via
        ``_trace_and_get_graph_from_model``.

    Returns:
        ``(graph, params, torch_out, module)`` where ``torch_out`` is the
        traced output (``None`` for script inputs) and ``module`` is the
        frozen module in the ScriptModule case, otherwise ``None``.
    """
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        torch_out = None

        if isinstance(model, torch.jit.ScriptModule):
            try:
                graph = model.forward.graph  # type: ignore[attr-defined]
            except AttributeError as e:
                raise RuntimeError("'forward' method must be a script method") from e
            _C._jit_pass_onnx_function_substitution(graph)
            # Freeze while preserving parameters so they can be listed and
            # lifted to explicit graph inputs below.
            freezed_module = _C._freeze_module(
                typing.cast(_C.ScriptModule, model._c), preserveParameters=True
            )
            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
            method_graph = module._get_method("forward").graph
            # User args followed by lifted parameters — same order as the
            # method graph's inputs after parameter lifting.
            args_params = tuple(args) + tuple(params)
            param_count_list = _get_param_count_list(method_graph, args_params)
            in_vars, _ = torch.jit._flatten(args_params)
            graph = _C._propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), param_count_list, False, False
            )
            return graph, params, torch_out, module

        # torch.jit.ScriptFunction: no owned parameters, use its graph as-is.
        params = []
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None

    # Eager module / callable: trace to obtain a graph.
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
    _C._jit_pass_onnx_lint(graph)
    state_dict = torch.jit._unique_state_dict(model)
    params = list(state_dict.values())
    graph_inputs = list(graph.inputs())
    # Trailing graph inputs correspond to state_dict entries; rename them to
    # their parameter FQNs for readability/debugging.
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            inp.setDebugName(param_names[i - user_input_num])
    _C._jit_pass_onnx_function_substitution(graph)
    return graph, params, torch_out, None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def list_add(a, b):
    """Return ``a + b`` (functional form used when converting TS list ops)."""
    combined = a + b
    return combined
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def list_append(container, element):
    """Return a new list equal to ``container`` with ``element`` appended.

    The original ``container`` is not modified.
    """
    extended = container + [element]
    return extended
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def execute_subgraph_from_prim_loop(
    subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
):
    """Invoke a prim::Loop sub-block's GraphModule with reordered arguments.

    Args:
        subgraph: GraphModule converted from the loop's sub-block.
        iter_idx: Index of the current iteration.
        len_loop_local_arguments: Number of leading entries of ``args`` that
            are loop-local (carried) values; the remainder are globals.
    """
    # Loop-local variables: the TS graph creates these as block inputs
    # because their values are updated inside the loop.
    carried = args[:len_loop_local_arguments]
    # Global variables that are not passed as inputs to the loop sub-block
    # but are used directly. Usually read-only, except when some operation
    # performs an in-place update.
    captured = args[len_loop_local_arguments:]
    return subgraph(*captured, iter_idx, *carried, **kwargs)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule) -> None:
    """Rewrite ``gm`` in place, replacing the TS lowering of a size
    computation (``sym_size`` -> ``scalar_tensor`` -> ``div(...,
    rounding_mode="trunc")`` -> ``Int``) with plain floor division on the
    SymInt, avoiding the round-trip through a scalar tensor.
    """

    def pattern(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
        div_scalar_mode = torch.ops.aten.div.Scalar_mode(
            scalar_tensor, scale, rounding_mode="trunc"
        )
        int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
        return int_tensor

    def replacement(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        return sym_size_int // scale

    # replace_pattern mutates gm in place; the returned match list was
    # previously bound to an unused local, so it is simply discarded here.
    subgraph_rewriter.replace_pattern(gm, pattern, replacement)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def is_valid_for_codegen(name):
    """Return whether ``name`` can be emitted as-is in generated code.

    A name starting with a digit is invalid. An empty name is a hard error.

    Raises:
        RuntimeError: if ``name`` is empty.
    """
    if not name:
        raise RuntimeError("Empty argument name for codegen")
    return not name[0].isdigit()
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def normalize_name(name: str, prefix: str = "rename") -> str:
    """Sanitize ``name`` for codegen: dots become underscores; names that
    are still invalid (e.g. leading digit) get ``prefix`` prepended.
    """
    candidate = name.replace(".", "_")
    if is_valid_for_codegen(candidate):
        return candidate
    return f"{prefix}_{candidate}"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def ir_name_to_func_name(name: str) -> str:
    """Map a TS IR kind to a converter method name: prim::If -> convert_prim_If."""
    return "convert_" + name.replace("::", "_")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph):
    """Materialize ``name`` in ``fx_graph``: as a ``get_attr`` node in the
    top-level graph, or as a ``placeholder`` inside subgraphs (where
    attributes are lifted to inputs).
    """
    make_node = fx_graph.get_attr if is_top_level_graph else fx_graph.placeholder
    return make_node(name)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Mapping from torch dtype to the integer value of the corresponding entry in
# the ScalarType enum (c10/core/ScalarType.h).
# NOTE(review): the jump from 13 (quint8) to 15 (bfloat16) deliberately skips
# enum value 14 (presumably qint32, unsupported here) — confirm against
# ScalarType.h when extending this table.
_TORCH_DTYPE_TO_ENUM = {
    torch.uint8: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
    torch.float64: 7,
    torch.complex32: 8,
    torch.complex64: 9,
    torch.complex128: 10,
    torch.bool: 11,
    torch.qint8: 12,
    torch.quint8: 13,
    torch.bfloat16: 15,
}

# Inverse mapping: ScalarType enum integer -> torch dtype.
_TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_dtype_as_int(tensor):
    """Implements prim::dtype, whose signature is "(Tensor a) -> int": return
    the integer corresponding to ``tensor``'s dtype per the ScalarType.h enum.

    Raises:
        RuntimeError: if the dtype has no entry in ``_TORCH_DTYPE_TO_ENUM``.
    """
    enum_value = _TORCH_DTYPE_TO_ENUM.get(tensor.dtype)
    if enum_value is None:
        raise RuntimeError(f"Unsupported dtype {tensor.dtype}")
    return enum_value
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# These operators are automatically populated as instance methods of
# TS2FXGraphConverter with the name convert_<namespace>_<opname>().
# See TS2FXGraphConverter.__init__ for the method-population implementation.
kind_to_standard_operators = {
    "prim::max": builtins.max,
    "prim::min": builtins.min,
    "prim::TupleIndex": operator.getitem,
    "aten::__is__": operator.is_,
    "aten::__isnot__": operator.is_not,
    "aten::__not__": operator.not_,
    "aten::__contains__": operator.contains,
    "prim::dtype": get_dtype_as_int,
    "aten::len": len,
    # Mapping from specialized op to its symbolic counterpart.
    # They currently do not have any other overrides.
    "aten::numel": torch.ops.aten.sym_numel,
    "aten::size": torch.ops.aten.sym_size,
    "aten::storage_offset": torch.ops.aten.sym_storage_offset,
    "aten::stride": torch.ops.aten.sym_stride,
}
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def get_ir_value_parent_name_and_attr_name(node):
    """For a GetAttr-style node, return (output name, input name, attr name).

    The output name is the debug name of the value the node produces, the
    parent name is the debug name of the object it reads from, and the
    attribute name comes from the node's ``name`` string attribute.
    """
    irv_name = node.output().debugName()
    irv_parent_name = node.input().debugName()
    attr_name = node.s("name")
    return irv_name, irv_parent_name, attr_name
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def construct_fqn(ir, ref_map, name_map):
    """Follow GetAttr aliasing links from ``ir`` toward the root and join the
    attribute names (root-first) into a dotted fully-qualified name.
    """
    parts = []
    current = ir
    while current in ref_map:
        parts.append(name_map[current])
        current = ref_map[current]
    parts.reverse()
    return ".".join(parts)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]:
    """
    Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
    When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
    each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model
    parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model,
    we will run this pass which will:
        1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls.
        2. Process the graph bottom up to find the lifted attributes of each block by taking the union
           of the attributes used in the current block, and the lifted attributes of all its child blocks.

    Returns:
        A mapping of blocks to a set of FQNs of its lifted attributes.
    """

    # A map from a block to its expected to be lifted arguments.
    blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {}

    # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
    # GetAttr node. By traversing this reference map, we can figure out the
    # full IR aliasing pass and figure out the FQN of an attribute.
    # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
    node_to_parent_map: Dict[str, str] = {}

    # Used for reconstructing the FQN of an attribute based on the reference map.
    # In a nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
    # This name map stores which attribute name is called for a src IR --> dest IR action.
    # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
    node_to_attr_name: Dict[str, str] = {}

    def _dfs_get_attr_dependency(entry):
        """
        First DFS path to construct reference map and name map.
        """
        # ``entry`` is either the top-level Graph or a nested Block; both
        # expose nodes() and nested blocks the same way.
        for node in entry.nodes():
            if node.kind() == "prim::GetAttr":
                (
                    irv_name,
                    irv_parent_name,
                    attr_name,
                ) = get_ir_value_parent_name_and_attr_name(node)
                node_to_parent_map[irv_name] = irv_parent_name
                node_to_attr_name[irv_name] = attr_name
            for block in node.blocks():
                _dfs_get_attr_dependency(block)

    def _map_blocks_to_lifted_attrs(entry):
        """
        Walk the graph in a bottom-up fashion to build the expected to be
        lifted arguments for each block.
        """
        arguments: Set[str] = set()
        for node in entry.nodes():
            for block in node.blocks():
                # Recursively build (child blocks' lifted attrs are a subset
                # of this block's).
                arguments = arguments.union(_map_blocks_to_lifted_attrs(block))
            if node.kind() == "prim::GetAttr":
                irv_name = node.output().debugName()
                # Skip for intermediate GetAttr, which will anyway not result a FQN.
                # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"}
                #       node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"}
                # There is only one FQN %3-->%2-->%1: self.linear.weight
                # %2-->%1 is not a FQN: self.linear
                if irv_name not in set(node_to_parent_map.values()):
                    arguments.add(
                        construct_fqn(irv_name, node_to_parent_map, node_to_attr_name)
                    )
        if not isinstance(entry, torch._C.Graph):  # Skip the top level.
            blocks_to_lifted_attrs[entry] = arguments
        return arguments

    _dfs_get_attr_dependency(graph)
    _map_blocks_to_lifted_attrs(graph)

    return blocks_to_lifted_attrs
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def get_attribute_fqn_from_ts_node(
    name_to_attribute_fqn: Dict[str, str], node: torch._C.Node
) -> str:
    """Resolve the fully-qualified name of the attribute referenced by a
    prim::GetAttr or prim::SetAttr node.

    Raises:
        RuntimeError: for any other node kind.
        ValueError: if the node's object input is not a known attribute.
    """
    kind = node.kind()
    if kind == "prim::SetAttr":
        # SetAttr carries extra inputs (the value being stored); the target
        # object is the first one.
        input_name = next(node.inputs()).debugName()
    elif kind == "prim::GetAttr":
        input_name = node.input().debugName()
    else:
        raise RuntimeError(
            f"Unexpected node kind when getting attribute fqn. node: {node} "
        )

    attr_name = node.s("name")
    if input_name not in name_to_attribute_fqn:
        raise ValueError(f"Attribute {input_name} not found")
    root_fqn = name_to_attribute_fqn[input_name]

    # An empty root means the attribute hangs directly off the module root.
    return f"{root_fqn}.{attr_name}" if root_fqn else attr_name
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def get_op_overload(node: torch._C.Node):
    """Look up the ``torch.ops`` OpOverload matching a TS node's schema.

    Parses the node's schema to find the namespace, op name and overload
    name, then resolves it through ``torch.ops``; when the schema names no
    overload, the packet's ``default`` overload is used.

    Raises:
        RuntimeError: if the operator cannot be resolved.
    """
    schema_str = node.schema()
    assert schema_str != "(no schema)", f"got empty schema for {node}"
    schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str)
    namespace, op_name = str(schema.name).split("::")
    overload_name = schema.overload_name

    try:
        packet = getattr(getattr(torch.ops, namespace), op_name)
        if overload_name:
            op_overload = getattr(packet, overload_name)
        else:
            op_overload = packet.default
    except Exception as e:
        raise RuntimeError(
            f"Unable to find operator {node.kind()} with schema {node.schema()}"
        ) from e

    return op_overload
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class TS2FXGraphConverter:
|
| 388 |
+
def __init__(
    self,
    ts_graph: Union[torch._C.Graph, torch._C.Block],
    name_to_param: Dict[str, torch.Tensor],
    name_to_buffer: Dict[str, torch.Tensor],
    blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
    name_to_non_tensor_attribute: Dict[str, Any],
    name_to_constant: Dict[str, Any],
):
    """Stateful converter translating one TorchScript graph (or sub-block)
    into a torch.fx.Graph while recording export-style input/output specs.

    Sub-blocks are converted by child instances of this class that share
    the parent's param/buffer/constant tables.
    """
    self.ts_graph = ts_graph
    self.name_to_param = name_to_param
    self.name_to_buffer = name_to_buffer

    # The fx graph being built, plus export input/output specs filled in
    # by convert_graph_inputs()/convert_graph_outputs().
    self.fx_graph: torch.fx.Graph = torch.fx.Graph()
    self.input_specs: List[InputSpec] = []
    self.output_specs: List[OutputSpec] = []

    # TS value debugName -> converted fx value; list/dict entries are for
    # prim::ListConstruct / prim::DictConstruct results kept at Python level.
    self.name_to_node: Dict[
        str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
    ] = {}
    self.name_to_constant: Dict[str, Any] = name_to_constant

    # Mapping from torchscript node output name to attribute fully qualified name
    self.name_to_attribute_fqn: Dict[str, str] = {}

    # Mapping from fully qualified name to real values or a fx graph node
    # During convert, this represents the current value of a non-tensor attribute
    # One use case is:
    # def forward(self, x):
    #     c1 = self.count
    #     self.count += 1
    #     c2 = self.count
    #     return x + c1 + c2
    self.name_to_non_tensor_attribute_node: Dict[str, Any] = {}

    # Mapping from fully qualified name to initial real values inputs
    # We separate it from self.name_to_non_tensor_attribute_node since
    # we need initial real value input when we construct fx.GraphModule
    self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute

    self.subgraphs: Dict[str, torch.fx.GraphModule] = {}

    self.blocks_to_lifted_attrs = blocks_to_lifted_attrs

    # Populate methods for the standard operators.
    for k in kind_to_standard_operators.keys():
        handler_func_name = ir_name_to_func_name(k)
        # Create an indirect function call:
        # convert_<namespace>_<opname> --> lambda node: _convert_standard_operator(node)
        # (the lambda ignores `k`, so late binding is not an issue here).
        setattr(
            self,
            handler_func_name,
            lambda node: self._convert_standard_operators(node),
        )

    # This stores a list of return results that do not appear in the original TS
    # graph's outputs. The reason we maintain this is because some operations in the sub-block
    # might have inplace updates to the variable defined in the parent fx graph. After
    # the execution of that sub-block, the variable defined in the parent fx graph also
    # needs to be updated.
    self.name_update_from_subblock_to_parent: Set[str] = set()
|
| 449 |
+
|
| 450 |
+
def _is_get_attr_node(self, fqn):
    """Return True when `fqn` should be materialized as a get_attr fx node:
    it names a buffer, a parameter, or a ScriptObject constant."""
    if fqn in self.name_to_buffer or fqn in self.name_to_param:
        return True
    return fqn in self.name_to_constant and isinstance(
        self.name_to_constant[fqn], torch.ScriptObject
    )
|
| 459 |
+
|
| 460 |
+
def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
    """Convert each sub-block of `node` into its own fx.GraphModule.

    Each block gets a child converter that shares this converter's
    attribute-fqn map; every name in `arguments` becomes a placeholder of
    the subgraph. Returns (get_attr fx nodes referencing the registered
    subgraphs, the child converters).
    """
    subgraph_nodes, subgraph_converters = [], []
    for block in node.blocks():
        # Child converter shares param/buffer/constant tables, but gets a
        # fresh (empty) non-tensor-attribute table.
        subgraph_converter = TS2FXGraphConverter(
            block,
            self.name_to_param,
            self.name_to_buffer,
            self.blocks_to_lifted_attrs,
            {},
            self.name_to_constant,
        )
        subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn

        for block_arg in arguments:
            normalized_block_arg_name = normalize_name(block_arg)
            placeholder_node = subgraph_converter.fx_graph.placeholder(
                normalized_block_arg_name
            )
            subgraph_converter.name_to_node[block_arg] = placeholder_node

        subgraph = subgraph_converter.convert()
        subgraph_name = self.add_subgraph(subgraph)
        # Reference the registered subgraph from the parent graph.
        subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
        subgraph_converters.append(subgraph_converter)
    return subgraph_nodes, subgraph_converters
|
| 485 |
+
|
| 486 |
+
def _identify_inputs_as_arguments(self, entry):
    """
    Identify inputs from the innermost sub-block. This is needed
    for nested sub-blocks when the input is hidden in the nested sub-block.
    E.g., example IR of input is hidden in the nested sub-block.
    Graph[x.1]
     %1 = ...
      Block[]
        Block[x.1]
          %2 = x.1 ...
    """
    # Collect debugNames already converted in this (parent) scope that are
    # read inside any (possibly nested) sub-block; attribute aliases are
    # excluded because they are resolved via name_to_attribute_fqn instead.
    arguments: Set[str] = set()
    for block in entry.blocks():
        for block_node in block.nodes():
            for block_node_in in block_node.inputs():
                if (
                    block_node_in.debugName() in self.name_to_node
                    and block_node_in.debugName() not in self.name_to_attribute_fqn
                ):
                    arguments.add(block_node_in.debugName())
            # Recurse into nested blocks of this node.
            arguments = arguments.union(
                self._identify_inputs_as_arguments(block_node)
            )
    return arguments
|
| 510 |
+
|
| 511 |
+
def is_top_level_graph(self):
    """True when converting the root graph (sub-blocks are torch._C.Block)."""
    return isinstance(self.ts_graph, torch._C.Graph)
|
| 513 |
+
|
| 514 |
+
def add_subgraph(self, subgraph) -> str:
    """Register `subgraph` under a fresh sequential name and return that name."""
    subgraph_name = "subgraph_{}".format(len(self.subgraphs))
    self.subgraphs[subgraph_name] = subgraph
    return subgraph_name
|
| 518 |
+
|
| 519 |
+
def get_args_kwargs(self, node: torch._C.Node, schema):
    """Split the node's inputs into positional args and kwargs per `schema`."""
    positional = []
    keyword = {}
    for ts_input, schema_arg in zip(node.inputs(), schema.arguments):
        converted = self.get_fx_value_by_ir_value(ts_input)
        if schema_arg.kwarg_only:
            keyword[schema_arg.name] = converted
        else:
            positional.append(converted)

    return tuple(positional), keyword
|
| 529 |
+
|
| 530 |
+
def get_fx_value_by_ir_value(self, value: torch._C.Value):
    """Map a TorchScript Value to its converted fx representation.

    Prefers already-converted nodes; falls back to recorded constants.
    ScriptObject constants produce a fresh get_attr node on each call.

    Raises:
        ValueError: when the value is neither converted nor a constant.
    """
    value_name = value.debugName()

    if value_name in self.name_to_node:
        input_node = self.name_to_node[value_name]
        return input_node
    elif value_name in self.name_to_constant:
        if isinstance(self.name_to_constant[value_name], torch.ScriptObject):
            # ScriptObject constants live as module attributes; emit a
            # get_attr node rather than inlining the object.
            return self.fx_graph.get_attr(value_name)
        return self.name_to_constant[value_name]
    else:
        raise ValueError(f"Input {value_name} not found")
|
| 542 |
+
|
| 543 |
+
def get_fx_value_by_fqn(self, name):
    """Look `name` up across the converter's value namespaces.

    Priority order: converted nodes, constants, current non-tensor
    attribute values, then initial non-tensor attribute values.

    Raises:
        ValueError: when the name is not found in any namespace.
    """
    lookup_tables = (
        self.name_to_node,
        self.name_to_constant,
        self.name_to_non_tensor_attribute_node,
        self.name_to_non_tensor_attribute,
    )
    for table in lookup_tables:
        if name in table:
            return table[name]
    raise ValueError(f"Attribute {name} not found")
|
| 555 |
+
|
| 556 |
+
def convert(self) -> torch.fx.GraphModule:
    """Run the full conversion and package the result as a GraphModule.

    Converts graph inputs, then every node, then graph outputs; the
    GraphModule root exposes subgraphs, params, buffers, non-tensor
    attributes and constants for get_attr lookup.
    """
    self.convert_graph_inputs()

    for node in self.ts_graph.nodes():
        self.convert_node(node)

    self.convert_graph_outputs()

    # Pass parameter and buffer to the root for lookup.
    gm = torch.fx.GraphModule(
        {
            **self.subgraphs,
            **self.name_to_param,
            **self.name_to_buffer,
            **self.name_to_non_tensor_attribute,
            **self.name_to_constant,
        },
        self.fx_graph,
    )

    # Post-pass: simplify sym_size/div patterns in place.
    inplace_optimize_sym_size_div(gm)

    gm.graph.lint()

    return gm
|
| 581 |
+
|
| 582 |
+
def convert_graph_inputs(self):
    """Convert TS graph inputs to fx placeholders/get_attr nodes.

    For each input, records the matching export InputSpec (PARAMETER,
    BUFFER, CUSTOM_OBJ, or USER_INPUT) and registers the resulting fx
    node under the input's debugName. Unused ScriptObject inputs are
    skipped entirely.
    """
    for graph_input in self.ts_graph.inputs():
        name = graph_input.debugName()

        if name in self.name_to_param:
            normalized_name = normalize_name(name)
            self.input_specs.append(
                InputSpec(
                    InputKind.PARAMETER,
                    arg=TensorArgument(name=normalized_name),
                    target=name,
                )
            )
            # Placeholder at the top level, get_attr inside sub-blocks.
            fx_node = get_node_as_placeholder_or_get_attr(
                self.fx_graph, name, self.is_top_level_graph()
            )
        elif name in self.name_to_buffer:
            normalized_name = normalize_name(name)
            self.input_specs.append(
                InputSpec(
                    InputKind.BUFFER,
                    arg=TensorArgument(name=normalized_name),
                    target=name,
                    persistent=True,
                )
            )
            fx_node = get_node_as_placeholder_or_get_attr(
                self.fx_graph, name, self.is_top_level_graph()
            )
        elif name in self.name_to_constant:
            assert isinstance(
                self.name_to_constant[name], torch.ScriptObject
            ), "Input conversion only handles ScriptObject"
            normalized_name = normalize_name(name)
            self.input_specs.append(
                InputSpec(
                    InputKind.CUSTOM_OBJ,
                    arg=CustomObjArgument(
                        name=normalized_name, class_fqn=normalized_name
                    ),
                    target=name,
                    persistent=False,
                )
            )
            fx_node = get_node_as_placeholder_or_get_attr(
                self.fx_graph, name, self.is_top_level_graph()
            )
        elif isinstance(graph_input.type(), torch.ClassType):
            # Directly skip inputs that are ScriptObject but not used in the graph.
            continue
        else:
            normalized_name = normalize_name(name, prefix="input")
            self.input_specs.append(
                InputSpec(
                    InputKind.USER_INPUT,
                    arg=TensorArgument(name=normalized_name),
                    target=name,
                )
            )
            fx_node = self.fx_graph.placeholder(normalized_name)

        self.name_to_node[name] = fx_node
|
| 644 |
+
|
| 645 |
+
def convert_aten_Float(self, node: torch._C.Node):
    """Convert aten::Float via a helper that casts to float and extracts
    the Python scalar."""

    def to_float_tensor(t):
        return t.to(dtype=torch.float).item()

    converted_inputs = tuple(
        self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
    )
    float_node = self.fx_graph.call_function(to_float_tensor, converted_inputs)
    self.name_to_node[node.output().debugName()] = float_node
|
| 657 |
+
|
| 658 |
+
def convert_aten_tensor(self, node: torch._C.Node):
    """aten::tensor creates a constant tensor ad-hoc --> GetAttr"""
    args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema)

    for k in kwargs:
        if k == "requires_grad":
            kwargs[k] = bool(kwargs[k])  # 0 -> False, 1 -> True

    # All-int data can use torch.tensor directly; otherwise route through
    # the reference implementation (presumably to handle non-plain-int
    # elements such as traced values — TODO confirm).
    to_tensor = (
        torch.tensor
        if all(isinstance(a, int) for a in args)
        else torch._refs.tensor
    )

    def target(*args, **kwargs):
        # TS encodes dtype as an integer enum; translate it to torch.dtype.
        if "dtype" in kwargs and kwargs["dtype"] is not None:
            kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
        return to_tensor(*args, **kwargs)

    output_name = node.output().debugName()
    fx_node = self.fx_graph.call_function(target, args, kwargs)
    self.name_to_node[output_name] = fx_node
|
| 685 |
+
|
| 686 |
+
def convert_aten_append(self, node: torch._C.Node):
    """Convert aten::append.t by emitting list_append and rebinding the
    list's name to the appended result (an in-place list mutation)."""
    # special handle python list append: "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)"

    # inplace append to the list!! This is kinda crazy, as we are inplace mutating the list
    # This makes the converter "non-functional", and the result depends on the order of the nodes being converter
    # In a sense, the converter now becomes an stateful interpreter
    warnings.warn(
        "Converting aten::append.t, which is a inplace mutation of the list. "
        "This makes the converter non-functional: the result depends on the order of the append nodes being converter!"
    )

    args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
    fx_node = self.fx_graph.call_function(list_append, args)
    self.name_to_node[node.output().debugName()] = fx_node

    # inplace mutate arg[0], which is the python list
    self.name_to_node[node.inputsAt(0).debugName()] = fx_node

    # Variables that need to be updated to parent module.
    # A placeholder arg means the list was defined in the parent graph.
    if not self.is_top_level_graph() and args[0].op == "placeholder":
        self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName())
|
| 707 |
+
|
| 708 |
+
def convert_prim_Constant(self, node: torch._C.Node):
    """Record a prim::Constant's payload in name_to_constant.

    Tensor constants additionally get a get_attr fx node under a
    "lifted_tensor_<name>" alias, following export tracing conventions.
    A constant without a "value" attribute represents None.
    """
    name = node.output().debugName()

    value: Any = None
    if node.hasAttribute("value"):
        constant_kind = node.kindOf("value")
        if constant_kind == "i":  # int
            value = node.i("value")
        elif constant_kind == "f":  # float
            value = node.f("value")
        elif constant_kind == "s":  # string
            value = node.s("value")
        elif constant_kind == "t":  # tensor
            alias_name = (
                f"lifted_tensor_{name}"  # Follow naming convention from EP tracing.
            )
            fx_node = self.fx_graph.get_attr(alias_name)
            self.name_to_node[name] = fx_node
            # Store the tensor under the alias, not the original name.
            name, value = alias_name, node.t("value")
        elif constant_kind == "ival":  # generic IValue
            value = node.ival("value")
        else:
            raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
    else:
        value = None

    self.name_to_constant[name] = value
|
| 735 |
+
|
| 736 |
+
def convert_prim_CallMethod(self, node: torch._C.Node):
    """Convert prim::CallMethod into an fx call_method node."""
    method_args = tuple(
        self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
    )
    call_node = self.fx_graph.call_method(node.s("name"), method_args)
    self.name_to_node[node.output().debugName()] = call_node
|
| 745 |
+
|
| 746 |
+
def convert_prim_device(self, node: torch._C.Node):
    """Record the statically known device of a tensor input as a constant.

    Raises:
        ValueError: when the input is not a tensor type.
    """
    input_type = node.input().type()
    if not input_type.isSubtypeOf(torch._C.TensorType.get()):
        raise ValueError(f"Unsupported JitType ({input_type}) when get device")

    device = input_type.device()  # type: ignore[attr-defined]
    self.name_to_constant[node.output().debugName()] = device
|
| 754 |
+
|
| 755 |
+
def convert_prim_GetAttr(self, node: torch._C.Node):
    """Convert prim::GetAttr: tensor-like attributes become get_attr fx
    nodes; non-tensor attributes resolve to their tracked current value."""
    # Build fully qualified name
    attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
    output_name = node.output().debugName()
    self.name_to_attribute_fqn[output_name] = attr_fqn

    if self.is_top_level_graph():
        if self._is_get_attr_node(attr_fqn):
            # We insert a get_attr node due to two reasons.
            # First, ts graph does not lift tensor constants as input nodes. So
            # tensor constants may be ignored by in convert_graph_inputs().
            # Second, attr_fqn may have been written to via SetAttr. Two
            # GetAttr may give different values.
            self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn)
        else:
            # Non-tensor attribute: seed its tracked value lazily, then
            # return whatever the current (possibly SetAttr-updated) value is.
            if attr_fqn not in self.name_to_non_tensor_attribute_node:
                self.name_to_non_tensor_attribute_node[
                    attr_fqn
                ] = self.name_to_non_tensor_attribute[attr_fqn]
            self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[
                attr_fqn
            ]
    else:
        # Special support for if blocks which do not allow SetAttr TorchScript
        # node and get_attr FX Graph Node.
        if self._is_get_attr_node(attr_fqn):
            self.name_to_node[output_name] = self.name_to_node[attr_fqn]
|
| 782 |
+
|
| 783 |
+
def convert_prim_SetAttr(self, node: torch._C.Node):
    """Convert prim::SetAttr: tensor attributes become an in-graph
    Tensor.copy_; non-tensor attributes just update the tracked value."""
    attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
    # SetAttr inputs are (object, new value); take the value.
    attr_value = tuple(node.inputs())[1]
    ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value)
    if self._is_get_attr_node(attr_fqn):
        fx_attr_node = self.fx_graph.get_attr(attr_fqn)
        # Mutate the stored tensor in place so later GetAttr sees the update.
        self.fx_graph.call_function(
            torch.Tensor.copy_, (fx_attr_node, ts_graph_tensor_input)
        )
    else:
        self.name_to_non_tensor_attribute_node[attr_fqn] = ts_graph_tensor_input
|
| 794 |
+
|
| 795 |
+
def convert_call_function_op(self, node: torch._C.Node):
    """Generic lowering: resolve the op overload from the node's schema and
    emit a call_function; multi-output nodes get one getitem per output."""
    target = get_op_overload(node)

    args, kwargs = self.get_args_kwargs(node, target._schema)

    fx_node = self.fx_graph.call_function(target, args, kwargs)

    # TODO: convert sourceRange() into stack_trace
    # fx_node.meta["stack_trace"] = node.sourceRange()

    if node.outputsSize() == 1:
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node
    else:
        # Expand each TS output into operator.getitem(fx_node, i).
        for i, outp in enumerate(node.outputs()):
            output_name = outp.debugName()
            next_fx_node = self.fx_graph.call_function(
                operator.getitem, (fx_node, i)
            )
            self.name_to_node[output_name] = next_fx_node
|
| 815 |
+
|
| 816 |
+
def convert_prim_TupleConstruct(self, node: torch._C.Node):
    """prim::TupleConstruct -> Python-level list of converted inputs."""
    self._convert_prim_iterator(node)
|
| 818 |
+
|
| 819 |
+
def convert_prim_ListConstruct(self, node: torch._C.Node):
    """prim::ListConstruct -> Python-level list of converted inputs."""
    self._convert_prim_iterator(node)
|
| 821 |
+
|
| 822 |
+
def _convert_prim_iterator(self, node: torch._C.Node):
    """Materialize a Tuple/ListConstruct node as a Python list of its
    converted inputs, bound to the node's output name."""
    elements = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()]
    self.name_to_node[node.output().debugName()] = elements
|
| 829 |
+
|
| 830 |
+
def convert_prim_DictConstruct(self, node: torch._C.Node):
    """Materialize prim::DictConstruct as a Python dict; inputs alternate
    key, value, key, value, ..."""
    output_dict = {}
    k, v = None, None
    for i, inp in enumerate(node.inputs()):
        # We assume key value are stored in pair in the DictConstruct.
        # The first element is the key and the following is the value.
        if i % 2 == 0:
            k = self.get_fx_value_by_ir_value(inp)
        else:
            v = self.get_fx_value_by_ir_value(inp)
            assert (
                k is not None and v is not None
            ), "DictConstruct has an empty key value pair."
            output_dict[k] = v
            k, v = None, None

    assert (
        k is None and v is None
    ), "DictConstruct has an odd number of elements (violating our assumption)."

    output_name = node.output().debugName()
    self.name_to_node[output_name] = output_dict
|
| 852 |
+
|
| 853 |
+
def convert_prim_ListUnpack(self, node: torch._C.Node):
    """prim::ListUnpack -> one operator.getitem node per output."""
    self._convert_prim_unpack_iterator(node)
|
| 855 |
+
|
| 856 |
+
def convert_prim_TupleUnpack(self, node: torch._C.Node):
    """prim::TupleUnpack -> one operator.getitem node per output."""
    self._convert_prim_unpack_iterator(node)
|
| 858 |
+
|
| 859 |
+
def _convert_prim_unpack_iterator(self, node: torch._C.Node):
    """Unpack the single list/tuple input into one getitem fx node per output."""
    # Single input and multiple outputs for unpacking.
    for idx, outp in enumerate(node.outputs()):
        # Re-resolve the input each iteration, matching the original
        # per-output lookup behavior.
        source = self.get_fx_value_by_ir_value(node.input())
        item_node = self.fx_graph.call_function(operator.getitem, (source, idx))
        self.name_to_node[outp.debugName()] = item_node
|
| 866 |
+
|
| 867 |
+
def convert_aten_Int(self, node: torch._C.Node):
    """Lower aten::Int as aten._to_copy(int32) + aten._local_scalar_dense."""
    # converts aten::Int as aten._to_copy + aten::_local_scalar_dense
    target = torch.ops.aten._to_copy.default
    args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
    to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})

    fx_node = self.fx_graph.call_function(
        torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
    )

    # TODO: convert sourceRange() into stack_trace
    # fx_node.meta["stack_trace"] = node.sourceRange()

    output_name = node.output().debugName()
    self.name_to_node[output_name] = fx_node
|
| 882 |
+
|
| 883 |
+
def convert_prim_NumToTensor(self, node: torch._C.Node):
    """Lower prim::NumToTensor as aten.scalar_tensor with a long dtype."""
    # Converts prim::NumToTensor as aten.scalar_tensor.
    # prim::NumToTensor IRs are currently triggered by:
    # .size() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950
    # .numel() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971
    # For both of those APIs, torch.jit.trace implicitly sets the output tensor type
    # to be LongTensor.
    target = torch.ops.aten.scalar_tensor
    args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())

    fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long})
    output_name = node.output().debugName()
    self.name_to_node[output_name] = fx_node
|
| 896 |
+
|
| 897 |
+
def convert_prim_CreateObject(self, node: torch._C.Node):
    """Treat a freshly created script object as the root (empty fqn) of an
    attribute path so later GetAttr/SetAttr on it resolve correctly."""
    output_name = node.output().debugName()
    self.name_to_attribute_fqn[output_name] = ""
|
| 900 |
+
|
| 901 |
+
def convert_aten__convolution(self, node: torch._C.Node):
    """Lower aten::_convolution to aten.convolution, since aten::_convolution
    doesn't have a meta function."""
    target = torch.ops.aten.convolution.default
    args, kwargs = self.get_args_kwargs(node, target._schema)

    conv_node = self.fx_graph.call_function(target, args, kwargs)
    self.name_to_node[node.output().debugName()] = conv_node
|
| 911 |
+
|
| 912 |
+
def convert_aten_div(self, node: torch._C.Node):
    """Convert aten::div; a Tensor_mode divide by a 1-element tensor
    constant is rewritten to div.Scalar_mode on the extracted scalar."""
    target = get_op_overload(node)
    schema = target._schema

    args, kwargs = self.get_args_kwargs(node, schema)

    # converts aten::div.Tensor_mode(x, tensor_constant)
    # as aten.div.Scalar_mode(x, tensor_constant.item())
    if schema.overload_name == "Tensor_mode":
        arg1_name = args[1].name
        if arg1_name in self.name_to_constant and isinstance(
            self.name_to_constant[arg1_name], torch.Tensor
        ):
            tensor_constant = self.name_to_constant[arg1_name]
            if tensor_constant.numel() == 1:
                updated_args = list(args)
                updated_args[1] = self.name_to_constant[arg1_name].item()

                fx_node = self.fx_graph.call_function(
                    torch.ops.aten.div.Scalar_mode,
                    tuple(updated_args),
                    kwargs,
                )

                # TODO: convert sourceRange() into stack_trace
                # fx_node.meta["stack_trace"] = node.sourceRange()

                output_name = node.output().debugName()
                self.name_to_node[output_name] = fx_node
                return

    # Default path: generic call_function lowering.
    self.convert_call_function_op(node)
|
| 944 |
+
|
| 945 |
+
def convert_aten___getitem__(self, node: torch._C.Node):
    """Convert aten::__getitem__ into operator.getitem on the fx graph."""
    container, index = (
        self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
    )
    getitem_node = self.fx_graph.call_function(
        operator.getitem, (container, index)
    )
    self.name_to_node[node.output().debugName()] = getitem_node
|
| 954 |
+
|
| 955 |
+
def convert_aten_to(self, node: torch._C.Node):
    """Convert aten::to; a dtype-cast feeding an in-place op is forced to
    copy (plus a clone) to avoid frozen-storage functionalization errors."""
    target = get_op_overload(node)
    args, kwargs = self.get_args_kwargs(node, target._schema)

    # special handle aten.to.dtype and aten.to.prim_dtype followed by inplace_mutation_op
    # coz aten.to + inplace_mutation_op pattern would trigger
    # "cannot mutate tensors with frozen storage" functionalization error.
    # To work around the issue, we override the copy to be True, so that the output
    # is for sure not an alias of input
    if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype:
        user_nodes = [use.user for use in node.output().uses()]
        user_targets = [
            get_op_overload(user_node)
            for user_node in user_nodes
            if user_node.schema() != "(no schema)"
        ]
        has_mutable_target = any(
            target._schema.is_mutable for target in user_targets
        )

        if has_mutable_target:
            # args[3] is the `copy` flag of aten.to.dtype.
            assert len(args) >= 4
            new_args = list(args)
            new_args[3] = True  # copy, override to True
            fx_node = self.fx_graph.call_function(
                torch.ops.aten.to.dtype, tuple(new_args)
            )
            # temp hack to work around the issue https://github.com/pytorch/pytorch/issues/131679
            # When this issue is fixed, the clone node would be no longer needed
            clone_node = self.fx_graph.call_function(
                torch.ops.aten.clone.default, (fx_node,)
            )
            output_name = node.output().debugName()
            self.name_to_node[output_name] = clone_node
            return

    # Default path: generic call_function lowering.
    self.convert_call_function_op(node)
|
| 992 |
+
|
| 993 |
+
def convert_aten_add(self, node: torch._C.Node):
    """Convert aten::add, special-casing list concatenation (aten::add.t).

    TorchScript emits aten::add.t for `list + list`; fx's immutable_list
    does not satisfy the aten schema, so that case is routed through the
    Python-level `list_add` helper instead of the aten op.
    """
    if node.schema() == "(no schema)":
        # Schema-less add: infer the target from the input types.
        if isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance(
            node.inputsAt(1).type(), torch.ListType
        ):
            target = torch.ops.aten.add.t
        else:
            # Fixed typo in the error message ("determind" -> "determine").
            raise RuntimeError(f"unable to determine the target for {node}")
    else:
        target = get_op_overload(node)

    if target == torch.ops.aten.add.t:
        # special handle python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]" for
        # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'.
        args, kwargs = self.get_args_kwargs(node, target._schema)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args)
    else:
        self.convert_call_function_op(node)
|
| 1012 |
+
|
| 1013 |
+
def _check_prim_loop_support(self, node):
    """Validate that a prim::Loop node is within current converter support.

    Raises:
        RuntimeError: when the trip count is not a recorded constant, or
            when the loop condition is recomputed inside the loop body
            (i.e. a dynamic while-style loop).
    """
    inputs = list(node.inputs())

    # TODO: (1/N) stage.
    if inputs[0].debugName() not in self.name_to_constant:
        raise RuntimeError(
            "prim::Loop currently cannot run with dynamic value of number of iterations."
        )

    # Make sure the condition is not updated in the subblock.
    # NOTE: the inner variable is named `body_node` so it no longer shadows
    # the `node` parameter (shadowing in the original was a latent-bug hazard).
    subblock = next(node.blocks())
    condition_output_name = next(subblock.outputs()).debugName()
    for body_node in subblock.nodes():
        if (
            body_node.outputsSize() == 1
            and body_node.output().debugName() == condition_output_name
        ):
            raise RuntimeError(
                "prim::Loop currently cannot run with dynamic value of condition."
            )
        if body_node.outputsSize() >= 2:
            for outp in body_node.outputs():
                if outp.debugName() == condition_output_name:
                    raise RuntimeError(
                        "prim::Loop currently cannot run with dynamic value of condition."
                    )
|
| 1039 |
+
|
| 1040 |
+
def convert_prim_Loop(self, node: torch._C.Node):
    """Convert prim::Loop by unrolling: the loop body becomes a subgraph,
    called once per iteration with updated carried values.

    Requires a constant trip count and static condition (checked by
    _check_prim_loop_support). After each call, loop-carried outputs and
    inplace-updated parent variables are re-extracted via getitem.
    """
    inputs = list(node.inputs())
    self._check_prim_loop_support(node)

    # Trip count must be a recorded constant, so this is a plain int.
    num_iterations = self.get_fx_value_by_ir_value(inputs[0])

    # Find inputs.
    # inputs[0] is the trip count, inputs[1] the initial condition;
    # the rest are the loop-carried (local) values.
    loop_local_arguments = [inp.debugName() for inp in inputs[2:]]

    global_arguments = self._identify_inputs_as_arguments(node)

    # Lift parameters as inputs.
    for block in node.blocks():
        global_arguments = global_arguments.union(
            self.blocks_to_lifted_attrs[block]
        )

    global_arguments = list(global_arguments)

    subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
        node, global_arguments
    )

    # prim::Loop has exactly one body block.
    assert len(subgraph_nodes) == 1
    subgraph_converter = subgraph_converters[0]
    if not self.is_top_level_graph():
        # Propagate inplace-update bookkeeping up to our own parent.
        self.name_update_from_subblock_to_parent = (
            self.name_update_from_subblock_to_parent.union(
                subgraph_converter.name_update_from_subblock_to_parent
            )
        )

    fx_block_args = [
        self.get_fx_value_by_fqn(name)
        for name in loop_local_arguments + global_arguments
    ]
    for iter_idx in range(num_iterations):
        loop_node = self.fx_graph.call_function(
            execute_subgraph_from_prim_loop,
            # Check execute_node function for the expected arguments order.
            (
                subgraph_nodes[0],
                iter_idx,
                len(loop_local_arguments),
                *fx_block_args,
            ),
            {},
        )

        # Update the value of loop local variables.
        if node.outputsSize() >= 1:
            for i, outp in enumerate(node.outputs()):
                output_name = outp.debugName()
                self.name_to_node[output_name] = self.fx_graph.call_function(
                    operator.getitem,
                    (
                        loop_node,
                        i + 1,
                    ),  # + 1 because the 0th element is the condition.
                )
                # Feed the updated value into the next iteration.
                fx_block_args[i] = self.name_to_node[output_name]

        # Update the value of global variables, whose values are modified inplace.
        for i, name in enumerate(
            subgraph_converter.name_update_from_subblock_to_parent
        ):
            self.name_to_node[name] = self.fx_graph.call_function(
                operator.getitem,
                (
                    loop_node,
                    i + node.outputsSize() + 1,
                ),  # + 1 because the 0th element is the condition.
            )
            # Locate the global arg's slot in fx_block_args and refresh it.
            global_argument_index = global_arguments.index(name)
            fx_block_args[
                i + node.outputsSize() + global_argument_index
            ] = self.name_to_node[name]
|
| 1117 |
+
|
| 1118 |
+
def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
|
| 1119 |
+
for block in if_node.blocks():
|
| 1120 |
+
for node in block.nodes():
|
| 1121 |
+
if node.kind() == "prim::SetAttr":
|
| 1122 |
+
raise RuntimeError(
|
| 1123 |
+
"During converting prim::If to torch.cond, found prim::SetAttr op"
|
| 1124 |
+
" which is not supported yet. Please file an issue if you come "
|
| 1125 |
+
"across this error."
|
| 1126 |
+
)
|
| 1127 |
+
|
| 1128 |
+
def convert_prim_If(self, node: torch._C.Node):
    """Convert a TorchScript ``prim::If`` node into a ``torch.cond`` call.

    Both sub-blocks are converted to fx subgraphs sharing one argument list;
    the node's outputs are then mapped to the ``torch.cond`` result (directly
    for a single output, via ``operator.getitem`` for multiple outputs).
    """
    # prim::SetAttr inside a branch cannot be expressed by torch.cond.
    self._check_set_attr_in_if_block(node)

    # A prim::If node has exactly one input: the boolean predicate.
    inputs = list(node.inputs())
    assert len(inputs) == 1
    predicate = self.get_fx_value_by_ir_value(inputs[0])

    # Find inputs.
    arguments = self._identify_inputs_as_arguments(node)

    # Lift parameters as inputs.
    for block in node.blocks():
        arguments = arguments.union(self.blocks_to_lifted_attrs[block])

    # NOTE(review): `arguments` comes from a set, so its order is whatever the
    # set iteration yields; both subgraphs and fx_block_args below use the
    # same `arguments` list, which keeps them consistent with each other.
    arguments = list(arguments)
    subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)

    # One subgraph per branch (true / false).
    assert len(subgraph_nodes) == 2

    fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments]

    args = (
        predicate,
        subgraph_nodes[0],
        subgraph_nodes[1],
        tuple(fx_block_args),
    )

    cond_node = self.fx_graph.call_function(torch.cond, args, {})

    # prim::If can also have zero output.
    if node.outputsSize() == 1:
        output_name = node.output().debugName()
        self.name_to_node[output_name] = cond_node
    elif node.outputsSize() > 1:
        # Multiple outputs: torch.cond returns a sequence; unpack by index.
        for i, output in enumerate(node.outputs()):
            output_name = output.debugName()
            getitem = self.fx_graph.call_function(operator.getitem, (cond_node, i))
            self.name_to_node[output_name] = getitem
|
| 1167 |
+
|
| 1168 |
+
def convert_aten_Bool(self, node: torch._C.Node):
    """``aten::Bool`` carries no computation of its own; forward it as a no-op."""
    self._convert_as_noop(node)
|
| 1170 |
+
|
| 1171 |
+
def convert_prim_Enter(self, node: torch._C.Node):
    """Ignore ``prim::Enter``; export treats entering a context manager as a no-op.

    The only context manager export supports is aten::enable_grad, and
    TorchScript does not support aten::enable_grad yet.
    TODO: support aten::enable_grad in both TorchScript and Converter.
    """
    return
|
| 1177 |
+
|
| 1178 |
+
def convert_prim_Exit(self, node: torch._C.Node):
    """Ignore ``prim::Exit``; export treats leaving a context manager as a no-op."""
    return
|
| 1181 |
+
|
| 1182 |
+
def _convert_as_noop(self, node: torch._C.Node):
    """Convert ``node`` as a no-op: alias its output to its first argument.

    The node's schema is consulted so positional/keyword arguments are
    resolved the same way as for a real op; only ``args[0]`` is kept.
    """
    overload = get_op_overload(node)
    positional, _unused_kwargs = self.get_args_kwargs(node, overload._schema)
    self.name_to_node[node.output().debugName()] = positional[0]
|
| 1192 |
+
|
| 1193 |
+
def convert_profiler__record_function_exit(self, node: torch._C.Node):
    """Keep ``profiler._record_function_exit`` in the fx graph for its side effect.

    Both _record_function_enter_new and _record_function_exit are currently
    discarded later during `retrace_as_exported_program`.
    """
    op = torch.ops.profiler._record_function_exit
    fx_args = tuple(map(self.get_fx_value_by_ir_value, node.inputs()))
    self.fx_graph.call_function(op, fx_args)
|
| 1200 |
+
|
| 1201 |
+
def convert_prim_tolist(self, node: torch._C.Node):
    """Convert ``prim::tolist`` via ``call_method``.

    This op cannot go through `_convert_standard_operators` because fx must
    record it as a method call ("tolist") rather than a call_function.
    """
    receiver = self.get_fx_value_by_ir_value(next(node.inputs()))
    result = self.fx_graph.call_method("tolist", (receiver,))
    self.name_to_node[node.output().debugName()] = result
|
| 1209 |
+
|
| 1210 |
+
def convert_prim_Uninitialized(self, node: torch._C.Node):
    """Map ``prim::Uninitialized`` to a dummy constant tensor.

    The compiler inserts this op when it can prove the value is never used
    (exception paths, break/continue/return), so any placeholder value works.
    """
    self.name_to_constant[node.output().debugName()] = torch.Tensor()
|
| 1217 |
+
|
| 1218 |
+
def _convert_standard_operators(self, node: torch._C.Node):
    """Convert a node whose kind maps 1:1 onto a standard Python/torch callable."""
    op = kind_to_standard_operators[node.kind()]
    fx_args = tuple(map(self.get_fx_value_by_ir_value, node.inputs()))
    self.name_to_node[node.output().debugName()] = self.fx_graph.call_function(
        op, fx_args
    )
|
| 1224 |
+
|
| 1225 |
+
def convert_node(self, node: torch._C.Node):
    """Dispatch ``node`` to its conversion handler, wrapping failures.

    A handler is looked up by namespace + operator name; when none exists,
    the generic ``convert_call_function_op`` is used.
    """
    node_kind = node.kind()
    handler = getattr(self, ir_name_to_func_name(node_kind), self.convert_call_function_op)

    # str(node) also prints sub-block IR; log only the first line to keep
    # the debug output one line per node.
    first_line = "".join(str(node).split("\n")[:1])
    log.debug("[%s] converts [%s]", handler.__name__, first_line)
    try:
        handler(node)
    except Exception as e:
        raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e
|
| 1243 |
+
|
| 1244 |
+
def convert_graph_outputs(self):
    """Build the fx graph's output node from the TS graph's outputs.

    Outputs are the TS graph's declared outputs plus any names that
    sub-blocks propagate back to the parent. Each output is resolved either
    from converted fx nodes (`name_to_node`) or from constants
    (`name_to_constant`); an `OutputSpec` is recorded for each.

    Raises:
        ValueError: if an output name resolves to neither a node nor a constant.
    """
    args = []
    outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
        self.name_update_from_subblock_to_parent
    )
    for output_name in outp_name_list:
        if output_name in self.name_to_node:
            fx_node = self.name_to_node[output_name]
            # TODO: Revisit this later after HigherOrderOp design changes.
            # Currently, we cannot directly return input as output.
            if (
                not self.is_top_level_graph()
                and isinstance(fx_node, torch.fx.Node)
                and fx_node.op == "placeholder"
            ):
                fx_node = self.fx_graph.call_function(torch.clone, (fx_node,))
            args.append(fx_node)
            self.output_specs.append(
                OutputSpec(
                    OutputKind.USER_OUTPUT,
                    arg=TensorArgument(name=output_name),
                    target=output_name,
                )
            )
        elif output_name in self.name_to_constant:
            args.append(self.name_to_constant[output_name])
            self.output_specs.append(
                OutputSpec(
                    OutputKind.USER_OUTPUT,
                    arg=ConstantArgument(
                        name=output_name, value=self.name_to_constant[output_name]
                    ),
                    target=output_name,
                )
            )
        else:
            raise ValueError(f"Output {output_name} not found")

    # NOTE: the original code had a trailing `else:` branch here that was
    # unreachable (len(args) is always ==0, ==1, or >1); it has been removed.
    if len(args) == 0:
        # Sub-block of prim::If can have zero output.
        self.fx_graph.output([])
    elif len(args) == 1:
        # Get rid of an extra list wrapped around the final output.
        self.fx_graph.output(args[0])
    else:
        # For prim::Loop and prim::If with multiple outputs.
        self.fx_graph.output(args)
|
| 1296 |
+
|
| 1297 |
+
|
| 1298 |
+
class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
    """
    Run TS2FXGraphConverter in an explain mode. It collects all failed operators conversions
    and provide that information to users. In order to collect all failed conversions, it
    also mocks some internal attributes (e.g., name_to_node).
    """

    class _DictMock(dict):
        """Dict that claims to contain every key, returning ``mock_value`` for
        missing ones so conversion can continue past failed nodes."""

        def __init__(self, dict_data, mock_value):
            super().__init__(dict_data)
            self.mock_value = mock_value

        def __getitem__(self, key):
            # If the original dictionary has the key, return its value.
            # Otherwise, return the mock value.
            if not super().__contains__(key):
                return self.mock_value
            return super().__getitem__(key)

        def __contains__(self, key):
            return True

    def __init__(
        self,
        ts_graph: Union[torch._C.Graph, torch._C.Block],
        name_to_param: Dict[str, torch.Tensor],
        name_to_buffer: Dict[str, torch.Tensor],
        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
        name_to_non_tensor_attribute: Dict[str, Any],
        name_to_constant: Dict[str, Any],
    ):
        super().__init__(
            ts_graph,
            name_to_param,
            name_to_buffer,
            blocks_to_lifted_attrs,
            name_to_non_tensor_attribute,
            name_to_constant,
        )

        # Data to keep track of unsupported nodes.
        self.unsupported_node_list: List[torch._C.Node] = []

        # Add mock to needed attributes.
        self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
            self.name_to_node,
            # Dummy node.
            torch.fx.Node(
                None,  # type: ignore[arg-type]
                "mock",
                "call_function",
                lambda: None,
                (),
                {},
            ),
        )

    def explain(self):
        """Run the full conversion, recording (not raising) unsupported nodes."""
        self.convert_graph_inputs()
        for node in self.ts_graph.nodes():
            self.convert_node(node)
        self.convert_graph_outputs()

    def convert_node(self, node):
        # FIX: the original bound the exception as `e` but never used it;
        # the binding has been dropped.
        try:
            super().convert_node(node)
        except Exception:
            # Record the failing node so `explain` can report all of them.
            self.unsupported_node_list.append(node)
|
| 1366 |
+
|
| 1367 |
+
|
| 1368 |
+
@contextmanager
def disable_logging(log):
    """Temporarily silence ``log``; the previous state is restored on exit,
    even if the wrapped code raises."""
    previous = log.disabled
    log.disabled = True
    try:
        yield
    finally:
        log.disabled = previous
|
| 1376 |
+
|
| 1377 |
+
|
| 1378 |
+
class TS2EPConverter:
    # TorchScript model to ExportedProgram converter
    def __init__(
        self,
        ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
        sample_args: Tuple[Any, ...],
        sample_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Prepare the TS graph and collect params/buffers/constants.

        Args:
            ts_model: scripted module or function to convert.
            sample_args: example positional inputs used to create the jit graph
                and later to retrace the converted GraphModule.
            sample_kwargs: example keyword inputs (currently only stored).
        """
        self.ts_model = ts_model
        self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)

        self.sample_args = sample_args
        self.sample_kwargs = sample_kwargs

        self.name_to_param: Dict[str, torch.Tensor] = {}
        self.name_to_buffer: Dict[str, torch.Tensor] = {}
        param_list = (
            list(self.ts_model.parameters())
            if not isinstance(self.ts_model, torch._C.ScriptFunction)
            else []
        )
        if not isinstance(self.ts_model, torch._C.ScriptFunction):
            for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
                # Check if tensor belongs to any parameter.
                # NOTE(review): membership is decided by value equality against
                # every same-shaped parameter, not by identity — two distinct
                # tensors with equal values would both be classified as params.
                if any(
                    (tensor == param).all()
                    for param in param_list
                    if tensor.shape == param.shape
                ):
                    self.name_to_param[k] = tensor
                else:
                    self.name_to_buffer[k] = tensor

        self.name_to_non_tensor_attributes: Dict[str, Any] = {}
        self.name_to_constant: Dict[str, Any] = {}

        self.lift_get_attr()

    def convert(self) -> ExportedProgram:
        """Convert the TorchScript model to an ExportedProgram.

        Pipeline: TS graph -> fx GraphModule (TS2FXGraphConverter) ->
        quantized-op replacement -> retrace with torch.export -> state_dict
        reconciliation against the original model.
        """
        log.info(
            """
TS2EPConverter logging starts from here.

INFO: (TORCH_LOGS="export" <cmd>)
* Log TorchScript IR.

DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
* Log conversion IR by IR in a format of [<conversion handler name>] converts [<IR>].
"""
        )
        log.info("TorchScript graph\n\n%s\n", self.ts_graph)

        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)

        graph_converter = TS2FXGraphConverter(
            self.ts_graph,
            self.name_to_param,
            self.name_to_buffer,
            blocks_to_lifted_attrs,
            self.name_to_non_tensor_attributes,
            self.name_to_constant,
        )
        gm = graph_converter.convert()

        # Post-processing step to deal with quantized operators.
        replace_quantized_ops_with_standard_ops(gm)
        log.info("GraphModule: %s", gm.print_readable(print_output=False))

        ep = self.retrace_as_exported_program(
            gm,
            graph_converter.name_to_constant,
        )
        log.info("%s", ep)

        # Post-processing step to ensure ExportedProgram has the same state_dict as
        # the original TorchScript model. Throw warnings for additionally populated
        # state_dict entries.
        if not isinstance(self.ts_model, torch._C.ScriptFunction):
            for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
                if k not in ep.state_dict:
                    warnings.warn(
                        f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram."
                    )
                    ep.state_dict[k] = tensor

        return ep

    @disable_logging(log)
    def explain(self, print_output=True):
        """Dry-run the conversion and report unsupported nodes (if any).

        Returns the human-readable report string; also prints it when
        ``print_output`` is True.
        """
        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)

        graph_converter = ExplainTS2FXGraphConverter(
            self.ts_graph,
            self.name_to_param,
            self.name_to_buffer,
            blocks_to_lifted_attrs,
            self.name_to_non_tensor_attributes,
            self.name_to_constant,
        )
        graph_converter.explain()
        if len(graph_converter.unsupported_node_list) > 0:
            explain_str = "Unsupported nodes are found in the following list:"
            for i, n in enumerate(graph_converter.unsupported_node_list):
                # Keep only the first line of the node's IR print.
                node_str = "".join(str(n).split("\n")[:1])
                explain_str += f"\n\n {i}. {n.kind()} [{node_str}]"
        else:
            explain_str = "Success!"
        if print_output:
            print(explain_str)
        return explain_str

    def retrace_as_exported_program(
        self,
        gm: torch.fx.GraphModule,
        name_to_constant: Dict[str, Any],
    ):
        """Re-export the converted GraphModule and repair its constant metadata.

        Retracing treats converter-inserted GetAttr tensor constants as
        buffers; this fixes the constants table, drops them from state_dict,
        and re-marks the affected input specs before re-verifying.
        """
        # TODO: adjust input orders to match GraphSignature convention
        ep = torch.export._trace._export(
            gm,
            self.sample_args,
            strict=False,
            pre_dispatch=True,
        )

        # Post-processing to make sure the ExportedProgram states are correct.
        # Because during conversion, we set tensor constants as GetAttr,
        # retracing cannot recognize them as tensor constants but instead
        # treat them as buffers. We need to set them again here.
        ep._constants.update(
            {
                k: v
                for k, v in name_to_constant.items()
                if isinstance(v, (torch.Tensor, torch.ScriptObject))
            }
        )
        for k in name_to_constant:
            ep.state_dict.pop(k, None)

        for i, spec in enumerate(ep.graph_signature.input_specs):
            # Mark as constant tensors for erroneously traced buffers.
            if spec.kind == InputKind.BUFFER and spec.target in name_to_constant:
                assert isinstance(
                    name_to_constant[spec.target], torch.Tensor
                ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
                spec.kind = InputKind.CONSTANT_TENSOR
        ep.verifier().check(ep)

        return ep

    def lift_get_attr(self):
        # This function lifts multiple data types.

        # 1. Tensor constants attributes (e.g., self.data = torch.tensor([2,3]))
        # to buffers. Currently, when there are tensor constants, export
        # would error and ask users to register tensor constants as buffers.
        # Since it is hard to manually do so for TorchScript models
        # (e.g., source code is missing), this function automatically
        # lifts tensor constants to be buffers.

        # 2. ScriptObject to constant. It will then be converted to getattr in
        # in the fx graph.
        #
        # This function should happen in TS2EPConverter instead of
        # TS2FXGraphConverter since it gets attributes from self.ts_model
        # which is not accessible in TS2FXGraphConverter. It is similar to where
        # we collect self.name_to_param and self.name_to_buffer.
        name_to_attribute_fqn: Dict[str, str] = {}

        def get_attr(fqn: str):
            # Resolve a dotted fully-qualified name against self.ts_model.
            name = fqn.split(".")
            v = self.ts_model
            for n in name:
                v = getattr(v, n)
            return v

        def get_fqn(node: torch._C.Node):
            # Build the FQN of a prim::GetAttr node from its input's FQN.
            attr_name = node.s("name")
            input_name = node.input().debugName()
            root_attr_name = name_to_attribute_fqn[input_name]
            attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
            return attr_fqn

        def _dfs_get_attr(block):
            # Walk the graph (including sub-blocks) classifying every
            # prim::GetAttr target as buffer / ScriptObject constant / other.
            for node in block.nodes():
                if node.kind() == "prim::CreateObject":
                    output_name = node.output().debugName()
                    name_to_attribute_fqn[output_name] = ""

                if node.kind() == "prim::GetAttr":
                    attr_fqn = get_fqn(node)
                    value = get_attr(attr_fqn)
                    output_name = node.output().debugName()
                    name_to_attribute_fqn[output_name] = attr_fqn
                    if isinstance(value, torch.Tensor):
                        if attr_fqn not in self.name_to_buffer:
                            # Lift tensor constants to be a buffer
                            self.name_to_buffer[attr_fqn] = value
                    elif isinstance(value, torch.ScriptObject):
                        if attr_fqn not in self.name_to_constant:
                            self.name_to_constant[attr_fqn] = value
                    else:
                        self.name_to_non_tensor_attributes[attr_fqn] = value

                for subblock in node.blocks():
                    _dfs_get_attr(subblock)

        _dfs_get_attr(self.ts_graph)
|
.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import inspect
|
| 4 |
+
import logging
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils._pytree as pytree
|
| 10 |
+
from torch._dynamo.source import (
|
| 11 |
+
AttrSource,
|
| 12 |
+
GetItemSource,
|
| 13 |
+
LocalSource,
|
| 14 |
+
TensorProperty,
|
| 15 |
+
TensorPropertySource,
|
| 16 |
+
)
|
| 17 |
+
from torch._dynamo.variables.builder import TrackedFake
|
| 18 |
+
from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim
|
| 19 |
+
from torch._export.passes.lift_constants_pass import ConstantAttrMap
|
| 20 |
+
from torch._guards import Source
|
| 21 |
+
from torch._library.fake_class_registry import FakeScriptObject
|
| 22 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 23 |
+
from torch.export import Constraint
|
| 24 |
+
from torch.export.dynamic_shapes import (
|
| 25 |
+
_check_dynamic_shapes,
|
| 26 |
+
_combine_args,
|
| 27 |
+
_DimHint,
|
| 28 |
+
_process_dynamic_shapes,
|
| 29 |
+
_transform_shapes_for_default_dynamic,
|
| 30 |
+
_tree_map_with_path,
|
| 31 |
+
)
|
| 32 |
+
from torch.export.graph_signature import CustomObjArgument
|
| 33 |
+
from torch.fx.experimental import _config as config
|
| 34 |
+
from torch.fx.experimental.symbolic_shapes import (
|
| 35 |
+
_find_user_code_frame,
|
| 36 |
+
_suggest_fixes_for_data_dependent_error_non_strict,
|
| 37 |
+
ConstraintViolationError,
|
| 38 |
+
DimDynamic,
|
| 39 |
+
EqualityConstraint,
|
| 40 |
+
GuardOnDataDependentSymNode,
|
| 41 |
+
ShapeEnv,
|
| 42 |
+
StatelessSymbolicContext,
|
| 43 |
+
ValueRanges,
|
| 44 |
+
)
|
| 45 |
+
from torch.utils._pytree import (
|
| 46 |
+
GetAttrKey,
|
| 47 |
+
KeyPath,
|
| 48 |
+
MappingKey,
|
| 49 |
+
SequenceKey,
|
| 50 |
+
tree_map_with_path,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if TYPE_CHECKING:
|
| 55 |
+
from sympy import Symbol
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
log = logging.getLogger(__name__)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def key_path_to_source(kp: KeyPath) -> Source:
|
| 62 |
+
"""
|
| 63 |
+
Given a key path, return the source for the key path.
|
| 64 |
+
"""
|
| 65 |
+
source: Source = LocalSource("args")
|
| 66 |
+
for k in kp:
|
| 67 |
+
if isinstance(k, SequenceKey):
|
| 68 |
+
source = GetItemSource(source, k.idx)
|
| 69 |
+
elif isinstance(k, MappingKey):
|
| 70 |
+
source = GetItemSource(source, k.key)
|
| 71 |
+
elif isinstance(k, GetAttrKey):
|
| 72 |
+
source = AttrSource(source, k.name)
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError(f"Unknown KeyEntry {k}")
|
| 75 |
+
|
| 76 |
+
return source
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _is_constant_argument(t):
|
| 80 |
+
return t is None or isinstance(t, (int, float, bool, str))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def fakify(
    mode: FakeTensorMode,
    kp: KeyPath,
    t: Any,
    t_constraints: Dict[int, Dict[int, Constraint]],
    sources: Dict[Tuple[int, int], List[Source]],
):
    """Fakify one flattened input for non-strict export tracing.

    Constants and ScriptObjects pass through unchanged; tensors are converted
    to fake tensors under ``mode`` with fully dynamic sizes, applying any
    per-dimension constraints registered for this tensor's ``id``.

    Args:
        mode: fake tensor mode to create the fake tensor in.
        kp: pytree key path of this input; determines its guard source.
        t: the input value (constant, ScriptObject, or tensor).
        t_constraints: per-tensor-id, per-dim constraints.
        sources: out-param; (tensor_id, dim) -> accumulated guard sources.

    Raises:
        ValueError: if ``t`` is neither a constant, ScriptObject, nor tensor.
    """
    source = key_path_to_source(kp)
    if _is_constant_argument(t) or isinstance(t, torch.ScriptObject):
        return t

    if not isinstance(t, torch.Tensor):
        raise ValueError(f"Unsupported input type {type(t)}")
    n_dims = len(t.shape)
    # Every dimension starts out dynamic and unconstrained; constraints
    # below narrow individual dims.
    symbolic_context = StatelessSymbolicContext(
        dynamic_sizes=[DimDynamic.DYNAMIC] * n_dims,
        constraint_sizes=[None] * n_dims,
    )
    t_id = id(t)
    assert mode.shape_env is not None
    if t_id in t_constraints:
        for i, constraint in t_constraints[t_id].items():
            symbolic_context.constraint_sizes[i] = constraint.constraint_range
            src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
            sources[(t_id, i)].append(src)
            # Map the internal source name to the user-visible Dim name for
            # readable error messages.
            mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name  # type: ignore[assignment]
    fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
    # Track the fake so produce_guards can later emit guards for it.
    mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))  # type: ignore[union-attr]
    return fake
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def make_fake_inputs(
    nn_module,
    args,
    kwargs,
    dynamic_shapes,
    _is_torch_jit_trace=False,
    allow_complex_guards_as_runtime_asserts=False,
):
    """
    Given an nn module, example inputs, and constraints, return a new fake mode,
    fake inputs created in that mode whose dynamic shape dimensions are constrained
    by the given ranges, and sources for pairs of dynamic shape dimensions that are
    constrained to be equal.
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following pre-tracing steps:
    # - Fakify inputs.
    # - Process input shape equalities.
    # In strict, these steps are spread across multiple files:
    # - output_graph.py fakifies inputs.
    # - [post-tracing] guards.py processes input shape equalities.

    combined_args = _combine_args(nn_module, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
        combined_args, dynamic_shapes
    )
    constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes)
    # Index constraints by (tensor id, dim) for O(1) lookup during fakify.
    t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
    for constraint in constraints:
        t_constraints[constraint.t_id][constraint.dim] = constraint

    context = torch._guards.TracingContext.try_get()
    if context is not None:
        # This occurs when we are exporting within dynamo. There already exists
        # a toplevel TracingContext with a fake mode, so we do not want to
        # create another fake mode.
        fake_mode = context.fake_mode
    elif not _is_torch_jit_trace:
        # Record the forward's code location so guard provenance points at
        # user code.
        code = nn_module.forward.__code__
        co_fields = {
            "co_name": code.co_name,
            "co_filename": code.co_filename,
            "co_firstlineno": code.co_firstlineno,
        }
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                co_fields=co_fields,
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
            export=True,
        )
    else:
        # ScriptMethod has no __code__, so skip co_fields in the jit-trace path.
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
        )
    if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
        raise ValueError(
            "Detected fake_mode does not have a shape_env with tracked fakes. "
            "If you constructed the module under a FakeTensorMode, "
            "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))"
        )

    with fake_mode:
        # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock
        if not _is_torch_jit_trace:
            original_signature = inspect.signature(nn_module.forward)
        else:
            original_signature = None
        sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list)
        fake_args, fake_kwargs = tree_map_with_path(
            lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
            (args, kwargs),
        )

        names: Dict[str, Tuple[int, int]] = {}
        source_pairs: List[Tuple[Source, Source]] = []
        derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
        phantom_symbols: Dict[str, Symbol] = {}
        # Translate each user constraint into source equalities / derived
        # equalities that produce_guards understands.
        for constraint in constraints:
            torch.export.dynamic_shapes._process_equalities(
                constraint,
                lambda t_id, dim: sources[(t_id, dim)],
                fake_mode.shape_env,
                names,
                source_pairs,
                derived_equalities,
                phantom_symbols,
            )

        equalities_inputs = EqualityConstraint(
            source_pairs=source_pairs,
            derived_equalities=derived_equalities,
            phantom_symbols=list(phantom_symbols.values()),
            warn_only=False,
        )
    return (
        fake_mode,
        fake_args,
        fake_kwargs,
        equalities_inputs,
        original_signature,
        transformed_dynamic_shapes,
    )
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _flatten_dynamic_shapes(
|
| 230 |
+
combined_args: Dict[str, Any],
|
| 231 |
+
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
|
| 232 |
+
) -> List[Any]:
|
| 233 |
+
flat_shapes = []
|
| 234 |
+
|
| 235 |
+
def _tree_map_helper(path, t, shape):
|
| 236 |
+
nonlocal flat_shapes
|
| 237 |
+
flat_shapes.append(shape)
|
| 238 |
+
|
| 239 |
+
_tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes)
|
| 240 |
+
return flat_shapes
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def produce_guards_and_solve_constraints(
|
| 244 |
+
fake_mode: FakeTensorMode,
|
| 245 |
+
gm: torch.fx.GraphModule,
|
| 246 |
+
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
|
| 247 |
+
equalities_inputs: EqualityConstraint,
|
| 248 |
+
original_signature: inspect.Signature,
|
| 249 |
+
_is_torch_jit_trace=False,
|
| 250 |
+
):
|
| 251 |
+
"""
|
| 252 |
+
Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions,
|
| 253 |
+
and a graph module, produce guards on the fake mode's shape env (raising constraint
|
| 254 |
+
violations if any), solve (to suggest simplifications or fixes).
|
| 255 |
+
Dynamo already performs this, so this is for non-strict mode.
|
| 256 |
+
|
| 257 |
+
Additional inputs:
|
| 258 |
+
equalities_inputs: the equality constraints to use for guards
|
| 259 |
+
original_signature: the signature of the forward method
|
| 260 |
+
"""
|
| 261 |
+
shape_env = fake_mode.shape_env
|
| 262 |
+
assert shape_env is not None
|
| 263 |
+
assert shape_env.tracked_fakes is not None
|
| 264 |
+
|
| 265 |
+
placeholders = [tf.fake for tf in shape_env.tracked_fakes]
|
| 266 |
+
sources = [tf.source for tf in shape_env.tracked_fakes]
|
| 267 |
+
input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
|
| 268 |
+
constraint_violation_error = None
|
| 269 |
+
try:
|
| 270 |
+
shape_env.produce_guards(
|
| 271 |
+
placeholders,
|
| 272 |
+
sources,
|
| 273 |
+
input_contexts=input_contexts,
|
| 274 |
+
equalities_inputs=equalities_inputs,
|
| 275 |
+
ignore_static=False,
|
| 276 |
+
)
|
| 277 |
+
except ConstraintViolationError as e:
|
| 278 |
+
constraint_violation_error = e
|
| 279 |
+
|
| 280 |
+
shape_env.frozen = True
|
| 281 |
+
dim_constraints = shape_env.dim_constraints
|
| 282 |
+
if dim_constraints is None:
|
| 283 |
+
# Expected when shape_env.produce_guards throws an early constraint violation error.
|
| 284 |
+
# There is nothing to solve for in this case.
|
| 285 |
+
# TODO(avik): Maybe record the constraint violation error instead and replay later?
|
| 286 |
+
assert constraint_violation_error
|
| 287 |
+
raise constraint_violation_error
|
| 288 |
+
dim_constraints.solve()
|
| 289 |
+
forced_specializations = dim_constraints.forced_specializations()
|
| 290 |
+
if not _is_torch_jit_trace:
|
| 291 |
+
msg = dim_constraints.prettify_results(
|
| 292 |
+
original_signature,
|
| 293 |
+
dynamic_shapes,
|
| 294 |
+
constraint_violation_error,
|
| 295 |
+
forced_specializations,
|
| 296 |
+
)
|
| 297 |
+
else:
|
| 298 |
+
# FIXME(ycao): This is a hack to get around missing signature from ScriptMethod
|
| 299 |
+
msg = "dummy constraint violation message"
|
| 300 |
+
if constraint_violation_error:
|
| 301 |
+
constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
|
| 302 |
+
elif forced_specializations:
|
| 303 |
+
constraint_violation_error = ConstraintViolationError(msg)
|
| 304 |
+
if constraint_violation_error:
|
| 305 |
+
raise constraint_violation_error
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def make_constraints(
|
| 309 |
+
fake_mode: FakeTensorMode,
|
| 310 |
+
gm: torch.fx.GraphModule,
|
| 311 |
+
combined_args: Dict[str, Any],
|
| 312 |
+
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
|
| 313 |
+
num_lifted_inputs: int,
|
| 314 |
+
):
|
| 315 |
+
"""
|
| 316 |
+
Given a fake mode's shape env and user-specified dynamic shapes,
|
| 317 |
+
return the resulting range constraints and equality constraints.
|
| 318 |
+
|
| 319 |
+
Additional args:
|
| 320 |
+
num_lifted_inputs: the number of non-user-input placeholder nodes in the graph
|
| 321 |
+
(used only to enumerate the user-input nodes)
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
shape_env = fake_mode.shape_env
|
| 325 |
+
assert shape_env is not None
|
| 326 |
+
inline_constraints = gm.meta.get("inline_constraints", [])
|
| 327 |
+
range_constraints = {
|
| 328 |
+
symbol: inline_constraints[symbol] for symbol in inline_constraints
|
| 329 |
+
}
|
| 330 |
+
if not dynamic_shapes:
|
| 331 |
+
return range_constraints
|
| 332 |
+
|
| 333 |
+
# get individual dynamic shapes spec for each input
|
| 334 |
+
if not isinstance(dynamic_shapes, dict):
|
| 335 |
+
assert isinstance(dynamic_shapes, (tuple, list))
|
| 336 |
+
combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
|
| 337 |
+
flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)
|
| 338 |
+
|
| 339 |
+
# check number of shapes vs. number of inputs
|
| 340 |
+
num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True)
|
| 341 |
+
assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs
|
| 342 |
+
|
| 343 |
+
input_dims = defaultdict(list)
|
| 344 |
+
free_symbols = set()
|
| 345 |
+
for input_index, node in enumerate(gm.graph.nodes):
|
| 346 |
+
if input_index < num_lifted_inputs or node.op != "placeholder":
|
| 347 |
+
continue
|
| 348 |
+
if _is_constant_argument(node.meta["val"]) or isinstance(
|
| 349 |
+
node.meta["val"], CustomObjArgument
|
| 350 |
+
):
|
| 351 |
+
continue
|
| 352 |
+
shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs]
|
| 353 |
+
for i, d in enumerate(node.meta["val"].shape):
|
| 354 |
+
if isinstance(d, torch.SymInt) and not d.node.expr.is_number:
|
| 355 |
+
# Look up the range constraint for the symbol corresponding to this shape dimension
|
| 356 |
+
# and store it indexed by the symbolic expression corresponding to it.
|
| 357 |
+
# NOTE(avik): Use node._expr instead of node.expr for the lookup here because
|
| 358 |
+
# we want the symbol, not its replacement, which could be an expression. Maybe
|
| 359 |
+
# there's a better way to do this, e.g., by (re)computing value ranges for expressions?
|
| 360 |
+
dim = shape_spec[i] if shape_spec else None
|
| 361 |
+
if dim is None or isinstance(dim, _DimHint):
|
| 362 |
+
range_constraints[d.node.expr] = shape_env.var_to_range[
|
| 363 |
+
d.node._expr
|
| 364 |
+
]
|
| 365 |
+
else:
|
| 366 |
+
range_constraints[d.node.expr] = ValueRanges(
|
| 367 |
+
lower=dim.min, upper=dim.max
|
| 368 |
+
)
|
| 369 |
+
input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
|
| 370 |
+
free_symbols.update(d.node.expr.free_symbols)
|
| 371 |
+
|
| 372 |
+
for symbol in free_symbols:
|
| 373 |
+
if symbol not in range_constraints:
|
| 374 |
+
# Placeholders can have symbolic shapes that are derived expressions.
|
| 375 |
+
# The above code will record direct range constraints for them
|
| 376 |
+
# so that we can do runtime assertions. In addition, for serde checks
|
| 377 |
+
# we want to record range constraints for their root symbols.
|
| 378 |
+
range_constraints[symbol] = shape_env.var_to_range[symbol]
|
| 379 |
+
|
| 380 |
+
return range_constraints
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
|
| 384 |
+
"""Search the module hierarchy, gathering up all tensor and ScriptObject constants.
|
| 385 |
+
|
| 386 |
+
Returns a dictionary mapping hash(value) to the name of the constant. We
|
| 387 |
+
have to abuse `hash` here unfortunately, see: [ScriptObject hash].
|
| 388 |
+
"""
|
| 389 |
+
constants = ConstantAttrMap()
|
| 390 |
+
buffers_parameters = set(m.buffers())
|
| 391 |
+
buffers_parameters.update(m.parameters())
|
| 392 |
+
|
| 393 |
+
def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
|
| 394 |
+
for k, v in m.__dict__.items():
|
| 395 |
+
if isinstance(
|
| 396 |
+
v,
|
| 397 |
+
(
|
| 398 |
+
torch.Tensor,
|
| 399 |
+
torch.ScriptObject,
|
| 400 |
+
FakeScriptObject,
|
| 401 |
+
),
|
| 402 |
+
):
|
| 403 |
+
if v in buffers_parameters:
|
| 404 |
+
# filter out buffers and parameters, leaving only constants
|
| 405 |
+
continue
|
| 406 |
+
|
| 407 |
+
fqn = ".".join(prefix_atoms + [k])
|
| 408 |
+
constants.add(v, fqn)
|
| 409 |
+
for k, v in m.named_children():
|
| 410 |
+
inner(v, prefix_atoms + [k], constants)
|
| 411 |
+
|
| 412 |
+
inner(m, [], constants)
|
| 413 |
+
return constants
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
@contextlib.contextmanager
|
| 417 |
+
def _fakify_script_objects(
|
| 418 |
+
mod: torch.nn.Module,
|
| 419 |
+
args: Tuple[Any],
|
| 420 |
+
kwargs: Dict[Any, Any],
|
| 421 |
+
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
|
| 422 |
+
):
|
| 423 |
+
# This context manager is used to fakify script objects into FakeScriptObject.
|
| 424 |
+
# Inputs:
|
| 425 |
+
# mod: the module to be exported, it (and its recursive submodules)'s script object attrs haven't been fakified.
|
| 426 |
+
# args, kwargs: the args and kwargs inputs for mod, script object inputs haven't been fakified.
|
| 427 |
+
# fake_mode: the fake mode to be used for fakifying script objects. It's the same mode that fakify input tensors.
|
| 428 |
+
#
|
| 429 |
+
# Returns:
|
| 430 |
+
# mod: the patched module, its (and its recursive submodules) script object attrs have been fakified.
|
| 431 |
+
# fake_args, fake_kwargs: new fakified args and kwargs.
|
| 432 |
+
# Script object inputs have been fakified. Don't touch the tensors.
|
| 433 |
+
# fake_constant_attrs: a new map from FakeScriptObject to the fqn of the original script object.
|
| 434 |
+
# fake_to_real: a mapping between FakeScriptObject and the original script object in order to un-do the patching.
|
| 435 |
+
|
| 436 |
+
constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod)
|
| 437 |
+
assert not any(
|
| 438 |
+
isinstance(obj, FakeScriptObject) for obj in constant_attrs.values()
|
| 439 |
+
), "Mod shouldn't contain any FakeScriptObject."
|
| 440 |
+
assert not pytree.tree_any(
|
| 441 |
+
lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs)
|
| 442 |
+
), "args and kwargs shouldn't contain any FakeScriptObject."
|
| 443 |
+
|
| 444 |
+
patched_attr = {}
|
| 445 |
+
fake_constant_attrs = ConstantAttrMap()
|
| 446 |
+
fake_to_real = {}
|
| 447 |
+
|
| 448 |
+
def _maybe_fakify_obj(obj):
|
| 449 |
+
fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj)
|
| 450 |
+
fake_to_real[fake_obj] = obj
|
| 451 |
+
return fake_obj
|
| 452 |
+
|
| 453 |
+
def _leaf_mod_and_attr(
|
| 454 |
+
mod: torch.nn.Module, attr_fqn: str
|
| 455 |
+
) -> Tuple[torch.nn.Module, str]:
|
| 456 |
+
*prefix_attr, last_attr = attr_fqn.split(".")
|
| 457 |
+
cur_mod = mod
|
| 458 |
+
for attr in prefix_attr:
|
| 459 |
+
cur_mod = getattr(cur_mod, attr)
|
| 460 |
+
return cur_mod, last_attr
|
| 461 |
+
|
| 462 |
+
try:
|
| 463 |
+
for obj, fqns in constant_attrs.items():
|
| 464 |
+
if isinstance(obj, torch.ScriptObject):
|
| 465 |
+
fake_script_obj = _maybe_fakify_obj(obj)
|
| 466 |
+
for fqn in fqns:
|
| 467 |
+
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
|
| 468 |
+
assert obj is getattr(cur_mod, attr)
|
| 469 |
+
setattr(cur_mod, attr, fake_script_obj)
|
| 470 |
+
fake_constant_attrs.add(fake_script_obj, fqn)
|
| 471 |
+
patched_attr[fqn] = obj
|
| 472 |
+
else:
|
| 473 |
+
for fqn in fqns:
|
| 474 |
+
fake_constant_attrs.add(obj, fqn)
|
| 475 |
+
|
| 476 |
+
fake_args, fake_kwargs = pytree.tree_map_only(
|
| 477 |
+
torch.ScriptObject, _maybe_fakify_obj, (args, kwargs)
|
| 478 |
+
)
|
| 479 |
+
yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real)
|
| 480 |
+
finally:
|
| 481 |
+
for fqn, orig_obj in patched_attr.items():
|
| 482 |
+
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
|
| 483 |
+
setattr(cur_mod, attr, orig_obj)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
|
| 487 |
+
"""
|
| 488 |
+
1. Handles data-dependent errors raised by torch function calls in non-strict.
|
| 489 |
+
|
| 490 |
+
Any data-dependent error is due to some condition on unbacked symints
|
| 491 |
+
that cannot be resolved. A mechanical way of fixing the error is to use
|
| 492 |
+
a torch._check() call to assert either that condition or its negation.
|
| 493 |
+
The handler suggests these options as code and points to the location
|
| 494 |
+
of the torch function call that raised the error as part of the error
|
| 495 |
+
message shown to the user, who can then simply select and copy-paste
|
| 496 |
+
a suggested fix at that location.
|
| 497 |
+
|
| 498 |
+
NOTE: Not all data-dependent errors are raised by torch function calls.
|
| 499 |
+
In particular, conditions on unbacked symints can appear outside such
|
| 500 |
+
calls, and as such are not handled here.
|
| 501 |
+
|
| 502 |
+
2. Handles line-of-code logging for each torch function call in non-strict.
|
| 503 |
+
|
| 504 |
+
Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ...
|
| 505 |
+
"""
|
| 506 |
+
|
| 507 |
+
def __torch_function__(self, func, types, args=(), kwargs=None):
|
| 508 |
+
kwargs = kwargs or {}
|
| 509 |
+
if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
|
| 510 |
+
frame = _find_user_code_frame()
|
| 511 |
+
if frame is not None:
|
| 512 |
+
log.debug(
|
| 513 |
+
"%s called at %s:%s in %s",
|
| 514 |
+
func.__qualname__,
|
| 515 |
+
frame.f_code.co_filename,
|
| 516 |
+
frame.f_lineno,
|
| 517 |
+
frame.f_code.co_name,
|
| 518 |
+
)
|
| 519 |
+
try:
|
| 520 |
+
return func(*args, **kwargs)
|
| 521 |
+
except GuardOnDataDependentSymNode as e:
|
| 522 |
+
_suggest_fixes_for_data_dependent_error_non_strict(e)
|
| 523 |
+
raise
|
.venv/lib/python3.11/site-packages/torch/_export/pass_base.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import operator
|
| 3 |
+
import traceback
|
| 4 |
+
import typing
|
| 5 |
+
from contextlib import nullcontext
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from functorch.experimental.control_flow import _unstack_pytree
|
| 10 |
+
from torch import fx
|
| 11 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 12 |
+
from torch._export.pass_infra.node_metadata import NodeMetadata
|
| 13 |
+
from torch._export.pass_infra.proxy_value import ProxyValue
|
| 14 |
+
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
|
| 15 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 16 |
+
from torch.fx import traceback as fx_traceback
|
| 17 |
+
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
|
| 18 |
+
from torch.fx.graph import CodeGen
|
| 19 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 20 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
| 21 |
+
from torch.utils import _pytree as pytree
|
| 22 |
+
from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
Argument = Any
|
| 29 |
+
Value = Any
|
| 30 |
+
Fn = Callable[..., Any]
|
| 31 |
+
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
_TORCH_SYM_OPS: Set[Callable] = {
|
| 35 |
+
torch.sym_int,
|
| 36 |
+
torch.sym_float,
|
| 37 |
+
torch.sym_ite,
|
| 38 |
+
torch.sym_max,
|
| 39 |
+
torch.sym_min,
|
| 40 |
+
torch.sym_not,
|
| 41 |
+
torch.sym_sqrt,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ExportPassBaseError(RuntimeError):
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
| 50 |
+
"""
|
| 51 |
+
Interpreter-based pass class to help users maintain the IR spec while writing
|
| 52 |
+
transformations.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
@staticmethod
|
| 56 |
+
def _create_dummy_node_metadata():
|
| 57 |
+
return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ExportTracer(PythonKeyTracer):
|
| 61 |
+
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.callback = callback
|
| 64 |
+
self.root = torch.nn.Module()
|
| 65 |
+
self.graph = torch.fx.Graph()
|
| 66 |
+
self.graph.set_codegen(codegen)
|
| 67 |
+
self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment]
|
| 68 |
+
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
| 69 |
+
self.submodules: Dict[torch.nn.Module, str] = {}
|
| 70 |
+
|
| 71 |
+
def trace(self) -> None: # type: ignore[override]
|
| 72 |
+
raise ExportPassBaseError("ExportTracer doesn't support trace().")
|
| 73 |
+
|
| 74 |
+
def create_arg(self, a: Argument) -> torch.fx.Node:
|
| 75 |
+
if isinstance(a, torch.nn.Module):
|
| 76 |
+
if a not in self.submodules:
|
| 77 |
+
name_submodule = f"submodule_{len(self.submodules)}"
|
| 78 |
+
self.root.add_module(name_submodule, a)
|
| 79 |
+
self.submodules[a] = name_submodule
|
| 80 |
+
elif isinstance(a, FakeTensor):
|
| 81 |
+
if not hasattr(a, "constant") or a.constant is None:
|
| 82 |
+
raise ExportPassBaseError(f"Cannot add {a} to graph.")
|
| 83 |
+
a = a.constant
|
| 84 |
+
node = super().create_arg(a)
|
| 85 |
+
if (
|
| 86 |
+
isinstance(a, torch.Tensor)
|
| 87 |
+
and isinstance(node, torch.fx.Node)
|
| 88 |
+
and node.op == "get_attr"
|
| 89 |
+
):
|
| 90 |
+
self.set_metadata(node, a)
|
| 91 |
+
self.callback.on_attr(ProxyValue(a, node))
|
| 92 |
+
return node
|
| 93 |
+
|
| 94 |
+
def set_metadata(
|
| 95 |
+
self, node: torch.fx.Node, value: Argument,
|
| 96 |
+
) -> None:
|
| 97 |
+
# propagate the fake tensor or sym nodes
|
| 98 |
+
def make_val(
|
| 99 |
+
x: Argument,
|
| 100 |
+
) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
|
| 101 |
+
if isinstance(x, FakeTensor):
|
| 102 |
+
return x
|
| 103 |
+
elif isinstance(x, torch.Tensor):
|
| 104 |
+
if x.is_quantized:
|
| 105 |
+
# TODO (tmanlaibaatar) properly support Quantized FakeTensor
|
| 106 |
+
x = torch.dequantize(x)
|
| 107 |
+
|
| 108 |
+
try:
|
| 109 |
+
assert self.fake_tensor_mode is not None
|
| 110 |
+
# TODO we should allocate static shapes
|
| 111 |
+
# for param/buffer values
|
| 112 |
+
if isinstance(x, torch.nn.Parameter):
|
| 113 |
+
fake_tensor = self.fake_tensor_mode.from_tensor(
|
| 114 |
+
x, static_shapes=True
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
fake_tensor = self.fake_tensor_mode.from_tensor(x)
|
| 118 |
+
except UnsupportedFakeTensorException:
|
| 119 |
+
# TODO: This is just a workaround to get over the
|
| 120 |
+
# x.as_subclass error
|
| 121 |
+
print(
|
| 122 |
+
"Fakeifying a Tensor subclass is not supported \
|
| 123 |
+
right now. Instead a TensorMetadata is used."
|
| 124 |
+
)
|
| 125 |
+
fake_tensor = None
|
| 126 |
+
return fake_tensor
|
| 127 |
+
elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
|
| 128 |
+
return x
|
| 129 |
+
else:
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
node.meta["val"] = pytree.tree_map(make_val, value)
|
| 133 |
+
|
| 134 |
+
# Set the tensor_metadata for values that do not have a corresponding FakeTensor
|
| 135 |
+
def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
|
| 136 |
+
if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
|
| 137 |
+
if x.is_quantized:
|
| 138 |
+
# TODO (tmanlaibaatar) properly support Quantized FakeTensor
|
| 139 |
+
x = torch.dequantize(x)
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
assert self.fake_tensor_mode is not None
|
| 143 |
+
_ = self.fake_tensor_mode.from_tensor(x)
|
| 144 |
+
tensor_meta = None
|
| 145 |
+
except UnsupportedFakeTensorException:
|
| 146 |
+
# TODO: This is just a workaround to get over the
|
| 147 |
+
# x.as_subclass error
|
| 148 |
+
tensor_meta = _extract_tensor_metadata(x)
|
| 149 |
+
return tensor_meta
|
| 150 |
+
else:
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
|
| 154 |
+
|
| 155 |
+
class ExportInterpreter(fx.Interpreter):
|
| 156 |
+
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
|
| 157 |
+
super().__init__(gm)
|
| 158 |
+
self.callback = callback
|
| 159 |
+
self.node: torch.fx.Node = next(iter(gm.graph.nodes))
|
| 160 |
+
|
| 161 |
+
def placeholder(
|
| 162 |
+
self,
|
| 163 |
+
target: str, # type: ignore[override]
|
| 164 |
+
args: Tuple[Argument, ...],
|
| 165 |
+
kwargs: Dict[str, Argument],
|
| 166 |
+
) -> ProxyValue:
|
| 167 |
+
arg = super().placeholder(target, args, kwargs)
|
| 168 |
+
return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))
|
| 169 |
+
|
| 170 |
+
def output(
|
| 171 |
+
self,
|
| 172 |
+
target: torch.fx.node.Target,
|
| 173 |
+
args: Tuple[Argument, ...],
|
| 174 |
+
kwargs: Dict[str, Argument],
|
| 175 |
+
) -> ProxyValue:
|
| 176 |
+
return self.callback.output(args[0], NodeMetadata(self.node.meta)).data
|
| 177 |
+
|
| 178 |
+
def call_function(
|
| 179 |
+
self,
|
| 180 |
+
target: torch.fx.node.Target,
|
| 181 |
+
args: Tuple[Argument, ...],
|
| 182 |
+
kwargs: Dict[str, Argument],
|
| 183 |
+
) -> ProxyValue:
|
| 184 |
+
meta = NodeMetadata(self.node.meta)
|
| 185 |
+
|
| 186 |
+
if target == operator.getitem:
|
| 187 |
+
value, key = args
|
| 188 |
+
return self.callback.call_getitem(value, key, meta)
|
| 189 |
+
elif getattr(target, "__module__", None) in {"_operator", "math"}:
|
| 190 |
+
assert callable(target)
|
| 191 |
+
return self.callback.call_sym(target, args, meta)
|
| 192 |
+
elif target in _TORCH_SYM_OPS:
|
| 193 |
+
assert callable(target)
|
| 194 |
+
return self.callback.call_sym(target, args, meta)
|
| 195 |
+
elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
|
| 196 |
+
return self.callback.call_operator(
|
| 197 |
+
target,
|
| 198 |
+
args,
|
| 199 |
+
kwargs,
|
| 200 |
+
meta,
|
| 201 |
+
)
|
| 202 |
+
elif target == torch.ops.higher_order.cond:
|
| 203 |
+
pred, true_fn, false_fn, inputs = args
|
| 204 |
+
return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
|
| 205 |
+
elif target == torch.ops.higher_order.map_impl:
|
| 206 |
+
f, mapped_args, operands = args # type: ignore[assignment]
|
| 207 |
+
return self.callback.call_map(f, mapped_args, operands, meta)
|
| 208 |
+
# For other unregistered HigherOrderOps, just interpret them blindly
|
| 209 |
+
elif isinstance(target, torch._ops.HigherOrderOperator):
|
| 210 |
+
return self.callback._fx(
|
| 211 |
+
"call_function",
|
| 212 |
+
target,
|
| 213 |
+
args,
|
| 214 |
+
kwargs,
|
| 215 |
+
meta,
|
| 216 |
+
)
|
| 217 |
+
else:
|
| 218 |
+
raise ExportPassBaseError(f"Unsupported target type: {target}")
|
| 219 |
+
|
| 220 |
+
def get_attr(
|
| 221 |
+
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
|
| 222 |
+
) -> Argument:
|
| 223 |
+
return super().get_attr(target, args, kwargs)
|
| 224 |
+
|
| 225 |
+
def call_module(
|
| 226 |
+
self,
|
| 227 |
+
target: torch.fx.node.Target,
|
| 228 |
+
args: Tuple[Argument, ...],
|
| 229 |
+
kwargs: Dict[str, Argument],
|
| 230 |
+
) -> None:
|
| 231 |
+
raise ExportPassBaseError("call_module is not supported.")
|
| 232 |
+
|
| 233 |
+
def call_method(
|
| 234 |
+
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
|
| 235 |
+
) -> None:
|
| 236 |
+
raise ExportPassBaseError("call_method is not supported.")
|
| 237 |
+
|
| 238 |
+
def run_node(self, n: torch.fx.Node) -> Argument:
|
| 239 |
+
self.node = n
|
| 240 |
+
self.callback.node_debug_str = n.format_node()
|
| 241 |
+
return super().run_node(n)
|
| 242 |
+
|
| 243 |
+
def __init__(self) -> None:
|
| 244 |
+
self.interpreter = PropagateUnbackedSymInts(
|
| 245 |
+
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
|
| 246 |
+
)
|
| 247 |
+
self.tracer = self.ExportTracer(self, CodeGen())
|
| 248 |
+
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
| 249 |
+
self._initialized = True
|
| 250 |
+
self.node_debug_str: typing.Optional[str] = None
|
| 251 |
+
|
| 252 |
+
def _fx(
|
| 253 |
+
self,
|
| 254 |
+
kind: str,
|
| 255 |
+
target: torch.fx.node.Target,
|
| 256 |
+
args: Tuple[Argument, ...],
|
| 257 |
+
kwargs: Dict[str, Argument],
|
| 258 |
+
meta: NodeMetadata,
|
| 259 |
+
) -> ProxyValue:
|
| 260 |
+
args_data, kwargs_data = pytree.tree_map_only(
|
| 261 |
+
ProxyValue, lambda x: x.data, (args, kwargs)
|
| 262 |
+
)
|
| 263 |
+
res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
|
| 264 |
+
args_proxy, kwargs_proxy = pytree.tree_map_only(
|
| 265 |
+
ProxyValue, lambda x: x.proxy, (args, kwargs)
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
name = None
|
| 269 |
+
if isinstance(target, torch._ops.OpOverload):
|
| 270 |
+
name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)
|
| 271 |
+
|
| 272 |
+
res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
|
| 273 |
+
res_proxy.node.meta.update(meta.data)
|
| 274 |
+
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
|
| 275 |
+
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
|
| 276 |
+
res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
|
| 277 |
+
self.tracer.set_metadata(res_proxy.node, res_data)
|
| 278 |
+
return ProxyValue(res_data, res_proxy)
|
| 279 |
+
|
| 280 |
+
def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
|
| 281 |
+
# TODO(angelayi): Update this with what we decide to do for metadata in
|
| 282 |
+
# the exported graph module
|
| 283 |
+
if (args := graph_module.meta.get("args", None)) is not None:
|
| 284 |
+
return list(args)
|
| 285 |
+
|
| 286 |
+
def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
|
| 287 |
+
if "val" in node.meta:
|
| 288 |
+
fake = node.meta["val"]
|
| 289 |
+
if hasattr(fake, "constant") and fake.constant is not None:
|
| 290 |
+
return fake.constant
|
| 291 |
+
return fake
|
| 292 |
+
elif tensor_meta := node.meta.get("tensor_meta"):
|
| 293 |
+
assert self.fake_tensor_mode is not None
|
| 294 |
+
return FakeTensor(
|
| 295 |
+
self.fake_tensor_mode,
|
| 296 |
+
torch.empty(
|
| 297 |
+
tensor_meta.shape,
|
| 298 |
+
dtype=tensor_meta.dtype,
|
| 299 |
+
device="meta",
|
| 300 |
+
requires_grad=tensor_meta.requires_grad,
|
| 301 |
+
memory_format=tensor_meta.memory_format,
|
| 302 |
+
),
|
| 303 |
+
torch.device("cpu"),
|
| 304 |
+
)
|
| 305 |
+
elif len(node.users) == 0:
|
| 306 |
+
return None
|
| 307 |
+
raise ExportPassBaseError(
|
| 308 |
+
f"Cannot construct an input for graph module: {graph_module}.",
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
return [
|
| 312 |
+
extract_input(node)
|
| 313 |
+
for node in graph_module.graph.nodes
|
| 314 |
+
if node.op == "placeholder"
|
| 315 |
+
]
|
| 316 |
+
|
| 317 |
+
def on_attr(self, attr: ProxyValue) -> None:
|
| 318 |
+
pass
|
| 319 |
+
|
| 320 |
+
def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
|
| 321 |
+
arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
|
| 322 |
+
arg_proxy.node.meta = meta.data
|
| 323 |
+
self.tracer.set_metadata(arg_proxy.node, arg)
|
| 324 |
+
return ProxyValue(arg, arg_proxy)
|
| 325 |
+
|
| 326 |
+
def call_operator(
|
| 327 |
+
self,
|
| 328 |
+
op,
|
| 329 |
+
args: Tuple[Argument, ...],
|
| 330 |
+
kwargs: Dict[str, Argument],
|
| 331 |
+
meta: NodeMetadata,
|
| 332 |
+
) -> ProxyValue:
|
| 333 |
+
return self._fx("call_function", op, args, kwargs, meta)
|
| 334 |
+
|
| 335 |
+
def call_sym(
|
| 336 |
+
self,
|
| 337 |
+
target: Fn,
|
| 338 |
+
args: Tuple[Argument, ...],
|
| 339 |
+
meta: NodeMetadata,
|
| 340 |
+
) -> ProxyValue:
|
| 341 |
+
return self._fx("call_function", target, args, {}, meta)
|
| 342 |
+
|
| 343 |
+
def call_cond(
|
| 344 |
+
self,
|
| 345 |
+
pred: ProxyValue,
|
| 346 |
+
true_fn: torch.fx.GraphModule,
|
| 347 |
+
false_fn: torch.fx.GraphModule,
|
| 348 |
+
inputs: List[Argument],
|
| 349 |
+
meta: NodeMetadata,
|
| 350 |
+
) -> ProxyValue:
|
| 351 |
+
true_branch = self.call_submodule(true_fn, tuple(inputs))
|
| 352 |
+
false_branch = self.call_submodule(false_fn, tuple(inputs))
|
| 353 |
+
assert true_branch is not None
|
| 354 |
+
assert false_branch is not None
|
| 355 |
+
return self._fx(
|
| 356 |
+
"call_function",
|
| 357 |
+
torch.ops.higher_order.cond,
|
| 358 |
+
(pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
|
| 359 |
+
{},
|
| 360 |
+
meta,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
def call_map(
|
| 364 |
+
self,
|
| 365 |
+
f: torch.fx.GraphModule,
|
| 366 |
+
mapped_args: List[ProxyValue],
|
| 367 |
+
operands: List[ProxyValue],
|
| 368 |
+
meta: NodeMetadata,
|
| 369 |
+
) -> ProxyValue:
|
| 370 |
+
xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
|
| 371 |
+
f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
|
| 372 |
+
assert f_branch is not None
|
| 373 |
+
return self._fx(
|
| 374 |
+
"call_function",
|
| 375 |
+
torch.ops.higher_order.map_impl,
|
| 376 |
+
(f_branch.graph_module, mapped_args, operands),
|
| 377 |
+
{},
|
| 378 |
+
meta,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def call_getitem(
|
| 382 |
+
self, value: ProxyValue, key: int, meta: NodeMetadata
|
| 383 |
+
) -> ProxyValue:
|
| 384 |
+
return self._fx("call_function", operator.getitem, (value, key), {}, meta)
|
| 385 |
+
|
| 386 |
+
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
|
| 387 |
+
return self._fx("output", "output", (results,), {}, meta)
|
| 388 |
+
|
| 389 |
+
def call_submodule(
|
| 390 |
+
self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
|
| 391 |
+
) -> PassResult:
|
| 392 |
+
prev_tracer, self.tracer = self.tracer, self.ExportTracer(
|
| 393 |
+
self, graph_module.graph._codegen
|
| 394 |
+
)
|
| 395 |
+
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
|
| 396 |
+
interpreter = self.ExportInterpreter(self, graph_module)
|
| 397 |
+
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
|
| 398 |
+
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
|
| 399 |
+
)
|
| 400 |
+
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
|
| 401 |
+
with fx_traceback.preserve_node_meta():
|
| 402 |
+
interpreter.run(*inputs_data)
|
| 403 |
+
|
| 404 |
+
new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
|
| 405 |
+
|
| 406 |
+
self.tracer = prev_tracer
|
| 407 |
+
self.interpreter = prev_interpreter
|
| 408 |
+
return PassResult(
|
| 409 |
+
new_graph_module,
|
| 410 |
+
True,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
def call(self, graph_module: fx.GraphModule) -> PassResult:
|
| 414 |
+
if not getattr(self, "_initialized", False):
|
| 415 |
+
raise ExportPassBaseError(
|
| 416 |
+
"ExportPass is not initialized with __init__().",
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
inputs = self.inputs(graph_module)
|
| 420 |
+
|
| 421 |
+
fake_tensor_mode = None
|
| 422 |
+
for i in inputs:
|
| 423 |
+
if isinstance(i, FakeTensor):
|
| 424 |
+
assert (
|
| 425 |
+
fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
|
| 426 |
+
), "Multiple fake tensor mode detected."
|
| 427 |
+
fake_tensor_mode = i.fake_mode
|
| 428 |
+
if fake_tensor_mode is None:
|
| 429 |
+
self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
|
| 430 |
+
fake_tensor_mode = nullcontext() # type: ignore[assignment]
|
| 431 |
+
dispatcher_mode = nullcontext() # type: ignore[assignment]
|
| 432 |
+
else:
|
| 433 |
+
fake_tensor_mode.allow_non_fake_inputs = True
|
| 434 |
+
self.tracer.fake_tensor_mode = fake_tensor_mode
|
| 435 |
+
dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment]
|
| 436 |
+
self.fake_tensor_mode = self.tracer.fake_tensor_mode
|
| 437 |
+
|
| 438 |
+
with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr]
|
| 439 |
+
result = self.call_submodule(graph_module, tuple(inputs))
|
| 440 |
+
|
| 441 |
+
return result
|
.venv/lib/python3.11/site-packages/torch/_export/tools.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Any, Dict, Iterable, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.export
|
| 8 |
+
import torch.export._trace
|
| 9 |
+
from torch._utils_internal import log_export_usage
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
log = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
__all__ = ["report_exportability"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _generate_inputs_for_submodules(
|
| 18 |
+
model: torch.nn.Module,
|
| 19 |
+
target_submodules: Iterable[str],
|
| 20 |
+
args: Tuple[Any, ...],
|
| 21 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
| 22 |
+
) -> Dict[str, Tuple[Any, Any]]:
|
| 23 |
+
"""
|
| 24 |
+
Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
|
| 25 |
+
function doesn't work.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
model: root model.
|
| 29 |
+
inputs: inputs to the root model.
|
| 30 |
+
target_submodules: submodules that we want to generate inputs for.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
A dict that maps from submodule name to its inputs.
|
| 34 |
+
"""
|
| 35 |
+
kwargs = kwargs or {}
|
| 36 |
+
|
| 37 |
+
handles = []
|
| 38 |
+
results = {}
|
| 39 |
+
submodule_to_names = {mod: name for name, mod in model.named_modules()}
|
| 40 |
+
|
| 41 |
+
def pre_forward(module, module_args, module_kwargs):
|
| 42 |
+
results[submodule_to_names[module]] = (module_args, module_kwargs)
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
for name, mod in model.named_modules():
|
| 46 |
+
if name in target_submodules:
|
| 47 |
+
handles.append(
|
| 48 |
+
mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
|
| 49 |
+
)
|
| 50 |
+
model(*args, **kwargs)
|
| 51 |
+
except Exception as e:
|
| 52 |
+
warnings.warn(
|
| 53 |
+
f"Failed to generate submodule inputs because of the following error:\n{e}"
|
| 54 |
+
)
|
| 55 |
+
finally:
|
| 56 |
+
for h in handles:
|
| 57 |
+
h.remove()
|
| 58 |
+
return results
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def report_exportability(
|
| 62 |
+
mod: torch.nn.Module,
|
| 63 |
+
args: Tuple[Any, ...],
|
| 64 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
| 65 |
+
*,
|
| 66 |
+
strict: bool = True,
|
| 67 |
+
pre_dispatch: bool = False,
|
| 68 |
+
) -> Dict[str, Optional[Exception]]:
|
| 69 |
+
"""
|
| 70 |
+
Report exportability issues for a module in one-shot.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
mod: root module.
|
| 74 |
+
args: args to the root module.
|
| 75 |
+
kwargs: kwargs to the root module.
|
| 76 |
+
Returns:
|
| 77 |
+
A dict that maps from submodule name to the exception that was raised when trying to export it.
|
| 78 |
+
`None` means the module is exportable without issue.
|
| 79 |
+
Sample output:
|
| 80 |
+
{
|
| 81 |
+
'': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
|
| 82 |
+
'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
|
| 83 |
+
'submod_2': None
|
| 84 |
+
}
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
log_export_usage(event="export.report_exportability")
|
| 88 |
+
|
| 89 |
+
kwargs = kwargs or {}
|
| 90 |
+
|
| 91 |
+
all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
|
| 92 |
+
submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)
|
| 93 |
+
|
| 94 |
+
tried_module_types = set()
|
| 95 |
+
report: Dict[str, Optional[Exception]] = {}
|
| 96 |
+
|
| 97 |
+
def try_export(module, module_name, args, kwargs):
|
| 98 |
+
nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types
|
| 99 |
+
|
| 100 |
+
if type(module) in tried_module_types:
|
| 101 |
+
return
|
| 102 |
+
tried_module_types.add(type(module))
|
| 103 |
+
|
| 104 |
+
if args is not None or kwargs is not None:
|
| 105 |
+
try:
|
| 106 |
+
torch.export._trace._export(
|
| 107 |
+
module,
|
| 108 |
+
args,
|
| 109 |
+
kwargs,
|
| 110 |
+
strict=strict,
|
| 111 |
+
pre_dispatch=pre_dispatch,
|
| 112 |
+
)
|
| 113 |
+
report[module_name] = None
|
| 114 |
+
log.info("Successfully exported `%s`", module_name)
|
| 115 |
+
return
|
| 116 |
+
except Exception as e:
|
| 117 |
+
short_msg = repr(e).split("\n")[0]
|
| 118 |
+
log.warning(
|
| 119 |
+
"Failed exporting `%s` with exception: %s", module_name, short_msg
|
| 120 |
+
)
|
| 121 |
+
report[module_name] = e
|
| 122 |
+
|
| 123 |
+
for name, submod in module.named_children():
|
| 124 |
+
sub_module_name = name if module_name == "" else f"{module_name}.{name}"
|
| 125 |
+
|
| 126 |
+
submod_args, submod_kwargs = submod_inputs.get(
|
| 127 |
+
sub_module_name, (None, None)
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
try_export(submod, sub_module_name, submod_args, submod_kwargs)
|
| 131 |
+
|
| 132 |
+
return
|
| 133 |
+
|
| 134 |
+
try_export(mod, "", args, kwargs)
|
| 135 |
+
|
| 136 |
+
unique_issues = set()
|
| 137 |
+
for exception in report.values():
|
| 138 |
+
if exception is not None:
|
| 139 |
+
key = repr(exception).split("\\n")[0]
|
| 140 |
+
unique_issues.add(key)
|
| 141 |
+
|
| 142 |
+
log.warning("Found %d export issues:", len(unique_issues))
|
| 143 |
+
for issue in unique_issues:
|
| 144 |
+
log.warning(issue)
|
| 145 |
+
|
| 146 |
+
return report
|
.venv/lib/python3.11/site-packages/torch/_export/verifier.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
import math
|
| 4 |
+
import operator
|
| 5 |
+
from collections.abc import Iterable
|
| 6 |
+
from typing import Any, Dict, final, List, Tuple, Type, TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch._ops import HigherOrderOperator, OpOverload
|
| 10 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 11 |
+
from torch.export.graph_signature import (
|
| 12 |
+
CustomObjArgument,
|
| 13 |
+
InputKind,
|
| 14 |
+
SymIntArgument,
|
| 15 |
+
TensorArgument,
|
| 16 |
+
TokenArgument,
|
| 17 |
+
)
|
| 18 |
+
from torch.fx import GraphModule
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from torch.export.exported_program import ExportedProgram
|
| 22 |
+
|
| 23 |
+
class SpecViolationError(Exception):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def is_functional(op: OpOverload) -> bool:
|
| 28 |
+
return not op._schema.is_mutable
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
|
| 32 |
+
# TODO(angelayi): remove this in favor of _check_val
|
| 33 |
+
return _check_val(node)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _check_val(node: torch.fx.Node) -> None:
|
| 37 |
+
from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
|
| 38 |
+
|
| 39 |
+
def _check_correct_val(val):
|
| 40 |
+
if val is None:
|
| 41 |
+
return True
|
| 42 |
+
elif isinstance(val, (int, bool, str, float)):
|
| 43 |
+
return True
|
| 44 |
+
elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
|
| 45 |
+
return True
|
| 46 |
+
elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor.
|
| 47 |
+
return True
|
| 48 |
+
elif isinstance(val, (SymInt, SymFloat, SymBool)):
|
| 49 |
+
return True
|
| 50 |
+
elif isinstance(val, CustomObjArgument):
|
| 51 |
+
return True
|
| 52 |
+
elif isinstance(val, Iterable):
|
| 53 |
+
return all(_check_correct_val(x) for x in val)
|
| 54 |
+
return False
|
| 55 |
+
|
| 56 |
+
def _no_returns(op):
|
| 57 |
+
if not isinstance(op, OpOverload):
|
| 58 |
+
return False
|
| 59 |
+
return len(op._schema.returns) == 0
|
| 60 |
+
|
| 61 |
+
if "val" not in node.meta:
|
| 62 |
+
if node.op == "call_function" and _no_returns(node.target):
|
| 63 |
+
return
|
| 64 |
+
raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
|
| 65 |
+
|
| 66 |
+
val = node.meta["val"]
|
| 67 |
+
if not _check_correct_val(val):
|
| 68 |
+
raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _check_torch_fn(node: torch.fx.Node) -> None:
|
| 72 |
+
torch_fn = node.meta.get("torch_fn")
|
| 73 |
+
if torch_fn is None:
|
| 74 |
+
raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}")
|
| 75 |
+
if (
|
| 76 |
+
not isinstance(torch_fn, tuple) and
|
| 77 |
+
isinstance(torch_fn[0], str) and
|
| 78 |
+
isinstance(torch_fn[1], str)
|
| 79 |
+
):
|
| 80 |
+
raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
|
| 81 |
+
|
| 82 |
+
class _VerifierMeta(type):
|
| 83 |
+
_registry: Dict[str, Type['Verifier']] = {}
|
| 84 |
+
|
| 85 |
+
def __new__(metacls, name, bases, attrs):
|
| 86 |
+
if bases:
|
| 87 |
+
if "check" in attrs or "_check_graph_module" in attrs:
|
| 88 |
+
raise SyntaxError("Overriding method check is not allowed.")
|
| 89 |
+
assert "dialect" in attrs and attrs["dialect"] != "ATEN"
|
| 90 |
+
else:
|
| 91 |
+
assert "check" in attrs
|
| 92 |
+
assert "_check_graph_module" in attrs
|
| 93 |
+
assert attrs["dialect"] == "ATEN"
|
| 94 |
+
|
| 95 |
+
assert isinstance(attrs["dialect"], str)
|
| 96 |
+
ret = type.__new__(metacls, name, bases, attrs)
|
| 97 |
+
metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment]
|
| 98 |
+
return ret
|
| 99 |
+
|
| 100 |
+
def getattr_recursive(obj: Any, target: str) -> Any:
|
| 101 |
+
target_atoms = target.split('.')
|
| 102 |
+
attr_itr = obj
|
| 103 |
+
for i, atom in enumerate(target_atoms):
|
| 104 |
+
if not hasattr(attr_itr, atom):
|
| 105 |
+
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
|
| 106 |
+
attr_itr = getattr(attr_itr, atom)
|
| 107 |
+
return attr_itr
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class Verifier(metaclass=_VerifierMeta):
|
| 111 |
+
dialect = "ATEN"
|
| 112 |
+
|
| 113 |
+
def allowed_builtin_ops(self) -> List:
|
| 114 |
+
return [
|
| 115 |
+
operator.getitem,
|
| 116 |
+
operator.add,
|
| 117 |
+
operator.mul,
|
| 118 |
+
operator.sub,
|
| 119 |
+
operator.truediv,
|
| 120 |
+
operator.ge,
|
| 121 |
+
operator.le,
|
| 122 |
+
operator.gt,
|
| 123 |
+
operator.lt,
|
| 124 |
+
operator.eq,
|
| 125 |
+
operator.ne,
|
| 126 |
+
operator.floordiv,
|
| 127 |
+
operator.mod,
|
| 128 |
+
operator.and_,
|
| 129 |
+
operator.or_,
|
| 130 |
+
operator.not_,
|
| 131 |
+
operator.pow,
|
| 132 |
+
operator.neg,
|
| 133 |
+
operator.abs,
|
| 134 |
+
math.ceil,
|
| 135 |
+
math.floor,
|
| 136 |
+
math.trunc,
|
| 137 |
+
]
|
| 138 |
+
|
| 139 |
+
def allowed_op_types(self) -> Tuple[Type[Any], ...]:
|
| 140 |
+
return (OpOverload, HigherOrderOperator)
|
| 141 |
+
|
| 142 |
+
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
|
| 143 |
+
return (torch.fx.GraphModule,)
|
| 144 |
+
|
| 145 |
+
def check_valid_op(self, op):
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
+
def check_additional(self, gm: GraphModule) -> None:
|
| 149 |
+
"""
|
| 150 |
+
Additional checks that are specific to some dialects.
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
@final
|
| 154 |
+
def check(self, ep: "ExportedProgram") -> None:
|
| 155 |
+
self._check_graph_module(ep.graph_module)
|
| 156 |
+
_verify_exported_program_module_call_graph(ep)
|
| 157 |
+
_verify_exported_program_signature(ep)
|
| 158 |
+
|
| 159 |
+
@final
|
| 160 |
+
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
|
| 161 |
+
def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
|
| 162 |
+
ret = self.allowed_getattr_types()
|
| 163 |
+
assert not any(t is object for t in ret)
|
| 164 |
+
return ret
|
| 165 |
+
|
| 166 |
+
def _check_valid_op(op) -> None:
|
| 167 |
+
def _allowed_builtin_ops() -> List:
|
| 168 |
+
ret = self.allowed_builtin_ops()
|
| 169 |
+
assert all(inspect.isbuiltin(op) for op in ret)
|
| 170 |
+
return ret
|
| 171 |
+
|
| 172 |
+
def _allowed_op_types() -> Tuple[Type[Any], ...]:
|
| 173 |
+
ret = self.allowed_op_types()
|
| 174 |
+
assert not any(t is object for t in ret)
|
| 175 |
+
return ret
|
| 176 |
+
|
| 177 |
+
# TODO Remove this allowlist.
|
| 178 |
+
_allowed_torch_functions = (
|
| 179 |
+
torch.autograd.grad_mode.set_grad_enabled,
|
| 180 |
+
torch.sym_int,
|
| 181 |
+
torch.sym_float,
|
| 182 |
+
torch.sym_ite,
|
| 183 |
+
torch.sym_max,
|
| 184 |
+
torch.sym_min,
|
| 185 |
+
torch.sym_not,
|
| 186 |
+
torch.sym_sqrt,
|
| 187 |
+
# TODO (tmanlaibaatar)
|
| 188 |
+
# Predispatch export is able to contain autograd ops.
|
| 189 |
+
# These will be modeled as HOO later
|
| 190 |
+
torch._C._set_grad_enabled,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if not isinstance(op, _allowed_op_types()):
|
| 194 |
+
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
|
| 195 |
+
raise SpecViolationError(
|
| 196 |
+
f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
|
| 197 |
+
f"Valid builtin ops: {_allowed_builtin_ops()}"
|
| 198 |
+
f"Valid torch functions: {_allowed_torch_functions}"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if isinstance(op, OpOverload):
|
| 202 |
+
# All ops functional
|
| 203 |
+
# TODO (tmanlaibaatar) more proper way is needed here
|
| 204 |
+
if self.dialect != "TRAINING" and not is_functional(op):
|
| 205 |
+
raise SpecViolationError(
|
| 206 |
+
f"operator '{op}' is not functional"
|
| 207 |
+
)
|
| 208 |
+
self.check_valid_op(op)
|
| 209 |
+
|
| 210 |
+
for mod in gm.modules():
|
| 211 |
+
if not isinstance(mod, torch.fx.GraphModule):
|
| 212 |
+
continue
|
| 213 |
+
|
| 214 |
+
mod.graph.lint()
|
| 215 |
+
for node in mod.graph.nodes:
|
| 216 |
+
# TODO(T140410192): should have fake tensor for all dialects
|
| 217 |
+
if node.op in {"call_module", "call_method"}:
|
| 218 |
+
raise SpecViolationError(
|
| 219 |
+
f"call_module is not valid: got a class '{node.target}' ",
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
elif node.op == "call_function":
|
| 223 |
+
_check_val(node)
|
| 224 |
+
|
| 225 |
+
_check_valid_op(node.target)
|
| 226 |
+
|
| 227 |
+
elif node.op == "get_attr":
|
| 228 |
+
if not isinstance(node.target, str):
|
| 229 |
+
raise SpecViolationError(
|
| 230 |
+
f"Expected get_attr target to be string, but got {type(node.target)}"
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
attr = getattr_recursive(mod, node.target)
|
| 234 |
+
if isinstance(attr, torch.nn.Module):
|
| 235 |
+
def _is_type(name, ty):
|
| 236 |
+
return isinstance(getattr(attr, name, None), ty)
|
| 237 |
+
if type(attr).__name__ == "LoweredBackendModule":
|
| 238 |
+
if _is_type("backend_id", str) \
|
| 239 |
+
and _is_type("processed_bytes", bytes) \
|
| 240 |
+
and _is_type("compile_specs", list) \
|
| 241 |
+
and hasattr(attr, "original_module"):
|
| 242 |
+
continue
|
| 243 |
+
else:
|
| 244 |
+
backend_id = getattr(attr, "backend_id", None)
|
| 245 |
+
processed_bytes = getattr(attr, "processed_bytes", None)
|
| 246 |
+
compile_specs = getattr(attr, "compile_specs", None)
|
| 247 |
+
raise SpecViolationError(
|
| 248 |
+
f"Invalid get_attr type {type(attr)}. \n"
|
| 249 |
+
f"LoweredBackendModule fields: "
|
| 250 |
+
f"backend_id(str) : {type(backend_id)}, "
|
| 251 |
+
f"processed_bytes(bytes) : {type(processed_bytes)}, "
|
| 252 |
+
f"compile_specs(list) : {type(compile_specs)}"
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
if not isinstance(attr, _allowed_getattr_types()):
|
| 256 |
+
raise SpecViolationError(
|
| 257 |
+
f"Invalid get_attr type {type(attr)}. \n"
|
| 258 |
+
f"Valid get_attr types: {_allowed_getattr_types()}"
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
elif node.op == "placeholder":
|
| 263 |
+
_check_val(node)
|
| 264 |
+
# TODO(zhxchen17)
|
| 265 |
+
# elif node.op == "output":
|
| 266 |
+
# _check_flattened_outputs()
|
| 267 |
+
|
| 268 |
+
self.check_additional(gm)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class TrainingIRVerifier(Verifier):
|
| 272 |
+
dialect = "TRAINING"
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _verify_exported_program_module_call_graph(exported_program) -> None:
|
| 276 |
+
module_call_graph = exported_program.module_call_graph
|
| 277 |
+
nodes = {
|
| 278 |
+
node.name for node in exported_program.graph.nodes
|
| 279 |
+
}
|
| 280 |
+
for entry in module_call_graph:
|
| 281 |
+
if entry.signature is not None:
|
| 282 |
+
for arg in entry.signature.inputs:
|
| 283 |
+
if arg.name and arg.name not in nodes:
|
| 284 |
+
raise SpecViolationError(
|
| 285 |
+
f"Input {arg.name} does not exist in the graph."
|
| 286 |
+
)
|
| 287 |
+
for arg in entry.signature.outputs:
|
| 288 |
+
if arg.name and arg.name not in nodes:
|
| 289 |
+
raise SpecViolationError(
|
| 290 |
+
f"Output {arg.name} does not exist in the graph."
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _verify_exported_program_signature(exported_program) -> None:
|
| 295 |
+
# Check ExportedProgram signature matches
|
| 296 |
+
gs = exported_program.graph_signature
|
| 297 |
+
|
| 298 |
+
# Check every node in the signature exists in the graph
|
| 299 |
+
input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]
|
| 300 |
+
|
| 301 |
+
if len(input_node_names) != len(gs.input_specs):
|
| 302 |
+
raise SpecViolationError(
|
| 303 |
+
f"Number of graph inputs ({len(input_node_names)}) "
|
| 304 |
+
f"does not match number of inputs in the graph signature ({len(gs.input_specs)})"
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
for input_spec, node in zip(gs.input_specs, input_node_names):
|
| 308 |
+
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
|
| 309 |
+
if input_spec.arg.name != node:
|
| 310 |
+
raise SpecViolationError(
|
| 311 |
+
f"Input spec name {input_spec.arg.name} does not match node name {node}"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
if input_spec.kind == InputKind.USER_INPUT:
|
| 315 |
+
continue
|
| 316 |
+
|
| 317 |
+
elif input_spec.kind == InputKind.PARAMETER:
|
| 318 |
+
if not isinstance(input_spec.arg, TensorArgument):
|
| 319 |
+
raise SpecViolationError(
|
| 320 |
+
f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
|
| 321 |
+
)
|
| 322 |
+
if input_spec.target is None:
|
| 323 |
+
raise SpecViolationError(
|
| 324 |
+
f"InputSpec for {input_spec.name} has no target."
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
param = input_spec.target
|
| 328 |
+
if param not in exported_program.state_dict:
|
| 329 |
+
raise SpecViolationError(
|
| 330 |
+
f"Parameter {param} is not in the state dict."
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
|
| 334 |
+
raise SpecViolationError(
|
| 335 |
+
f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
elif input_spec.kind == InputKind.BUFFER:
|
| 339 |
+
if not isinstance(input_spec.arg, TensorArgument):
|
| 340 |
+
raise SpecViolationError(
|
| 341 |
+
f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
|
| 342 |
+
)
|
| 343 |
+
if input_spec.target is None:
|
| 344 |
+
raise SpecViolationError(
|
| 345 |
+
f"InputSpec for {input_spec.name} has no target."
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
buffer = input_spec.target
|
| 349 |
+
if input_spec.persistent is None:
|
| 350 |
+
raise SpecViolationError(
|
| 351 |
+
f"Buffer {buffer} is missing a persistence flag"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
if input_spec.persistent is True and buffer not in exported_program.state_dict:
|
| 355 |
+
raise SpecViolationError(
|
| 356 |
+
f"Buffer {buffer} is not in the state dict."
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
if input_spec.persistent is False and buffer in exported_program.state_dict:
|
| 360 |
+
raise SpecViolationError(
|
| 361 |
+
f"Non-persistent buffer {buffer} is in the state dict, it should not be."
|
| 362 |
+
)
|
| 363 |
+
elif input_spec.kind == InputKind.CONSTANT_TENSOR:
|
| 364 |
+
if not isinstance(input_spec.arg, TensorArgument):
|
| 365 |
+
raise SpecViolationError(
|
| 366 |
+
f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
|
| 367 |
+
)
|
| 368 |
+
if input_spec.target is None:
|
| 369 |
+
raise SpecViolationError(
|
| 370 |
+
f"InputSpec for {input_spec.name} has no target."
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
tensor_const = input_spec.target
|
| 374 |
+
if tensor_const not in exported_program.constants:
|
| 375 |
+
raise SpecViolationError(
|
| 376 |
+
f"Constant tensor {tensor_const} is not in the constants dictionary."
|
| 377 |
+
)
|
| 378 |
+
elif input_spec.kind == InputKind.CUSTOM_OBJ:
|
| 379 |
+
if not isinstance(input_spec.arg, CustomObjArgument):
|
| 380 |
+
raise SpecViolationError(
|
| 381 |
+
f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
|
| 382 |
+
)
|
| 383 |
+
if input_spec.target is None:
|
| 384 |
+
raise SpecViolationError(
|
| 385 |
+
f"InputSpec for {input_spec.name} has no target."
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
custom_obj = input_spec.target
|
| 389 |
+
if custom_obj not in exported_program.constants:
|
| 390 |
+
raise SpecViolationError(
|
| 391 |
+
f"Custom object {custom_obj} is not in the constants dictionary."
|
| 392 |
+
)
|
| 393 |
+
elif input_spec.kind == InputKind.TOKEN:
|
| 394 |
+
if not isinstance(input_spec.arg, TokenArgument):
|
| 395 |
+
raise SpecViolationError(
|
| 396 |
+
f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
|
| 397 |
+
)
|
| 398 |
+
else:
|
| 399 |
+
raise SpecViolationError(
|
| 400 |
+
f"Unknown InputKind {input_spec.kind}."
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# Check outputs
|
| 404 |
+
output_node = list(exported_program.graph.nodes)[-1]
|
| 405 |
+
assert output_node.op == "output"
|
| 406 |
+
output_nodes = [
|
| 407 |
+
arg.name if isinstance(arg, torch.fx.Node) else arg
|
| 408 |
+
for arg in output_node.args[0]
|
| 409 |
+
]
|
| 410 |
+
|
| 411 |
+
if len(output_nodes) != len(gs.output_specs):
|
| 412 |
+
raise SpecViolationError(
|
| 413 |
+
f"Number of output nodes {len(output_nodes)} is different "
|
| 414 |
+
"Than the number of outputs specified by the graph signature: \n"
|
| 415 |
+
f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
|
| 416 |
+
f"Number of user outputs: {len(gs.user_outputs)}. \n"
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
num_tokens = len(gs.output_tokens)
|
| 420 |
+
end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
|
| 421 |
+
mutate_nodes: List[str] = output_nodes[num_tokens:end]
|
| 422 |
+
user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
|
| 423 |
+
|
| 424 |
+
for mutation_node in mutate_nodes:
|
| 425 |
+
if mutation_node in gs.buffers_to_mutate:
|
| 426 |
+
if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
|
| 427 |
+
raise SpecViolationError(
|
| 428 |
+
f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
|
| 429 |
+
f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
|
| 430 |
+
f"Buffer nodes available: {gs.buffers} \n"
|
| 431 |
+
)
|
| 432 |
+
elif mutation_node in gs.user_inputs_to_mutate:
|
| 433 |
+
if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
|
| 434 |
+
raise SpecViolationError(
|
| 435 |
+
f"User input output {mutation_node} does not point to a user input that exists. \n"
|
| 436 |
+
f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
|
| 437 |
+
f"User input nodes available: {gs.user_inputs} \n")
|
| 438 |
+
else:
|
| 439 |
+
raise SpecViolationError(
|
| 440 |
+
f"Mutation node {mutation_node} is neither a buffer nor a user input. "
|
| 441 |
+
f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
|
| 445 |
+
if user_output_node != user_output_name:
|
| 446 |
+
raise SpecViolationError(
|
| 447 |
+
f"User output {user_output_node} is not in the correct "
|
| 448 |
+
"order or is not found in the "
|
| 449 |
+
f"exported program's user_output list: {gs.user_outputs}. "
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def load_verifier(dialect: str) -> Type[Verifier]:
|
| 454 |
+
if dialect == "ATEN" or dialect == "":
|
| 455 |
+
return _VerifierMeta._registry.get(dialect, Verifier)
|
| 456 |
+
return _VerifierMeta._registry[dialect]
|
.venv/lib/python3.11/site-packages/torch/_export/wrappers.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch._custom_ops
|
| 6 |
+
from torch._C import DispatchKey
|
| 7 |
+
from torch._higher_order_ops.strict_mode import strict_mode
|
| 8 |
+
from torch._higher_order_ops.utils import autograd_not_implemented
|
| 9 |
+
from torch._ops import HigherOrderOperator
|
| 10 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 11 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
| 12 |
+
from torch.utils import _pytree as pytree
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ExportTracepoint(HigherOrderOperator):
    # HigherOrderOperator used as a marker ("tracepoint") inserted around
    # submodule calls so export can later recover module-call signatures.
    def __init__(self):
        super().__init__("_export_tracepoint")

    def __call__(self, *args, **kwargs):
        # Pure pass-through; per-dispatch-key behavior is registered below.
        return super().__call__(*args, **kwargs)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Singleton operator instance used by the forward hooks installed below.
_export_tracepoint = ExportTracepoint()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    """Record a ``_export_tracepoint`` call_function node during proxy tracing.

    Inputs are mapped to their proxies, a graph node is emitted, and the
    original (unchanged) args are re-associated with the new proxy so tracing
    continues through the tracepoint transparently.
    """
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, p_args, p_kwargs
    )
    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    # Under fake tensors the tracepoint is a pure identity on its inputs.
    with mode:
        return args
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    # Functionalization: unwrap the functional tensors, redispatch the
    # (side-effect-free) op, and wrap the results back up.
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

    with ctx.redispatch_to_next():
        out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
        return ctx.wrap_tensors(out)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# The tracepoint has no backward formula; defer the error until a backward
# pass actually reaches it.
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    # Eager CPU execution: identity on the inputs.
    return args
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _wrap_submodule(mod, path, module_call_specs):
    """Install forward pre/post hooks on the submodule at dotted *path* of *mod*.

    The hooks flatten the submodule's inputs/outputs, route the flat values
    through ``_export_tracepoint`` (so they appear in the exported graph), and
    record the pytree in/out specs into *module_call_specs* keyed by *path*.

    Returns the two hook handles so the caller can remove them later.
    """
    assert isinstance(mod, torch.nn.Module)
    assert path != ""
    submodule = mod
    for name in path.split("."):
        if not hasattr(submodule, name):
            raise RuntimeError(f"Couldn't find submodule at path {path}")
        submodule = getattr(submodule, name)

    def update_module_call_signatures(path, in_spec, out_spec):
        # A submodule called more than once must use an identical signature
        # on every call.
        if path in module_call_specs:
            assert module_call_specs[path]["in_spec"] == in_spec
            assert module_call_specs[path]["out_spec"] == out_spec
        module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

    def check_flattened(flat_args):
        # Only leaf types export can represent are allowed across the boundary.
        for a in flat_args:
            if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
                raise AssertionError(
                    f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
                )

    def pre_hook(module, args, kwargs):
        # Mark the module-call inputs in the trace, then rebuild the original
        # (args, kwargs) structure from the (possibly proxied) flat values.
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
        check_flattened(flat_args)
        flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
        args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
        return args, kwargs

    def post_hook(module, args, kwargs, res):
        # Mark the module-call outputs and record the in/out specs for export.
        _, in_spec = pytree.tree_flatten((args, kwargs))
        flat_res, out_spec = pytree.tree_flatten(res)
        check_flattened(flat_res)
        flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
        update_module_call_signatures(path, in_spec, out_spec)
        return pytree.tree_unflatten(flat_res, out_spec)

    pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
    post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
    return pre_handle, post_handle
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@contextmanager
|
| 104 |
+
def _wrap_submodules(f, preserve_signature, module_call_signatures):
|
| 105 |
+
handles = []
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
for path in preserve_signature:
|
| 109 |
+
handles.extend(_wrap_submodule(f, path, module_call_signatures))
|
| 110 |
+
yield
|
| 111 |
+
finally:
|
| 112 |
+
for handle in handles:
|
| 113 |
+
handle.remove()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _mark_strict_experimental(cls):
|
| 117 |
+
def call(self, *args):
|
| 118 |
+
return strict_mode(self, args)
|
| 119 |
+
|
| 120 |
+
cls.__call__ = call
|
| 121 |
+
return cls
|
.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
|
| 3 |
+
import torch._C._lazy
|
| 4 |
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
| 5 |
+
|
| 6 |
+
from .closure import add_step_closure, run_step_closures
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def mark_step(device: str = "", wait=False):
    """Triggers a mark step, which amounts to
    - collecting a group of 'live' lazy tensors to index into the compilation cache
      (lowering/compiling their IR graphs if not cached)
    - kicking off execution of the compiled function
    - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
    """
    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
    torch._C._lazy._mark_step(device, [], wait=wait)

    # Run any callbacks registered via add_step_closure for this step.
    run_step_closures()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def wait_device_ops(devices=None):
    """Block until all pending async operations on *devices* complete.

    Args:
        devices (string..., optional): The devices whose async ops need to be
            waited for. If empty, all the local devices will be waited for.
    """
    torch._C._lazy._wait_device_ops(devices=[] if devices is None else devices)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def sync_multi(tensors, devices):
    """
    Sync the list of lazy tensors so their IR gets lowered for the active
    backend and the compiled computation graph gets cached.
    """
    torch._C._lazy._sync_multi(tensors, devices)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_tensor_id(tensor):
    """Return a unique id of the lazy tensor maintained by LTC (Lazy Tensor Core)."""
    return torch._C._lazy._get_tensor_id(tensor)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def to_cpu(tensors, devices=None):
    """Materialize a pytree of lazy *tensors* and return CPU copies, same structure."""
    target_devices = devices or ["lazy"]

    flat_tensors, spec = tree_flatten(tensors)
    sync_multi(flat_tensors, target_devices)
    cpu_tensors = [t.to("cpu") for t in flat_tensors]
    return tree_unflatten(cpu_tensors, spec)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def save(tensors, *args, **kwargs):
    # Sync the lazy tensors to CPU first, then delegate to torch.save;
    # extra args/kwargs are forwarded unchanged.
    torch.save(to_cpu(tensors), *args, **kwargs)
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc
ADDED
|
Binary file (859 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc
ADDED
|
Binary file (522 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/computation.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch._C._lazy
|
| 3 |
+
import torch._C._lazy_ts_backend
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_tensors_ts_device_data_node(tensors):
    """Return tensor ids and eager tensors for DeviceData nodes in the
    IR for the passed in lazy tensors.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_graph_hash(tensors):
    """Return the graph hash for the passed in lazy tensors.

    The hash can later be used to run the cached graph (see run_cached_graph).
    """
    return torch._C._lazy._get_graph_hash(tensors)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def run_cached_graph(hash_str, graph_inputs):
    """Run the cached computation graph identified by *hash_str* with the given inputs.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs)
|
.venv/lib/python3.11/site-packages/torch/_lazy/config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch._C._lazy
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_force_fallback():
    """Get the config value used to force LTC fallback to eager."""
    return torch._C._lazy._get_force_fallback()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def set_force_fallback(configval):
    """Set the config value used to force LTC fallback to eager."""
    torch._C._lazy._set_force_fallback(configval)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def set_reuse_ir(val: bool):
    """Set the config to reuse IR nodes for faster tracing."""
    torch._C._lazy._set_reuse_ir(val)
|
.venv/lib/python3.11/site-packages/torch/_lazy/debug.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch._C._lazy
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def render_ir_graph(tensors):
    """Return a text dump of the LTC IR graph in dot format for the tensors.

    The text can be processed by tools like dot to be rendered in pdf, png, etc.
    """
    return torch._C._lazy._get_tensors_dot(tensors)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def dump_ir(tensors, ir_format):
    """Return a dump of the tensors in the specified format.

    Valid formats are
    - text: for LTC IR
    - backend: for the active backend IR
    """
    if ir_format == "text":
        return torch._C._lazy._get_tensors_text(tensors)
    if ir_format == "backend":
        return torch._C._lazy._get_tensors_backend(tensors)
    raise RuntimeError(f"Unrecognized IR format: {ir_format}")
|
.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import threading
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
import torch._C._lazy
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DeviceContext:
    # Process-wide registry of per-device contexts; the lock makes
    # get_device_context safe to call from multiple threads.
    _CONTEXTS: Dict[str, Any] = {}
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device):
        # Device identifier string this context belongs to.
        self.device = device
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_device_context(device=None):
    """Return the cached DeviceContext for *device*, creating it on first use.

    When *device* is None, the default lazy device type is used as the key.
    """
    key = (
        torch._C._lazy._get_default_device_type()
        if device is None
        else str(device)
    )
    with DeviceContext._CONTEXTS_LOCK:
        ctx = DeviceContext._CONTEXTS.get(key)
        if ctx is None:
            ctx = DeviceContext(key)
            DeviceContext._CONTEXTS[key] = ctx
        return ctx
|
.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import copy
|
| 3 |
+
import dataclasses
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
from typing import Any, Callable, Dict, List
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch._lazy as lazy
|
| 10 |
+
import torch._lazy.metrics as metrics
|
| 11 |
+
from torch import fx
|
| 12 |
+
from torch._lazy import computation, debug as lazy_debug
|
| 13 |
+
from torch._lazy.tensor_factory_functions import tensor_factory_functions
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Set env var debug_extract_compiled_graph (to any value) to enable verbose dumps.
debug = os.environ.get("debug_extract_compiled_graph") is not None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclasses.dataclass
class GraphInputMatcher:
    """
    The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing.
    Specifically, those graph inputs corresponding to method parameters should be replaced with the
    arguments for the current call.

    tensor_id_to_arg_idx maps the tensor id to the parameter index.
    graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the
    TS/XLA graph inputs.
    """

    tensor_id_to_arg_idx: Dict[int, int]
    graph_input_tensor_ids: List[int]
    # there are 2 categories of graph_input_tensors.
    # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are
    # most likely const tensors and we can get its content from graph_input_tensors
    # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get
    # the tensor from method arguments
    graph_input_ivalues: List[Any]

    # get the real graph input tensors
    def __call__(self, args):
        # For each traced graph input, substitute the caller's argument when
        # its tensor id maps to a parameter index; otherwise reuse the value
        # captured at trace time (Category 1 above).
        real_input = []
        for tensor_id, traced_ivalue in zip(
            self.graph_input_tensor_ids, self.graph_input_ivalues
        ):
            arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None)
            if arg_idx is None:
                inp = traced_ivalue
            else:
                inp = args[arg_idx]
            real_input.append(inp)
        return real_input
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ReturnValueHandler:
    r"""
    When ltc_sync_multi is called on multi tensors, the compiled graph
    will contain output only for unique tensors - if a tensor appears multiple
    times in the input to _ltc_sync_multi, only the first occurrence matters.

    However from python level, we still expect multi tensors returned with duplication
    even if the TS graph dedups the output. e.g. for method:

      def forward(self, a):
          return a, a

    the TS graph captured by LTC will return a single tensor, but Python method expects 2.

    This class dedups the lazy tensors first to get the index that will be used
    to duplicate the eager tensors later.
    """

    def __init__(self, lazy_out_list):
        # index[i] lists every output position that should receive the i-th
        # unique tensor; total_count is the original (duplicated) arity.
        self.index: List[List[int]] = []
        self.total_count = len(lazy_out_list)

        tensor_id_to_idx: Dict[int, int] = {}
        for dup_idx, lazy_tensor in enumerate(lazy_out_list):
            uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None)
            if uniq_idx is not None:
                self.index[uniq_idx].append(dup_idx)
            else:
                uniq_idx = len(self.index)
                self.index.append([dup_idx])
                tensor_id_to_idx[id(lazy_tensor)] = uniq_idx

    def duplicate_eager_tensors(self, eager_tensor_list):
        # Expand the deduped eager outputs back to the original arity.
        duplicated_list = [None] * self.total_count
        assert len(eager_tensor_list) == len(self.index)

        for uniq_idx, eager_tensor in enumerate(eager_tensor_list):
            for dup_idx in self.index[uniq_idx]:
                duplicated_list[dup_idx] = eager_tensor
        return duplicated_list
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def force_lazy_device(model: fx.GraphModule):
    """
    Factory methods in an Fx graph may create tensors for specific eager devices.
    If we take no action, those eager tensors will be mixed with lazy tensors and
    cause a crash. This method overwrites those eager devices with the lazy device.
    """

    def tolazydevice(dev):
        # Map any torch.device argument to the equivalent lazy device.
        if isinstance(dev, torch.device):
            return torch.device("lazy", index=dev.index)
        return dev

    def hasDeviceArg(args, kwargs):
        # True if any positional or keyword argument is a torch.device.
        return any(
            isinstance(arg, torch.device)
            for arg in itertools.chain(args, kwargs.values())
        )

    for nd in model.graph.nodes:
        nd.args = tuple(tolazydevice(arg) for arg in nd.args)
        nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()}

        # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return
        # eager tensors on the default device
        # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove,
        # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart).
        # To force those tensors on the lazy device, we can not simply override
        # the device argument since there is no explicit device argument.
        # What we are doing here is, for the list of covered tensor factory methods
        # we add a lazy device argument explicitly.
        #
        # TODO: This solution is not ideal since we may miss some factory methods. In future
        # when we support lazy mode, this method can be replaced by that.
        if nd.target in tensor_factory_functions and not hasDeviceArg(
            nd.args, nd.kwargs
        ):
            kwargs = dict(nd.kwargs)  # nd.kwargs is immutable. make a mutable copy.
            kwargs["device"] = torch.device("lazy")
            nd.kwargs = kwargs

    model.recompile()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_fallback_ops():
    """Return ``"<op>=<count>"`` strings for every aten op that fell back to eager."""
    fallbacks = []
    for counter_name in metrics.counter_names():
        if "aten::" in counter_name:
            count = int(metrics.counter_value(counter_name))
            if count > 0:
                fallbacks.append(f"{counter_name}={count}")
    return fallbacks
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable:
    """
    Optimize an eager model with LTC and returns a wrapper to execute the
    compiled graph directly without retracing. It depends on other mechanisms
    like TorchDynamo guards to guarantee the returned wrapper is only called
    when it's safe.

    Raises RuntimeError if any op fell back to eager during tracing, since the
    cached graph would then be incomplete.
    """
    # Move the example inputs (and a copy of the model) to the lazy device and
    # remember which tensor id corresponds to which positional argument.
    lazy_args = [arg.to(device="lazy") for arg in example_inputs]
    args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args]
    tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
    lazy_model = copy.deepcopy(model).to(device=torch.device("lazy"))
    force_lazy_device(lazy_model)

    # This line executes lazy tracing and enables us to extract the compiled graph later
    metrics.reset()
    lazy_out = lazy_model(*lazy_args)
    fallback_ops = get_fallback_ops()
    metrics.reset()

    if len(fallback_ops) > 0:
        # Fixed typo in the user-facing message ("Fail to extact").
        raise RuntimeError(
            f"Failed to extract the compiled graph because of fallback: {','.join(fallback_ops)}"
        )

    if not isinstance(lazy_out, (tuple, list)):
        lazy_out = (lazy_out,)

    # Inputs are included among the synced tensors so in-place updates to the
    # arguments are captured by the cached graph as well.
    args_and_out = tuple(lazy_args) + tuple(lazy_out)
    return_value_handler = ReturnValueHandler(args_and_out)
    if debug:
        print("Fx code:\n", model.code)
        print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text"))

    # TODO: this part is TS backend specific for now and will be generalized to
    # support XLA
    (
        graph_input_tensor_ids,
        graph_input_ivalues,
    ) = computation.get_tensors_ts_device_data_node(args_and_out)
    assert len(graph_input_tensor_ids) == len(graph_input_ivalues)
    graph_input_matcher = GraphInputMatcher(
        tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues
    )

    graph_hash = computation.get_graph_hash(args_and_out)

    if debug:
        print("graph_hash", graph_hash)
        print(f"args_tensor_ids {args_tensor_ids}")
        print("tensor ids from device data:", graph_input_tensor_ids)

    # sync the list of output tensors so the computation graph for these
    # tensors will be cached. Those computation graphs can be retrieved
    # by graph hash later.
    lazy.sync_multi(args_and_out, [])

    def optimized_mod(*args):
        # Execute the cached graph with the current call's arguments spliced in.
        if len(args_and_out) == 0:
            return ()
        graph_input = graph_input_matcher(args)
        res = return_value_handler.duplicate_eager_tensors(
            computation.run_cached_graph(graph_hash, graph_input)
        )

        assert len(res) == len(args_and_out)
        for i, arg in enumerate(args):
            # only copy those tensors that get inplace updated
            if arg is not res[i]:
                arg.copy_(res[i])

        # skip the args
        return res[len(args) :]

    return optimized_mod
|
.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch._C._lazy
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def reset():
    """Resets all metric counters."""
    torch._C._lazy._reset_metrics()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def counter_names():
    """Retrieves all the currently active counter names."""
    return torch._C._lazy._counter_names()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def counter_value(name: str):
    """Return the value of the counter with the specified name."""
    return torch._C._lazy._counter_value(name)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def metrics_report():
    """Return the combined (lazy core and backend) metric report."""
    return torch._C._lazy._metrics_report()
|
.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch._C._lazy_ts_backend
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def init():
    """Initializes the lazy Torchscript backend."""
    torch._C._lazy_ts_backend._init()
|
.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
__all__ = ["register_after_fork"]
|
| 6 |
+
|
| 7 |
+
if sys.platform == "win32":
    import multiprocessing.util as _util

    def _register(func):
        # Windows has no fork(); fall back to multiprocessing's after-fork
        # machinery. The registered callback receives one argument, which we
        # discard before invoking func.
        def wrapper(arg):
            func()

        _util.register_after_fork(_register, wrapper)

else:
    import os

    def _register(func):
        # POSIX: run func in the child after any fork, including plain os.fork().
        os.register_at_fork(after_in_child=func)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def register_after_fork(func):
    """Register a callable to be executed in the child process after a fork.

    Note:
        In python < 3.7 this will only work with processes created using the
        ``multiprocessing`` module. In python >= 3.7 it also works with
        ``os.fork()``.

    Args:
        func (function): Function taking no arguments to be called in the child after fork

    """
    # Delegates to the platform-specific _register defined above.
    _register(func)
|
.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing.pool
|
| 2 |
+
import multiprocessing.util as util
|
| 3 |
+
|
| 4 |
+
from .queue import SimpleQueue
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def clean_worker(*args, **kwargs):
    """Drop-in replacement for multiprocessing.pool.worker that GCs before exit."""
    import gc

    multiprocessing.pool.worker(*args, **kwargs)
    # Regular multiprocessing workers don't fully clean up after themselves,
    # so we have to explicitly trigger garbage collection to make sure that all
    # destructors are called...
    gc.collect()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Pool(multiprocessing.pool.Pool):
    """Pool implementation which uses our version of SimpleQueue.

    This lets us pass tensors in shared memory across processes instead of
    serializing the underlying data.
    """

    def _setup_queues(self):
        # Replace the stdlib queues with ForkingPickler-backed ones so tensor
        # storages are shared rather than copied byte-for-byte.
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def _repopulate_pool(self):
        """Increase the number of pool processes to the specified number.

        Bring the number of pool processes up to the specified number, for use after
        reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            # changed worker -> clean_worker
            args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            )
            # _wrap_exception exists only on some Python versions' Pool; pass it
            # through when present so the stdlib worker signature matches.
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            w = self.Process(target=clean_worker, args=args)
            self._pool.append(w)
            w.name = w.name.replace("Process", "PoolWorker")
            w.daemon = True
            w.start()
            util.debug("added worker")
|
.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import io
|
| 3 |
+
import multiprocessing.queues
|
| 4 |
+
import pickle
|
| 5 |
+
from multiprocessing.reduction import ForkingPickler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConnectionWrapper:
    """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization."""

    def __init__(self, conn):
        self.conn = conn

    def send(self, obj):
        # Serialize with ForkingPickler so torch tensors are reduced to
        # shared-memory handles instead of copies of their data.
        buf = io.BytesIO()
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.send_bytes(buf.getvalue())

    def recv(self):
        buf = self.recv_bytes()
        return pickle.loads(buf)

    def __getattr__(self, name):
        # Delegate everything else (send_bytes, recv_bytes, fileno, ...) to the
        # wrapped connection; the guard avoids infinite recursion when 'conn'
        # itself is missing (e.g. during unpickling before __init__ ran).
        if "conn" in self.__dict__:
            return getattr(self.conn, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Queue(multiprocessing.queues.Queue):
    # multiprocessing.Queue whose pipe endpoints are wrapped in
    # ConnectionWrapper so objects go through ForkingPickler.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
        self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        # Rebind the bound methods the base class cached from the raw pipes.
        self._send = self._writer.send
        self._recv = self._reader.recv
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SimpleQueue(multiprocessing.queues.SimpleQueue):
    def _make_methods(self):
        # Wrap the raw pipe endpoints exactly once (this hook can run again
        # after unpickling), then let the base class build send/recv on top.
        if not isinstance(self._reader, ConnectionWrapper):
            self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
            self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        super()._make_methods()  # type: ignore[misc]
|
.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import multiprocessing
|
| 3 |
+
import os
|
| 4 |
+
import threading
|
| 5 |
+
from multiprocessing.reduction import ForkingPickler
|
| 6 |
+
from multiprocessing.util import register_after_fork
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch._namedtensor_internals import check_serializing_named_tensor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
# Early load resource_sharer to prevent a partially initialized instance
|
| 15 |
+
# from being inherited in a forked child process. The reduce_storage method
|
| 16 |
+
# requires this module indirectly through DupFd(). The built-in mp.Queue
|
| 17 |
+
# class pickles arguments in a background thread which may overlap with the
|
| 18 |
+
# fork.
|
| 19 |
+
import multiprocessing.resource_sharer
|
| 20 |
+
except ImportError:
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class StorageWeakRef:
    r"""A weak reference to a Storage.

    The cdata member is a Python number containing the integer representation of
    the Storage pointer.
    """

    __slots__ = ["cdata", "_free_weak_ref"]

    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref because the `torch` module
        # might be cleared during Python shutdown before this module is cleared.
        self._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]

    @classmethod
    def from_weakref(cls, cdata):
        # Alternate constructor used when only the raw weak pointer is known.
        ref = cls.__new__(cls)
        ref.cdata = cdata
        ref._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]
        return ref

    def expired(self):
        # True once the referenced storage has been deallocated.
        return torch.Storage._expired(self.cdata)  # type: ignore[attr-defined]

    def __del__(self):
        self._free_weak_ref(self.cdata)

    def __hash__(self):
        return self.cdata

    def __eq__(self, other):
        # Identity fast-path, otherwise compare the underlying pointers.
        return id(self) == id(other) or self.cdata == other.cdata
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class SharedCache(dict):
    """Dictionary from multiprocessing handles to StorageWeakRef."""

    def __init__(self) -> None:
        # free_dead_references() runs once the size exceeds this limit;
        # the limit then scales with the number of entries still alive.
        self.limit = 128
        # `fork` inherits lock state, so a child could start with a lock a
        # parent thread held.  Re-create the lock after every fork, mirroring
        # the python multiprocessing library's own design.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def get(self, key):
        with self.lock:
            return dict.get(self, key)

    def __setitem__(self, key, storage_ref):
        with self.lock:
            dict.__setitem__(self, key, storage_ref)
            if len(self) > self.limit:
                self.free_dead_references()

    def free_dead_references(self):
        # Drop entries whose storage is gone; callers hold self.lock already.
        alive = 0
        for key, ref in list(self.items()):
            if ref.expired():
                del self[key]
            else:
                alive += 1
        self.limit = max(128, alive * 2)
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Global cache mapping sharing handles (fd ids, shm filenames, CUDA IPC
# handles) to StorageWeakRef objects, so a storage received multiple times
# resolves to a single live object.
shared_cache = SharedCache()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def rebuild_event(device, handle):
    """Receiver side: recreate a CUDA event on `device` from its IPC handle."""
    return torch.cuda.Event.from_ipc_handle(device, handle)
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def reduce_event(event):
    """ForkingPickler reducer for torch.cuda.Event: share via its IPC handle."""
    handle = event.ipc_handle()
    return (rebuild_event, (event.device, handle))
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def rebuild_tensor(cls, storage, metadata):
    """Rebuild a tensor of class `cls` on top of an already-shared `storage`.

    `metadata` is the (storage_offset, size, stride, requires_grad) tuple
    produced by reduce_tensor.
    """
    storage_offset, size, stride, requires_grad = metadata
    t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    if cls == torch.nn.parameter.Parameter:
        # requires_grad must go through the constructor rather than be set
        # afterwards: integer Parameters must see requires_grad=False there
        # (or else they raise an error).
        return torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    t.requires_grad = requires_grad
    return t
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def rebuild_meta_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    dtype,
    storage_size_bytes,
    requires_grad,
):
    """Recreate a meta-device tensor from metadata (meta tensors carry no data)."""
    typed_storage = torch.TypedStorage(
        wrap_storage=torch.UntypedStorage(storage_size_bytes, device="meta"),
        dtype=dtype,
        _internal=True,
    )

    t = torch._utils._rebuild_tensor(
        typed_storage, tensor_offset, tensor_size, tensor_stride
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # the requires_grad=False as an argument in the constructor
        return torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    t.requires_grad = requires_grad
    return t
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def rebuild_cuda_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    """Receiver side: rebuild a CUDA tensor shared over IPC.

    The storage lives inside the sender's cudaMalloc block identified by
    `storage_handle`, at byte offset `storage_offset_bytes`; see
    Note [CUDA IPC and the caching allocator] in reduce_tensor.
    """
    # If storage_handle is None, storage points to nullptr.
    if storage_handle is None or storage_size_bytes == 0:
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        # Reuse an already-opened mapping for this (handle, offset) if cached;
        # the same cudaIpcMemHandle may only be opened once per device/process.
        storage = storage_from_cache(
            storage_cls, (storage_handle, storage_offset_bytes)
        )
        if storage is None:
            torch.cuda._lazy_init()
            storage = storage_cls._new_shared_cuda(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
                storage
            )
        else:
            # We already ref counting this Storage, but producer needs new ref-counters to be released.
            storage_cls._release_ipc_counter(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )

    # Normalize to the untyped bytes before wrapping in a TypedStorage below.
    _storage = (
        storage
        if isinstance(storage, torch.UntypedStorage)
        else storage._untyped_storage
    )

    t = torch._utils._rebuild_tensor(
        torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # the requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def reduce_tensor(tensor):
    """ForkingPickler reducer for tensors.

    Dispatches on layout/device: nested and sparse tensors are decomposed
    into component tensors, CUDA tensors are shared via CUDA IPC, meta
    tensors are rebuilt from metadata only, and everything else shares its
    CPU storage through reduce_storage.
    """
    if tensor.requires_grad and not tensor.is_leaf:
        raise RuntimeError(
            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
            "since autograd does not support crossing process boundaries. "
            "If you just want to transfer the data, call detach() on the tensor "
            "before serializing (e.g., putting it on the queue)."
        )

    check_serializing_named_tensor(tensor)
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Note [CUDA IPC and the caching allocator]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # A CUDA storage lives inside a larger cudaMalloc block owned by the
    # caching allocator, and CUDA IPC can only share whole cudaMalloc blocks
    # (via cudaIpcMemHandle).  We therefore transmit the handle of the whole
    # block plus the byte offset and size of the storage inside it, and the
    # receiver reconstructs the storage as a view into the mapped block:
    #
    #   Tensor(size, offset, storage=Storage(data=basePtr+storage_offset,
    #                                        size=storage_size))
    #
    # Two consequences:
    # 1. Each cudaIpcMemHandle may only be opened once per device per other
    #    process, so the C++ side caches the mapped device pointer, and we
    #    keep a per-process shared_cache of rebuilt storages keyed on
    #    (handle, offset).
    # 2. The rebuilt tensor must NEVER be resized: growing past its original
    #    bounds would clobber unrelated data in the shared allocator block.

    # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
    # https://github.com/pytorch/pytorch/issues/110543
    from torch.nested._internal.nested_tensor import NestedTensor

    if tensor.is_nested and not isinstance(tensor, NestedTensor):
        return reduce_nested_tensor(tensor)

    if tensor.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        return reduce_sparse_tensor(tensor)

    storage = tensor._typed_storage()

    if storage._untyped_storage.device.type == "cuda":
        (
            device,
            handle,
            storage_size_bytes,
            storage_offset_bytes,
            ref_counter_handle,
            ref_counter_offset,
            event_handle,
            event_sync_required,
        ) = storage._share_cuda_()
        tensor_offset = tensor.storage_offset()
        shared_cache[handle] = StorageWeakRef(storage)
        # _backward_hooks purposely omitted here, see
        # Note [Don't serialize hooks]
        return (
            rebuild_cuda_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor_offset,  # tensor offset in its storage
                type(storage),
                tensor.dtype,
                device,
                handle,  # identifier which CUDA allocation is the storage in.
                storage_size_bytes,  # size(in bytes) of the storage
                storage_offset_bytes,  # offset(in bytes) of the storage in the CUDA allocation
                tensor.requires_grad,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            ),
        )
    elif storage._untyped_storage.device.type == "meta":
        return (
            rebuild_meta_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor.storage_offset(),
                tensor.dtype,
                tensor.untyped_storage().size(),
                tensor.requires_grad,
            ),
        )

    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def rebuild_nested_tensor(
|
| 404 |
+
rebuild_buffer_func,
|
| 405 |
+
rebuild_buffer_args,
|
| 406 |
+
rebuild_sizes_func,
|
| 407 |
+
rebuild_sizes_args,
|
| 408 |
+
rebuild_strides_func,
|
| 409 |
+
rebuild_strides_args,
|
| 410 |
+
rebuild_offsets_func,
|
| 411 |
+
rebuild_offsets_args,
|
| 412 |
+
):
|
| 413 |
+
buffer = rebuild_buffer_func(*rebuild_buffer_args)
|
| 414 |
+
sizes = rebuild_sizes_func(*rebuild_sizes_args)
|
| 415 |
+
strides = rebuild_strides_func(*rebuild_strides_args)
|
| 416 |
+
offsets = rebuild_offsets_func(*rebuild_offsets_args)
|
| 417 |
+
return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def reduce_nested_tensor(nt):
    """Reducer for nested tensors: reduce buffer, sizes, strides and offsets separately."""
    pieces = []
    for component in (
        nt.values(),
        nt._nested_tensor_size(),
        nt._nested_tensor_strides(),
        nt._nested_tensor_storage_offsets(),
    ):
        func, args = reduce_tensor(component)
        pieces.append(func)
        pieces.append(args)

    return (rebuild_nested_tensor, tuple(pieces))
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def rebuild_sparse_coo_tensor(
|
| 446 |
+
rebuild_indices_func,
|
| 447 |
+
rebuild_indices_args,
|
| 448 |
+
rebuild_values_func,
|
| 449 |
+
rebuild_values_args,
|
| 450 |
+
shape,
|
| 451 |
+
is_coalesced,
|
| 452 |
+
):
|
| 453 |
+
indices = rebuild_indices_func(*rebuild_indices_args)
|
| 454 |
+
values = rebuild_values_func(*rebuild_values_args)
|
| 455 |
+
return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    """Reassemble a CSR/CSC/BSR/BSC tensor from separately rebuilt components."""
    return torch.sparse_compressed_tensor(
        rebuild_compressed_indices_func(*rebuild_compressed_indices_args),
        rebuild_plain_indices_func(*rebuild_plain_indices_args),
        rebuild_values_func(*rebuild_values_args),
        shape,
        layout=layout,
    )
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def reduce_sparse_tensor(sparse):
    """Reducer for sparse tensors: share each component tensor separately."""
    if sparse.layout is torch.sparse_coo:
        indices_func, indices_args = reduce_tensor(sparse._indices())
        values_func, values_args = reduce_tensor(sparse._values())
        return (
            rebuild_sparse_coo_tensor,
            (
                indices_func,
                indices_args,
                values_func,
                values_args,
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )

    # Compressed layouts: pick the index accessors for the layout family.
    if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
        compressed_indices = sparse.crow_indices()
        plain_indices = sparse.col_indices()
    elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
        compressed_indices = sparse.ccol_indices()
        plain_indices = sparse.row_indices()
    else:
        raise NotImplementedError(sparse.layout)

    compressed_func, compressed_args = reduce_tensor(compressed_indices)
    plain_func, plain_args = reduce_tensor(plain_indices)
    values_func, values_args = reduce_tensor(sparse.values())
    return (
        rebuild_sparse_compressed_tensor,
        (
            compressed_func,
            compressed_args,
            plain_func,
            plain_args,
            values_func,
            values_args,
            sparse.shape,
            sparse.layout,
        ),
    )
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def fd_id(fd):
    """Return a tuple uniquely identifying an open file descriptor.

    Uses (st_ino, st_dev).  On Mac OS this doesn't work for shared memory
    handles, which is why the "file_descriptor" sharing method is not
    supported on that platform.
    """
    info = os.fstat(fd)
    return (info.st_ino, info.st_dev)
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def storage_from_cache(cls, key):
    """Return a live UntypedStorage cached under `key`, or None if absent."""
    ref = shared_cache.get(key)
    if ref is None:
        return None
    return torch.UntypedStorage._new_with_weak_ptr(ref.cdata)
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def rebuild_storage_fd(cls, df, size):
    """Receiver side: rebuild a CPU storage shared via a duplicated fd."""
    fd = df.detach()
    try:
        cached = storage_from_cache(cls, fd_id(fd))
        if cached is not None:
            return cached
        storage = cls._new_shared_fd_cpu(fd, size)
        shared_cache[fd_id(fd)] = StorageWeakRef(storage)
        return storage
    finally:
        # The dup'd descriptor is only needed while mapping; always close it.
        os.close(fd)
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    """Receiver side: rebuild a CPU storage shared via the file_system strategy.

    `dtype` is None when the sender shared an UntypedStorage; otherwise it is
    the element dtype needed to re-wrap the bytes in a TypedStorage.
    """
    cached: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
        cls, handle
    )
    if cached is not None:
        return cached._shared_decref()
    if dtype is None:
        storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        # `size` counts elements here, so convert to bytes before mapping.
        byte_size = size * torch._utils._element_size(dtype)
        untyped: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(
            manager, handle, byte_size
        )
        storage = torch.TypedStorage(
            wrap_storage=untyped, dtype=dtype, _internal=True
        )
    shared_cache[handle] = StorageWeakRef(storage)
    # Release the extra manager refcount the sender took for the transfer.
    return storage._shared_decref()
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def rebuild_storage_empty(cls):
    """Recreate a zero-size storage; empty storages are never actually shared."""
    return cls()
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def rebuild_typed_storage(storage, dtype):
    """Re-wrap untyped `storage` as a TypedStorage with element type `dtype`."""
    typed = torch.storage.TypedStorage(
        wrap_storage=storage, dtype=dtype, _internal=True
    )
    return typed
| 579 |
+
|
| 580 |
+
|
| 581 |
+
# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
    """Reducer for TypedStorage: share the untyped bytes plus the dtype."""
    return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def rebuild_typed_storage_child(storage, storage_type):
|
| 587 |
+
return storage_type(wrap_storage=storage, _internal=True)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
    """Reducer for legacy typed storages: share bytes plus the wrapper class."""
    return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def reduce_storage(storage):
    """ForkingPickler reducer for CPU storages.

    CUDA and meta storages are rejected (share the tensor instead).  Under
    the "file_system" strategy the storage is shared through a named shm
    file; empty storages are recreated from scratch (0-byte maps are
    invalid); otherwise a duplicated file descriptor is transferred.
    """
    from . import get_sharing_strategy

    if storage.is_cuda:
        raise RuntimeError(
            "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
        )
    elif storage.device.type == "meta":
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )
    elif get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            # Receiver needs the dtype to rebuild the typed wrapper.
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special cased because Empty tensors
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def init_reductions():
    """Register all torch-specific reducers with ForkingPickler."""
    ForkingPickler.register(torch.cuda.Event, reduce_event)

    for t in torch._storage_classes:
        # UntypedStorage is shared directly; typed subclasses share their
        # untyped bytes plus enough info to restore the wrapper class.
        if t.__name__ == "UntypedStorage":
            reducer = reduce_storage
        else:
            reducer = reduce_typed_storage_child
        ForkingPickler.register(t, reducer)

    ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)

    for t in torch._tensor_classes:
        ForkingPickler.register(t, reduce_tensor)

    # TODO: Maybe this should be in tensor_classes? :)
    ForkingPickler.register(torch.Tensor, reduce_tensor)

    from torch.nn.parameter import Parameter

    ForkingPickler.register(Parameter, reduce_tensor)
.venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
import multiprocessing
|
| 4 |
+
import multiprocessing.connection
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
import signal
|
| 8 |
+
import sys
|
| 9 |
+
import tempfile
|
| 10 |
+
import time
|
| 11 |
+
import warnings
|
| 12 |
+
from concurrent.futures import as_completed, ThreadPoolExecutor
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
|
| 19 |
+
|
| 20 |
+
log = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"ProcessContext",
|
| 24 |
+
"ProcessException",
|
| 25 |
+
"ProcessExitedException",
|
| 26 |
+
"ProcessRaisedException",
|
| 27 |
+
"spawn",
|
| 28 |
+
"SpawnContext",
|
| 29 |
+
"start_processes",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ProcessException(Exception):
    """Base error describing a failed spawned child process.

    Carries the human-readable message, the index of the failing worker
    within the spawned group, and that worker's PID.
    """

    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        # Keep the raw constructor arguments around so __reduce__ can
        # rebuild an equal exception across a pickle boundary.
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        # Custom __init__ signatures break default exception pickling;
        # hand pickle the exact (class, args) pair to reconstruct with.
        return type(self), (self.msg, self.error_index, self.pid)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ProcessRaisedException(ProcessException):
    """Raised when a child process died because its target function raised.

    The message carries the child's formatted traceback (recovered from
    the per-process error file), so the parent-side exception shows the
    original failure.
    """

    def __init__(self, msg: str, error_index: int, error_pid: int):
        super().__init__(msg, error_index, error_pid)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ProcessExitedException(ProcessException):
    """Raised when a child process died from a signal or a non-zero exit code."""

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        # Negative exit codes denote death by signal; signal_name is the
        # symbolic name when one could be resolved, else None.
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        # Include the extra fields so pickling round-trips the exception.
        return type(self), (
            self.msg,
            self.error_index,
            self.pid,
            self.exit_code,
            self.signal_name,
        )
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _wrap(fn, i, args, error_file):
    """Child-process entrypoint: run ``fn(i, *args)`` and report crashes.

    Any exception other than ``KeyboardInterrupt`` is rendered to a
    traceback string, pickled into ``error_file``, and the child exits
    with status 1 so the parent can surface the original error.
    """
    # Ask the kernel to deliver SIGINT to us if the parent dies first, so
    # non-daemonic children do not outlive a crashed parent.  prctl(2) is
    # Linux-specific; on other systems this call has no effect.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        # SIGINT means the parent killed us on purpose; exit quietly.
        pass
    except Exception:
        # Propagate the failure to the parent, keeping the original
        # traceback text (the exception object itself may not pickle).
        import traceback

        formatted = traceback.format_exc()
        with open(error_file, "wb") as fh:
            pickle.dump(formatted, fh)
        sys.exit(1)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ProcessContext:
    """Bookkeeping for a group of spawned worker processes.

    Holds the ``multiprocessing.Process`` handles together with the
    per-process error files that workers (see ``_wrap``) write their
    pickled tracebacks to when they crash.
    """

    def __init__(self, processes, error_files):
        # error_files[i]: path of the traceback file for worker i; the
        # file only exists if that worker crashed (see _wrap).
        self.error_files = error_files
        self.processes = processes
        # Map each process sentinel (a handle that becomes ready when the
        # process ends) back to the worker index, so results of
        # multiprocessing.connection.wait() can be attributed to workers.
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        # PIDs of all managed processes, in spawn order.
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            # Pop so a finished process is not waited on by later calls.
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                # Only the first failure matters; it determines which
                # exception is raised below.
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        # Try SIGTERM then SIGKILL if the process isn't going down.
        # The reason is related to python signal handling is limited
        # to main thread and if that is in c/c++ land and stuck it won't
        # to handle it. We have seen processes getting stuck not handling
        # SIGTERM for the above reason.
        # NOTE(review): this local shadows the ``timeout`` parameter; from
        # here on ``timeout`` is the 30s SIGTERM grace period, not the
        # caller-supplied wait — presumably intentional, worth confirming.
        timeout: int = 30
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()
        # All survivors share one 30-second budget to exit after SIGTERM.
        end = time.monotonic() + timeout
        for process in self.processes:
            time_to_wait = max(0, end - time.monotonic())
            process.join(time_to_wait)
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
            # Reap every process (killed or not) so no zombies remain.
            process.join()

        # The file will only be created if the process crashed.
        failed_process = self.processes[error_index]
        if not os.access(self.error_files[error_index], os.R_OK):
            # No traceback file: the worker died from a signal or exited
            # non-zero without raising a Python exception.
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                # Negative exitcode encodes the terminating signal number.
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" % (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                )

        # Crash path: re-raise in the parent with the child's traceback text.
        with open(self.error_files[error_index], "rb") as fh:
            original_trace = pickle.load(fh)
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class SpawnContext(ProcessContext):
    """Deprecated alias for :class:`ProcessContext`, kept for old callers."""

    def __init__(self, processes, error_files):
        # Warn on construction, then behave exactly like ProcessContext.
        warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
        super().__init__(processes, error_files)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Note: [start_processes]
|
| 213 |
+
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
|
| 214 |
+
# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
|
| 215 |
+
# CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork'
|
| 216 |
+
# works better than 'spawn'. Every helper function we created for mp.spawn is indeed
|
| 217 |
+
# general enough, and backends like XLA can reuse them in Colab notebooks as well.
|
| 218 |
+
# Currently we only add this API first, we can consider adding it to documentation as
|
| 219 |
+
# needed in the future.
|
| 220 |
+
def start_processes(
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method="spawn",
):
    """Start ``nprocs`` worker processes running ``fn(i, *args)``.

    Generalization of :func:`spawn` that accepts any multiprocessing
    start method.  Returns a :class:`ProcessContext` when ``join`` is
    False; otherwise blocks until all workers finish (raising if one
    fails) and returns None.

    To speed up startup in certain cases (see
    https://github.com/pytorch/pytorch/issues/133010), workers are
    launched in parallel when start_method is 'forkserver' and the env
    var TORCH_MP_PARALLEL_START is set to 1.
    # todo: investigate why spawn does not work with threadpool and raises SIGINT
    """
    # Parallel start is opt-in (env var) and forkserver-only.
    start_parallel = (
        start_method == "forkserver"
        and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
    )
    if start_parallel:
        log.info("Starting processes in parallel.")

    mp_ctx = multiprocessing.get_context(start_method)
    error_files = [None] * nprocs
    processes = [None] * nprocs

    def launch(rank):
        """Start worker ``rank``; return (rank, process, error-file path)."""
        # Each worker gets a dedicated file to write its traceback to; a
        # non-existent file signals a clean shutdown.  A plain file is
        # used instead of a multiprocessing.Queue because queues can
        # deadlock, and only a one-shot message is needed.
        handle = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        handle.close()
        # Only the *name* is needed; the worker recreates the file if it
        # crashes, so the empty placeholder is removed right away.
        os.unlink(handle.name)
        proc = mp_ctx.Process(
            target=_wrap,
            args=(fn, rank, args, handle.name),
            daemon=daemon,
        )
        proc.start()
        return rank, proc, handle.name

    if start_parallel:
        with ThreadPoolExecutor(max_workers=nprocs) as pool:
            pending = [pool.submit(launch, rank) for rank in range(nprocs)]
            for fut in as_completed(pending):
                rank, proc, err_file = fut.result()
                # Slot results by rank so list position matches worker index
                # even though futures complete out of order.
                processes[rank] = proc
                error_files[rank] = err_file
    else:
        for rank in range(nprocs):
            slot, proc, err_file = launch(rank)
            processes[slot] = proc
            error_files[slot] = err_file

    context = ProcessContext(processes, error_files)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Args:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
            as the start method. To use a different start method
            use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``

    """
    # Any non-default start_method is ignored (spawn is forced); warn the
    # caller and point at start_processes() instead.
    if start_method != "spawn":
        warnings.warn(
            f"This method only supports start_method=spawn (got: {start_method}).\n"
            "To use a different start_method use:\n\t\t"
            " torch.multiprocessing.start_processes(...)",
            FutureWarning,
            stacklevel=2,
        )
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
|
.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (246 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
|
| 2 |
+
from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"LSTM",
|
| 7 |
+
"LSTMCell",
|
| 8 |
+
"MultiheadAttention",
|
| 9 |
+
]
|
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (454 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc
ADDED
|
Binary file (669 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc
ADDED
|
Binary file (665 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (253 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.37 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (878 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc
ADDED
|
Binary file (679 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import parametrizations, rnn, stateless
|
| 2 |
+
from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
|
| 3 |
+
from .convert_parameters import parameters_to_vector, vector_to_parameters
|
| 4 |
+
from .fusion import (
|
| 5 |
+
fuse_conv_bn_eval,
|
| 6 |
+
fuse_conv_bn_weights,
|
| 7 |
+
fuse_linear_bn_eval,
|
| 8 |
+
fuse_linear_bn_weights,
|
| 9 |
+
)
|
| 10 |
+
from .init import skip_init
|
| 11 |
+
from .memory_format import (
|
| 12 |
+
convert_conv2d_weight_memory_format,
|
| 13 |
+
convert_conv3d_weight_memory_format,
|
| 14 |
+
)
|
| 15 |
+
from .spectral_norm import remove_spectral_norm, spectral_norm
|
| 16 |
+
from .weight_norm import remove_weight_norm, weight_norm
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"clip_grad_norm",
|
| 21 |
+
"clip_grad_norm_",
|
| 22 |
+
"clip_grad_value_",
|
| 23 |
+
"convert_conv2d_weight_memory_format",
|
| 24 |
+
"convert_conv3d_weight_memory_format",
|
| 25 |
+
"fuse_conv_bn_eval",
|
| 26 |
+
"fuse_conv_bn_weights",
|
| 27 |
+
"fuse_linear_bn_eval",
|
| 28 |
+
"fuse_linear_bn_weights",
|
| 29 |
+
"parameters_to_vector",
|
| 30 |
+
"parametrizations",
|
| 31 |
+
"remove_spectral_norm",
|
| 32 |
+
"remove_weight_norm",
|
| 33 |
+
"rnn",
|
| 34 |
+
"skip_init",
|
| 35 |
+
"spectral_norm",
|
| 36 |
+
"stateless",
|
| 37 |
+
"vector_to_parameters",
|
| 38 |
+
"weight_norm",
|
| 39 |
+
]
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc
ADDED
|
Binary file (2.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc
ADDED
|
Binary file (19.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc
ADDED
|
Binary file (6.85 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc
ADDED
|
Binary file (9.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc
ADDED
|
Binary file (3.61 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc
ADDED
|
Binary file (7.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc
ADDED
|
Binary file (8.39 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc
ADDED
|
Binary file (35.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc
ADDED
|
Binary file (59.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc
ADDED
|
Binary file (28.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import importlib
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Callable, List
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Deprecation message emitted on first (lazy) access of a migrated name.
_MESSAGE_TEMPLATE = (
    r"Usage of '{old_location}' is deprecated; please use '{new_location}' instead."
)


def lazy_deprecated_import(
    all: List[str],
    old_module: str,
    new_module: str,
) -> Callable:
    r"""Build a module ``__getattr__`` that lazily forwards to a new location.

    The returned callable, when assigned to a deprecated module's
    ``__getattr__``, resolves any name listed in ``all`` by importing
    ``new_module`` and fetching the attribute from there, emitting a
    deprecation warning (templated from ``_MESSAGE_TEMPLATE``) on each
    access.  Names not in ``all`` raise :class:`AttributeError`.

    Args:
        all: The list of the functions that are imported. Generally, the
            module's ``__all__`` list.
        old_module: Old module location.
        new_module: New module location / migrated location.

    Returns:
        Callable to assign to the ``__getattr__``.

    Usage::

        # In the `torch/nn/quantized/functional.py`
        from torch.nn.utils._deprecation_utils import lazy_deprecated_import
        _MIGRATED_TO = "torch.ao.nn.quantized.functional"
        __getattr__ = lazy_deprecated_import(
            all=__all__,
            old_module=__name__,
            new_module=_MIGRATED_TO)
    """
    warning_message = _MESSAGE_TEMPLATE.format(
        old_location=old_module, new_location=new_module
    )

    def getattr_dunder(name):
        # Unknown names fail fast, matching normal attribute semantics.
        if name not in all:
            raise AttributeError(f"Module {new_module!r} has no attribute {name!r}.")
        # RuntimeWarning is used because, unlike DeprecationWarning, it is
        # not ignored by default.
        warnings.warn(warning_message, RuntimeWarning)
        package = importlib.import_module(new_module)
        return getattr(package, name)

    return getattr_dunder
|