Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py +1451 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py +105 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py +506 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/exc.py +98 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py +220 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py +277 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py +419 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_helpers.py +344 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h +401 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h +61 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h +100 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h +393 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h +321 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h +139 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h +72 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h +11 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pow.h +69 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOps.h +56 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h +55 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorCompare.h +49 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIterator.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h +57 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h +62 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh +296 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.h +32 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh +384 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh +379 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h +61 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sort.h +17 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh +40 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh +38 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh +40 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh +38 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh +680 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/thread_constants.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/OperationUtils.h +394 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/TensorFactory.h +12 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h +103 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h +130 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h +62 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h +147 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h +527 -0
.gitattributes
CHANGED
|
@@ -77,3 +77,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/_
|
|
| 77 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 78 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
| 79 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 77 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 78 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
| 79 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ebc4c06b7d95e74e315419ee7e88e1d0f71e9e9477538c00a93a9ff8c66a6cfc
|
| 3 |
+
size 182784
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-311.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py
ADDED
|
@@ -0,0 +1,1451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import functools
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
import warnings
|
| 8 |
+
from itertools import count
|
| 9 |
+
|
| 10 |
+
from typing import (
|
| 11 |
+
Any,
|
| 12 |
+
Callable,
|
| 13 |
+
Dict,
|
| 14 |
+
FrozenSet,
|
| 15 |
+
List,
|
| 16 |
+
Optional,
|
| 17 |
+
Sequence,
|
| 18 |
+
Tuple,
|
| 19 |
+
Union,
|
| 20 |
+
)
|
| 21 |
+
from unittest import mock
|
| 22 |
+
|
| 23 |
+
from functorch.compile import min_cut_rematerialization_partition
|
| 24 |
+
|
| 25 |
+
import torch.fx
|
| 26 |
+
import torch.utils._pytree as pytree
|
| 27 |
+
from torch._dynamo import (
|
| 28 |
+
compiled_autograd,
|
| 29 |
+
config as dynamo_config,
|
| 30 |
+
logging as dynamo_logging,
|
| 31 |
+
utils as dynamo_utils,
|
| 32 |
+
)
|
| 33 |
+
from torch._dynamo.utils import (
|
| 34 |
+
counters,
|
| 35 |
+
detect_fake_mode,
|
| 36 |
+
lazy_format_graph_code,
|
| 37 |
+
optimus_scuba_log,
|
| 38 |
+
)
|
| 39 |
+
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
|
| 40 |
+
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
|
| 41 |
+
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
|
| 42 |
+
|
| 43 |
+
from torch._inductor.debug import save_args_for_compile_fx_inner
|
| 44 |
+
from torch._inductor.utils import BoxedBool, count_tangents
|
| 45 |
+
from torch._logging import trace_structured
|
| 46 |
+
from torch._ops import OpOverload
|
| 47 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 48 |
+
from torch._utils_internal import signpost_event
|
| 49 |
+
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
| 50 |
+
|
| 51 |
+
from .._dynamo.backends.common import aot_autograd
|
| 52 |
+
from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined]
|
| 53 |
+
from ..fx.graph import _PyTreeCodeGen
|
| 54 |
+
from . import config, metrics
|
| 55 |
+
from .debug import DebugContext
|
| 56 |
+
from .decomposition import select_decomp_table
|
| 57 |
+
from .fx_passes.joint_graph import joint_graph_passes
|
| 58 |
+
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
|
| 59 |
+
from .fx_passes.pre_grad import pre_grad_passes
|
| 60 |
+
from .graph import GraphLowering
|
| 61 |
+
from .ir import ExternKernelNode
|
| 62 |
+
from .utils import get_dtype_size, has_incompatible_cudagraph_ops, output_node
|
| 63 |
+
from .virtualized import V
|
| 64 |
+
|
| 65 |
+
if config.is_fbcode():
|
| 66 |
+
from torch._inductor.fb.utils import time_and_log
|
| 67 |
+
else:
|
| 68 |
+
# no-op decorator
|
| 69 |
+
def time_and_log(attr: str, extra_loggings: Optional[Dict[str, str]] = None):
|
| 70 |
+
return dynamo_utils.identity
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
log = logging.getLogger(__name__)
|
| 74 |
+
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
| 75 |
+
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
|
| 76 |
+
ALIGNMENT = 16
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# copy_ fails when trying to write to tensors with memory overlap,
|
| 80 |
+
# for expanded dimensions (a dimension which used to have size 1 -> ?)
|
| 81 |
+
# we can select one element from that dimension and write to it
|
| 82 |
+
# to achieve writing to all values of that dimension of the input tensor
|
| 83 |
+
def get_expanded_dims(t):
|
| 84 |
+
if not isinstance(t, torch.Tensor):
|
| 85 |
+
return None
|
| 86 |
+
return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
|
| 90 |
+
for expanded_dim in expanded_dims:
|
| 91 |
+
t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
|
| 92 |
+
return t
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def complex_memory_overlap(t: torch.Tensor) -> bool:
|
| 96 |
+
# if torch._debug_has_internal_overlap thinks this tensor potentially has
|
| 97 |
+
# memory overlap internally, let's dig deeper to find out whether it's true.
|
| 98 |
+
t = index_expanded_dims(t, get_expanded_dims(t))
|
| 99 |
+
if torch._debug_has_internal_overlap(t) != 0:
|
| 100 |
+
strides = t.stride()
|
| 101 |
+
sizes = t.shape
|
| 102 |
+
indices = list(range(len(strides)))
|
| 103 |
+
indices = [x for _, x in sorted(zip(strides, indices))]
|
| 104 |
+
for i in range(len(strides)):
|
| 105 |
+
prev_stride = 1 if i == 0 else strides[indices[i - 1]]
|
| 106 |
+
prev_size = 1 if i == 0 else sizes[indices[i - 1]]
|
| 107 |
+
if strides[indices[i]] < prev_stride * prev_size:
|
| 108 |
+
return True
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@functools.lru_cache(None)
|
| 113 |
+
def _step_logger():
|
| 114 |
+
return dynamo_logging.get_step_logger(log)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@functools.lru_cache(None)
|
| 118 |
+
def _warn_tf32_disabled():
|
| 119 |
+
if (
|
| 120 |
+
torch.cuda.is_available()
|
| 121 |
+
and not torch.backends.cuda.matmul.allow_tf32
|
| 122 |
+
and torch.cuda.get_device_capability() >= (8, 0)
|
| 123 |
+
):
|
| 124 |
+
warnings.warn(
|
| 125 |
+
"TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
|
| 126 |
+
"Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _unlift_graph(mod, gm, graph_signature):
|
| 131 |
+
from torch.export.unflatten import _assign_attr, _AttrKind
|
| 132 |
+
|
| 133 |
+
state_dict = {}
|
| 134 |
+
for name, param in mod.named_parameters(remove_duplicate=False):
|
| 135 |
+
state_dict[name] = param
|
| 136 |
+
_assign_attr(
|
| 137 |
+
param,
|
| 138 |
+
gm,
|
| 139 |
+
name,
|
| 140 |
+
attr_kind=_AttrKind.PARAMETER,
|
| 141 |
+
)
|
| 142 |
+
for name, buffer in mod.named_buffers(remove_duplicate=False):
|
| 143 |
+
state_dict[name] = buffer
|
| 144 |
+
_assign_attr(
|
| 145 |
+
buffer,
|
| 146 |
+
gm,
|
| 147 |
+
name,
|
| 148 |
+
attr_kind=_AttrKind.BUFFER,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
|
| 152 |
+
lifted_inputs = []
|
| 153 |
+
for node in placeholder_nodes:
|
| 154 |
+
node_name = node.name
|
| 155 |
+
if node_name in graph_signature.inputs_to_parameters:
|
| 156 |
+
lifted_inputs.append(graph_signature.inputs_to_parameters[node_name])
|
| 157 |
+
elif node_name in graph_signature.inputs_to_buffers:
|
| 158 |
+
lifted_inputs.append(graph_signature.inputs_to_buffers[node_name])
|
| 159 |
+
else:
|
| 160 |
+
assert node_name in graph_signature.user_inputs
|
| 161 |
+
lifted_inputs.append(None)
|
| 162 |
+
|
| 163 |
+
from torch.export._unlift import _unlift
|
| 164 |
+
|
| 165 |
+
outputs = list(gm.graph.nodes)[-1].args[0]
|
| 166 |
+
mutated_outputs = []
|
| 167 |
+
for out in outputs:
|
| 168 |
+
if out in graph_signature.buffers_to_mutate:
|
| 169 |
+
mutated_outputs.append(graph_signature.buffers_to_mutate[out.name])
|
| 170 |
+
else:
|
| 171 |
+
mutated_outputs.append(None)
|
| 172 |
+
|
| 173 |
+
unlifted_gm = _unlift(
|
| 174 |
+
gm,
|
| 175 |
+
lifted_inputs,
|
| 176 |
+
mutated_outputs,
|
| 177 |
+
pytree.LeafSpec(),
|
| 178 |
+
None,
|
| 179 |
+
state_dict,
|
| 180 |
+
{},
|
| 181 |
+
)
|
| 182 |
+
return unlifted_gm
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _get_subgraph_names(gm):
|
| 186 |
+
for node in gm.graph.nodes:
|
| 187 |
+
if node.target == torch.ops.higher_order.cond:
|
| 188 |
+
true_subgraph_name = node.args[1].name
|
| 189 |
+
false_subgraph_name = node.args[2].name
|
| 190 |
+
yield true_subgraph_name
|
| 191 |
+
yield false_subgraph_name
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _recursive_pre_grad_passes(gm, example_inputs):
|
| 195 |
+
for subgraph_name in _get_subgraph_names(gm):
|
| 196 |
+
subgraph = getattr(gm, subgraph_name)
|
| 197 |
+
# as we don't have recursive example inputs, passing None here
|
| 198 |
+
new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
|
| 199 |
+
setattr(gm, subgraph_name, new_subgraph)
|
| 200 |
+
return pre_grad_passes(gm, example_inputs)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _recursive_joint_graph_passes(gm):
|
| 204 |
+
for subgraph_name in _get_subgraph_names(gm):
|
| 205 |
+
subgraph = getattr(gm, subgraph_name)
|
| 206 |
+
_recursive_joint_graph_passes(subgraph)
|
| 207 |
+
joint_graph_passes(gm)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _recursive_post_grad_passes(gm, is_inference: bool = False):
|
| 211 |
+
for subgraph_name in _get_subgraph_names(gm):
|
| 212 |
+
subgraph = getattr(gm, subgraph_name)
|
| 213 |
+
_recursive_post_grad_passes(subgraph, is_inference)
|
| 214 |
+
post_grad_passes(gm, is_inference)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def split_const_gm(
|
| 218 |
+
gm: torch.fx.GraphModule,
|
| 219 |
+
) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
|
| 220 |
+
"""
|
| 221 |
+
This function takes an GraphModule input "gm".
|
| 222 |
+
The gm will be split into 2 components,
|
| 223 |
+
1) const_gm, which consists the subgraph of gm that can be constant folded.
|
| 224 |
+
2) gm (being inplace modified,) which returns the graph after constant folding.
|
| 225 |
+
|
| 226 |
+
const_output_index is a mapping of corresponding node name from gm to the
|
| 227 |
+
output index of const_gm.
|
| 228 |
+
Returns (const_gm, const_output_index)
|
| 229 |
+
"""
|
| 230 |
+
from torch._inductor.constant_folding import (
|
| 231 |
+
CONST_MODULE_TAG,
|
| 232 |
+
META_TAG,
|
| 233 |
+
MODULE_TAG,
|
| 234 |
+
replace_node_with_constant,
|
| 235 |
+
run_and_get_constant_graph,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
const_gm = run_and_get_constant_graph(gm)
|
| 239 |
+
const_result = const_gm()
|
| 240 |
+
|
| 241 |
+
const_outputs = {
|
| 242 |
+
x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
to_erase_node = []
|
| 246 |
+
to_replace_node = []
|
| 247 |
+
const_output_index = {}
|
| 248 |
+
for node in gm.graph.nodes:
|
| 249 |
+
if node.name in const_outputs:
|
| 250 |
+
to_replace_node.append(node)
|
| 251 |
+
elif node.meta[META_TAG] == CONST_MODULE_TAG:
|
| 252 |
+
to_erase_node.append(node)
|
| 253 |
+
|
| 254 |
+
for node in to_replace_node:
|
| 255 |
+
new_const_name = "_FOLDED_CONST_" + node.name
|
| 256 |
+
replace_node_with_constant(
|
| 257 |
+
gm,
|
| 258 |
+
node,
|
| 259 |
+
const_result[const_outputs[node.name]],
|
| 260 |
+
new_const_name,
|
| 261 |
+
)
|
| 262 |
+
const_output_index[new_const_name] = const_outputs[node.name]
|
| 263 |
+
for node in to_erase_node[::-1]:
|
| 264 |
+
if node.users:
|
| 265 |
+
for n in node.users:
|
| 266 |
+
assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
|
| 267 |
+
else:
|
| 268 |
+
gm.graph.erase_node(node)
|
| 269 |
+
gm.recompile()
|
| 270 |
+
|
| 271 |
+
return const_gm, const_output_index
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
|
| 275 |
+
aten = torch.ops.aten
|
| 276 |
+
tf32_ops = {
|
| 277 |
+
aten.mm.default,
|
| 278 |
+
aten.addmm.default,
|
| 279 |
+
aten.bmm.default,
|
| 280 |
+
aten.baddbmm.default,
|
| 281 |
+
}
|
| 282 |
+
for node in gm.graph.nodes:
|
| 283 |
+
if (
|
| 284 |
+
node.op == "call_function"
|
| 285 |
+
and node.target in tf32_ops
|
| 286 |
+
and isinstance(node.meta.get("val", None), torch.Tensor)
|
| 287 |
+
and node.meta["val"].dtype == torch.float32
|
| 288 |
+
and node.meta["val"].device.type == "cuda"
|
| 289 |
+
):
|
| 290 |
+
return True
|
| 291 |
+
return False
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@DebugContext.wrap
|
| 295 |
+
def count_bytes_inner(
|
| 296 |
+
gm: torch.fx.GraphModule,
|
| 297 |
+
example_inputs: List[torch.Tensor],
|
| 298 |
+
num_fixed: int = 0,
|
| 299 |
+
**kwargs,
|
| 300 |
+
):
|
| 301 |
+
shape_env = _shape_env_from_inputs(example_inputs)
|
| 302 |
+
fake_mode = fake_tensor_prop(gm, example_inputs)
|
| 303 |
+
|
| 304 |
+
with V.set_fake_mode(fake_mode):
|
| 305 |
+
_recursive_post_grad_passes(gm, False)
|
| 306 |
+
|
| 307 |
+
graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
|
| 308 |
+
with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
|
| 309 |
+
graph.run(*example_inputs)
|
| 310 |
+
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
|
| 311 |
+
metrics.num_bytes_accessed += num_bytes
|
| 312 |
+
metrics.nodes_num_elem += nodes_num_elem
|
| 313 |
+
metrics.node_runtimes += node_runtimes
|
| 314 |
+
return make_boxed_func(gm.forward)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def fake_tensor_prop(
|
| 318 |
+
gm: torch.fx.GraphModule,
|
| 319 |
+
example_inputs: List[torch.Tensor],
|
| 320 |
+
force_allow_non_fake_inputs: bool = False,
|
| 321 |
+
):
|
| 322 |
+
"""
|
| 323 |
+
If we can not detect fake mode from the context of inputs, create one.
|
| 324 |
+
|
| 325 |
+
The created fake mode will be returned.
|
| 326 |
+
"""
|
| 327 |
+
fake_mode = detect_fake_mode(example_inputs)
|
| 328 |
+
if not fake_mode:
|
| 329 |
+
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
|
| 330 |
+
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
|
| 331 |
+
else:
|
| 332 |
+
ctx = (
|
| 333 |
+
contextlib.nullcontext()
|
| 334 |
+
if not force_allow_non_fake_inputs
|
| 335 |
+
else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
|
| 336 |
+
)
|
| 337 |
+
with ctx: # type: ignore[attr-defined]
|
| 338 |
+
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
|
| 339 |
+
*example_inputs
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
return fake_mode
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# pass config dict back to user
|
| 346 |
+
def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
|
| 347 |
+
with config.patch(config_patches):
|
| 348 |
+
return config.get_config_copy()
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(
    attr="compilation time (in seconds)",
    extra_loggings={"config_dict": str(get_patched_config_dict())},
)
# Need this decorator for compile_fx_inner even if we already have one for
# compile_fx. The reason is the compilation for backward graph may happen after
# compile_fx return and we may want to use the _LazyGraphModule for compiling
# the backward graph as well.
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
@dynamo_utils.dynamo_timed(phase_name="inductor_compile")
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    num_fixed: int = 0,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    user_visible_outputs: FrozenSet[str] = frozenset(),
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Inductor API that compiles a single graph.

    If you change the argument list for this function, make sure you
    also update the call to save_args_for_compile_fx_inner below accordingly.

    Notes on selected arguments (as used by this function):
        cudagraphs: mutable flag; disabled in place (BoxedBool.disable) when
            any cudagraph-incompatibility is detected below.
        num_fixed: number of leading inputs treated as static (stable memory
            address) for cudagraphs/alignment purposes.
        boxed_forward_device_index: written during the forward compile and
            read back during the backward compile (see the cudagraph_trees
            backward handling below).
    """
    # Empty graphs (no calls) need no compilation; just hand back the eager
    # forward, after forcing _LazyGraphModule to materialize it.
    if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
        # trigger the real recompilation for _LazyGraphModule before returning
        # the forward method.
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(gm)
        return make_boxed_func(gm.forward)

    assert isinstance(
        next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
    ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

    if config.save_args:
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            cudagraphs=cudagraphs,
            num_fixed=num_fixed,
            is_backward=is_backward,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            is_inference=is_inference,
            boxed_forward_device_index=boxed_forward_device_index,
            user_visible_outputs=user_visible_outputs,
            layout_opt=layout_opt,
        )

    if cudagraphs is None:
        cudagraphs = BoxedBool(config.triton.cudagraphs)

    # Inputs to fx_codegen_and_compile
    # Anything that affects codegen should go here, so if the signature
    # of fx_codegen_and_compile changes, the dict should be updated accordingly
    graph_kwargs = {
        "cudagraphs": cudagraphs,
        "num_fixed": num_fixed,
        "is_backward": is_backward,
        "graph_id": graph_id,
        "cpp_wrapper": cpp_wrapper,
        "aot_mode": aot_mode,
        "is_inference": is_inference,
        "user_visible_outputs": user_visible_outputs,
        "layout_opt": layout_opt,
        "extern_node_serializer": extern_node_serializer,
    }

    start = time.time()

    # The fx graph cache is bypassed in AOT mode (the artifact is a compiled
    # library, not a CompiledFxGraph).
    if config.fx_graph_cache and not aot_mode:
        compiled_graph = FxGraphCache.load(
            fx_codegen_and_compile, gm, example_inputs, graph_kwargs
        )
    else:
        compiled_graph = fx_codegen_and_compile(
            gm, example_inputs, **graph_kwargs  # type: ignore[arg-type]
        )

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)

    # check cudagraph disabling reasons from inductor lowering
    if cudagraphs and compiled_graph.disabled_cudagraphs_reason:
        perf_hint_log.warning(
            "skipping cudagraphs due to %s", compiled_graph.disabled_cudagraphs_reason
        )
        BoxedBool.disable(cudagraphs)

    # Return the output strides to the caller via TracingContext
    context = torch._guards.TracingContext.try_get()
    if context is not None and context.output_strides is not None:
        assert len(context.output_strides) == 0
        context.output_strides.extend(compiled_graph.output_strides)

    if aot_mode:
        return compiled_graph

    if cudagraphs:
        # output args are tuple of first argument
        output = output_node(gm)
        assert len(output.args) == 1
        stack_traces = [
            (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
            for arg in output.args[0]
        ]

        complex_memory_overlap_inputs = any(
            complex_memory_overlap(t)
            for t in example_inputs
            if isinstance(t, torch.Tensor)
        )

        from torch._inductor.cudagraph_utils import check_for_mutation

        has_mutation_str = check_for_mutation(gm, compiled_graph, num_fixed)
        has_mutation = has_mutation_str is not None

        if has_mutation:
            compiled_graph.disabled_cudagraphs_reason = has_mutation_str

        # Each entry is (condition-that-must-hold, reason-if-it-doesn't);
        # any False condition disables cudagraphs below.
        cudagraph_tests = [
            (not has_mutation, "mutated inputs"),
            (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
            (not complex_memory_overlap_inputs, "complex memory overlap"),
            (
                all(
                    isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs
                ),
                "non-Tensor inputs",
            ),
        ]
        cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]

        if not cudagraph_fail_reasons:
            if not config.triton.cudagraph_trees:
                # Force specialize all inputs so that CUDA graphs will work
                for t in example_inputs:
                    if isinstance(t, torch.SymInt):
                        int(t)  # guard

            # Record the forward's device so the backward compile (which may
            # happen later) can find the same cudagraph manager.
            if (
                boxed_forward_device_index is not None
                and not is_inference
                and not is_backward
            ):
                boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

            compiled_graph.current_callable = cudagraphify(
                compiled_graph.get_current_callable(),
                example_inputs,
                static_input_idxs=range(num_fixed),
                device_index=next(iter(compiled_graph.device_idxs)),
                stack_traces=stack_traces,
                is_backward=is_backward,
                is_inference=is_inference,
                constants=tuple(compiled_graph.constants.values()),
            )
        else:
            BoxedBool.disable(cudagraphs)

            # See [Backward Generation Handling]
            # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
            # know we are we running the backward even if we will not run it in cudagraphs
            if is_backward and config.triton.cudagraph_trees:
                assert boxed_forward_device_index is not None
                assert boxed_forward_device_index.value is not None
                compiled_graph_callable = compiled_graph.get_current_callable()

                manager = torch._inductor.cudagraph_trees.get_manager(
                    boxed_forward_device_index.value, create_if_none_exists=False
                )
                # should already exist from forward
                assert manager is not None

                def compiled_artifact(new_inputs):
                    manager.set_to_running_backward()
                    return compiled_graph_callable(new_inputs)

                compiled_graph.current_callable = compiled_artifact

            if "cuda" in compiled_graph.device_types:
                # prefer better disable_cudagraphs_reason bc stack trace
                # TODO: migrate all disable reasons to stack trace, refactor
                if compiled_graph.disabled_cudagraphs_reason:
                    perf_hint_log.warning(compiled_graph.disabled_cudagraphs_reason)
                else:
                    perf_hint_log.warning(
                        "skipping cudagraphs due to %s", cudagraph_fail_reasons
                    )

    # cudagraphs does its own aligning of inputs
    if not cudagraphs:
        new_callable = align_inputs(
            compiled_graph.get_current_callable(), example_inputs, range(num_fixed)
        )
        if new_callable is not compiled_graph.get_current_callable():
            compiled_graph.current_callable = new_callable

    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )

    # aot autograd needs to know to pass in inputs as a list
    compiled_graph._boxed_call = True
    return compiled_graph
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def fx_codegen_and_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    num_fixed: int = 0,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    user_visible_outputs: FrozenSet[str] = frozenset(),
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Run post-grad passes over ``gm``, lower it via GraphLowering, and compile
    it to a CompiledFxGraph. When V.aot_compilation is set, returns
    ``graph.compile_to_fn()``'s result directly instead (per the return
    annotation this may be a str). Keyword arguments mirror
    ``compile_fx_inner``; see ``graph_kwargs`` there.
    """
    if is_tf32_warning_applicable(gm):
        _warn_tf32_disabled()

    # lift the maximum depth of the Python interpreter stack
    # to adapt large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    _step_logger()(
        logging.INFO,
        "torchinductor compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    V.debug.fx_graph(gm, example_inputs)
    # TODO: Should we actually dump this? It should be redundant with the aot
    # structured logs...
    # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))

    shape_env = _shape_env_from_inputs(example_inputs)

    # Convert view to reshape in the graph. This is necessary primarily for
    # layout optimization. Do it unconditionally for uniformity.
    #
    # It's needed because when we do layout optimization, an contiguous tensor
    # in eager mode may becomes a channels last tensor. A view op previously
    # can be applied to the contiguous tensor may not be able to be applied
    # on the channels tensor any more. An error like
    # RuntimeError: view size is not compatible with input tensor's size and stride
    # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    # will be printed.
    #
    # Replace view op to reshape op in this case.
    # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
    #
    # Also this has to be done before FakeTensorProp below to avoid the failed
    # .view() call.
    view_to_reshape(gm)

    # It is safe to run FakeTensorProp under no_grad because by the time
    # we're in inductor, we assume that AOTAutograd has already "taken care"
    # of autograd, so there should be no more autograd-related API's in the
    # graph.
    with torch.no_grad():
        fake_mode = fake_tensor_prop(gm, example_inputs)

    # pattern matcher passes might not preserve striding information
    # on node.meta["val"]. if in the future we rely on these being
    # correct we will need to fix.

    with V.set_fake_mode(fake_mode):
        # has some issues with memory in training
        _recursive_post_grad_passes(gm, is_inference=is_inference)
        V.debug.fx_graph_transformed(gm, example_inputs)
        post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
        trace_structured(
            "inductor_post_grad_graph",
            payload_fn=lambda: gm.print_readable(print_output=False),
        )
        optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
        signpost_event(
            "optimus",
            "compile_fx.post_grad_passes",
            optimus_scuba_log,
        )

    with V.set_fake_mode(fake_mode):
        const_output_index = None
        const_graph = None
        const_code = None

        # AOT runtime constant folding: split out the constant subgraph,
        # lower and codegen it separately, then feed the result into the
        # main GraphLowering below.
        if aot_mode and config.aot_inductor.use_runtime_constant_folding:
            const_gm, const_output_index = split_const_gm(gm)

            const_graph = GraphLowering(
                const_gm,
                example_inputs=[],
                shape_env=shape_env,
                num_static_inputs=num_fixed,
                graph_id=graph_id,
                cpp_wrapper=cpp_wrapper,
                aot_mode=aot_mode,
                user_visible_outputs=user_visible_outputs,
                extern_node_serializer=extern_node_serializer,
                is_inference=is_inference,
                is_const_graph=True,
            )
            with V.set_graph_handler(const_graph):
                assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                const_graph.run()

                const_code, _ = const_graph.codegen_with_cpp_wrapper()

        graph = GraphLowering(
            gm,
            # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
            # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
            # we currently use fake tensors and defake them later.
            example_inputs=example_inputs,
            shape_env=shape_env,
            num_static_inputs=num_fixed,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            user_visible_outputs=user_visible_outputs,
            extern_node_serializer=extern_node_serializer,
            is_inference=is_inference,
            const_output_index=const_output_index,
            const_code=const_code,
            const_module=const_graph,
        )
        with V.set_graph_handler(graph):
            graph.run(*example_inputs)
            output_strides: List[Optional[Tuple[int, ...]]] = []
            if graph.graph_outputs is not None:
                # We'll put the output strides in the compiled graph so we
                # can later return them to the caller via TracingContext
                for out in graph.graph_outputs:
                    if hasattr(out, "layout"):
                        output_strides.append(
                            tuple(
                                V.graph.sizevars.size_hint(s) for s in out.layout.stride
                            )
                        )
                    else:
                        output_strides.append(None)

            # Snapshot metric deltas produced by this compilation so the
            # cached CompiledFxGraph can replay them on a cache hit.
            metrics_helper = metrics.CachedMetricsHelper()
            compiled_fn = graph.compile_to_fn()

            if V.aot_compilation is True:
                return compiled_fn

            if cudagraphs and not V.graph.disable_cudagraphs_reason:
                from torch._inductor.cudagraph_utils import (
                    check_lowering_disable_cudagraph,
                )

                V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph(
                    V.graph.device_node_mapping
                )

            compiled_graph = CompiledFxGraph(
                compiled_fn,
                graph,
                output_strides,
                V.graph.disable_cudagraphs_reason,
                metrics_helper.get_deltas(),
            )

    return compiled_graph
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def clone_preserve_strides(x: torch.Tensor):
    """
    Clone ``x`` into freshly allocated storage while keeping its exact
    (size, stride) layout, including any gaps implied by non-contiguous
    strides.
    """
    # Smallest flat storage extent that covers every element addressed by
    # x's size/stride pairs.
    extent = 1
    for dim_size, dim_stride in zip(x.size(), x.stride()):
        extent += (dim_size - 1) * dim_stride
    flat_copy = torch.as_strided(x, (extent,), (1,)).clone()
    return torch.as_strided(flat_copy, x.size(), x.stride())
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def copy_misaligned_inputs(
    new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
) -> None:
    """
    Mutates ``new_inputs`` in place: for every position listed in
    ``check_inputs_idxs``, replace the tensor with a layout-preserving clone
    when its data pointer is not a multiple of ALIGNMENT.
    """
    for idx in check_inputs_idxs:
        candidate = new_inputs[idx]
        if candidate.data_ptr() % ALIGNMENT != 0:
            new_inputs[idx] = clone_preserve_strides(candidate)
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def get_input_idxs_to_check(
    inputs: Union[List[torch.Tensor], Sequence[int]],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    Return positions of CUDA tensor inputs whose alignment must be checked at
    call time: every non-static input, plus any static input whose storage
    offset is not ALIGNMENT-aligned for its dtype.
    """

    def offset_is_aligned(storage_offset, dtype):
        return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0

    positions = []
    for pos, candidate in enumerate(inputs):
        if not isinstance(candidate, torch.Tensor):
            continue
        if candidate.device.type != "cuda":
            continue
        if pos not in static_input_idxs or not offset_is_aligned(
            candidate.storage_offset(), candidate.dtype
        ):
            positions.append(pos)
    return positions
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def align_inputs_from_check_idxs(
    model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int]
):
    """
    Wrap ``model`` so that, on every call, the inputs at ``inputs_to_check``
    are re-aligned (cloned if misaligned) first. When there is nothing to
    check, ``model`` is returned unwrapped.
    """
    if not inputs_to_check:
        return model

    def aligned_call(new_inputs):
        copy_misaligned_inputs(new_inputs, inputs_to_check)
        return model(new_inputs)

    return aligned_call
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
def align_inputs(
    model: Callable[[List[torch.Tensor]], Any],
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
):
    """
    Convenience wrapper: determine which input positions need runtime
    alignment checks, then wrap ``model`` accordingly (identity when none do).
    """
    check_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
    return align_inputs_from_check_idxs(model, check_idxs)
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
@dynamo_utils.dynamo_timed
def cudagraphify(
    model: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    stack_traces: List[Optional[str]],
    is_backward: bool,
    is_inference: bool,
    constants: Tuple[torch.Tensor, ...] = (),
):
    """
    Wrap an inductor-compiled callable with CUDA graph capture/replay.

    Dispatches to the cudagraph-trees implementation when
    config.triton.cudagraph_trees is set, otherwise to the local
    ``cudagraphify_impl``. If any input is a FakeTensor, capture is deferred
    to the first real call (see ``run`` below).
    """
    from torch._inductor.cudagraph_trees import (
        cudagraphify_impl as new_cudagraphify_impl,
    )

    cudagraphify_fn: Callable[..., Any]
    if config.triton.cudagraph_trees:
        cudagraphify_fn = functools.partial(
            new_cudagraphify_impl,
            device_index=device_index,
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=constants,
        )
    else:
        cudagraphify_fn = cudagraphify_impl

    # if using fake tensors, defer cudagraphs until we get real inputs at runtime
    if not any(isinstance(inp, FakeTensor) for inp in inputs):
        return cudagraphify_fn(model, inputs, static_input_idxs)

    compiled_fn = None

    def run(new_inputs):
        # Lazily capture on first call with real inputs; subsequent calls
        # reuse the captured graph.
        nonlocal compiled_fn
        if compiled_fn is None:
            with dynamo_utils.preserve_rng_state():
                compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
        return compiled_fn(new_inputs)

    return run
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
def remove_unaligned_input_idxs(
    inputs: Union[List[torch.Tensor], Sequence[int]],
    static_input_idxs: Sequence[int],
):
    """
    Filter ``static_input_idxs`` down to the indices whose corresponding
    tensor in ``inputs`` is ALIGNMENT-aligned; misaligned entries cannot be
    treated as static (stable-address) inputs.

    Returns the original ``static_input_idxs`` object unchanged when every
    index survives, so callers can cheaply detect "nothing removed".
    """
    aligned_static_input_idxs = []
    for idx in static_input_idxs:
        # BUGFIX: look up inputs[idx] rather than zipping static_input_idxs
        # with inputs positionally — the old zip paired the i-th static index
        # with inputs[i], which is only correct when static_input_idxs is a
        # 0-based prefix range.
        candidate = inputs[idx]
        if isinstance(candidate, torch.Tensor) and (candidate.data_ptr() % ALIGNMENT) == 0:
            aligned_static_input_idxs.append(idx)
    if len(aligned_static_input_idxs) != len(static_input_idxs):
        return aligned_static_input_idxs
    return static_input_idxs
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
def static_input(x: torch.Tensor):
    """
    Allocate a new (uninitialized) tensor with the same dtype/device and the
    exact (size, stride) layout as ``x``, backed by just enough flat storage.
    """
    # TODO(jansel): figure out why this version doesn't work:
    # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
    storage_len = 1
    for dim_size, dim_stride in zip(x.size(), x.stride()):
        storage_len += (dim_size - 1) * dim_stride
    backing = torch.empty(storage_len, dtype=x.dtype, device=x.device)
    return torch.as_strided(backing, x.size(), x.stride())
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
def index_expanded_dims_and_copy_(
    dst: torch.Tensor,
    src: torch.Tensor,
    expanded_dims: List[int],
):
    """Index into the expanded dimensions of both ``dst`` and ``src`` (via
    ``index_expanded_dims``) and then copy ``src`` into ``dst`` in place."""
    narrowed_dst = index_expanded_dims(dst, expanded_dims)
    narrowed_src = index_expanded_dims(src, expanded_dims)
    narrowed_dst.copy_(narrowed_src)
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def cudagraphify_impl(
    model: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
):
    """
    Assumes inputs[static_input_idxs[i]] are always the same memory address

    Captures ``model`` into a torch.cuda.CUDAGraph after one warmup run on a
    side stream, then returns a ``run(new_inputs)`` callable that copies
    non-static inputs into the captured placeholders and replays the graph.
    NOTE: ``run`` clears ``new_inputs`` (boxed-call convention) and returns
    the tensors captured as ``static_outputs``.
    """
    check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
    static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
    copy_misaligned_inputs(inputs, check_input_idxs)

    assert isinstance(inputs, list)

    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]

    # allocate static tensor inputs
    static_inputs = [
        x
        if not isinstance(x, torch.Tensor)
        else static_input(x)
        if idx not in static_input_idxs
        else x.detach()
        for idx, x in enumerate(inputs)
    ]

    # copy over input values for fresh allocations
    for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
        if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
            index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # copy static_inputs because it will be cleared in model
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    if config.size_asserts:

        def run(new_inputs):
            # Debug-checked variant: verify static inputs really kept their
            # address, and copy the rest into the captured placeholders.
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if not isinstance(dst, torch.Tensor):
                    pass
                elif idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    index_expanded_dims_and_copy_(dst, src, expanded_dims)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    else:
        # Fast variant: precompute which positions need copying.
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]

        def run(new_inputs):
            for idx in copy_indices:
                expanded_dims = inps_expanded_dims[idx]
                index_expanded_dims_and_copy_(
                    static_inputs[idx], new_inputs[idx], expanded_dims
                )
            new_inputs.clear()
            graph.replay()
            return static_outputs

    return align_inputs_from_check_idxs(run, check_input_idxs)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
def compile_fx_aot(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
):
    """
    AOT entry point: compile ``model_`` through ``compile_fx`` with
    cpp_wrapper forced on and aot_mode=True, returning the path of the
    compiled library produced on disk.
    """
    # AOT always uses the C++ wrapper codegen.
    config_patches: Dict[str, Any] = (
        {"cpp_wrapper": True}
        if config_patches is None
        else {**config_patches, "cpp_wrapper": True}
    )
    # Default the output path to a content hash of the model code when the
    # caller did not choose one.
    if (
        "aot_inductor.output_path" not in config_patches
        and not config.aot_inductor.output_path
    ):
        config_patches = {
            **config_patches,
            "aot_inductor.output_path": code_hash(model_.code),
        }

    # extern_node_serializer is passed through to inner_compile, not config.
    extern_node_serializer = config_patches.pop("extern_node_serializer", None)
    with V.set_aot_compilation(True):
        compiled_lib_path = compile_fx(
            model_,
            example_inputs_,
            inner_compile=functools.partial(
                inner_compile,
                aot_mode=True,
                extern_node_serializer=extern_node_serializer,
            ),
            config_patches=config_patches,
        )
        assert os.path.exists(
            compiled_lib_path
        ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
        return compiled_lib_path
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
# Process-wide counter; compile_fx draws graph_id values from it so each
# top-level compilation gets a unique id.
_graph_counter = count(0)
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
def fw_compiler_freezing(
    aot_autograd_model: torch.fx.GraphModule,
    aot_example_inputs: List[torch.Tensor],
    dynamo_model: torch.fx.GraphModule,
    num_example_inputs: int,
    inner_compile: Callable[..., Any],
    cudagraphs: BoxedBool,
    graph_id: int,
    forward_device: BoxedDeviceIndex,
):
    """
    Inference forward compiler that freezes parameters into constants before
    handing the graph to ``inner_compile``, then wraps the result so callers
    can keep passing the original (unfrozen) argument list.
    """
    from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

    # partition_fn won't be called
    _recursive_joint_graph_passes(aot_autograd_model)

    layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
    if layout_opt:
        # make sure meta['val'] is properly setup
        fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
        convert_conv_weights_to_channels_last(aot_autograd_model)

    opt_model, preserved_arg_indices = freeze(
        dynamo_model,
        aot_autograd_model,
        aot_example_inputs,  # type: ignore[arg-type]
    )

    # Keep only the inputs that survived freezing; the rest became constants.
    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs

    fake_mode = detect_fake_mode(aot_example_inputs)

    # for freezing, all graph outputs should be user visible
    *_, model_outputs_node = opt_model.graph.nodes
    model_outputs = model_outputs_node.args[0]
    user_visible_outputs = [
        n.name for n in model_outputs if isinstance(n, torch.fx.Node)
    ]

    # constant params will be real tensors, not fake
    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context is not None:
        params_flat = tracing_context.params_flat
        assert params_flat is not None
        # Drop references to params that were folded into the graph so their
        # original storage can be freed.
        for i in range(len(params_flat)):
            if i not in preserved_arg_indices:
                params_flat[i] = None

    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(
            opt_model,
            aot_example_inputs,
            num_fixed=num_fixed,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=True,
            boxed_forward_device_index=forward_device,
            layout_opt=layout_opt,
            user_visible_outputs=user_visible_outputs,
        )

    # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
    # that drops constant-ified params
    if V.aot_compilation is True:
        return optimized_function

    def wrapper(args):
        # Project the caller's full argument list down to the preserved ones
        # and clear the list (boxed-call convention).
        args_new = [args[i] for i in preserved_arg_indices]
        args.clear()
        return optimized_function(args_new)

    wrapper._boxed_call = True  # type: ignore[attr-defined]

    return wrapper
|
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
|
| 1092 |
+
def compile_fx(
|
| 1093 |
+
model_: torch.fx.GraphModule,
|
| 1094 |
+
example_inputs_: List[torch.Tensor],
|
| 1095 |
+
inner_compile: Callable[..., Any] = compile_fx_inner,
|
| 1096 |
+
config_patches: Optional[Dict[str, Any]] = None,
|
| 1097 |
+
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
|
| 1098 |
+
):
|
| 1099 |
+
"""Main entrypoint to a compile given FX graph"""
|
| 1100 |
+
if config_patches:
|
| 1101 |
+
with config.patch(config_patches):
|
| 1102 |
+
return compile_fx(
|
| 1103 |
+
model_,
|
| 1104 |
+
example_inputs_,
|
| 1105 |
+
# need extra layer of patching as backwards is compiled out of scope
|
| 1106 |
+
inner_compile=config.patch(config_patches)(inner_compile),
|
| 1107 |
+
decompositions=decompositions,
|
| 1108 |
+
)
|
| 1109 |
+
|
| 1110 |
+
if config.cpp_wrapper:
|
| 1111 |
+
with config.patch(
|
| 1112 |
+
{
|
| 1113 |
+
"cpp_wrapper": False,
|
| 1114 |
+
"triton.autotune_cublasLt": False,
|
| 1115 |
+
"triton.cudagraphs": False,
|
| 1116 |
+
"triton.store_cubin": True,
|
| 1117 |
+
}
|
| 1118 |
+
), V.set_real_inputs(example_inputs_):
|
| 1119 |
+
inputs_ = example_inputs_
|
| 1120 |
+
if isinstance(model_, torch.fx.GraphModule):
|
| 1121 |
+
fake_inputs = [
|
| 1122 |
+
node.meta.get("val")
|
| 1123 |
+
for node in model_.graph.nodes
|
| 1124 |
+
if node.op == "placeholder"
|
| 1125 |
+
]
|
| 1126 |
+
if all(v is not None for v in fake_inputs):
|
| 1127 |
+
# Validate devices before switching to fake tensors.
|
| 1128 |
+
for idx, fi, i in zip(count(), fake_inputs, inputs_):
|
| 1129 |
+
if fi.device != i.device:
|
| 1130 |
+
raise ValueError(
|
| 1131 |
+
f"Device mismatch between fake input and example input at position #{idx}: "
|
| 1132 |
+
f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
|
| 1133 |
+
"make sure torch.export() and torch.aot_compile() run on the same device."
|
| 1134 |
+
)
|
| 1135 |
+
inputs_ = fake_inputs
|
| 1136 |
+
return compile_fx(
|
| 1137 |
+
model_,
|
| 1138 |
+
inputs_,
|
| 1139 |
+
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
|
| 1140 |
+
decompositions=decompositions,
|
| 1141 |
+
)
|
| 1142 |
+
|
| 1143 |
+
recursive_compile_fx = functools.partial(
|
| 1144 |
+
compile_fx,
|
| 1145 |
+
inner_compile=inner_compile,
|
| 1146 |
+
decompositions=decompositions,
|
| 1147 |
+
)
|
| 1148 |
+
|
| 1149 |
+
if not graph_returns_tuple(model_):
|
| 1150 |
+
return make_graph_return_tuple(
|
| 1151 |
+
model_,
|
| 1152 |
+
example_inputs_,
|
| 1153 |
+
recursive_compile_fx,
|
| 1154 |
+
)
|
| 1155 |
+
|
| 1156 |
+
if isinstance(model_, torch.fx.GraphModule):
|
| 1157 |
+
if isinstance(model_.graph._codegen, _PyTreeCodeGen):
|
| 1158 |
+
# this graph is the result of dynamo.export()
|
| 1159 |
+
return handle_dynamo_export_graph(
|
| 1160 |
+
model_,
|
| 1161 |
+
example_inputs_,
|
| 1162 |
+
recursive_compile_fx,
|
| 1163 |
+
)
|
| 1164 |
+
|
| 1165 |
+
model_ = _recursive_pre_grad_passes(model_, example_inputs_)
|
| 1166 |
+
optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
|
| 1167 |
+
signpost_event(
|
| 1168 |
+
"optimus",
|
| 1169 |
+
"compile_fx.pre_grad_passes",
|
| 1170 |
+
optimus_scuba_log,
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
|
| 1174 |
+
return flatten_graph_inputs(
|
| 1175 |
+
model_,
|
| 1176 |
+
example_inputs_,
|
| 1177 |
+
recursive_compile_fx,
|
| 1178 |
+
)
|
| 1179 |
+
|
| 1180 |
+
assert not config._raise_error_for_testing
|
| 1181 |
+
num_example_inputs = len(example_inputs_)
|
| 1182 |
+
cudagraphs = BoxedBool(config.triton.cudagraphs)
|
| 1183 |
+
forward_device = BoxedDeviceIndex(None)
|
| 1184 |
+
|
| 1185 |
+
graph_id = next(_graph_counter)
|
| 1186 |
+
|
| 1187 |
+
decompositions = (
|
| 1188 |
+
decompositions if decompositions is not None else select_decomp_table()
|
| 1189 |
+
)
|
| 1190 |
+
|
| 1191 |
+
    @dynamo_utils.dynamo_timed
    def fw_compiler_base(
        model: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        is_inference: bool,
    ):
        """Compile the forward graph produced by AOTAutograd.

        Closes over enclosing compile_fx state: num_example_inputs,
        cudagraphs, graph_id, forward_device, inner_compile and model_.
        """
        if is_inference:
            # partition_fn won't be called for inference-only graphs, so run
            # the joint-graph passes here instead.
            _recursive_joint_graph_passes(model)

        # Number of leading inputs (params/buffers added by AOTAutograd) that
        # are "fixed" — i.e. not part of the user's example inputs.
        fixed = torch._inductor.utils.num_fw_fixed_arguments(
            num_example_inputs, len(example_inputs)
        )
        user_visible_outputs = set()

        if config.keep_output_stride:
            *_, model_outputs_node = model.graph.nodes
            assert model_outputs_node.op == "output"
            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
            num_model_outputs = len(model_outputs)

            context = torch._guards.TracingContext.try_get()
            # See Note [User Outputs in the inductor graph]
            if context is not None and context.fw_metadata and not is_inference:
                original_output_start_index = (
                    context.fw_metadata.num_mutated_inp_runtime_indices
                )
            else:
                original_output_start_index = 0

            if isinstance(model_, torch.fx.GraphModule):
                *_, orig_model_outputs_node = model_.graph.nodes
                assert orig_model_outputs_node.op == "output"
                orig_model_outputs, _ = pytree.tree_flatten(
                    orig_model_outputs_node.args
                )
                num_orig_model_outputs = len(orig_model_outputs)
            else:
                num_orig_model_outputs = num_model_outputs

            assert num_orig_model_outputs <= num_model_outputs

            # Note [User Outputs in the inductor graph]
            # We make the following assumption:
            # For inference
            #   len(orig_model_outputs) == len(model_outputs)
            # For training
            #   len(orig_model_outputs) <= len(model_outputs)
            # During training, most of the time the model_outputs starts with
            # the original module's outputs followed by saved activations.
            # But this can be untrue if the model has inplace-updated tensors:
            # AOTAutograd will make those tensors be returned before the original
            # module's output.
            # To make things safe, we'll use the original_output_start_index field
            # set by AOTAutograd to decide where the original module outputs start.
            orig_output_end_idx = original_output_start_index + num_orig_model_outputs
            # Sanity check: we are about to splice out the "user" outputs from the full set
            # of "graph" outputs. Make sure we're within bounds.
            assert orig_output_end_idx <= num_model_outputs

            user_visible_outputs = {
                n.name
                for n in model_outputs[original_output_start_index:orig_output_end_idx]
                if isinstance(n, torch.fx.Node)
            }

        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=is_inference,
            boxed_forward_device_index=forward_device,
            user_visible_outputs=user_visible_outputs,
        )
|
| 1267 |
+
|
| 1268 |
+
fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
|
| 1269 |
+
|
| 1270 |
+
if config.freezing and not torch.is_grad_enabled():
|
| 1271 |
+
inference_compiler = functools.partial(
|
| 1272 |
+
fw_compiler_freezing,
|
| 1273 |
+
dynamo_model=model_,
|
| 1274 |
+
num_example_inputs=num_example_inputs,
|
| 1275 |
+
inner_compile=inner_compile,
|
| 1276 |
+
cudagraphs=cudagraphs,
|
| 1277 |
+
graph_id=graph_id,
|
| 1278 |
+
forward_device=forward_device,
|
| 1279 |
+
)
|
| 1280 |
+
else:
|
| 1281 |
+
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
|
| 1282 |
+
|
| 1283 |
+
    def partition_fn(graph, joint_inputs, **kwargs):
        """Split the AOTAutograd joint graph into forward/backward graphs via
        min-cut rematerialization, after running the joint-graph passes."""
        _recursive_joint_graph_passes(graph)
        return min_cut_rematerialization_partition(
            graph, joint_inputs, **kwargs, compiler="inductor"
        )
|
| 1288 |
+
|
| 1289 |
+
    @dynamo_utils.dynamo_timed
    @dynamo_utils.maybe_cprofile
    def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        """Compile the backward graph; tangent inputs are counted as fixed.

        Closes over enclosing compile_fx state: inner_compile, cudagraphs,
        graph_id and forward_device.
        """
        fixed = count_tangents(model)
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            is_backward=True,
            graph_id=graph_id,
            boxed_forward_device_index=forward_device,
        )
|
| 1302 |
+
|
| 1303 |
+
# TODO: can add logging before/after the call to create_aot_dispatcher_function
|
| 1304 |
+
# in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
|
| 1305 |
+
# once torchdynamo is merged into pytorch
|
| 1306 |
+
|
| 1307 |
+
fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode(
|
| 1308 |
+
allow_non_fake_inputs=True
|
| 1309 |
+
)
|
| 1310 |
+
tracing_context = (
|
| 1311 |
+
torch._guards.TracingContext.try_get()
|
| 1312 |
+
or torch._guards.TracingContext(fake_mode)
|
| 1313 |
+
)
|
| 1314 |
+
|
| 1315 |
+
if V.aot_compilation is True:
|
| 1316 |
+
gm, graph_signature = aot_export_module(
|
| 1317 |
+
model_, example_inputs_, trace_joint=False, decompositions=decompositions
|
| 1318 |
+
)
|
| 1319 |
+
unlifted_gm = _unlift_graph(model_, gm, graph_signature)
|
| 1320 |
+
if "dynamo_flat_name_to_original_fqn" in model_.meta:
|
| 1321 |
+
unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
|
| 1322 |
+
"dynamo_flat_name_to_original_fqn"
|
| 1323 |
+
]
|
| 1324 |
+
with V.set_fake_mode(fake_mode), compiled_autograd.disable():
|
| 1325 |
+
return inference_compiler(unlifted_gm, example_inputs_)
|
| 1326 |
+
|
| 1327 |
+
with V.set_fake_mode(fake_mode), torch._guards.tracing(
|
| 1328 |
+
tracing_context
|
| 1329 |
+
), compiled_autograd.disable():
|
| 1330 |
+
return aot_autograd(
|
| 1331 |
+
fw_compiler=fw_compiler,
|
| 1332 |
+
bw_compiler=bw_compiler,
|
| 1333 |
+
inference_compiler=inference_compiler,
|
| 1334 |
+
decompositions=decompositions,
|
| 1335 |
+
partition_fn=partition_fn,
|
| 1336 |
+
keep_inference_input_mutations=True,
|
| 1337 |
+
)(model_, example_inputs_)
|
| 1338 |
+
|
| 1339 |
+
|
| 1340 |
+
def _shape_env_from_inputs(inputs: List[torch.Tensor]):
|
| 1341 |
+
shape_env = None
|
| 1342 |
+
fake_mode = detect_fake_mode(inputs)
|
| 1343 |
+
|
| 1344 |
+
# TODO(voz): It would be nice to enable this assert, but there are lots of tests that
|
| 1345 |
+
# pass in real inputs for now.
|
| 1346 |
+
# if len(inputs) > 0:
|
| 1347 |
+
# assert fake_mode is not None, breakpoint()
|
| 1348 |
+
|
| 1349 |
+
if fake_mode is not None:
|
| 1350 |
+
return fake_mode.shape_env
|
| 1351 |
+
|
| 1352 |
+
# When there are no tensor inputs, get shape_env from the first SymInt.
|
| 1353 |
+
for input in inputs:
|
| 1354 |
+
if isinstance(input, torch.SymInt):
|
| 1355 |
+
return input.node.shape_env
|
| 1356 |
+
|
| 1357 |
+
# TODO(voz): Should we always have one anyway?
|
| 1358 |
+
return None
|
| 1359 |
+
|
| 1360 |
+
|
| 1361 |
+
def graph_returns_tuple(gm: torch.fx.GraphModule):
    """True if a FX graph returns a tuple"""
    if not isinstance(gm, torch.fx.GraphModule):
        # Not an FX graph, so we cannot inspect it; assume the convention holds.
        return True
    (rv,) = output_node(gm).args
    if isinstance(rv, (list, tuple)):
        return True
    # A single call producing multiple Tensor outputs also counts as a tuple.
    schema = (
        getattr(rv.target, "_schema", None)
        if isinstance(rv, torch.fx.node.Node)
        else None
    )
    return (
        schema is not None
        and len(schema.returns) > 1
        and all(str(ret.type) == "Tensor" for ret in schema.returns)
    )
|
| 1377 |
+
|
| 1378 |
+
|
| 1379 |
+
def make_graph_return_tuple(
    gm: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    compile_gm: Callable[..., Any],
):
    """
    Mutate gm so it returns a tuple. This is only needed for graphs
    not created by torchdynamo that return non-tuples.
    """
    old_output = output_node(gm)
    (retval,) = old_output.args
    flat_retval, spec = pytree.tree_flatten(retval)
    with gm.graph.inserting_before(old_output):
        gm.graph.output(flat_retval)
    gm.graph.erase_node(old_output)
    assert graph_returns_tuple(gm)

    compiled_fn = compile_gm(gm, inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args, **kwargs):
        # Re-nest the flat outputs into the caller's original structure.
        flat_out = compiled_fn(*args, **kwargs)
        return pytree.tree_unflatten(flat_out, spec)

    return wrapper
|
| 1403 |
+
|
| 1404 |
+
|
| 1405 |
+
def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
|
| 1406 |
+
"""
|
| 1407 |
+
Mutate inputs so that they are flat and wrap gm such that it
|
| 1408 |
+
accepts those inputs. This is only needed for graphs not created
|
| 1409 |
+
by torchdynamo that take bumpy inputs.
|
| 1410 |
+
"""
|
| 1411 |
+
inputs, spec = pytree.tree_flatten(inputs)
|
| 1412 |
+
|
| 1413 |
+
class GmWrapper(torch.nn.Module):
|
| 1414 |
+
def __init__(self):
|
| 1415 |
+
super().__init__()
|
| 1416 |
+
self.gm = gm
|
| 1417 |
+
|
| 1418 |
+
def forward(self, *args):
|
| 1419 |
+
args: List[Any] = list(args)
|
| 1420 |
+
return self.gm(*pytree.tree_unflatten(args, spec))
|
| 1421 |
+
|
| 1422 |
+
compiled_fn = compile_gm(GmWrapper(), inputs)
|
| 1423 |
+
|
| 1424 |
+
@functools.wraps(compiled_fn)
|
| 1425 |
+
def wrapper(*args):
|
| 1426 |
+
# note this doesn't check the spec, assuming it is the same
|
| 1427 |
+
return compiled_fn(*pytree.arg_tree_leaves(*args))
|
| 1428 |
+
|
| 1429 |
+
return wrapper
|
| 1430 |
+
|
| 1431 |
+
|
| 1432 |
+
def handle_dynamo_export_graph(
    gm: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    compile_gm: Callable[..., Any],
):
    """
    `torch._dynamo.export` embeds pytrees in the FX graph codegen object,
    convert that to a normal FX graph so inductor can compile it.
    """
    pytree_codegen = gm.graph._codegen
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.recompile()

    compiled_fn = compile_gm(gm, pytree_codegen.process_inputs(*inputs))

    @functools.wraps(compiled_fn)
    def wrapper(*args):
        # Flatten inputs / re-nest outputs using the exported pytree codegen.
        flat_args = pytree_codegen.process_inputs(*args)
        return pytree_codegen.process_outputs(compiled_fn(*flat_args))

    return wrapper
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from typing import Dict, Iterable, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._inductor.codecache import CompiledFxGraph
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_mutating_use_stack_trace(placeholder_node: torch.fx.Node) -> Optional[str]:
|
| 9 |
+
# reinplaced uses might have a single, non-copy_ use
|
| 10 |
+
if len(placeholder_node.users) == 1:
|
| 11 |
+
return next(iter(placeholder_node.users)).meta.get("stack_trace", None)
|
| 12 |
+
|
| 13 |
+
for use in placeholder_node.users:
|
| 14 |
+
if use.target == torch.ops.aten.copy_.default:
|
| 15 |
+
if stack_trace := use.meta.get("stack_trace", None):
|
| 16 |
+
return stack_trace
|
| 17 |
+
|
| 18 |
+
return None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def format_default_skip_message(reason: str) -> str:
    """Standard prefix for messages explaining why cudagraphs were skipped."""
    return "skipping cudagraphs due to " + reason
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_mutation_stack_trace(
    gm: torch.fx.GraphModule, mutation_indices: Iterable[int]
) -> str:
    """Build a cudagraph skip message for mutated inputs, including the first
    recoverable stack trace among the mutated placeholders."""
    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]

    stack_trace: Optional[str] = ""
    for idx in mutation_indices:
        stack_trace = get_mutating_use_stack_trace(placeholders[idx])
        if stack_trace:
            break

    if stack_trace:
        return f"skipping cudagraphs due to mutation on input. Found from : \n {stack_trace}"
    return format_default_skip_message("mutated inputs")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def check_for_mutation(
    gm: torch.fx.GraphModule, compiled_graph: CompiledFxGraph, num_fixed: int
) -> Optional[str]:
    """Return a skip message when the compiled graph mutates inputs that
    prevent cudagraph capture, or None when cudagraphs are fine."""
    # doesnt work for non-trees because the warmup run would apply mutation twice
    if not torch._inductor.config.triton.cudagraph_trees:
        if len(compiled_graph.mutated_inputs) == 0:
            return None
        return format_default_skip_message("mutated inputs")

    # checking if mutation is only on parameters/static inputs
    mutation_indices = [
        idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed
    ]
    if not mutation_indices:
        return None
    return get_mutation_stack_trace(gm, mutation_indices)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_use_stack_trace(node) -> Optional[str]:
|
| 66 |
+
for use in node.users:
|
| 67 |
+
if stack_trace := use.meta.get("stack_trace", None):
|
| 68 |
+
return stack_trace
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    """Return a skip reason when the graph touches the cpu device or spans
    multiple devices; return None when everything is on one cuda device."""
    cpu_node = device_node_mapping.get(torch.device("cpu"))
    if cpu_node:
        stack_trace = get_use_stack_trace(cpu_node)
        if stack_trace:
            return format_default_skip_message(
                f"cpu device. Found from : \n {stack_trace}"
            )
        return format_default_skip_message("cpu device")

    if len(device_node_mapping) == 1:
        only_device = next(iter(device_node_mapping.keys()))
        if only_device.type == "cuda":
            return None

    keys_repr = (repr(key) for key in device_node_mapping.keys())
    return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def check_lowering_disable_cudagraph(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
):
    """Lowering-time cudagraph check; currently only device placement
    (cpu nodes / multiple devices) can disable cudagraphs here."""
    return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dataclasses.dataclass
class BoxedDeviceIndex:
    """A mutable box holding an optional device index, so it can be filled in
    after construction and observed by everyone sharing the box."""

    value: Optional[int]

    def set(self, device_idx: Optional[int]):
        """Store ``device_idx``; must be an int or None (meaning unset)."""
        assert isinstance(device_idx, int) or device_idx is None
        self.value = device_idx
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import dataclasses
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
import re
|
| 6 |
+
import typing
|
| 7 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
| 8 |
+
from unittest.mock import patch
|
| 9 |
+
|
| 10 |
+
import sympy
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
| 14 |
+
|
| 15 |
+
from .codegen.common import index_prevent_reordering
|
| 16 |
+
from .utils import (
|
| 17 |
+
get_dtype_size,
|
| 18 |
+
reduction_num_outputs,
|
| 19 |
+
sympy_index_symbol,
|
| 20 |
+
sympy_str,
|
| 21 |
+
sympy_subs,
|
| 22 |
+
VarRanges,
|
| 23 |
+
)
|
| 24 |
+
from .virtualized import OpsHandler, ReductionType, V
|
| 25 |
+
|
| 26 |
+
log = logging.getLogger(__name__)
|
| 27 |
+
is_indirect = re.compile(r"indirect|tmp").search
|
| 28 |
+
Dep = Union["MemoryDep", "StarDep", "WeakDep"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MemoryDep(typing.NamedTuple):
    """A read/write of buffer ``name`` at symbolic ``index``, expressed over
    iteration variables ``var_names`` with extents ``size``."""

    name: str
    index: sympy.Expr  # type: ignore[assignment]
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]

    def __repr__(self):
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges})"

    @property
    def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...}"""
        return dict(zip(self.var_names, self.size))

    def get_numel(self) -> sympy.Expr:
        # Element count touched by this access.  Indirect accesses are
        # conservatively treated as touching the whole buffer; otherwise we
        # multiply the extents of only the loop vars that appear in the index.
        if self.is_indirect():
            numel = V.graph.get_numel(self.name)
        else:
            vars = set(self.index.free_symbols)
            numel = sympy.Integer(1)
            for var, size in zip(self.var_names, self.size):
                if var in vars:
                    numel = numel * size
        return numel

    def rename(self, renames: Dict[str, str]) -> "MemoryDep":
        # Return a copy pointing at the renamed buffer; no-op if not renamed.
        if self.name in renames:
            return MemoryDep(
                renames[self.name], self.index, var_names=self.var_names, size=self.size
            )
        return self

    def numbytes_hint(self):
        # Estimated bytes accessed = element-count hint * dtype size.
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        # True when the element count involves unbacked (data-dependent) symbols.
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        # Contiguous iff the index is exactly one of the iteration variables.
        return isinstance(self.index, sympy.Symbol) and self.index in self.var_names

    def is_scalar(self) -> bool:
        # Scalar access: a constant index, or a symbol that is neither an
        # iteration variable nor an indirect (tmp) variable.
        if isinstance(self.index, sympy.Symbol):
            return self.index not in self.var_names and not self.is_indirect()
        return isinstance(self.index, (int, sympy.Integer))

    def is_indirect(self) -> bool:
        # Indirect loads are detected by symbol names matching "indirect"/"tmp"
        # (see the module-level ``is_indirect`` regex).
        return any(is_indirect(v.name) for v in self.index.free_symbols)  # type: ignore[attr-defined]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class StarDep(typing.NamedTuple):
    """A dependency on the entire buffer ``name`` rather than specific
    elements (e.g. the offsets buffer recorded by ``bucketize``)."""

    # depends on the entire buffer
    name: str

    @property
    def index(self):
        # There is no single index for a whole-buffer dependency.
        raise NotImplementedError("StarDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # The whole buffer counts as touched.
        return V.graph.get_numel(self.name)

    def rename(self, renames: Dict[str, str]) -> "StarDep":
        if self.name in renames:
            return StarDep(renames[self.name])
        return self

    def numbytes_hint(self):
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        return False

    def is_scalar(self) -> bool:
        return False

    def is_indirect(self) -> bool:
        return False
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Used for tracking mutation ordering
# if A reads a buffer and B mutates it
# B must be ordered after A
#
# It is weak because if it turns out A's read is never used, we can still
# eliminate it
class WeakDep(typing.NamedTuple):
    # Name of the buffer this ordering-only edge refers to.
    name: str

    @property
    def index(self):
        # Ordering edges carry no index.
        raise NotImplementedError("WeakDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # Contributes nothing to data volume; a nominal single element.
        return sympy.Integer(1)

    def rename(self, renames: Dict[str, str]) -> "WeakDep":
        if self.name in renames:
            return WeakDep(renames[self.name])
        return self

    def numbytes_hint(self):
        return 1  # Purely inserted for ordering, not an actual dep

    def has_unbacked_symbols(self):
        return False

    def is_contiguous(self) -> bool:
        return False
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class IndexExprDep(typing.NamedTuple):
    """An index expression used by a kernel body (recorded via
    ``index_expr``), canonicalized against its iteration vars/sizes."""

    index: sympy.Expr  # type: ignore[assignment]
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@dataclasses.dataclass
class ReadWrites:
    """Collected memory dependencies of one kernel body: the buffers it
    reads and writes, plus any bare index expressions it evaluates."""

    reads: Set[Dep]
    writes: Set[Dep]
    index_exprs: Set[IndexExprDep]
    # Flattened iteration variables (left empty when deps were normalized,
    # since normalization renumbers the variables).
    range_vars: Optional[List[sympy.Expr]] = None
    # Mapping of iteration variable -> extent.
    var_ranges: Optional[VarRanges] = None
    # How many times each ops.* name was used while tracing the body.
    op_counts: typing.Counter[str] = dataclasses.field(
        default_factory=collections.Counter
    )

    def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
        # Apply buffer renames to every read/write dep; other fields carry over.
        return ReadWrites(
            {dep.rename(renames) for dep in self.reads},
            {dep.rename(renames) for dep in self.writes},
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def with_read(self, dep: Dep) -> "ReadWrites":
        # Copy with one extra whole-buffer/ordering read added.
        assert isinstance(dep, (WeakDep, StarDep))
        return ReadWrites(
            set.union(self.reads, {dep}),
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def merge(self, other: "ReadWrites"):
        # Union two dep sets; reads satisfied by the combined writes are dropped.
        reads = set.union(self.reads, other.reads)
        writes = set.union(self.writes, other.writes)
        index_exprs = set.union(self.index_exprs, other.index_exprs)
        op_counts = collections.Counter(self.op_counts)
        op_counts.update(other.op_counts)
        return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts)

    @staticmethod
    def merge_list(read_writes: List["ReadWrites"]):
        # N-way version of merge(): reads of any written buffer are dropped.
        all_writes = set.union(*[rw.writes for rw in read_writes])
        all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes
        all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes])

        op_counts: typing.Counter[Any] = collections.Counter()
        for rw in read_writes:
            op_counts.update(rw.op_counts)

        return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts)

    def remove_reads(self, rem_reads):
        # Copy without the given read deps.
        return ReadWrites(
            self.reads - rem_reads,
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def reads_and_writes(self):
        # Single iterable over all deps (reads first, then writes).
        return itertools.chain(self.reads, self.writes)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
    """Ops handler that records every load/store/index_expr it sees, so a
    symbolic trace of a kernel body yields its memory dependencies."""

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        super().__init__()
        self._reads: Set[Dep] = set()
        self._writes: Set[MemoryDep] = set()
        self._index_exprs: Set[IndexExprDep] = set()
        self._var_ranges: VarRanges = var_ranges
        self._normalize: bool = normalize

    def canonicalize(
        self, index: sympy.Expr
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        # Reduce an index to (index, var_names, sizes), dropping size-1 dims.
        # Without normalization we only prune; with it we also re-simplify the
        # loop nest and renumber variables canonically so that equivalent
        # accesses from different kernels compare equal.
        if not self._normalize:
            sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
            var_names = tuple(
                k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
            )
            sizes = tuple(v for v in sizes if v != 1)
            return index, var_names, sizes  # type: ignore[return-value]

        # Try to further simplify the indexes even if simplify_loops didn't
        # convert it to the simplest form because of the interference from
        # different indexing formulas.
        free_symbols = index.free_symbols
        var_ranges = {
            k: V.graph.sizevars.simplify(v)
            for k, v in self._var_ranges.items()
            # TODO(jansel): explore this further normalization
            # if k in free_symbols
        }
        index_vars = [*var_ranges.keys()]
        sizes = tuple(var_ranges.values())
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars,
            sizes,
            index_prevent_reordering([index], index_vars, sizes),
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        new_vars, add_var = var_builder(canonicalization_prefix())
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)

        new_vars = [*new_vars.keys()]
        new_sizes = [*new_sizes]
        free_symbols = index.free_symbols
        while new_vars and new_vars[-1] not in free_symbols:
            # Reduction has last (reduced) dim in its sizes, but
            # downstream users won't. Normalize this away.
            new_vars.pop()
            new_sizes.pop()
        return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[arg-type]

    def load(self, name: str, index: sympy.Expr) -> str:
        # Record a read and return a printable form of the op.
        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
        return f"load({name}, {sympy_str(index)})"

    def load_seed(self, name: str, index: int):
        # Seeds are addressed with a literal int offset.
        assert isinstance(index, int)
        return self.load(name, sympy.Integer(index))

    def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
        # Record a write and return a printable form of the op.
        self._writes.add(MemoryDep(name, *self.canonicalize(index)))
        return f"store({name}, {sympy_str(index)}, {value}, {mode})"

    def store_reduction(self, name: str, index, value) -> str:
        # Reduction stores are recorded as ordinary writes.
        return self.store(name, index, f"store_reduction({value})")

    def index_expr(self, index: sympy.Expr, dtype) -> str:
        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
        return f"index_expr({sympy_str(index)}, {dtype})"

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        # bucketize scans the offsets buffer, so record a whole-buffer read.
        self._reads.add(StarDep(offsets_name))
        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class _OpCounter:
    """Shim to count how many times each op is used"""

    def __init__(self, inner):
        super().__init__()
        # Handler that actually implements the ops; we only observe.
        self.parent_handler = inner
        self._op_counts: typing.Counter[Any] = collections.Counter()

    def __getattr__(self, name):
        # Counts attribute *lookups* on this shim (presumably one per op
        # invocation by the ops dispatcher) and delegates to the wrapped
        # handler.  NOTE: any lookup, not just calls, increments the counter.
        self._op_counts[name] += 1
        return getattr(self.parent_handler, name)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class RecordLoadStore(V.KernelFormatterHandler):  # type: ignore[name-defined]
    """Formatter handler that routes ops through an _OpCounter into a
    _RecordLoadStoreInner, so one trace yields both deps and op counts."""

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        parent_handler = _RecordLoadStoreInner(
            var_ranges=var_ranges, normalize=normalize
        )
        parent_handler = _OpCounter(parent_handler)
        super().__init__(parent_handler=parent_handler)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
|
| 328 |
+
cnt = itertools.count()
|
| 329 |
+
var_ranges: VarRanges = dict()
|
| 330 |
+
|
| 331 |
+
def add_var(length: sympy.Expr) -> sympy.Symbol:
|
| 332 |
+
v = sympy_index_symbol(f"{prefix}{next(cnt)}")
|
| 333 |
+
var_ranges[v] = length
|
| 334 |
+
return v
|
| 335 |
+
|
| 336 |
+
return var_ranges, add_var
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
|
| 340 |
+
var_ranges, add_var = var_builder(prefix)
|
| 341 |
+
args: List[List[sympy.Symbol]] = []
|
| 342 |
+
for size in argsizes:
|
| 343 |
+
args.append(list(map(add_var, size)))
|
| 344 |
+
return args, var_ranges
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
|
| 348 |
+
from .ir import SqueezeView
|
| 349 |
+
|
| 350 |
+
var_ranges, add_var = var_builder(prefix)
|
| 351 |
+
args: List[List[sympy.Expr]] = []
|
| 352 |
+
new_sizes: List[List[sympy.Expr]] = []
|
| 353 |
+
for size in argsizes:
|
| 354 |
+
new_size, reindex = SqueezeView.squeezer(size)
|
| 355 |
+
new_sizes.append(new_size)
|
| 356 |
+
args.append(reindex(list(map(add_var, new_size))))
|
| 357 |
+
return args, var_ranges
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def extract_read_writes(
|
| 361 |
+
fn: Callable[..., Any],
|
| 362 |
+
*argsizes: Tuple[sympy.Expr, ...],
|
| 363 |
+
normalize: bool = False,
|
| 364 |
+
prefix: str = "d",
|
| 365 |
+
):
|
| 366 |
+
args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
|
| 367 |
+
rw = RecordLoadStore(var_ranges, normalize=normalize)
|
| 368 |
+
with V.set_ops_handler(rw):
|
| 369 |
+
fn(*args)
|
| 370 |
+
|
| 371 |
+
if normalize:
|
| 372 |
+
range_vars = [] # Number of vars could differ due to normalization
|
| 373 |
+
else:
|
| 374 |
+
range_vars = list(itertools.chain.from_iterable(args))
|
| 375 |
+
|
| 376 |
+
inner = rw.parent_handler.parent_handler
|
| 377 |
+
return ReadWrites(
|
| 378 |
+
set(inner._reads),
|
| 379 |
+
set(inner._writes),
|
| 380 |
+
inner._index_exprs,
|
| 381 |
+
range_vars,
|
| 382 |
+
var_ranges,
|
| 383 |
+
rw.parent_handler._op_counts,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def extract_input_node_reduction_ranges(
|
| 388 |
+
input_node: "torch._inductor.ir.TensorBox",
|
| 389 |
+
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
|
| 390 |
+
"""
|
| 391 |
+
Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
|
| 392 |
+
It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
|
| 393 |
+
In this case, reduction_sizes of the Reduction nodes need to be the same.
|
| 394 |
+
Otherwise returns (None, None).
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
from .ir import ComputedBuffer, Loops
|
| 398 |
+
|
| 399 |
+
if isinstance(input_node.data, ComputedBuffer):
|
| 400 |
+
# Input node has already been realized. Return its size and reduction_size.
|
| 401 |
+
size = input_node.get_size()
|
| 402 |
+
reduction_size = input_node.get_reduction_size()
|
| 403 |
+
if len(reduction_size) > 0:
|
| 404 |
+
return (size, reduction_size)
|
| 405 |
+
else:
|
| 406 |
+
return (None, None)
|
| 407 |
+
|
| 408 |
+
if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined]
|
| 409 |
+
# Other IRNodes do not have reduction_ranges.
|
| 410 |
+
return (None, None)
|
| 411 |
+
|
| 412 |
+
# There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
|
| 413 |
+
# The current method still uses reduction ranges from the dependent realized node, which is not ideal.
|
| 414 |
+
# Is there a way to check whether there are permutations inbetween?
|
| 415 |
+
reads = input_node.get_reads()
|
| 416 |
+
reduction_size = None
|
| 417 |
+
size = None
|
| 418 |
+
while reduction_size is None and len(reads) > 0:
|
| 419 |
+
seen = set()
|
| 420 |
+
new_reads = []
|
| 421 |
+
for read in reads:
|
| 422 |
+
if not isinstance(read, MemoryDep):
|
| 423 |
+
continue
|
| 424 |
+
if read.name in seen:
|
| 425 |
+
continue
|
| 426 |
+
seen.add(read.name)
|
| 427 |
+
buffer = V.graph.get_buffer(read.name)
|
| 428 |
+
if buffer is None:
|
| 429 |
+
continue
|
| 430 |
+
if (
|
| 431 |
+
isinstance(buffer, ComputedBuffer)
|
| 432 |
+
and len(buffer.get_reduction_size()) > 0
|
| 433 |
+
):
|
| 434 |
+
if reduction_size is None:
|
| 435 |
+
reduction_size = buffer.get_reduction_size()
|
| 436 |
+
size = buffer.get_size()
|
| 437 |
+
elif (
|
| 438 |
+
reduction_size != buffer.get_reduction_size()
|
| 439 |
+
or size != buffer.get_size()
|
| 440 |
+
):
|
| 441 |
+
return (None, None)
|
| 442 |
+
else:
|
| 443 |
+
new_reads.extend(buffer.get_reads())
|
| 444 |
+
if reads == new_reads:
|
| 445 |
+
return (size, reduction_size)
|
| 446 |
+
else:
|
| 447 |
+
reads = new_reads
|
| 448 |
+
return (size, reduction_size)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def canonicalization_prefix():
|
| 452 |
+
return "c"
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# ops handler which computes all the free unbacked symbols for an IR
|
| 456 |
+
class FreeUnbackedSymbolsOpsHandler:
|
| 457 |
+
symbols: Set[sympy.Symbol]
|
| 458 |
+
|
| 459 |
+
def __init__(self):
|
| 460 |
+
self.symbols = set()
|
| 461 |
+
|
| 462 |
+
def __getattr__(self, name: str) -> Callable[..., Any]:
|
| 463 |
+
def inner(*args, **kwargs):
|
| 464 |
+
for a in itertools.chain(args, kwargs.values()):
|
| 465 |
+
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
|
| 466 |
+
self.symbols |= free_unbacked_symbols(a)
|
| 467 |
+
|
| 468 |
+
return inner
|
| 469 |
+
|
| 470 |
+
def indirect_indexing(self, index_var, size, check=True) -> sympy.Symbol:
|
| 471 |
+
assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
|
| 472 |
+
self.symbols |= free_unbacked_symbols(size)
|
| 473 |
+
return sympy_index_symbol(f"({str(index_var)})")
|
| 474 |
+
|
| 475 |
+
def frexp(self, x):
|
| 476 |
+
return (None,) * 2
|
| 477 |
+
|
| 478 |
+
def reduction(
|
| 479 |
+
self,
|
| 480 |
+
dtype: torch.dtype,
|
| 481 |
+
src_dtype: torch.dtype,
|
| 482 |
+
reduction_type: ReductionType,
|
| 483 |
+
value: Union[None, Tuple[None, ...]],
|
| 484 |
+
) -> Union[None, Tuple[None, ...]]:
|
| 485 |
+
num_values = reduction_num_outputs(reduction_type)
|
| 486 |
+
return (None,) * num_values if num_values > 1 else None
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def _typecheck_FreeUnbackedSymbolsOpsHandler(
|
| 490 |
+
h: FreeUnbackedSymbolsOpsHandler,
|
| 491 |
+
) -> OpsHandler[None]:
|
| 492 |
+
return h
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
|
| 496 |
+
from .ir import FlexibleLayout
|
| 497 |
+
|
| 498 |
+
args = [index, rindex] if rindex is not None else [index]
|
| 499 |
+
handler = FreeUnbackedSymbolsOpsHandler()
|
| 500 |
+
# NB: I cargo culted the allow_indexing patch here, I don't understand why
|
| 501 |
+
# people do this all over
|
| 502 |
+
with V.set_ops_handler(handler), patch.object(
|
| 503 |
+
FlexibleLayout, "allow_indexing", True
|
| 504 |
+
):
|
| 505 |
+
fn(*args)
|
| 506 |
+
return handler.symbols
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/exc.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
import textwrap
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
|
| 9 |
+
|
| 10 |
+
@lru_cache(None)
|
| 11 |
+
def _record_missing_op(target):
|
| 12 |
+
with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
|
| 13 |
+
fd.write(str(target) + "\n")
|
| 14 |
+
|
| 15 |
+
else:
|
| 16 |
+
|
| 17 |
+
def _record_missing_op(target): # type: ignore[misc]
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class OperatorIssue(RuntimeError):
|
| 22 |
+
@staticmethod
|
| 23 |
+
def operator_str(target, args, kwargs):
|
| 24 |
+
lines = [f"target: {target}"] + [
|
| 25 |
+
f"args[{i}]: {arg}" for i, arg in enumerate(args)
|
| 26 |
+
]
|
| 27 |
+
if kwargs:
|
| 28 |
+
lines.append(f"kwargs: {kwargs}")
|
| 29 |
+
return textwrap.indent("\n".join(lines), " ")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MissingOperatorWithoutDecomp(OperatorIssue):
|
| 33 |
+
def __init__(self, target, args, kwargs):
|
| 34 |
+
_record_missing_op(target)
|
| 35 |
+
super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MissingOperatorWithDecomp(OperatorIssue):
|
| 39 |
+
def __init__(self, target, args, kwargs):
|
| 40 |
+
_record_missing_op(target)
|
| 41 |
+
super().__init__(
|
| 42 |
+
f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
|
| 43 |
+
+ textwrap.dedent(
|
| 44 |
+
f"""
|
| 45 |
+
|
| 46 |
+
There is a decomposition available for {target} in
|
| 47 |
+
torch._decomp.get_decompositions(). Please add this operator to the
|
| 48 |
+
`decompositions` list in torch._inductor.decompositions
|
| 49 |
+
"""
|
| 50 |
+
)
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class LoweringException(OperatorIssue):
|
| 55 |
+
def __init__(self, exc: Exception, target, args, kwargs):
|
| 56 |
+
super().__init__(
|
| 57 |
+
f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class InvalidCxxCompiler(RuntimeError):
|
| 62 |
+
def __init__(self):
|
| 63 |
+
from . import config
|
| 64 |
+
|
| 65 |
+
super().__init__(
|
| 66 |
+
f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}"
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class CppWrapperCodeGenError(RuntimeError):
|
| 71 |
+
def __init__(self, msg: str):
|
| 72 |
+
super().__init__(f"C++ wrapper codegen error: {msg}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class CppCompileError(RuntimeError):
|
| 76 |
+
def __init__(self, cmd: list[str], output: str):
|
| 77 |
+
if isinstance(output, bytes):
|
| 78 |
+
output = output.decode("utf-8")
|
| 79 |
+
|
| 80 |
+
super().__init__(
|
| 81 |
+
textwrap.dedent(
|
| 82 |
+
"""
|
| 83 |
+
C++ compile error
|
| 84 |
+
|
| 85 |
+
Command:
|
| 86 |
+
{cmd}
|
| 87 |
+
|
| 88 |
+
Output:
|
| 89 |
+
{output}
|
| 90 |
+
"""
|
| 91 |
+
)
|
| 92 |
+
.strip()
|
| 93 |
+
.format(cmd=" ".join(cmd), output=output)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class CUDACompileError(CppCompileError):
|
| 98 |
+
pass
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.fx
|
| 7 |
+
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
|
| 8 |
+
from torch.utils import _pytree as pytree
|
| 9 |
+
from torch.utils._pytree import tree_map
|
| 10 |
+
from .virtualized import V
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
|
| 14 |
+
# Works for length 2 patterns with 1 module and 1 function/method.
|
| 15 |
+
def matches_module_function_pattern(
|
| 16 |
+
pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
|
| 17 |
+
node: torch.fx.node.Node,
|
| 18 |
+
modules: Dict[str, torch.nn.modules.Module],
|
| 19 |
+
) -> bool:
|
| 20 |
+
if len(node.args) == 0:
|
| 21 |
+
return False
|
| 22 |
+
if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
|
| 23 |
+
node, torch.fx.Node
|
| 24 |
+
):
|
| 25 |
+
return False
|
| 26 |
+
# the first node is call_module
|
| 27 |
+
if node.args[0].op != "call_module":
|
| 28 |
+
return False
|
| 29 |
+
if not isinstance(node.args[0].target, str):
|
| 30 |
+
return False
|
| 31 |
+
if node.args[0].target not in modules:
|
| 32 |
+
return False
|
| 33 |
+
if type(modules[node.args[0].target]) is not pattern[0]:
|
| 34 |
+
return False
|
| 35 |
+
# the second node is call_function or call_method
|
| 36 |
+
if node.op != "call_function" and node.op != "call_method":
|
| 37 |
+
return False
|
| 38 |
+
if node.target != pattern[1]:
|
| 39 |
+
return False
|
| 40 |
+
# make sure node.args[0] output is only used by current node.
|
| 41 |
+
if len(node.args[0].users) > 1:
|
| 42 |
+
return False
|
| 43 |
+
return True
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class FakeTensorUpdater:
|
| 47 |
+
"""
|
| 48 |
+
The main idea here is that it's difficult to maintain accurate fake
|
| 49 |
+
tensors (our primary form of metadata) for each node in our graph as we
|
| 50 |
+
transform it.
|
| 51 |
+
|
| 52 |
+
The most reliable way to obtain this information is by rerunning
|
| 53 |
+
faketensor propagation. However, in general, faketensor propagation is
|
| 54 |
+
fairly expensive. So, instead we'd like to only rerun faketensor
|
| 55 |
+
propagation on nodes that have changed.
|
| 56 |
+
|
| 57 |
+
In order to detect which nodes have changed, we first hash its node,
|
| 58 |
+
target, and argument lists (which are immutable in FX).
|
| 59 |
+
|
| 60 |
+
Then, whenever we call incremental_update, we check which FX nodes have a
|
| 61 |
+
new hash, and recompute the faketensor metadata for that node. Then, we
|
| 62 |
+
continue to recursively compute the faketensors for all users until the
|
| 63 |
+
fake tensors stop changing.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, graph: torch.fx.Graph):
|
| 67 |
+
self.processed_hashes = set()
|
| 68 |
+
self.graph = graph
|
| 69 |
+
|
| 70 |
+
for node in self.graph.nodes:
|
| 71 |
+
self.processed_hashes.add(self.hash_node(node))
|
| 72 |
+
|
| 73 |
+
def hash_node(self, node: torch.fx.Node):
|
| 74 |
+
# todo(chilli): Not a great hash function
|
| 75 |
+
return (node, node.target, id(node.args), id(node.kwargs))
|
| 76 |
+
|
| 77 |
+
def incremental_update(self):
|
| 78 |
+
processed = set()
|
| 79 |
+
existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
|
| 80 |
+
for node in self.graph.nodes:
|
| 81 |
+
existing_storages[get_node_storage(node)] += 1
|
| 82 |
+
|
| 83 |
+
def is_intlist_same(new, old):
|
| 84 |
+
return statically_known_true(sym_eq(new, old))
|
| 85 |
+
|
| 86 |
+
def is_fake_tensor_same(new, old):
|
| 87 |
+
if type(new) != type(old):
|
| 88 |
+
return False
|
| 89 |
+
if isinstance(new, (list, tuple)):
|
| 90 |
+
if len(new) != len(old):
|
| 91 |
+
return False
|
| 92 |
+
return all(
|
| 93 |
+
is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
|
| 94 |
+
)
|
| 95 |
+
assert isinstance(new, torch.Tensor)
|
| 96 |
+
if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
|
| 97 |
+
return False
|
| 98 |
+
if new.layout == torch.strided and (
|
| 99 |
+
not is_intlist_same(new.stride(), old.stride())
|
| 100 |
+
or not statically_known_true(
|
| 101 |
+
new.storage_offset() == old.storage_offset()
|
| 102 |
+
)
|
| 103 |
+
):
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
if get_storage(new) == get_storage(old):
|
| 107 |
+
return True
|
| 108 |
+
|
| 109 |
+
# This is the case where it returns a completely fresh storage that's used nowhere else.
|
| 110 |
+
if (
|
| 111 |
+
existing_storages[get_storage(old)] == 1
|
| 112 |
+
and get_storage(new) not in existing_storages
|
| 113 |
+
):
|
| 114 |
+
return True
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
for node in self.graph.nodes:
|
| 118 |
+
if self.hash_node(node) in self.processed_hashes:
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
def is_aten_node(node):
|
| 122 |
+
return node.op == "call_function" and isinstance(
|
| 123 |
+
node.target, torch._ops.OpOverload
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
if not is_aten_node(node):
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
processing = [node]
|
| 130 |
+
while len(processing) > 0:
|
| 131 |
+
updating_node = processing.pop()
|
| 132 |
+
if updating_node in processed:
|
| 133 |
+
continue
|
| 134 |
+
if is_aten_node(updating_node):
|
| 135 |
+
continue
|
| 136 |
+
|
| 137 |
+
is_valid, args, kwargs = get_fake_args_kwargs(updating_node)
|
| 138 |
+
if not is_valid:
|
| 139 |
+
continue
|
| 140 |
+
with V.fake_mode:
|
| 141 |
+
new_fake_tensor = updating_node.target(*args, **kwargs)
|
| 142 |
+
if "val" in updating_node.meta and is_fake_tensor_same(
|
| 143 |
+
new_fake_tensor, updating_node.meta["val"]
|
| 144 |
+
):
|
| 145 |
+
continue
|
| 146 |
+
updating_node.meta["val"] = new_fake_tensor
|
| 147 |
+
|
| 148 |
+
# todo(chilli): This code path is not exercised by our existing
|
| 149 |
+
# tests - add a test
|
| 150 |
+
existing_storages[get_node_storage(new_fake_tensor)] += 1
|
| 151 |
+
processed.add(updating_node)
|
| 152 |
+
processing.extend(updating_node.users)
|
| 153 |
+
|
| 154 |
+
self.processed_hashes.add(self.hash_node(updating_node))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_storage(t: torch.Tensor) -> int:
|
| 158 |
+
return t.untyped_storage()._cdata
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_node_storage(node: torch.fx.Node) -> Optional[int]:
|
| 162 |
+
if "val" not in node.meta:
|
| 163 |
+
return None
|
| 164 |
+
if not isinstance(node.meta["val"], torch.Tensor):
|
| 165 |
+
return None
|
| 166 |
+
if not torch._C._has_storage(node.meta["val"]):
|
| 167 |
+
return None
|
| 168 |
+
return get_storage(node.meta["val"])
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def get_fake(x):
|
| 172 |
+
if isinstance(x, torch.fx.Node):
|
| 173 |
+
if "val" not in x.meta:
|
| 174 |
+
return x
|
| 175 |
+
return x.meta["val"]
|
| 176 |
+
return x
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
|
| 180 |
+
"""
|
| 181 |
+
First value returns a boolean if any of the input nodes don't have a faketensor.
|
| 182 |
+
"""
|
| 183 |
+
args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
|
| 184 |
+
if any(
|
| 185 |
+
isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
|
| 186 |
+
):
|
| 187 |
+
return False, args, kwargs
|
| 188 |
+
return True, args, kwargs
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def is_node_realized(node: torch.fx.Node) -> bool:
|
| 192 |
+
"""Returns true if a node is always realized when lowered to inductor IR.
|
| 193 |
+
|
| 194 |
+
NOTE: This may return some false negatives. e.g. it doesn't
|
| 195 |
+
handle buffers realized heuristically during lowering, or
|
| 196 |
+
buffers realized indirectly through view ops.
|
| 197 |
+
"""
|
| 198 |
+
from torch._inductor.lowering import fallbacks, needs_realized_inputs
|
| 199 |
+
|
| 200 |
+
def is_buffer(node: torch.fx.Node) -> bool:
|
| 201 |
+
if node.op == "call_function" and node.target is operator.getitem:
|
| 202 |
+
# For nodes with multiple outputs, we get the fx graph:
|
| 203 |
+
# foo = torch.ops.aten.foo(...)
|
| 204 |
+
# getitem = foo[0]
|
| 205 |
+
# getitem_1 = foo[1]
|
| 206 |
+
# where we need to check if foo is a fallback kernel
|
| 207 |
+
return is_buffer(node.args[0]) # type: ignore[arg-type]
|
| 208 |
+
return node.op in ("placeholder", "output") or node.target in fallbacks
|
| 209 |
+
|
| 210 |
+
if is_buffer(node):
|
| 211 |
+
return True
|
| 212 |
+
|
| 213 |
+
def realizes_inputs(node: torch.fx.Node) -> bool:
|
| 214 |
+
return node.op == "output" or node.target in needs_realized_inputs
|
| 215 |
+
|
| 216 |
+
if any(realizes_inputs(user) for user in node.users):
|
| 217 |
+
return True
|
| 218 |
+
|
| 219 |
+
# Otherwise, assume node isn't realized
|
| 220 |
+
return False
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This file implements the IndexPropagation ops handler, which wraps an
|
| 2 |
+
underlying handler to add a limited form of constant propagation, as well as
|
| 3 |
+
propagation of sympy expressions downstream of ops.index_expr calls.
|
| 4 |
+
|
| 5 |
+
For example, say we have the IR:
|
| 6 |
+
|
| 7 |
+
tmp0 = ops.index_expr(x, torch.int32)
|
| 8 |
+
tmp1 = ops.constant(2, torch.int32)
|
| 9 |
+
tmp2 = ops.mul(tmp0, tmp1)
|
| 10 |
+
tmp3 = ops.indirect_indexing(tmp2, x_size)
|
| 11 |
+
tmp4 = ops.load("buf0", tmp3)
|
| 12 |
+
|
| 13 |
+
The underlying handler would just see:
|
| 14 |
+
|
| 15 |
+
ops.load("buf0", x * 2)
|
| 16 |
+
|
| 17 |
+
This is limited by the set of operators handled in the sympy expression
|
| 18 |
+
printers. So simple operations like minimum and maximum cannot be translated to
|
| 19 |
+
SymPy expressions yet, despite sympy.Min and sympy.Max existing.
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
import itertools
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import sympy
|
| 27 |
+
|
| 28 |
+
from typing_extensions import TypeAlias
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from torch._prims_common import is_boolean_dtype, is_integer_dtype
|
| 32 |
+
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class TypedExpr:
|
| 37 |
+
"""A SymPy expression with associated type"""
|
| 38 |
+
|
| 39 |
+
expr: sympy.Expr
|
| 40 |
+
dtype: torch.dtype
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class SymPyOps:
|
| 44 |
+
"""An ops handler where all IR values are SymPy expressions
|
| 45 |
+
|
| 46 |
+
When a value cannot be represented as a SymPy expression, the method is
|
| 47 |
+
either not defined, or returns NotImplemented
|
| 48 |
+
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
@staticmethod
|
| 52 |
+
def identity(value: Any) -> Any:
|
| 53 |
+
return value
|
| 54 |
+
|
| 55 |
+
@staticmethod
|
| 56 |
+
def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
|
| 57 |
+
if is_boolean_dtype(dtype):
|
| 58 |
+
expr = sympy.Integer(bool(value))
|
| 59 |
+
elif is_integer_dtype(dtype):
|
| 60 |
+
expr = sympy.Integer(int(value))
|
| 61 |
+
else:
|
| 62 |
+
expr = sympy.Float(float(value))
|
| 63 |
+
return TypedExpr(expr, dtype)
|
| 64 |
+
|
| 65 |
+
@staticmethod
|
| 66 |
+
def index_expr(value: sympy.Expr, dtype: torch.dtype) -> Union[int, TypedExpr]:
|
| 67 |
+
if isinstance(value, int):
|
| 68 |
+
value = sympy.Integer(value)
|
| 69 |
+
return TypedExpr(value, dtype)
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def to_dtype(
|
| 73 |
+
value: Any, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
|
| 74 |
+
) -> Union[int, TypedExpr]:
|
| 75 |
+
if isinstance(value.expr, (sympy.Integer, sympy.Float)):
|
| 76 |
+
return SymPyOps.constant(value.expr, dtype)
|
| 77 |
+
elif is_integer_dtype(dtype) and is_integer_dtype(value.dtype):
|
| 78 |
+
return SymPyOps.index_expr(value.expr, dtype)
|
| 79 |
+
else:
|
| 80 |
+
# TODO: Inductor doesn't handle floating point in sympy expressions well at the moment
|
| 81 |
+
return NotImplemented
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def square(x: TypedExpr) -> TypedExpr:
|
| 85 |
+
return TypedExpr(x.expr * x.expr, x.dtype)
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
|
| 89 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 90 |
+
return TypedExpr(x.expr + y.expr, result_type)
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
|
| 94 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 95 |
+
return TypedExpr(x.expr - y.expr, result_type)
|
| 96 |
+
|
| 97 |
+
@staticmethod
|
| 98 |
+
def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
|
| 99 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 100 |
+
return TypedExpr(x.expr * y.expr, result_type)
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def neg(x: TypedExpr) -> TypedExpr:
|
| 104 |
+
return TypedExpr(-x.expr, x.dtype)
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
|
| 108 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 109 |
+
if not is_integer_dtype(result_type):
|
| 110 |
+
return NotImplemented
|
| 111 |
+
|
| 112 |
+
return TypedExpr(FloorDiv(x.expr, y.expr), result_type)
|
| 113 |
+
|
| 114 |
+
@staticmethod
|
| 115 |
+
def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
|
| 116 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 117 |
+
if not is_integer_dtype(result_type):
|
| 118 |
+
return NotImplemented
|
| 119 |
+
|
| 120 |
+
result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
|
| 121 |
+
return TypedExpr(result_expr, result_type)
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
|
| 125 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 126 |
+
if not is_integer_dtype(result_type):
|
| 127 |
+
return NotImplemented
|
| 128 |
+
# In these cases, remainder in Python == remainder in C++, so this transformation
|
| 129 |
+
# is sound
|
| 130 |
+
if (
|
| 131 |
+
x.expr.is_nonnegative is not None
|
| 132 |
+
and x.expr.is_nonnegative == y.expr.is_positive
|
| 133 |
+
):
|
| 134 |
+
result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
|
| 135 |
+
return TypedExpr(result_expr, result_type)
|
| 136 |
+
return NotImplemented
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
|
| 140 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 141 |
+
return TypedExpr(sympy.Min(x.expr, y.expr), result_type)
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
|
| 145 |
+
result_type = torch.promote_types(x.dtype, y.dtype)
|
| 146 |
+
return TypedExpr(sympy.Max(x.expr, y.expr), result_type)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@dataclass
|
| 150 |
+
class IndexPropVar:
|
| 151 |
+
value: Any # Either an IR value, or TypedExpr if is_symbolic is true
|
| 152 |
+
is_symbolic: bool = False
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
|
| 156 |
+
return IndexPropVar(expr, is_symbolic=True)
|
| 157 |
+
|
| 158 |
+
def __post_init__(self):
|
| 159 |
+
assert not self.is_symbolic or isinstance(
|
| 160 |
+
self.value, TypedExpr
|
| 161 |
+
), "Symbolic IndexPropVar must contain a TypedExpr"
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class IndexPropagation:
    """Ops wrapper that tries to propagate constant and index_expr values through the computation.

    This aims to maximize the compile time simplification possible, and convert
    indirect indexing from arange into normal static indexing.

    """

    def __init__(self, inner: Any):
        # `inner` is the wrapped ops handler we delegate to whenever symbolic
        # propagation is not possible.
        self._inner = inner

    def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
        # Construct a new constant/index_expr from the SymPy expression
        if isinstance(expr, sympy.Integer):
            return self._inner.constant(int(expr), dtype)
        elif expr.is_number:
            return self._inner.constant(float(expr), dtype)
        return self._inner.index_expr(expr, dtype)

    def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
        """Convert an IndexPropVar (or nested tuple of them) back into inner-handler values."""
        if isinstance(a, (list, tuple)):
            return tuple(self.unwrap(v) for v in a)

        if not isinstance(a, IndexPropVar):
            return a

        # Prefer the sympy representation if possible
        if a.is_symbolic:
            return self.materialize_expr(a.value.expr, a.value.dtype)

        return a.value

    def wrap(self, a) -> IndexPropResult:
        """Wrap inner-handler results (possibly nested) into IndexPropVar."""
        if isinstance(a, (list, tuple)):
            return tuple(self.wrap(v) for v in a)
        return IndexPropVar(a)

    # Overloads: indirect_indexing always yields a single IndexPropVar;
    # other ops may yield nested tuples.
    @overload
    def fallback(
        self,
        name: Literal["indirect_indexing"],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> IndexPropVar:
        ...

    @overload
    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        ...

    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Fallback to the wrapped handler
        new_args = [self.unwrap(a) for a in args]
        new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
        return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

    def propagate_sympy(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Build a new SymPy expression from this ops call
        def unwrap(a: Union[Any, IndexPropVar]) -> Any:
            if not isinstance(a, IndexPropVar):
                return a
            return a.value

        new_args = [unwrap(a) for a in args]
        new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
        is_valid_expr = new_expr is not NotImplemented and (
            # Inductor doesn't expect floating point in sympy expressions, but
            # allow floating point constants to be propagated
            isinstance(new_expr.expr, sympy.Number)
            or new_expr.expr.is_integer
        )
        if not is_valid_expr:
            return self.fallback(name, args, kwargs)
        return IndexPropVar.new_symbolic(new_expr)

    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
        # Intercept every ops call: try SymPy propagation first, otherwise
        # fall back to the wrapped handler.
        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
            if not hasattr(SymPyOps, name):
                return self.fallback(name, args, kwargs)

            var_arguments = [
                a
                for a in itertools.chain(args, kwargs.values())
                if isinstance(a, IndexPropVar)
            ]
            # Propagation only works if every IndexPropVar input is symbolic.
            if not all(v.is_symbolic for v in var_arguments):
                return self.fallback(name, args, kwargs)

            return self.propagate_sympy(name, args, kwargs)

        return inner

    def indirect_indexing(
        self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
    ) -> Any:
        # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE
        # for SymPy expressions, so we don't want to repeat idx too much

        # indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here
        if isinstance(index, IndexPropVar) and index.is_symbolic:
            # If we are turning a indirect indexing into direct, we need to wrap it.
            index = index.value.expr
            return index + Where(index >= 0, 0, size)
        return self.fallback("indirect_indexing", (index, size, check), {}).value
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import csv
|
| 4 |
+
import inspect
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from functools import lru_cache
|
| 9 |
+
|
| 10 |
+
from typing import Dict, List, Set, Tuple, TYPE_CHECKING, Union
|
| 11 |
+
|
| 12 |
+
from torch._inductor import config
|
| 13 |
+
from torch._inductor.utils import get_benchmark_name
|
| 14 |
+
|
| 15 |
+
# Prevent circular import
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from torch._inductor.scheduler import (
|
| 18 |
+
BaseSchedulerNode,
|
| 19 |
+
ExternKernelSchedulerNode,
|
| 20 |
+
NopKernelSchedulerNode,
|
| 21 |
+
SchedulerNode,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# counter for tracking how many kernels have been generated
generated_kernel_count = 0
# counter for kernels that got a vectorized C++ codegen path
generated_cpp_vec_kernel_count = 0
# running total of bytes read/written, as accounted elsewhere in inductor
num_bytes_accessed = 0
# per-node element counts recorded during scheduling
nodes_num_elem: List[
    Tuple[
        Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
        int,
    ]
] = []
# estimated runtime per scheduler node
node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0

# counters for tracking cpp_wrapper disabled
disable_cpp_wrapper = 0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# reset all counters
def reset():
    """Restore every module-level metric counter/accumulator to its initial state."""
    global generated_kernel_count, generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion, cpp_to_dtype_count, disable_cpp_wrapper

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    disable_cpp_wrapper = 0
    # Lists are emptied in place so external references stay valid.
    nodes_num_elem.clear()
    node_runtimes.clear()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
class CachedMetricsDeltas:
    """
    The subset of metrics we want update across cache hits, e.g., the
    FxGraphCache.
    """

    # Each field holds the delta (new - snapshot) of the same-named
    # module-level counter in this module.
    generated_kernel_count: int
    generated_cpp_vec_kernel_count: int
    ir_nodes_pre_fusion: int
    cpp_to_dtype_count: int
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class CachedMetricsHelper:
    """
    A helper class to help calculate and apply counter deltas for those
    metrics we want to save with cache entries (e.g., FxGraphCache) and
    apply on a cache hit.
    """

    # Module-level counters tracked across cache hits; order matches the
    # field order of CachedMetricsDeltas.
    _tracked = (
        "generated_kernel_count",
        "generated_cpp_vec_kernel_count",
        "ir_nodes_pre_fusion",
        "cpp_to_dtype_count",
    )

    def __init__(self):
        # Snapshot the current counter values so deltas can be computed later.
        for metric in self._tracked:
            setattr(self, metric, globals()[metric])

    def get_deltas(self) -> CachedMetricsDeltas:
        # Difference between the live counters and the snapshot taken at
        # construction time, in CachedMetricsDeltas field order.
        return CachedMetricsDeltas(
            *(globals()[metric] - getattr(self, metric) for metric in self._tracked)
        )

    @staticmethod
    def apply_deltas(delta: CachedMetricsDeltas):
        # Replay previously recorded deltas onto the live module counters.
        for metric in CachedMetricsHelper._tracked:
            globals()[metric] += getattr(delta, metric)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# Global registry of all metric tables, keyed by table name.
REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass
class MetricTable:
    """A named CSV metric table; rows are appended lazily, only when enabled."""

    table_name: str
    column_names: List[str]

    # Rows written by this process; used to decide whether a header is needed.
    num_rows_added: int = 0

    def add_row(self, row_fn):
        """Append one row. ``row_fn`` is a zero-arg callable returning a dict;
        it is only invoked when this table is enabled, keeping the disabled
        path free of row-construction cost."""
        if self.table_name not in enabled_metric_tables():
            return

        row_dict = row_fn()
        assert len(self.column_names) == len(
            row_dict
        ), f"{len(self.column_names)} v.s. {len(row_dict)}"
        assert set(self.column_names) == set(
            row_dict.keys()
        ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}"

        # First column is always the benchmark/model name.
        row = [
            get_benchmark_name(),
        ]
        row += [row_dict[column_name] for column_name in self.column_names]
        self._write_row(row)

    def output_filename(self):
        return f"metric_table_{self.table_name}.csv"

    def write_header(self):
        # Overwrites any existing file with just the header line.
        filename = self.output_filename()
        with open(filename, "w") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(["model_name"] + self.column_names)

    def _write_row(self, row):
        filename = self.output_filename()
        if self.num_rows_added == 0 and not os.path.exists(filename):
            self.write_header()

        self.num_rows_added += 1

        # Normalize cell values in place: floats get fixed precision,
        # None becomes an empty cell.
        for idx, orig_val in enumerate(row):
            if isinstance(orig_val, float):
                new_val = f"{orig_val:.6f}"
            elif orig_val is None:
                new_val = ""
            else:
                new_val = orig_val
            row[idx] = new_val

        with open(filename, "a") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(row)

    @staticmethod
    def register_table(name, column_names):
        # Create and register a table in the module-global registry.
        table = MetricTable(name, column_names)
        REGISTERED_METRIC_TABLES[name] = table
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# track fusions that made the combined kernel slower than the parts
MetricTable.register_table(
    "slow_fusion",
    [
        "kernel1_path",
        "kernel1_latency",
        "kernel2_path",
        "kernel2_latency",
        "fused_kernel_path",
        "fused_kernel_latency",
        "slow_down_ratio",
    ],
)

# track the fusion statistics for each graph
MetricTable.register_table(
    "graph_stats",
    [
        "graph_id",
        "num_nodes_before_fusion",
        "num_nodes_after_fusion",
    ],
)

# track the perf difference between persistent reduction and non-persistent
# reductions
MetricTable.register_table(
    "persistent_red_perf",
    [
        "kernel1_name",
        "kernel2_name",
        "kernel1_latency",
        "kernel2_latency",
        "size_hints",
        "reduction_hint",
        "speedup",
    ],
)

# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint
MetricTable.register_table(
    "kernel_metadata",
    [
        "kernel_name",
        "kernel_path",
        "kernel_category",  # pointwise/reduction/foreach etc.
        "size_hints",
        "reduction_hint",
        "line_of_code",
        "num_load",
        "num_store",
        "num_for_loop",
        "num_atomic_add",
        "num_args",
        # xyz numel can be different to size_hints since size_hints are rounded
        # up to the nearest power of 2.
        # Inductor kernel will burn in the xyz numel in kernel code for static
        # shape kernels.
        # Logging them will be helpful to find unaligned shape for reduction
        "xnumel",
        "ynumel",
        "rnumel",
        "kernel_args_num_gb",
    ],
)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _parse_kernel_fn_code(kernel_module_code):
    """
    The kernel_module_code is the python module that contains kernel function code.
    kernel function is the proper triton kernel function annotated with
    @triton.jit

    Returns the source text of that function (decorators included).
    """
    # Local imports to avoid circular imports at module load time.
    from .codecache import PyCodeCache
    from .wrapper_benchmark import get_triton_kernel

    mod = PyCodeCache.load(kernel_module_code)
    kernel = get_triton_kernel(mod)
    # kernel is a CachingAutotune; kernel.fn is the JITFunction;
    # kernel.fn.fn is the function being decorate by triton.jit
    return inspect.getsource(kernel.fn.fn)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _parse_kernel_line_of_code(proper_kernel_fn_code):
|
| 268 |
+
"""
|
| 269 |
+
Return the line of code for the kernel excluding the decorators.
|
| 270 |
+
"""
|
| 271 |
+
return len(proper_kernel_fn_code.splitlines())
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _parse_size_hints(kernel_module_code, kernel_category):
|
| 275 |
+
if kernel_category == "foreach":
|
| 276 |
+
# foreach kernel does not have size_hints
|
| 277 |
+
return None
|
| 278 |
+
m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code)
|
| 279 |
+
assert m, "size_hints missing!"
|
| 280 |
+
return m.group(1)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _parse_reduction_hint(kernel_category, kernel_module_code):
|
| 284 |
+
if kernel_category not in ("reduction", "persistent_reduction"):
|
| 285 |
+
return None
|
| 286 |
+
m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code)
|
| 287 |
+
assert m, "reduction_hint not found in kernel source code!"
|
| 288 |
+
return m.group(1)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def _count_pattern(proper_kernel_fn_code, pattern):
|
| 292 |
+
return proper_kernel_fn_code.count(pattern)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _count_args(proper_kernel_fn_code):
|
| 296 |
+
def_line = proper_kernel_fn_code.splitlines()[0]
|
| 297 |
+
assert def_line.startswith("def ")
|
| 298 |
+
start_idx = def_line.index("(")
|
| 299 |
+
end_idx = def_line.index("):")
|
| 300 |
+
decl_csv = def_line[start_idx + 1 : end_idx]
|
| 301 |
+
comps = decl_csv.split(",")
|
| 302 |
+
return len(comps)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _parse_proper_kernel_fn_code(kernel_fn_code):
|
| 306 |
+
"""
|
| 307 |
+
Skip decorators.
|
| 308 |
+
"""
|
| 309 |
+
start_pos = kernel_fn_code.index("def ")
|
| 310 |
+
return kernel_fn_code[start_pos:]
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def _parse_numel(proper_kernel_fn_code, numel_arg_name):
|
| 314 |
+
m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code)
|
| 315 |
+
if m:
|
| 316 |
+
return int(m.group(1))
|
| 317 |
+
else:
|
| 318 |
+
return None
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category):
|
| 322 |
+
"""
|
| 323 |
+
inductor meta looks like:
|
| 324 |
+
inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0},
|
| 325 |
+
"""
|
| 326 |
+
m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code)
|
| 327 |
+
if m:
|
| 328 |
+
return float(m.group(1))
|
| 329 |
+
else:
|
| 330 |
+
"""
|
| 331 |
+
There are a few cases that kernel_num_gdb field can be missing:
|
| 332 |
+
1. the field will be missing if config.benchmark_kernel and
|
| 333 |
+
config.profile_bandwidth are false
|
| 334 |
+
2. even if config.benchmark_kernel or config.profile_bandwidth is true.
|
| 335 |
+
foreach kernel does not have kernel_num_gb field in the metadata
|
| 336 |
+
"""
|
| 337 |
+
return None
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code):
    """
    An utility to log kernel metadata. We may parse metadata from kernel source code here.

    It's fine to parse the generated kernel code here since the logging is
    disabled by default. It would hurt compilation time.
    """
    # Local import to avoid a circular import at module load time.
    from .wrapper_benchmark import get_kernel_category_by_source_code

    kernel_category = get_kernel_category_by_source_code(kernel_module_code)
    reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code)
    size_hints = _parse_size_hints(kernel_module_code, kernel_category)
    kernel_fn_code = _parse_kernel_fn_code(kernel_module_code)

    proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code)

    # the line of code excluding the decorators
    kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code)

    # Row construction is wrapped in a lambda so the work only happens when
    # the "kernel_metadata" table is actually enabled.
    get_metric_table("kernel_metadata").add_row(
        lambda: {
            "kernel_name": kernel_name,
            "kernel_path": kernel_path,
            "kernel_category": kernel_category,
            "size_hints": size_hints,
            "reduction_hint": reduction_hint,
            "line_of_code": kernel_line_of_code,
            "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"),
            "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"),
            "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "),
            "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"),
            "num_args": _count_args(proper_kernel_fn_code),
            "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"),
            "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"),
            "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"),
            "kernel_args_num_gb": _parse_kernel_args_num_gb(
                kernel_fn_code, kernel_category
            ),
        }
    )
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def purge_old_log_files():
    """
    Purge the old log file at the beginning when the benchmark script runs.
    Should do it in the parent process rather than the child processes running
    each individual model.
    """
    for name, table in REGISTERED_METRIC_TABLES.items():
        if name in enabled_metric_tables():
            filename = table.output_filename()
            # Remove any stale file, then re-create it with just the header.
            if os.path.exists(filename):
                os.unlink(filename)

            table.write_header()
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
@lru_cache
def enabled_metric_tables() -> Set[str]:
    """Parse config.enabled_metric_tables (comma-separated names) into a set.

    Cached: config changes after the first call are not picked up.
    """
    config_str = config.enabled_metric_tables

    enabled = set()
    for name in config_str.split(","):
        name = name.strip()
        if not name:
            continue
        # Fail loudly on typos rather than silently logging nothing.
        assert (
            name in REGISTERED_METRIC_TABLES
        ), f"Metric table name {name} is not registered"
        enabled.add(name)
    return enabled
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def is_metric_table_enabled(name):
    # True if the named table is selected in config.enabled_metric_tables.
    return name in enabled_metric_tables()
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def get_metric_table(name):
    # Look up a registered table; asserting gives a clearer error than KeyError.
    assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
    return REGISTERED_METRIC_TABLES[name]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_helpers.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import triton
|
| 2 |
+
import triton.language as tl
|
| 3 |
+
|
| 4 |
+
# In the latest triton, math functions were shuffled around into different modules:
# https://github.com/openai/triton/pull/3172
# Compatibility shim: expose `libdevice` and `math` names that work on both
# old and new triton layouts.
if hasattr(tl.extra.cuda, "libdevice"):
    libdevice = tl.extra.cuda.libdevice
    math = tl.math
else:
    libdevice = tl.math
    math = tl
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us
    return x + tl.zeros((1,), tl.int1)


@triton.jit
def is_floating(x):
    # True when x (promoted to tensor) has a floating-point dtype.
    return promote_to_tensor(x).dtype.is_floating()


@triton.jit
def _prod_accumulate(a, b):
    # Combine function used by prod().
    return a * b


@triton.jit
def prod(input, axis):
    # Product reduction along `axis`.
    return tl.reduce(input, axis, _prod_accumulate)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@triton.jit
def minimum(a, b):
    mask = a < b
    if is_floating(a):
        # `a != a` is true only for NaN; this makes NaN win, i.e. NaN propagates.
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def maximum(a, b):
    mask = a > b
    if is_floating(a):
        # NaN-propagating, mirroring minimum() above.
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def min2(a, dim):
    # Min reduction along `dim` using the NaN-propagating minimum.
    return tl.reduce(a, dim, minimum)


@triton.jit
def max2(a, dim):
    # Max reduction along `dim` using the NaN-propagating maximum.
    return tl.reduce(a, dim, maximum)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    # Combine function for argmin-style reductions: picks the smaller value,
    # propagating NaN, and the lowest index on ties.
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    # Combine function for argmax-style reductions; mirrors minimum_with_index.
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
    # (min value, its index) reduction along `dim`.
    return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
    # (max value, its index) reduction along `dim`.
    return tl.reduce((value, index), dim, maximum_with_index)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
    # One Welford online-update step: fold a new sample into (mean, m2, weight).
    if first_iteration:
        new_weight = tl.full(weight.shape, 1, weight.dtype)
        new_mean = value
        new_m2 = tl.zeros_like(m2)
    else:
        delta = value - mean
        new_weight = weight + 1
        new_mean = mean + delta / new_weight
        new_m2 = m2 + delta * (value - new_mean)
    return new_mean, new_m2, new_weight


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    # Parallel (pairwise) Welford combine of two partial (mean, m2, weight)
    # accumulators; the tl.where guards against division by a zero total weight.
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


@triton.jit
def welford(mean, m2, weight, dim):
    # Reduce per-element Welford accumulators along `dim`.
    return tl.reduce((mean, m2, weight), dim, welford_combine)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@triton.jit
def device_assert_then(cond, msg, r):
    # Assert `cond` on device, then pass `r` through; lets an assert be
    # inserted inside an expression.
    tl.device_assert(cond, msg)
    return r


@triton.jit
def randint64(seed, offset, low, high):
    # Build a 64-bit random integer in [low, high) from two 32-bit draws.
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    # NOTE: modulo mapping has slight bias when `size` does not divide 2**64.
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@triton.jit
def _any_combine(a, b):
    # Combine function used by any().
    return a | b


@triton.jit
def any(a, dim):
    # Logical-or reduction along `dim` (shadows the Python builtin by design).
    return tl.reduce(a, dim, _any_combine)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]

    Elementwise binary search of `values` over the sorted boundaries at
    `offsets_ptr`; returns the bucket index for each element.
    """

    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    # Halve the remaining search range each iteration (ceil division), so the
    # loop runs ceil(log2(OFFSETS_SIZE + 1)) times for every lane in lockstep.
    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@triton.jit
def pack_value_flag(
    value,
    flag,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    # Pack (value, flag) into one wide integer: flag in the low bits, the
    # bit-cast value in the high bits.
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
    return flag.to(DTYPE_PACK) | (uv << bitwidth)


@triton.jit
def unpack_value(
    pack,
    DTYPE_VALUE,
    DTYPE_VALUE_AS_UINT,
):
    # Recover the value stored in the high bits by pack_value_flag.
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
    return value_uint.to(DTYPE_VALUE, bitcast=True)


@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
    # Recover the flag stored in the low bits by pack_value_flag (truncating cast).
    return pack.to(DTYPE_FLAG)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@triton.jit
|
| 230 |
+
def exclusive_scan_decoupled_lookback(
|
| 231 |
+
scratch_base,
|
| 232 |
+
block_value,
|
| 233 |
+
index,
|
| 234 |
+
combine_fn,
|
| 235 |
+
init,
|
| 236 |
+
DTYPE_VALUE_AS_UINT: tl.constexpr,
|
| 237 |
+
DTYPE_PACK: tl.constexpr,
|
| 238 |
+
):
|
| 239 |
+
"""Compute exclusive scan of a scalar value between blocks
|
| 240 |
+
|
| 241 |
+
Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
|
| 242 |
+
|
| 243 |
+
scratch_base: Pointer to scratch space in global memory
|
| 244 |
+
block_value: Scalar value for this block
|
| 245 |
+
index: Scalar index of this block relative to the current scan
|
| 246 |
+
combine_fn: Function ``(value, value) -> value`` which is scanned over
|
| 247 |
+
init: Scalar value equal to the identiy of combine_fn
|
| 248 |
+
DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
|
| 249 |
+
DTYPE_PACK: Unsigned type twice the width of block_value
|
| 250 |
+
|
| 251 |
+
NOTE: This function is limited to values which are 32-bits or less.
|
| 252 |
+
"""
|
| 253 |
+
DTYPE_VALUE = block_value.dtype
|
| 254 |
+
pack = pack_value_flag(
|
| 255 |
+
block_value,
|
| 256 |
+
tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
|
| 257 |
+
DTYPE_VALUE_AS_UINT,
|
| 258 |
+
DTYPE_PACK,
|
| 259 |
+
)
|
| 260 |
+
tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
|
| 261 |
+
|
| 262 |
+
exclusive_prefix = init
|
| 263 |
+
test_target = index - 1
|
| 264 |
+
while test_target >= 0:
|
| 265 |
+
# tl.atomic_load
|
| 266 |
+
flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
|
| 267 |
+
while flag == 0:
|
| 268 |
+
pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
|
| 269 |
+
flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)
|
| 270 |
+
|
| 271 |
+
value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
|
| 272 |
+
exclusive_prefix = combine_fn(value, exclusive_prefix)
|
| 273 |
+
|
| 274 |
+
if flag == 2:
|
| 275 |
+
test_target = -1
|
| 276 |
+
else:
|
| 277 |
+
test_target = test_target - 1
|
| 278 |
+
|
| 279 |
+
# Make inclusive block sum visible to other blocks
|
| 280 |
+
inclusive_prefix = combine_fn(exclusive_prefix, block_value)
|
| 281 |
+
pack = pack_value_flag(
|
| 282 |
+
inclusive_prefix,
|
| 283 |
+
tl.full([], 2, DTYPE_VALUE_AS_UINT),
|
| 284 |
+
DTYPE_VALUE_AS_UINT,
|
| 285 |
+
DTYPE_PACK,
|
| 286 |
+
)
|
| 287 |
+
tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
|
| 288 |
+
return exclusive_prefix
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@triton.jit
|
| 292 |
+
def exclusive_scan_decoupled_lookback_64(
|
| 293 |
+
scratch_base, block_value, index, combine_fn, init
|
| 294 |
+
):
|
| 295 |
+
"""Compute exclusive scan of a scalar value between blocks
|
| 296 |
+
|
| 297 |
+
Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
|
| 298 |
+
|
| 299 |
+
scratch_base: Pointer to scratch space in global memory
|
| 300 |
+
block_value: Scalar value for this block, must be 64-bits wide
|
| 301 |
+
index: Scalar index of this block relative to the current scan
|
| 302 |
+
combine_fn: Function ``(value, value) -> value`` which is scanned over
|
| 303 |
+
init: Scalar value equal to the identiy of combine_fn
|
| 304 |
+
"""
|
| 305 |
+
block_value_u64 = block_value.to(tl.uint64, bitcast=True)
|
| 306 |
+
tl.store(scratch_base + 3 * index + 1, block_value_u64)
|
| 307 |
+
tl.debug_barrier()
|
| 308 |
+
flag_one = tl.full([], 1, tl.uint64)
|
| 309 |
+
tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")
|
| 310 |
+
|
| 311 |
+
exclusive_prefix = init
|
| 312 |
+
test_target = index - 1
|
| 313 |
+
while test_target >= 0:
|
| 314 |
+
flag = tl.full([], 0, tl.uint64)
|
| 315 |
+
while flag == 0:
|
| 316 |
+
flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")
|
| 317 |
+
|
| 318 |
+
value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
|
| 319 |
+
value = value_u64.to(block_value.dtype, bitcast=True)
|
| 320 |
+
exclusive_prefix = combine_fn(value, exclusive_prefix)
|
| 321 |
+
|
| 322 |
+
if flag == 2:
|
| 323 |
+
test_target = -1
|
| 324 |
+
else:
|
| 325 |
+
test_target = test_target - 1
|
| 326 |
+
|
| 327 |
+
# Make inclusive block sum visible to other blocks
|
| 328 |
+
inclusive_prefix = combine_fn(exclusive_prefix, block_value)
|
| 329 |
+
inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
|
| 330 |
+
tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
|
| 331 |
+
tl.debug_barrier()
|
| 332 |
+
flag_two = tl.full([], 2, tl.uint64)
|
| 333 |
+
tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")
|
| 334 |
+
|
| 335 |
+
return exclusive_prefix
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@triton.jit
|
| 339 |
+
def frexp(x):
|
| 340 |
+
# TODO(isuruf): use inline_asm_elementwise here
|
| 341 |
+
y = libdevice.ilogb(x) + 1
|
| 342 |
+
exponent = tl.where(x == 0, 0, y)
|
| 343 |
+
mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))
|
| 344 |
+
return mantissa, exponent
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/mps/MPSAllocatorInterface.h>
|
| 6 |
+
#include <ATen/mps/MPSEvent.h>
|
| 7 |
+
#include <ATen/mps/MPSStream.h>
|
| 8 |
+
|
| 9 |
+
#include <cstdio>
|
| 10 |
+
#include <mutex>
|
| 11 |
+
#include <set>
|
| 12 |
+
#include <unordered_set>
|
| 13 |
+
#include <mach/vm_page_size.h>
|
| 14 |
+
#include <c10/util/flat_hash_map.h>
|
| 15 |
+
|
| 16 |
+
// this implementation is based on CUDACachingAllocator.
|
| 17 |
+
// It utilizes Metal Heaps to improve the performance with buffer allocation.
|
| 18 |
+
// Do not include this header. Use MPSAllocatorInterface.h instead.
|
| 19 |
+
// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
|
| 20 |
+
namespace at::mps::HeapAllocator {
|
| 21 |
+
|
| 22 |
+
static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
|
| 23 |
+
static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
|
| 24 |
+
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
|
| 25 |
+
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
|
| 26 |
+
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
|
| 27 |
+
static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
|
| 28 |
+
static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
|
| 29 |
+
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
|
| 30 |
+
|
| 31 |
+
// buffer pools could be customized with a combination of usage flags
|
| 32 |
+
enum UsageFlags : uint32_t {
|
| 33 |
+
PRIVATE = 0,
|
| 34 |
+
SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
|
| 35 |
+
SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
|
| 36 |
+
MANAGED = (1 << 2), // managed storage mode
|
| 37 |
+
HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
|
| 38 |
+
SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
|
| 39 |
+
};
|
| 40 |
+
// debug verbosity flags
|
| 41 |
+
enum DebugVerbosity : uint32_t {
|
| 42 |
+
SILENT = 0,
|
| 43 |
+
PROFILING = (1 << 0), // print generic profiling data for total system memory usage
|
| 44 |
+
ALLOCATIONS = (1 << 1), // print buffer allocations
|
| 45 |
+
RECYCLES = (1 << 2), // print buffer recycling
|
| 46 |
+
RELEASES = (1 << 3), // print buffer releases
|
| 47 |
+
LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
struct HeapBlock;
|
| 51 |
+
|
| 52 |
+
struct BufferBlock {
|
| 53 |
+
id<MTLBuffer> buffer;
|
| 54 |
+
void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer
|
| 55 |
+
size_t size; // size after alignment
|
| 56 |
+
size_t requested_size; // requested size (before alignment)
|
| 57 |
+
// buffer shape is used for retrieving base of views in cached graphs
|
| 58 |
+
std::vector<int64_t> shape;
|
| 59 |
+
bool in_use = false;
|
| 60 |
+
HeapBlock* heap;
|
| 61 |
+
id_t buf_id;
|
| 62 |
+
// counter to candidate least recently used buffers for garbage collection
|
| 63 |
+
uint32_t gc_count = 0;
|
| 64 |
+
uint32_t use_count = 0;
|
| 65 |
+
// counter to assign unique ids to buffer blocks
|
| 66 |
+
static uint64_t buffer_counter;
|
| 67 |
+
// Metal events used to sync GPU/CPU operations on the shared-storage buffers
|
| 68 |
+
MPSEventPtr event;
|
| 69 |
+
|
| 70 |
+
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
|
| 71 |
+
HeapBlock* Heap = nullptr) :
|
| 72 |
+
buffer(Buffer), size(Size), requested_size(RequestedSize),
|
| 73 |
+
heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
|
| 74 |
+
|
| 75 |
+
static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
|
| 76 |
+
return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
|
| 77 |
+
}
|
| 78 |
+
static size_t alignUp(size_t Size, size_t Alignment) {
|
| 79 |
+
assert(((Alignment - 1) & Alignment) == 0);
|
| 80 |
+
return ((Size + Alignment - 1) & ~(Alignment - 1));
|
| 81 |
+
}
|
| 82 |
+
uint32_t retainCount() const { return [buffer retainCount]; }
|
| 83 |
+
};
|
| 84 |
+
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
|
| 85 |
+
|
| 86 |
+
struct BufferPool;
|
| 87 |
+
struct AllocParams {
|
| 88 |
+
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
|
| 89 |
+
search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
|
| 90 |
+
size_t size() const { return search_key.size; }
|
| 91 |
+
|
| 92 |
+
BufferBlock search_key;
|
| 93 |
+
BufferPool* pool;
|
| 94 |
+
BufferBlock* buffer_block = nullptr;
|
| 95 |
+
size_t requested_size;
|
| 96 |
+
// true if we exceed the low watermark limit. In this case
|
| 97 |
+
// we apply strategies to relieve the pressure before allocation.
|
| 98 |
+
bool has_memory_pressure = false;
|
| 99 |
+
// true if we're allocating on a unified memory device
|
| 100 |
+
bool has_unified_memory = true;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
struct HeapBlock {
|
| 104 |
+
id<MTLHeap> heap;
|
| 105 |
+
struct { size_t total, available; } size;
|
| 106 |
+
BufferPool* pool;
|
| 107 |
+
unsigned int n_buffers = 0;
|
| 108 |
+
id_t heap_id;
|
| 109 |
+
// indicates if we split this heap to sub-allocate 'several' buffers (otherwise single buffer)
|
| 110 |
+
bool is_split;
|
| 111 |
+
// counter to assign unique ids to heap blocks
|
| 112 |
+
static uint64_t heap_counter;
|
| 113 |
+
|
| 114 |
+
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
|
| 115 |
+
heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
|
| 116 |
+
heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
|
| 117 |
+
|
| 118 |
+
static MTLResourceOptions getOptions(uint32_t usage) {
|
| 119 |
+
// TODO: check the caching performance of write-combined mode
|
| 120 |
+
MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;
|
| 121 |
+
|
| 122 |
+
if (usage & UsageFlags::MANAGED)
|
| 123 |
+
options |= MTLResourceStorageModeManaged;
|
| 124 |
+
else if (usage & UsageFlags::SHARED)
|
| 125 |
+
options |= MTLResourceStorageModeShared;
|
| 126 |
+
else
|
| 127 |
+
options |= MTLResourceStorageModePrivate;
|
| 128 |
+
|
| 129 |
+
options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
|
| 130 |
+
|
| 131 |
+
return options;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
|
| 135 |
+
HeapBlock *heapBlock = nullptr;
|
| 136 |
+
bool is_split = true;
|
| 137 |
+
const size_t size = params.size();
|
| 138 |
+
MTLHeapDescriptor *d = [MTLHeapDescriptor new];
|
| 139 |
+
if (d) {
|
| 140 |
+
const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
|
| 141 |
+
if (size <= kMaxSmallAlloc) {
|
| 142 |
+
d.size = kSmallHeap;
|
| 143 |
+
} else if (size < kMinLargeAlloc) {
|
| 144 |
+
d.size = kLargeHeap;
|
| 145 |
+
} else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
|
| 146 |
+
d.size = kXLargeHeap;
|
| 147 |
+
} else {
|
| 148 |
+
d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
|
| 149 |
+
is_split = false;
|
| 150 |
+
}
|
| 151 |
+
d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
|
| 152 |
+
d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
| 153 |
+
// this automatically handles Metal buffer access synchronizations at the
|
| 154 |
+
// cost of slightly lower performance.
|
| 155 |
+
d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
|
| 156 |
+
d.resourceOptions = getOptions(usage);
|
| 157 |
+
d.type = MTLHeapTypeAutomatic;
|
| 158 |
+
id<MTLHeap> heap = [device newHeapWithDescriptor: d];
|
| 159 |
+
if (heap) {
|
| 160 |
+
[heap setPurgeableState:MTLPurgeableStateNonVolatile];
|
| 161 |
+
const size_t heap_size = heapAvailableSize(heap);
|
| 162 |
+
heapBlock = new HeapBlock(heap_size, heap, params.pool);
|
| 163 |
+
if (heapBlock) {
|
| 164 |
+
heapBlock->is_split = is_split;
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
[d release];
|
| 168 |
+
}
|
| 169 |
+
return heapBlock;
|
| 170 |
+
}
|
| 171 |
+
static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
|
| 172 |
+
return (a->size.available != b->size.available) ? a->size.available < b->size.available :
|
| 173 |
+
(uintptr_t)a->heap < (uintptr_t)b->heap;
|
| 174 |
+
}
|
| 175 |
+
static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
|
| 176 |
+
return [heap maxAvailableSizeWithAlignment:Alignment];
|
| 177 |
+
}
|
| 178 |
+
NSUInteger Size() {
|
| 179 |
+
return [heap size];
|
| 180 |
+
}
|
| 181 |
+
id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
|
| 182 |
+
id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
|
| 183 |
+
if (buf) {
|
| 184 |
+
updateAvailableSize();
|
| 185 |
+
n_buffers++;
|
| 186 |
+
}
|
| 187 |
+
return buf;
|
| 188 |
+
}
|
| 189 |
+
// returns the retainCount before releasing the buffer
|
| 190 |
+
uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
|
| 191 |
+
const uint32_t retainCount = [buffer retainCount];
|
| 192 |
+
[buffer release];
|
| 193 |
+
buffer = nil;
|
| 194 |
+
updateAvailableSize();
|
| 195 |
+
n_buffers--;
|
| 196 |
+
return retainCount;
|
| 197 |
+
}
|
| 198 |
+
// returns the retainCount before releasing the heap
|
| 199 |
+
uint32_t releaseMTLHeap() {
|
| 200 |
+
const uint32_t retainCount = [heap retainCount];
|
| 201 |
+
TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
|
| 202 |
+
[heap setPurgeableState:MTLPurgeableStateEmpty];
|
| 203 |
+
[heap release];
|
| 204 |
+
heap = nil;
|
| 205 |
+
size.available = 0;
|
| 206 |
+
return retainCount;
|
| 207 |
+
}
|
| 208 |
+
uint32_t retainCount() const { return [heap retainCount]; }
|
| 209 |
+
void updateAvailableSize() { size.available = heapAvailableSize(heap); }
|
| 210 |
+
};
|
| 211 |
+
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
|
| 212 |
+
|
| 213 |
+
struct BufferPool {
|
| 214 |
+
enum class Kind {
|
| 215 |
+
PRIVATE_SMALL,
|
| 216 |
+
PRIVATE_LARGE,
|
| 217 |
+
SHARED_SMALL,
|
| 218 |
+
SHARED_LARGE,
|
| 219 |
+
SCALAR,
|
| 220 |
+
};
|
| 221 |
+
|
| 222 |
+
BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
|
| 223 |
+
device(Device), usage(Usage),
|
| 224 |
+
heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
|
| 225 |
+
|
| 226 |
+
const id<MTLDevice> device;
|
| 227 |
+
// usage flags to customize the pool for various purposes (see UsageFlags enum)
|
| 228 |
+
const uint32_t usage;
|
| 229 |
+
// total number of buffers in the pool
|
| 230 |
+
uint32_t n_buffers = 0;
|
| 231 |
+
// total allocations size on this pool
|
| 232 |
+
size_t allocated_size = 0;
|
| 233 |
+
// total memory available in the pool
|
| 234 |
+
size_t available_size = 0;
|
| 235 |
+
// list of heaps ordered by their "available" (not total) memory size
|
| 236 |
+
std::set<HeapBlock*, HeapComparison> heaps;
|
| 237 |
+
// list of only "available" buffers in the pool (i.e., buffers not in-use)
|
| 238 |
+
std::set<BufferBlock*, BufferComparison> available_buffers;
|
| 239 |
+
// list of buffers that are in a state of "limbo" where they've already been freed
|
| 240 |
+
// from PyTorch-side, but were not returned to pool due to still being
|
| 241 |
+
// in-use by command buffers with retainCount > 1. In this state, the buffer is
|
| 242 |
+
// neither ready to be recycled, nor could be returned to pool as available.
|
| 243 |
+
// These buffers will be returned to pool once the command buffer's
|
| 244 |
+
// completionHandler callbacks are called.
|
| 245 |
+
std::unordered_set<BufferBlock*> buffers_pending_free;
|
| 246 |
+
// list of heaps pending size update
|
| 247 |
+
std::unordered_set<HeapBlock*> heaps_pending_update;
|
| 248 |
+
};
|
| 249 |
+
|
| 250 |
+
class MPSHeapAllocatorImpl {
|
| 251 |
+
public:
|
| 252 |
+
explicit MPSHeapAllocatorImpl() :
|
| 253 |
+
m_device(at::mps::MPSDevice::getInstance()->device()),
|
| 254 |
+
m_max_buffer_size([m_device maxBufferLength]),
|
| 255 |
+
m_stream(getDefaultMPSStream()),
|
| 256 |
+
m_event_pool(getMPSEventPool()) {
|
| 257 |
+
init_allocator();
|
| 258 |
+
}
|
| 259 |
+
~MPSHeapAllocatorImpl() {
|
| 260 |
+
emptyCache();
|
| 261 |
+
}
|
| 262 |
+
// interface exposed to at::Allocator
|
| 263 |
+
id<MTLBuffer> malloc(size_t size, uint32_t usage);
|
| 264 |
+
// frees a buffer and returns it into buffer pool
|
| 265 |
+
void free(void* ptr);
|
| 266 |
+
// releases all the cached buffers and their associated heaps
|
| 267 |
+
void emptyCache();
|
| 268 |
+
// free inactive buffers that are pending to be freed
|
| 269 |
+
void freeInactiveBuffers();
|
| 270 |
+
// returns true if buffer was allocated from the shared pool
|
| 271 |
+
bool isSharedBuffer(const void* ptr);
|
| 272 |
+
// get the requested unaligned size of an MTLBuffer
|
| 273 |
+
ssize_t getUnalignedBufferSize(const void* ptr);
|
| 274 |
+
// set the shape of a base tensor from a view tensor
|
| 275 |
+
void setBufferShape(const void* ptr, const IntArrayRef& shape);
|
| 276 |
+
// retrieve the shape of a base tensor from a view tensor
|
| 277 |
+
IntArrayRef getBufferShape(const void* ptr);
|
| 278 |
+
// get the unique ID of the buffer
|
| 279 |
+
id_t getBufferId(const void* ptr);
|
| 280 |
+
// allocate a buffer from a specialized pool to import CPU scalars into GPU
|
| 281 |
+
id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
|
| 282 |
+
// returns a CPU-mapping of the input buffer and its retainCount,
|
| 283 |
+
// if only it has Shared storage-mode and allocated on MPSAllocator
|
| 284 |
+
std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
|
| 285 |
+
// records events for a list of MTLBuffers (list is used to lock the mutex once)
|
| 286 |
+
// returns true if records any event (given if passed buffers exist and are shared-storage)
|
| 287 |
+
bool recordEvents(c10::ArrayRef<const void*> buffers);
|
| 288 |
+
// waits for the event to signal the completion of GPU execution
|
| 289 |
+
// on the passed shared buffers (list is used to lock the mutex once)
|
| 290 |
+
// returns true if actually waited on any event
|
| 291 |
+
bool waitForEvents(c10::ArrayRef<const void*> buffers);
|
| 292 |
+
// this indicates how far (in Megabytes) the current total allocations are from the
|
| 293 |
+
// low watermark limit which is used to detect if we're under memory pressure
|
| 294 |
+
// This returns zero if we've reached the low watermark limit
|
| 295 |
+
ssize_t getLowWatermarkValue();
|
| 296 |
+
// (see m_low_watermark_ratio for description)
|
| 297 |
+
void setLowWatermarkRatio(double ratio);
|
| 298 |
+
// (see m_high_watermark_ratio for description)
|
| 299 |
+
void setHighWatermarkRatio(double ratio);
|
| 300 |
+
// (see m_low_watermark_limit for description)
|
| 301 |
+
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
|
| 302 |
+
// (see m_max_total_allowed_size for description)
|
| 303 |
+
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
|
| 304 |
+
// (see m_total_allocated_memory for description)
|
| 305 |
+
size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
|
| 306 |
+
// (see m_current_allocated_memory for description)
|
| 307 |
+
size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
|
| 308 |
+
// total GPU memory allocated in the process by Metal driver; including
|
| 309 |
+
// implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
|
| 310 |
+
size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
|
| 311 |
+
// (see enum DebugVerbosity for description)
|
| 312 |
+
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
|
| 313 |
+
// returns the device that we allocate from
|
| 314 |
+
inline id<MTLDevice> Device() const { return m_device; }
|
| 315 |
+
|
| 316 |
+
// TODO: make a common function to do size unit conversions in PyTorch.
|
| 317 |
+
inline std::string format_size(uint64_t size) const;
|
| 318 |
+
|
| 319 |
+
private:
|
| 320 |
+
// (see m_high_watermark_ratio for description)
|
| 321 |
+
constexpr static double default_high_watermark_ratio = 1.7;
|
| 322 |
+
// we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
|
| 323 |
+
constexpr static double default_high_watermark_upper_bound = 2.0;
|
| 324 |
+
// (see m_low_watermark_ratio for description)
|
| 325 |
+
// on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
|
| 326 |
+
constexpr static double default_low_watermark_ratio_unified = 1.4;
|
| 327 |
+
constexpr static double default_low_watermark_ratio_discrete = 1.0;
|
| 328 |
+
|
| 329 |
+
const id<MTLDevice> m_device;
|
| 330 |
+
std::recursive_mutex m_mutex;
|
| 331 |
+
// allocated buffers by device pointer
|
| 332 |
+
ska::flat_hash_map<const void*, BufferBlock*> m_allocated_buffers;
|
| 333 |
+
// using a container for pools to simplify iterating them
|
| 334 |
+
ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
|
| 335 |
+
// total memory allocated by HeapAllocator (including blocks in pools)
|
| 336 |
+
size_t m_total_allocated_memory = 0;
|
| 337 |
+
// currently active memory allocations in use (i.e., blocks not in pools)
|
| 338 |
+
size_t m_current_allocated_memory = 0;
|
| 339 |
+
// max buffer size allowed by Metal
|
| 340 |
+
size_t m_max_buffer_size = 0;
|
| 341 |
+
// maximum total size allowed to be allocated
|
| 342 |
+
size_t m_max_total_allowed_size = 0;
|
| 343 |
+
// high watermark ratio is a hard limit for the total allowed allocations
|
| 344 |
+
// 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs)
|
| 345 |
+
// 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
|
| 346 |
+
// >1.: allows limits beyond the device.recommendedMaxWorkingSetSize
|
| 347 |
+
// e.g., value 0.95 means we allocate up to 95% of recommended maximum
|
| 348 |
+
// allocation size; beyond that, the allocations would fail with OOM error.
|
| 349 |
+
double m_high_watermark_ratio;
|
| 350 |
+
// low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark
|
| 351 |
+
// level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit).
|
| 352 |
+
// Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
|
| 353 |
+
// e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
|
| 354 |
+
// allocation size.
|
| 355 |
+
double m_low_watermark_ratio;
|
| 356 |
+
// low watermark size limit (in Bytes) at the time we initialize the allocator
|
| 357 |
+
size_t m_low_watermark_limit;
|
| 358 |
+
// use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
|
| 359 |
+
uint32_t m_debug_verbosity;
|
| 360 |
+
// default MPS stream
|
| 361 |
+
MPSStream* m_stream;
|
| 362 |
+
// we hold a reference to MPSEventPool so it could get destroyed after MPSAllocator
|
| 363 |
+
std::shared_ptr<MPSEventPool> m_event_pool;
|
| 364 |
+
|
| 365 |
+
void init_allocator();
|
| 366 |
+
void init_buffer_pools();
|
| 367 |
+
HeapBlock* get_free_heap(AllocParams& params);
|
| 368 |
+
bool get_free_buffer(AllocParams& params);
|
| 369 |
+
BufferBlock* get_allocated_buffer_block(const void* ptr);
|
| 370 |
+
BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
|
| 371 |
+
bool alloc_buffer(AllocParams& params);
|
| 372 |
+
void free_buffer(BufferBlock* buffer_block);
|
| 373 |
+
// returns true if the container heap is also released
|
| 374 |
+
bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
|
| 375 |
+
void release_buffers(BufferPool& pool);
|
| 376 |
+
bool release_available_cached_buffers(AllocParams& params);
|
| 377 |
+
bool release_cached_buffers();
|
| 378 |
+
// free unused cached blocks to reclaim GPU memory if memory pressure is high
|
| 379 |
+
void garbage_collect_cached_buffers(AllocParams& params);
|
| 380 |
+
// returns the suitable buffer pool type for the usage or
|
| 381 |
+
// requested/allocated sizes
|
| 382 |
+
BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
|
| 383 |
+
// returns the aligned allocation size that is optimized
|
| 384 |
+
// for the buffers to get reused frequently
|
| 385 |
+
size_t get_allocation_size(size_t size, uint32_t usage) const;
|
| 386 |
+
// maximum size of device memory available for allocation in current process
|
| 387 |
+
// Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
|
| 388 |
+
size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
|
| 389 |
+
// there are implicit allocations from MPS backend, so we need to query the 'device' for
|
| 390 |
+
// total allocated size instead of manually tracking in MPSAllocator
|
| 391 |
+
size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
|
| 392 |
+
|
| 393 |
+
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
|
| 394 |
+
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
|
| 395 |
+
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
|
| 396 |
+
}
|
| 397 |
+
return true;
|
| 398 |
+
}
|
| 399 |
+
};
|
| 400 |
+
|
| 401 |
+
} // namespace at::mps::HeapAllocator
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Allocator.h>
|
| 6 |
+
#include <c10/util/Registry.h>
|
| 7 |
+
#include <ATen/core/ATen_fwd.h>
|
| 8 |
+
|
| 9 |
+
#define MB(x) (x * 1048576UL)
|
| 10 |
+
|
| 11 |
+
namespace at::mps {
|
| 12 |
+
|
| 13 |
+
// this is a public interface to access MPSAllocator.
|
| 14 |
+
// Do not declare methods that would depend on MPS or Metal frameworks.
|
| 15 |
+
class IMPSAllocator : public c10::Allocator {
|
| 16 |
+
public:
|
| 17 |
+
// see the comments in MPSAllocator.h for the description of these methods.
|
| 18 |
+
virtual void emptyCache() const = 0;
|
| 19 |
+
virtual void freeInactiveBuffers() const = 0;
|
| 20 |
+
virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
|
| 21 |
+
virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
|
| 22 |
+
virtual id_t getBufferId(const void* ptr) const = 0;
|
| 23 |
+
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
|
| 24 |
+
virtual bool isSharedBuffer(const void* ptr) const = 0;
|
| 25 |
+
virtual bool isSharedStorageSupported() const = 0;
|
| 26 |
+
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
|
| 27 |
+
virtual std::string formatSize(size_t size) const = 0;
|
| 28 |
+
virtual void setLowWatermarkRatio(double ratio) const = 0;
|
| 29 |
+
virtual void setHighWatermarkRatio(double ratio) const = 0;
|
| 30 |
+
virtual ssize_t getLowWatermarkValue() const = 0;
|
| 31 |
+
virtual size_t getLowWatermarkLimit() const = 0;
|
| 32 |
+
virtual size_t getHighWatermarkLimit() const = 0;
|
| 33 |
+
virtual size_t getTotalAllocatedMemory() const = 0;
|
| 34 |
+
virtual size_t getCurrentAllocatedMemory() const = 0;
|
| 35 |
+
virtual size_t getDriverAllocatedMemory() const = 0;
|
| 36 |
+
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
|
| 37 |
+
virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
|
| 38 |
+
virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
class IMpsAllocatorCallback {
|
| 42 |
+
public:
|
| 43 |
+
enum class EventType {
|
| 44 |
+
ALLOCATED, // buffer got allocated to be used immediately
|
| 45 |
+
RECYCLED, // buffer pulled from free list to be reused
|
| 46 |
+
FREED, // buffer put to free list for future recycling
|
| 47 |
+
RELEASED, // buffer memory released
|
| 48 |
+
ALLOCATION_FAILED // buffer allocation failed
|
| 49 |
+
};
|
| 50 |
+
virtual ~IMpsAllocatorCallback() = default;
|
| 51 |
+
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
// MPS allocator will execute every registered callback when a block of memory is freed.
|
| 55 |
+
C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
|
| 56 |
+
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
|
| 57 |
+
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
|
| 58 |
+
|
| 59 |
+
IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
|
| 60 |
+
|
| 61 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/mps/MPSStream.h>
|
| 6 |
+
#include <ctime>
|
| 7 |
+
#include <stack>
|
| 8 |
+
|
| 9 |
+
namespace at::mps {
|
| 10 |
+
|
| 11 |
+
// NOTE: don't create instances of this class directly.
|
| 12 |
+
// Use MPSEventPool to acquire instances of MPSEvent.
|
| 13 |
+
class MPSEvent {
|
| 14 |
+
public:
|
| 15 |
+
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
|
| 16 |
+
~MPSEvent();
|
| 17 |
+
|
| 18 |
+
// records an event on the stream
|
| 19 |
+
void record(bool needsLock, bool syncEvent = false);
|
| 20 |
+
// makes all future work submitted to the stream wait for this event.
|
| 21 |
+
bool wait(bool needsLock, bool syncEvent = false);
|
| 22 |
+
// schedules a notifyListener callback for the event.
|
| 23 |
+
bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
|
| 24 |
+
// checks if events are already signaled.
|
| 25 |
+
bool query() const;
|
| 26 |
+
// blocks the CPU thread until all the GPU work that were scheduled
|
| 27 |
+
// prior to recording this event are completed.
|
| 28 |
+
bool synchronize();
|
| 29 |
+
// resets this event with new parameters in case it gets reused from the event pool
|
| 30 |
+
void reset(MPSStream* stream, bool enable_timing);
|
| 31 |
+
// returns the unique ID of the event instance
|
| 32 |
+
id_t getID() const { return m_id; }
|
| 33 |
+
// returns the completion timestamp of the event
|
| 34 |
+
uint64_t getCompletionTime() const { return m_completion_time; }
|
| 35 |
+
// if already recorded, waits for cpu_sync_cv to be signaled
|
| 36 |
+
void waitForCpuSync();
|
| 37 |
+
|
| 38 |
+
private:
|
| 39 |
+
id_t m_id;
|
| 40 |
+
// enables measuring the completion time of the notifyListener of this event
|
| 41 |
+
bool m_enable_timing;
|
| 42 |
+
uint64_t m_signalCounter = 0;
|
| 43 |
+
MPSStream* m_stream = nullptr;
|
| 44 |
+
MTLSharedEvent_t m_event = nullptr;
|
| 45 |
+
MTLSharedEventListener* m_listener = nullptr;
|
| 46 |
+
// used to sync the events created on this Stream with CPU
|
| 47 |
+
std::mutex m_cpu_sync_mutex{};
|
| 48 |
+
std::condition_variable m_cpu_sync_cv{};
|
| 49 |
+
// CondVar predicate to sync the events created on this Stream with CPU
|
| 50 |
+
bool m_cpu_sync_completed = false;
|
| 51 |
+
// used to compute elapsed time
|
| 52 |
+
uint64_t m_completion_time = 0;
|
| 53 |
+
|
| 54 |
+
void recordLocked(bool syncEvent);
|
| 55 |
+
bool waitLocked(bool syncEvent);
|
| 56 |
+
bool notifyLocked(MTLSharedEventNotificationBlock block);
|
| 57 |
+
void notifyCpuSync();
|
| 58 |
+
static uint64_t getTime() {
|
| 59 |
+
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
|
| 60 |
+
}
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
|
| 64 |
+
|
| 65 |
+
class MPSEventPool {
|
| 66 |
+
public:
|
| 67 |
+
explicit MPSEventPool(MPSStream* default_stream);
|
| 68 |
+
~MPSEventPool();
|
| 69 |
+
|
| 70 |
+
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
|
| 71 |
+
void emptyCache();
|
| 72 |
+
|
| 73 |
+
// these are mainly used for MPSHooks and torch.mps.Event() bindings
|
| 74 |
+
id_t acquireEvent(bool enable_timing);
|
| 75 |
+
void releaseEvent(id_t event_id);
|
| 76 |
+
void recordEvent(id_t event_id, bool syncEvent);
|
| 77 |
+
void waitForEvent(id_t event_id, bool syncEvent);
|
| 78 |
+
void synchronizeEvent(id_t event_id);
|
| 79 |
+
bool queryEvent(id_t event_id);
|
| 80 |
+
// returns elapsed time between two recorded events in milliseconds
|
| 81 |
+
double elapsedTime(id_t start_event_id, id_t end_event_id);
|
| 82 |
+
|
| 83 |
+
private:
|
| 84 |
+
MPSStream* m_default_stream = nullptr;
|
| 85 |
+
std::recursive_mutex m_mutex;
|
| 86 |
+
std::stack<std::unique_ptr<MPSEvent>> m_pool{};
|
| 87 |
+
// dictionary to associate event IDs with event objects
|
| 88 |
+
// used to retain in-use events out of the pool
|
| 89 |
+
// for torch.mps.Event() bindings.
|
| 90 |
+
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
|
| 91 |
+
uint64_t m_event_counter = 0;
|
| 92 |
+
std::function<void(MPSEvent*)> m_default_deleter;
|
| 93 |
+
|
| 94 |
+
MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
// shared_ptr is used to get MPSEventPool destroyed after dependent instances
|
| 98 |
+
std::shared_ptr<MPSEventPool> getMPSEventPool();
|
| 99 |
+
|
| 100 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/Tensor.h>
|
| 6 |
+
#include <ATen/mps/MPSStream.h>
|
| 7 |
+
#include <ATen/mps/MPSAllocatorInterface.h>
|
| 8 |
+
|
| 9 |
+
#include <os/signpost.h>
|
| 10 |
+
#include <os/log.h>
|
| 11 |
+
|
| 12 |
+
#include <sstream>
|
| 13 |
+
#include <string>
|
| 14 |
+
#include <atomic>
|
| 15 |
+
#include <unordered_map>
|
| 16 |
+
#include <utility>
|
| 17 |
+
#include <ctime>
|
| 18 |
+
|
| 19 |
+
namespace at::mps {
|
| 20 |
+
|
| 21 |
+
namespace Profiler {
|
| 22 |
+
|
| 23 |
+
struct BaseInfo {
|
| 24 |
+
// profiling info types
|
| 25 |
+
enum class Type {
|
| 26 |
+
GRAPH,
|
| 27 |
+
KERNEL,
|
| 28 |
+
COPY,
|
| 29 |
+
CPU_FALLBACK,
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
|
| 33 |
+
type(infoType), profileId(Id), handle(Handle) { }
|
| 34 |
+
virtual ~BaseInfo() = default;
|
| 35 |
+
|
| 36 |
+
// type of profiling info
|
| 37 |
+
Type type;
|
| 38 |
+
// unique profile ID for execution instances of operations or copies
|
| 39 |
+
uint64_t profileId;
|
| 40 |
+
// ID generated by os_signpost
|
| 41 |
+
// since it's possible to use event and interval-based signposts at the
|
| 42 |
+
// same time, we need separate IDs for each.
|
| 43 |
+
os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
|
| 44 |
+
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
|
| 45 |
+
std::atomic<double> totalGpuTime{0.0};
|
| 46 |
+
// accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
|
| 47 |
+
std::atomic<double> totalSchedulingTime{0.0};
|
| 48 |
+
// indicates if the operation or copy execution has completed
|
| 49 |
+
std::atomic_bool completed{false};
|
| 50 |
+
// handle used to identify the profile info's instance (usually the pointer)
|
| 51 |
+
const uintptr_t handle;
|
| 52 |
+
|
| 53 |
+
virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
|
| 54 |
+
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
|
| 55 |
+
static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
|
| 56 |
+
if (tensor.defined()) {
|
| 57 |
+
std::stringstream tensorStr;
|
| 58 |
+
auto deviceType = tensor.device().type();
|
| 59 |
+
tensorStr << c10::DeviceTypeName(deviceType);
|
| 60 |
+
// see comments for INCLUDE_BUFFER_ID
|
| 61 |
+
if (includeBufferId && deviceType == at::kMPS) {
|
| 62 |
+
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
| 63 |
+
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
|
| 64 |
+
<< ":" << buffer.retainCount << ")";
|
| 65 |
+
}
|
| 66 |
+
tensorStr << ":"
|
| 67 |
+
<< tensor.scalar_type() << tensor.sizes();
|
| 68 |
+
return tensorStr.str();
|
| 69 |
+
} else {
|
| 70 |
+
return "undefined";
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
static uint64_t getTime() {
|
| 74 |
+
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
|
| 75 |
+
}
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
struct OperationInfo : BaseInfo {
|
| 79 |
+
OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
|
| 80 |
+
BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
|
| 81 |
+
|
| 82 |
+
uint64_t runCount = 0;
|
| 83 |
+
std::string strKey;
|
| 84 |
+
|
| 85 |
+
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
| 86 |
+
|
| 87 |
+
// builds a string for a kernel
|
| 88 |
+
static std::string buildKernelString(const std::string& kernelName,
|
| 89 |
+
const TensorList& tensors,
|
| 90 |
+
bool includeBufferId = false) {
|
| 91 |
+
std::stringstream kernelStr;
|
| 92 |
+
kernelStr << kernelName;
|
| 93 |
+
for (const Tensor& tensor: tensors) {
|
| 94 |
+
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
|
| 95 |
+
}
|
| 96 |
+
return kernelStr.str();
|
| 97 |
+
}
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
struct CpuFbInfo : BaseInfo {
|
| 101 |
+
CpuFbInfo(uint64_t Id, const std::string& OpName) :
|
| 102 |
+
BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
|
| 103 |
+
|
| 104 |
+
uint64_t runCount = 0;
|
| 105 |
+
// the current and total overhead of copies in bytes required to convert the Op's
|
| 106 |
+
// input tensors from MPS to CPU and then output from CPU back to MPS
|
| 107 |
+
size_t currentCopyOverhead = 0;
|
| 108 |
+
size_t totalCopyOverhead = 0;
|
| 109 |
+
std::string opName;
|
| 110 |
+
std::string strKey;
|
| 111 |
+
uint64_t startTime = 0;
|
| 112 |
+
|
| 113 |
+
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
| 114 |
+
|
| 115 |
+
void updateCopyOverhead(const TensorList& tensors) {
|
| 116 |
+
currentCopyOverhead = 0;
|
| 117 |
+
for (const Tensor& tensor: tensors) {
|
| 118 |
+
if (tensor.defined()) {
|
| 119 |
+
currentCopyOverhead += tensor.nbytes();
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
totalCopyOverhead += currentCopyOverhead;
|
| 123 |
+
}
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
struct CopyInfo : BaseInfo {
|
| 127 |
+
enum class Kind {
|
| 128 |
+
MPS_TO_MPS,
|
| 129 |
+
MPS_TO_CPU,
|
| 130 |
+
CPU_TO_MPS,
|
| 131 |
+
};
|
| 132 |
+
|
| 133 |
+
CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
|
| 134 |
+
BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
|
| 135 |
+
length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
|
| 136 |
+
|
| 137 |
+
Kind kind;
|
| 138 |
+
size_t length;
|
| 139 |
+
bool isNonBlocking;
|
| 140 |
+
bool usesBlitter;
|
| 141 |
+
std::string srcStrKey;
|
| 142 |
+
std::string dstStrKey;
|
| 143 |
+
// for copies that don't use blitters, we measure CPU time
|
| 144 |
+
uint64_t startTime = 0;
|
| 145 |
+
|
| 146 |
+
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
| 147 |
+
|
| 148 |
+
static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
|
| 149 |
+
|
| 150 |
+
static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
|
| 151 |
+
if (tensor.has_value()) {
|
| 152 |
+
return tensor->device().type() == at::kMPS;
|
| 153 |
+
}
|
| 154 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
|
| 155 |
+
// getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
|
| 156 |
+
return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
|
| 160 |
+
const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
|
| 161 |
+
const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
|
| 162 |
+
const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
|
| 163 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
|
| 164 |
+
if (isSrcOnMPS && !isDstOnMPS) {
|
| 165 |
+
return Kind::MPS_TO_CPU;
|
| 166 |
+
} else if (!isSrcOnMPS && isDstOnMPS) {
|
| 167 |
+
return Kind::CPU_TO_MPS;
|
| 168 |
+
}
|
| 169 |
+
return Kind::MPS_TO_MPS;
|
| 170 |
+
}
|
| 171 |
+
};
|
| 172 |
+
|
| 173 |
+
struct CopyStat : CopyInfo {
|
| 174 |
+
explicit CopyStat(std::string CopyKindStr) :
|
| 175 |
+
CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
|
| 176 |
+
// total number of copies
|
| 177 |
+
size_t totalCount = 0;
|
| 178 |
+
// number of Scalar copies (i.e., less than sizeof(int64))
|
| 179 |
+
size_t scalarsCount = 0;
|
| 180 |
+
// number of blocking copies (i.e., require syncing to GPU)
|
| 181 |
+
size_t blockingCount = 0;
|
| 182 |
+
// number of copies that used memcpy(), instead of Metal Blit Encoder
|
| 183 |
+
size_t memcpyCount = 0;
|
| 184 |
+
// accumulated GPU time in ms for the scalar copies
|
| 185 |
+
std::atomic<double> scalarsGpuTime{0.0};
|
| 186 |
+
// copy kind in string type
|
| 187 |
+
std::string kindStr;
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
class MPSProfiler {
|
| 191 |
+
public:
|
| 192 |
+
// lower 16 bits used for profiler options
|
| 193 |
+
enum ProfileOptions : uint32_t {
|
| 194 |
+
OPTIONS_NONE = 0,
|
| 195 |
+
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
|
| 196 |
+
// (used for convenience to not compute bit flags by OR-ing manually)
|
| 197 |
+
// trace all signpost types using events
|
| 198 |
+
ALL_SIGNPOST_EVENTS = (1 << 0),
|
| 199 |
+
// trace all signpost types using intervals
|
| 200 |
+
ALL_SIGNPOST_INTERVALS = (1 << 1),
|
| 201 |
+
// always wait for command buffer to finish executing after each commit
|
| 202 |
+
WAIT_UNTIL_COMPLETED = (1 << 2),
|
| 203 |
+
// for interval-based signposts, include the scheduling portion of
|
| 204 |
+
// Graph/Kernel/Copy executions as well.
|
| 205 |
+
// if flag is disable, only "GPU run time" is included in interval,
|
| 206 |
+
// and not schedule time.
|
| 207 |
+
INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
|
| 208 |
+
|
| 209 |
+
// use these if you need to trace signposts types individually (rarely required)
|
| 210 |
+
// trace signpost using intervals
|
| 211 |
+
USE_INTERVALS = (1 << 4),
|
| 212 |
+
// trace signpost by emitting events
|
| 213 |
+
USE_EVENTS = (1 << 5),
|
| 214 |
+
// used for sanity check (Change this when new option added)
|
| 215 |
+
OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
|
| 216 |
+
};
|
| 217 |
+
|
| 218 |
+
// when adding new types, #define the type string in MPSProfiler.mm as well.
|
| 219 |
+
// upper 16 bits used for event types
|
| 220 |
+
enum SignpostTypes : uint32_t {
|
| 221 |
+
SIGNPOST_NONE = 0,
|
| 222 |
+
// trace signposts for PyTorch operation executions
|
| 223 |
+
RUN_OPERATION = (1 << 16),
|
| 224 |
+
// trace signposts for blitter copies
|
| 225 |
+
BLIT_COPY = (1 << 17),
|
| 226 |
+
// trace signposts for ops that fall back on CPU
|
| 227 |
+
CPU_FALLBACK = (1 << 18),
|
| 228 |
+
// used for sanity check (Change this when new type added)
|
| 229 |
+
SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
|
| 230 |
+
};
|
| 231 |
+
|
| 232 |
+
enum LogOptions : uint32_t {
|
| 233 |
+
LOG_NONE = 0,
|
| 234 |
+
|
| 235 |
+
// Info logging options during execution
|
| 236 |
+
// -------------------------------------
|
| 237 |
+
// prints operation info (id/key/run_count) during execution
|
| 238 |
+
OPERATION_INFO = (1 << 0),
|
| 239 |
+
// prints copy info (src/dst tensors/buffers, size, etc.) during execution
|
| 240 |
+
COPY_INFO = (1 << 1),
|
| 241 |
+
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
|
| 242 |
+
CPU_FALLBACK_INFO = (1 << 2),
|
| 243 |
+
|
| 244 |
+
// Profiling Statistics logging options when process terminates
|
| 245 |
+
// ------------------------------------------------------------
|
| 246 |
+
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
|
| 247 |
+
// this is convenient to not combine following stats bit flags manually
|
| 248 |
+
ALL_STATS = (1 << 3),
|
| 249 |
+
// prints operation stats (GPU times, run count, etc.) before process terminates
|
| 250 |
+
OPERATION_STATS = (1 << 4),
|
| 251 |
+
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
|
| 252 |
+
COPY_STATS = (1 << 5),
|
| 253 |
+
// prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
|
| 254 |
+
// for tensors, etc.) before process terminates
|
| 255 |
+
CPU_FALLBACK_STATS = (1 << 6),
|
| 256 |
+
|
| 257 |
+
// Metadata format options when logging the info
|
| 258 |
+
// ---------------------------------------------
|
| 259 |
+
// if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
|
| 260 |
+
// from Metal Command Buffers) (e.g., [GPU=0.324 ms])
|
| 261 |
+
INCLUDE_GPU_TIME = (1 << 7),
|
| 262 |
+
// if enabled, includes GPU scheduling time in metadata separately
|
| 263 |
+
// (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
|
| 264 |
+
// e.g., [GPU=0.324 ms, KRNL=0.036 ms]
|
| 265 |
+
INCLUDE_KERNEL_TIME = (1 << 8),
|
| 266 |
+
// if enabled, includes the unique buffer ID in metadata for the storage
|
| 267 |
+
// of a tensor that was allocated on MPSAllocator. This is useful (along with
|
| 268 |
+
// the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
|
| 269 |
+
// with various operations.
|
| 270 |
+
INCLUDE_BUFFER_ID = (1 << 9),
|
| 271 |
+
|
| 272 |
+
// used for sanity check (Change this when new option added)
|
| 273 |
+
LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
|
| 274 |
+
};
|
| 275 |
+
|
| 276 |
+
explicit MPSProfiler();
|
| 277 |
+
~MPSProfiler();
|
| 278 |
+
|
| 279 |
+
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
| 280 |
+
// the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
|
| 281 |
+
uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
|
| 282 |
+
uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
|
| 283 |
+
uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
|
| 284 |
+
const OptionalTensorRef srcTensor,
|
| 285 |
+
const OptionalTensorRef dstTensor,
|
| 286 |
+
size_t length, bool isNonBlocking, bool usesBlitter = true);
|
| 287 |
+
uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
|
| 288 |
+
void beginProfileGPUInterval(const void* handle);
|
| 289 |
+
|
| 290 |
+
void endProfileCopy(uint64_t profileId, SyncType syncType);
|
| 291 |
+
void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
|
| 292 |
+
void endProfileCPUFallback(const std::string& opName);
|
| 293 |
+
|
| 294 |
+
// these are used to hook into Python bindings for torch.mps.profiler module.
|
| 295 |
+
// this enables generating OS Signpost traces from MPSProfiler on-demand
|
| 296 |
+
// during runtime (instead of environment variables).
|
| 297 |
+
// The "mode" could be either "interval", "event", or both "interval,event"
|
| 298 |
+
// for interval-based and/or event-based signpost tracing.
|
| 299 |
+
void StartTrace(const string& mode, bool waitUntilCompleted);
|
| 300 |
+
void StopTrace();
|
| 301 |
+
|
| 302 |
+
// convenience functions to indicate whether signpost tracing or
|
| 303 |
+
// logging are enabled for the SignpostTypes
|
| 304 |
+
bool isOperationProfilingEnabled() const {
|
| 305 |
+
return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
|
| 306 |
+
(m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
|
| 307 |
+
}
|
| 308 |
+
bool isCopyProfilingEnabled() const {
|
| 309 |
+
return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
|
| 310 |
+
(m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
|
| 311 |
+
}
|
| 312 |
+
bool isCPUFallbackProfilingEnabled() const {
|
| 313 |
+
return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
|
| 314 |
+
(m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
|
| 315 |
+
}
|
| 316 |
+
bool isSignpostTracingEnabled() const {
|
| 317 |
+
return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
private:
|
| 321 |
+
// indicates what type of signpost types are enabled and traced by MPS profiler.
|
| 322 |
+
uint32_t m_signpost_types = 0;
|
| 323 |
+
uint32_t m_profile_options = 0;
|
| 324 |
+
uint32_t m_log_options = 0;
|
| 325 |
+
uint64_t m_kernel_counter = 0;
|
| 326 |
+
uint64_t m_graph_counter = 0;
|
| 327 |
+
uint64_t m_cpu_fb_counter = 0;
|
| 328 |
+
uint64_t m_copy_counter = 0;
|
| 329 |
+
// technically, it's possible to trace both events and intervals at the same time
|
| 330 |
+
// so we use separate os_log categories for them
|
| 331 |
+
os_log_t m_os_log_events;
|
| 332 |
+
os_log_t m_os_log_intervals;
|
| 333 |
+
// stats logging could run either from destructor or signal handler
|
| 334 |
+
// so this is used to check if logging has already started.
|
| 335 |
+
std::atomic_bool hasLoggedStats{false};
|
| 336 |
+
// indicates there are pending completionHandler callbacks that haven't been called yet.
|
| 337 |
+
std::atomic_bool hasPendingCompletionHandlers{false};
|
| 338 |
+
// used to capture sigint signal to log profiling stats
|
| 339 |
+
static struct sigaction currentSigint, previousSigint;
|
| 340 |
+
|
| 341 |
+
// We use the following lists for two reasons:
|
| 342 |
+
// 1- for interval-based signposts the "begin" point won't be in same function
|
| 343 |
+
// as the "end" point where we need to be able to retrieve signpost's info
|
| 344 |
+
// 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
|
| 345 |
+
|
| 346 |
+
// the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
| 347 |
+
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
| 348 |
+
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
|
| 349 |
+
// the string key for this map is the op name that we fall back to execute on CPU
|
| 350 |
+
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
| 351 |
+
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
|
| 352 |
+
// this list contains the info for copies, and its key is the unique profileId
|
| 353 |
+
// which is generated from m_copy_counter
|
| 354 |
+
// The copyInfo list is not retained.
|
| 355 |
+
std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
|
| 356 |
+
// a short list that contains copy stats
|
| 357 |
+
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
|
| 358 |
+
|
| 359 |
+
void initialize();
|
| 360 |
+
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
|
| 361 |
+
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
|
| 362 |
+
os_signpost_id_t interval_signpost_id,
|
| 363 |
+
double gpuTime, double schedulingTime);
|
| 364 |
+
void addProfilerScheduledHandler(BaseInfo& info);
|
| 365 |
+
void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
|
| 366 |
+
void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
| 367 |
+
const std::string& msg) const;
|
| 368 |
+
void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
| 369 |
+
const std::string& msg) const;
|
| 370 |
+
void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
|
| 371 |
+
|
| 372 |
+
void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
|
| 373 |
+
// returns true if logging the profiling info "during the execution" is enabled
|
| 374 |
+
bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
|
| 375 |
+
// logs all the profiling stats that are enabled
|
| 376 |
+
void logProfilingStats();
|
| 377 |
+
// logs kernel profiling stats when the process ends.
|
| 378 |
+
void logOperationsProfilingStats(std::FILE* f) const;
|
| 379 |
+
// logs CPU Fallback profiling stats when the process ends.
|
| 380 |
+
void logCPUFallbackProfilingStats(std::FILE* f) const;
|
| 381 |
+
// logs copy profiling stats when the process ends.
|
| 382 |
+
void logCopyProfilingStats(std::FILE* f) const;
|
| 383 |
+
|
| 384 |
+
os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
|
| 385 |
+
static SignpostTypes getSignpostType(BaseInfo::Type infoType);
|
| 386 |
+
static void handleIntSignal(int signal);
|
| 387 |
+
};
|
| 388 |
+
|
| 389 |
+
} // namespace Profiler
|
| 390 |
+
|
| 391 |
+
Profiler::MPSProfiler& getMPSProfiler();
|
| 392 |
+
|
| 393 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Optional.h>
|
| 4 |
+
#include <c10/util/string_view.h>
|
| 5 |
+
#include <ATen/Config.h>
|
| 6 |
+
#include <ATen/native/DispatchStub.h>
|
| 7 |
+
|
| 8 |
+
// Forward declare TI
|
| 9 |
+
namespace at {
|
| 10 |
+
class Tensor;
|
| 11 |
+
struct TensorIterator;
|
| 12 |
+
|
| 13 |
+
namespace native {
|
| 14 |
+
enum class TransposeType;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
namespace at::native {
|
| 20 |
+
|
| 21 |
+
enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
|
| 22 |
+
|
| 23 |
+
#if AT_BUILD_WITH_LAPACK()
|
| 24 |
+
// Define per-batch functions to be used in the implementation of batched
|
| 25 |
+
// linear algebra operations
|
| 26 |
+
|
| 27 |
+
template <class scalar_t>
|
| 28 |
+
void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
|
| 29 |
+
|
| 30 |
+
template <class scalar_t>
|
| 31 |
+
void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
|
| 32 |
+
|
| 33 |
+
template <class scalar_t, class value_t=scalar_t>
|
| 34 |
+
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
|
| 35 |
+
|
| 36 |
+
template <class scalar_t>
|
| 37 |
+
void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
|
| 38 |
+
|
| 39 |
+
template <class scalar_t>
|
| 40 |
+
void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
|
| 41 |
+
|
| 42 |
+
template <class scalar_t>
|
| 43 |
+
void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
|
| 44 |
+
|
| 45 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 46 |
+
void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
|
| 47 |
+
|
| 48 |
+
template <class scalar_t>
|
| 49 |
+
void lapackGels(char trans, int m, int n, int nrhs,
|
| 50 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 51 |
+
scalar_t *work, int lwork, int *info);
|
| 52 |
+
|
| 53 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 54 |
+
void lapackGelsd(int m, int n, int nrhs,
|
| 55 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 56 |
+
value_t *s, value_t rcond, int *rank,
|
| 57 |
+
scalar_t* work, int lwork,
|
| 58 |
+
value_t *rwork, int* iwork, int *info);
|
| 59 |
+
|
| 60 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 61 |
+
void lapackGelsy(int m, int n, int nrhs,
|
| 62 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 63 |
+
int *jpvt, value_t rcond, int *rank,
|
| 64 |
+
scalar_t *work, int lwork, value_t* rwork, int *info);
|
| 65 |
+
|
| 66 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 67 |
+
void lapackGelss(int m, int n, int nrhs,
|
| 68 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 69 |
+
value_t *s, value_t rcond, int *rank,
|
| 70 |
+
scalar_t *work, int lwork,
|
| 71 |
+
value_t *rwork, int *info);
|
| 72 |
+
|
| 73 |
+
template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
|
| 74 |
+
struct lapackLstsq_impl;
|
| 75 |
+
|
| 76 |
+
template <class scalar_t, class value_t>
|
| 77 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
|
| 78 |
+
static void call(
|
| 79 |
+
char trans, int m, int n, int nrhs,
|
| 80 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 81 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 82 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 83 |
+
value_t *s, // Gelss flavor
|
| 84 |
+
int *iwork // Gelsd flavor
|
| 85 |
+
) {
|
| 86 |
+
lapackGels<scalar_t>(
|
| 87 |
+
trans, m, n, nrhs,
|
| 88 |
+
a, lda, b, ldb,
|
| 89 |
+
work, lwork, info);
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
template <class scalar_t, class value_t>
|
| 94 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
|
| 95 |
+
static void call(
|
| 96 |
+
char trans, int m, int n, int nrhs,
|
| 97 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 98 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 99 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 100 |
+
value_t *s, // Gelss flavor
|
| 101 |
+
int *iwork // Gelsd flavor
|
| 102 |
+
) {
|
| 103 |
+
lapackGelsy<scalar_t, value_t>(
|
| 104 |
+
m, n, nrhs,
|
| 105 |
+
a, lda, b, ldb,
|
| 106 |
+
jpvt, rcond, rank,
|
| 107 |
+
work, lwork, rwork, info);
|
| 108 |
+
}
|
| 109 |
+
};
|
| 110 |
+
|
| 111 |
+
template <class scalar_t, class value_t>
|
| 112 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
|
| 113 |
+
static void call(
|
| 114 |
+
char trans, int m, int n, int nrhs,
|
| 115 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 116 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 117 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 118 |
+
value_t *s, // Gelss flavor
|
| 119 |
+
int *iwork // Gelsd flavor
|
| 120 |
+
) {
|
| 121 |
+
lapackGelsd<scalar_t, value_t>(
|
| 122 |
+
m, n, nrhs,
|
| 123 |
+
a, lda, b, ldb,
|
| 124 |
+
s, rcond, rank,
|
| 125 |
+
work, lwork,
|
| 126 |
+
rwork, iwork, info);
|
| 127 |
+
}
|
| 128 |
+
};
|
| 129 |
+
|
| 130 |
+
template <class scalar_t, class value_t>
|
| 131 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
|
| 132 |
+
static void call(
|
| 133 |
+
char trans, int m, int n, int nrhs,
|
| 134 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 135 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 136 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 137 |
+
value_t *s, // Gelss flavor
|
| 138 |
+
int *iwork // Gelsd flavor
|
| 139 |
+
) {
|
| 140 |
+
lapackGelss<scalar_t, value_t>(
|
| 141 |
+
m, n, nrhs,
|
| 142 |
+
a, lda, b, ldb,
|
| 143 |
+
s, rcond, rank,
|
| 144 |
+
work, lwork,
|
| 145 |
+
rwork, info);
|
| 146 |
+
}
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
|
| 150 |
+
void lapackLstsq(
|
| 151 |
+
char trans, int m, int n, int nrhs,
|
| 152 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 153 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 154 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 155 |
+
value_t *s, // Gelss flavor
|
| 156 |
+
int *iwork // Gelsd flavor
|
| 157 |
+
) {
|
| 158 |
+
lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
|
| 159 |
+
trans, m, n, nrhs,
|
| 160 |
+
a, lda, b, ldb,
|
| 161 |
+
work, lwork, info,
|
| 162 |
+
jpvt, rcond, rank, rwork,
|
| 163 |
+
s,
|
| 164 |
+
iwork);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
template <class scalar_t>
|
| 168 |
+
void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
|
| 169 |
+
|
| 170 |
+
template <class scalar_t>
|
| 171 |
+
void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
|
| 172 |
+
|
| 173 |
+
template <class scalar_t>
|
| 174 |
+
void lapackLdlHermitian(
|
| 175 |
+
char uplo,
|
| 176 |
+
int n,
|
| 177 |
+
scalar_t* a,
|
| 178 |
+
int lda,
|
| 179 |
+
int* ipiv,
|
| 180 |
+
scalar_t* work,
|
| 181 |
+
int lwork,
|
| 182 |
+
int* info);
|
| 183 |
+
|
| 184 |
+
template <class scalar_t>
|
| 185 |
+
void lapackLdlSymmetric(
|
| 186 |
+
char uplo,
|
| 187 |
+
int n,
|
| 188 |
+
scalar_t* a,
|
| 189 |
+
int lda,
|
| 190 |
+
int* ipiv,
|
| 191 |
+
scalar_t* work,
|
| 192 |
+
int lwork,
|
| 193 |
+
int* info);
|
| 194 |
+
|
| 195 |
+
template <class scalar_t>
|
| 196 |
+
void lapackLdlSolveHermitian(
|
| 197 |
+
char uplo,
|
| 198 |
+
int n,
|
| 199 |
+
int nrhs,
|
| 200 |
+
scalar_t* a,
|
| 201 |
+
int lda,
|
| 202 |
+
int* ipiv,
|
| 203 |
+
scalar_t* b,
|
| 204 |
+
int ldb,
|
| 205 |
+
int* info);
|
| 206 |
+
|
| 207 |
+
template <class scalar_t>
|
| 208 |
+
void lapackLdlSolveSymmetric(
|
| 209 |
+
char uplo,
|
| 210 |
+
int n,
|
| 211 |
+
int nrhs,
|
| 212 |
+
scalar_t* a,
|
| 213 |
+
int lda,
|
| 214 |
+
int* ipiv,
|
| 215 |
+
scalar_t* b,
|
| 216 |
+
int ldb,
|
| 217 |
+
int* info);
|
| 218 |
+
|
| 219 |
+
template<class scalar_t, class value_t=scalar_t>
|
| 220 |
+
void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
|
| 221 |
+
#endif
|
| 222 |
+
|
| 223 |
+
#if AT_BUILD_WITH_BLAS()
|
| 224 |
+
template <class scalar_t>
|
| 225 |
+
void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
|
| 226 |
+
#endif
|
| 227 |
+
|
| 228 |
+
using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
|
| 229 |
+
DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
|
| 230 |
+
|
| 231 |
+
using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
|
| 232 |
+
|
| 233 |
+
DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
|
| 234 |
+
|
| 235 |
+
using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
|
| 236 |
+
|
| 237 |
+
DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
|
| 238 |
+
|
| 239 |
+
using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
|
| 240 |
+
DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
|
| 241 |
+
|
| 242 |
+
using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
|
| 243 |
+
DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
|
| 244 |
+
|
| 245 |
+
using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
|
| 246 |
+
DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
|
| 247 |
+
|
| 248 |
+
using linalg_eigh_fn = void (*)(
|
| 249 |
+
const Tensor& /*eigenvalues*/,
|
| 250 |
+
const Tensor& /*eigenvectors*/,
|
| 251 |
+
const Tensor& /*infos*/,
|
| 252 |
+
bool /*upper*/,
|
| 253 |
+
bool /*compute_eigenvectors*/);
|
| 254 |
+
DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
|
| 255 |
+
|
| 256 |
+
using lstsq_fn = void (*)(
|
| 257 |
+
const Tensor& /*a*/,
|
| 258 |
+
Tensor& /*b*/,
|
| 259 |
+
Tensor& /*rank*/,
|
| 260 |
+
Tensor& /*singular_values*/,
|
| 261 |
+
Tensor& /*infos*/,
|
| 262 |
+
double /*rcond*/,
|
| 263 |
+
std::string /*driver_name*/);
|
| 264 |
+
DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
|
| 265 |
+
|
| 266 |
+
using triangular_solve_fn = void (*)(
|
| 267 |
+
const Tensor& /*A*/,
|
| 268 |
+
const Tensor& /*B*/,
|
| 269 |
+
bool /*left*/,
|
| 270 |
+
bool /*upper*/,
|
| 271 |
+
TransposeType /*transpose*/,
|
| 272 |
+
bool /*unitriangular*/);
|
| 273 |
+
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
|
| 274 |
+
|
| 275 |
+
using lu_factor_fn = void (*)(
|
| 276 |
+
const Tensor& /*input*/,
|
| 277 |
+
const Tensor& /*pivots*/,
|
| 278 |
+
const Tensor& /*infos*/,
|
| 279 |
+
bool /*compute_pivots*/);
|
| 280 |
+
DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
|
| 281 |
+
|
| 282 |
+
using unpack_pivots_fn = void(*)(
|
| 283 |
+
TensorIterator& iter,
|
| 284 |
+
const int64_t dim_size,
|
| 285 |
+
const int64_t max_pivot);
|
| 286 |
+
DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
|
| 287 |
+
|
| 288 |
+
using lu_solve_fn = void (*)(
|
| 289 |
+
const Tensor& /*LU*/,
|
| 290 |
+
const Tensor& /*pivots*/,
|
| 291 |
+
const Tensor& /*B*/,
|
| 292 |
+
TransposeType /*trans*/);
|
| 293 |
+
DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
|
| 294 |
+
|
| 295 |
+
using ldl_factor_fn = void (*)(
|
| 296 |
+
const Tensor& /*LD*/,
|
| 297 |
+
const Tensor& /*pivots*/,
|
| 298 |
+
const Tensor& /*info*/,
|
| 299 |
+
bool /*upper*/,
|
| 300 |
+
bool /*hermitian*/);
|
| 301 |
+
DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
|
| 302 |
+
|
| 303 |
+
using svd_fn = void (*)(
|
| 304 |
+
const Tensor& /*A*/,
|
| 305 |
+
const bool /*full_matrices*/,
|
| 306 |
+
const bool /*compute_uv*/,
|
| 307 |
+
const c10::optional<c10::string_view>& /*driver*/,
|
| 308 |
+
const Tensor& /*U*/,
|
| 309 |
+
const Tensor& /*S*/,
|
| 310 |
+
const Tensor& /*Vh*/,
|
| 311 |
+
const Tensor& /*info*/);
|
| 312 |
+
DECLARE_DISPATCH(svd_fn, svd_stub);
|
| 313 |
+
|
| 314 |
+
using ldl_solve_fn = void (*)(
|
| 315 |
+
const Tensor& /*LD*/,
|
| 316 |
+
const Tensor& /*pivots*/,
|
| 317 |
+
const Tensor& /*result*/,
|
| 318 |
+
bool /*upper*/,
|
| 319 |
+
bool /*hermitian*/);
|
| 320 |
+
DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
|
| 321 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/Config.h>
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
#ifdef USE_FBGEMM
|
| 6 |
+
#include <fbgemm/FbgemmEmbedding.h>
|
| 7 |
+
#endif
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
void check_arguments(
|
| 12 |
+
const Tensor& weight,
|
| 13 |
+
const Tensor& indices,
|
| 14 |
+
const Tensor& offsets,
|
| 15 |
+
const int64_t mode,
|
| 16 |
+
const c10::optional<Tensor>& per_sample_weights,
|
| 17 |
+
bool include_last_offset);
|
| 18 |
+
|
| 19 |
+
void make_bag_size_out(
|
| 20 |
+
Tensor& bag_size_out,
|
| 21 |
+
const Tensor& offsets,
|
| 22 |
+
const Tensor& indices,
|
| 23 |
+
const int64_t mode,
|
| 24 |
+
const bool include_last_offset,
|
| 25 |
+
const bool requires_grad);
|
| 26 |
+
|
| 27 |
+
void make_max_indices_out(
|
| 28 |
+
Tensor& max_indices_out,
|
| 29 |
+
const Tensor& weight,
|
| 30 |
+
const Tensor& indices,
|
| 31 |
+
const Tensor& offsets,
|
| 32 |
+
const Tensor& bag_size,
|
| 33 |
+
const int64_t mode,
|
| 34 |
+
bool include_last_offset);
|
| 35 |
+
|
| 36 |
+
void make_offset2bag_out(
|
| 37 |
+
Tensor& offset2bag,
|
| 38 |
+
Tensor& output,
|
| 39 |
+
const Tensor& weight,
|
| 40 |
+
const Tensor& indices,
|
| 41 |
+
const Tensor& offsets,
|
| 42 |
+
const int64_t mode,
|
| 43 |
+
const c10::optional<Tensor>& per_sample_weights,
|
| 44 |
+
const int64_t padding_idx = -1);
|
| 45 |
+
|
| 46 |
+
#ifdef USE_FBGEMM
|
| 47 |
+
|
| 48 |
+
template<bool has_weight, typename TIndex, typename TData>
|
| 49 |
+
struct _CallbackAndBlockSize {
|
| 50 |
+
using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
|
| 51 |
+
|
| 52 |
+
int64_t blockSize = -1;
|
| 53 |
+
TCallback callback = nullptr;
|
| 54 |
+
|
| 55 |
+
static TCallback generateCallback(int64_t block_size) {
|
| 56 |
+
return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
|
| 57 |
+
block_size,
|
| 58 |
+
has_weight,
|
| 59 |
+
/* normalize_by_lengths */false,
|
| 60 |
+
/* prefetch */16,
|
| 61 |
+
/* is_weight_positional */false,
|
| 62 |
+
/* use_offsets */true);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
_CallbackAndBlockSize() = default;
|
| 66 |
+
|
| 67 |
+
explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size)
|
| 68 |
+
: blockSize(maybe_block_size.value_or(-1))
|
| 69 |
+
, callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
|
| 70 |
+
{}
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
template<typename... StorageMixins>
|
| 74 |
+
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
|
| 75 |
+
|
| 76 |
+
_EmbeddingBagKernelCacheImpl() = default;
|
| 77 |
+
// use each of the mixins to store corresponding kernel and block size
|
| 78 |
+
explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
|
| 79 |
+
: StorageMixins(maybe_block_size)...
|
| 80 |
+
{}
|
| 81 |
+
|
| 82 |
+
// this method is thread safe (call sites may call from different threads)
|
| 83 |
+
template<bool has_weight, typename TIndex, typename TData>
|
| 84 |
+
typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
|
| 85 |
+
getCallback(int64_t block_size) const {
|
| 86 |
+
// if the cache doesn't store the kernel for the incoming block size
|
| 87 |
+
// (so it is different from the one stored in corresponding mixin)
|
| 88 |
+
// regenerate the kernel (not writing it into the cache so we avoid locks)
|
| 89 |
+
if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
|
| 90 |
+
return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
|
| 91 |
+
}
|
| 92 |
+
// else retrieve the cached kernel from the corresponding mixin
|
| 93 |
+
return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
|
| 94 |
+
}
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
// instantiate the cache with the list of storage mixins
|
| 98 |
+
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
|
| 99 |
+
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
|
| 100 |
+
_CallbackAndBlockSize<true, int32_t, float>,
|
| 101 |
+
_CallbackAndBlockSize<false, int32_t, float>,
|
| 102 |
+
_CallbackAndBlockSize<true, int64_t, float>,
|
| 103 |
+
_CallbackAndBlockSize<false, int64_t, float>,
|
| 104 |
+
_CallbackAndBlockSize<true, int32_t, unsigned short>,
|
| 105 |
+
_CallbackAndBlockSize<false, int32_t, unsigned short>,
|
| 106 |
+
_CallbackAndBlockSize<true, int64_t, unsigned short>,
|
| 107 |
+
_CallbackAndBlockSize<false, int64_t, unsigned short>>;
|
| 108 |
+
#else
|
| 109 |
+
struct _EmbeddingBagKernelCache {
|
| 110 |
+
explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
|
| 111 |
+
};
|
| 112 |
+
#endif
|
| 113 |
+
|
| 114 |
+
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
|
| 115 |
+
Tensor& bag_size, Tensor* max_indices,
|
| 116 |
+
const Tensor &weight, const Tensor &indices,
|
| 117 |
+
const Tensor &offsets, const int64_t mode = 0,
|
| 118 |
+
const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
|
| 119 |
+
bool include_last_offset = false,
|
| 120 |
+
int64_t padding_idx = -1,
|
| 121 |
+
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
|
| 122 |
+
|
| 123 |
+
void _embedding_bag_cpu_out(
|
| 124 |
+
at::Tensor& output,
|
| 125 |
+
at::Tensor& offset2bag,
|
| 126 |
+
at::Tensor& bag_size,
|
| 127 |
+
at::Tensor* p_max_indices,
|
| 128 |
+
const at::Tensor& weight,
|
| 129 |
+
const at::Tensor& indices,
|
| 130 |
+
const at::Tensor& offsets,
|
| 131 |
+
const bool scale_grad_by_freq,
|
| 132 |
+
const int64_t mode,
|
| 133 |
+
const bool sparse,
|
| 134 |
+
const c10::optional<at::Tensor>& per_sample_weights,
|
| 135 |
+
const bool include_last_offset,
|
| 136 |
+
const c10::optional<int64_t>& padding_idx,
|
| 137 |
+
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
|
| 138 |
+
|
| 139 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Functions that fill Tensors with constants. Implementations are in Fill.cpp.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
|
| 7 |
+
namespace c10 {
|
| 8 |
+
class Scalar;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
class Tensor;
|
| 13 |
+
struct TensorIterator;
|
| 14 |
+
|
| 15 |
+
namespace native {
|
| 16 |
+
|
| 17 |
+
DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub);
|
| 18 |
+
|
| 19 |
+
Tensor& fill_out(Tensor& self, const Scalar& value);
|
| 20 |
+
|
| 21 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/AccumulateType.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/TensorUtils.h>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
namespace {
|
| 9 |
+
static C10_UNUSED void multilabel_margin_loss_shape_check(
|
| 10 |
+
int64_t& nframe,
|
| 11 |
+
int64_t& dim,
|
| 12 |
+
const int64_t& ndims,
|
| 13 |
+
const Tensor& input,
|
| 14 |
+
const Tensor& target) {
|
| 15 |
+
TORCH_CHECK(
|
| 16 |
+
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
|
| 17 |
+
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
|
| 18 |
+
input.sizes());
|
| 19 |
+
|
| 20 |
+
if (ndims <= 1) {
|
| 21 |
+
nframe = 1;
|
| 22 |
+
dim = ndims == 0 ? 1 : input.size(0);
|
| 23 |
+
TORCH_CHECK(
|
| 24 |
+
target.dim() <= 1 && target.numel() == dim,
|
| 25 |
+
"inconsistent target size: ", target.sizes(), " for input of size: ",
|
| 26 |
+
input.sizes());
|
| 27 |
+
} else {
|
| 28 |
+
nframe = input.size(0);
|
| 29 |
+
dim = input.size(1);
|
| 30 |
+
TORCH_CHECK(
|
| 31 |
+
target.dim() == 2 && target.size(0) == nframe &&
|
| 32 |
+
target.size(1) == dim,
|
| 33 |
+
"inconsistent target size: ", target.sizes(), " for input of size: ",
|
| 34 |
+
input.sizes());
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
static C10_UNUSED void multi_margin_loss_shape_check(
|
| 39 |
+
int64_t& nframe,
|
| 40 |
+
int64_t& dim,
|
| 41 |
+
const int64_t& ndims,
|
| 42 |
+
const Tensor& input,
|
| 43 |
+
const Tensor& target,
|
| 44 |
+
const c10::optional<Tensor>& weight) {
|
| 45 |
+
TORCH_CHECK(
|
| 46 |
+
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
|
| 47 |
+
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
|
| 48 |
+
input.sizes());
|
| 49 |
+
|
| 50 |
+
if (ndims <= 1) {
|
| 51 |
+
nframe = 1;
|
| 52 |
+
dim = ndims == 0 ? 1 : input.size(0);
|
| 53 |
+
} else {
|
| 54 |
+
nframe = input.size(0);
|
| 55 |
+
dim = input.size(1);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
TORCH_CHECK(
|
| 59 |
+
target.dim() <= 1 && target.numel() == nframe,
|
| 60 |
+
"inconsistent target size, expected ", nframe, " but got ",
|
| 61 |
+
target.sizes());
|
| 62 |
+
if (weight && weight->defined()) {
|
| 63 |
+
TORCH_CHECK(
|
| 64 |
+
weight->dim() <= 1 && weight->numel() == dim,
|
| 65 |
+
"inconsistent weight size, expected ", dim, " but got ",
|
| 66 |
+
weight->sizes());
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
} // anonymous namespace
|
| 72 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/TensorIterator.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
|
| 9 |
+
DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
|
| 10 |
+
|
| 11 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pow.h
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace c10 {
|
| 6 |
+
class Scalar;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
struct TensorIterator;
|
| 12 |
+
struct TensorIteratorBase;
|
| 13 |
+
|
| 14 |
+
namespace native {
|
| 15 |
+
|
| 16 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 17 |
+
#define HOST_DEVICE __host__ __device__
|
| 18 |
+
#else
|
| 19 |
+
#define HOST_DEVICE
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
// integral power in pytorch allows for negative exponents, giving truncated integral results.
|
| 23 |
+
// e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the
|
| 24 |
+
// only non-zero result.
|
| 25 |
+
template <class T,
|
| 26 |
+
typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
|
| 27 |
+
static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
|
| 28 |
+
T result = 1;
|
| 29 |
+
while (b) {
|
| 30 |
+
if (b & 1) {
|
| 31 |
+
result *= a;
|
| 32 |
+
}
|
| 33 |
+
b /= 2;
|
| 34 |
+
a *= a;
|
| 35 |
+
}
|
| 36 |
+
return result;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
template <class T,
|
| 40 |
+
typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
|
| 41 |
+
static inline HOST_DEVICE T powi(T a, T b) {
|
| 42 |
+
return powi_impl(a, b);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template <class T,
|
| 46 |
+
typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
|
| 47 |
+
static inline HOST_DEVICE T powi(T a, T b) {
|
| 48 |
+
if ( b < 0 ) {
|
| 49 |
+
if ( a == 1 ) {
|
| 50 |
+
return 1;
|
| 51 |
+
} else if ( a == -1 ) {
|
| 52 |
+
auto negative = (-b) % static_cast<T>(2);
|
| 53 |
+
return negative ? -1 : 1;
|
| 54 |
+
} else {
|
| 55 |
+
return 0;
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
return powi_impl(a, b);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
|
| 62 |
+
using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
|
| 63 |
+
|
| 64 |
+
DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
|
| 65 |
+
DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);
|
| 66 |
+
|
| 67 |
+
} // namespace native
|
| 68 |
+
|
| 69 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOps.h
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <c10/util/ArrayRef.h>
|
| 5 |
+
#include <c10/util/Optional.h>
|
| 6 |
+
|
| 7 |
+
namespace c10 {
|
| 8 |
+
class Scalar;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
struct TensorIterator;
|
| 13 |
+
class Tensor;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
namespace at::native {
|
| 17 |
+
|
| 18 |
+
using reduce_fn = void(*)(TensorIterator &);
|
| 19 |
+
|
| 20 |
+
DECLARE_DISPATCH(reduce_fn, sum_stub);
|
| 21 |
+
DECLARE_DISPATCH(reduce_fn, nansum_stub);
|
| 22 |
+
DECLARE_DISPATCH(reduce_fn, prod_stub);
|
| 23 |
+
DECLARE_DISPATCH(reduce_fn, mean_stub);
|
| 24 |
+
DECLARE_DISPATCH(reduce_fn, and_stub);
|
| 25 |
+
DECLARE_DISPATCH(reduce_fn, or_stub);
|
| 26 |
+
DECLARE_DISPATCH(reduce_fn, min_values_stub);
|
| 27 |
+
DECLARE_DISPATCH(reduce_fn, max_values_stub);
|
| 28 |
+
DECLARE_DISPATCH(reduce_fn, argmax_stub);
|
| 29 |
+
DECLARE_DISPATCH(reduce_fn, argmin_stub);
|
| 30 |
+
|
| 31 |
+
using reduce_std_var_function =
|
| 32 |
+
void (*)(TensorIterator&, double correction, bool take_sqrt);
|
| 33 |
+
DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);
|
| 34 |
+
|
| 35 |
+
using reduce_norm_fn =
|
| 36 |
+
void (*)(Tensor&, const Tensor&, const c10::Scalar&, c10::optional<int64_t>);
|
| 37 |
+
DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
|
| 38 |
+
|
| 39 |
+
using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&);
|
| 40 |
+
DECLARE_DISPATCH(reduce_fn_flag, norm_stub);
|
| 41 |
+
|
| 42 |
+
using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t);
|
| 43 |
+
using cum_fn = void (*)(Tensor&, const Tensor&, int64_t);
|
| 44 |
+
DECLARE_DISPATCH(structured_cum_fn, cumsum_stub);
|
| 45 |
+
DECLARE_DISPATCH(structured_cum_fn, cumprod_stub);
|
| 46 |
+
DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);
|
| 47 |
+
|
| 48 |
+
DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub);
|
| 49 |
+
DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub);
|
| 50 |
+
|
| 51 |
+
// Used in cuda/Normalization.cu
|
| 52 |
+
TORCH_API std::tuple<Tensor&,Tensor&> var_mean_out(
|
| 53 |
+
Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim,
|
| 54 |
+
int64_t correction, bool keepdim);
|
| 55 |
+
|
| 56 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/// This file contains some tensor-agnostic operations to be used in the
|
| 2 |
+
/// core functions of the `SobolEngine`
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
|
| 5 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 6 |
+
#include <ATen/Functions.h>
|
| 7 |
+
#else
|
| 8 |
+
#include <ATen/ops/arange.h>
|
| 9 |
+
#include <ATen/ops/mul.h>
|
| 10 |
+
#include <ATen/ops/pow.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::native::sobol_utils {
|
| 14 |
+
|
| 15 |
+
/// Function to return the minimum of number of bits to represent the integer `n`
|
| 16 |
+
inline int64_t bit_length(const int64_t n) {
|
| 17 |
+
int64_t nbits, nloc;
|
| 18 |
+
for (nloc = n, nbits = 0; nloc > 0; nloc /= 2, nbits++);
|
| 19 |
+
return nbits;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
/// Function to get the position of the rightmost zero in the bit representation of an integer
|
| 23 |
+
/// This value is the zero-indexed position
|
| 24 |
+
inline int64_t rightmost_zero(const int64_t n) {
|
| 25 |
+
int64_t z, i;
|
| 26 |
+
for (z = n, i = 0; z % 2 == 1; z /= 2, i++);
|
| 27 |
+
return i;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
/// Function to get a subsequence of bits in the representation of an integer starting from
|
| 31 |
+
/// `pos` and of length `length`
|
| 32 |
+
inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) {
|
| 33 |
+
return (n >> pos) & ((1 << length) - 1);
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
/// Function to perform the inner product between a batched square matrix and a power of 2 vector
|
| 37 |
+
inline at::Tensor cdot_pow2(const at::Tensor& bmat) {
|
| 38 |
+
at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options());
|
| 39 |
+
inter = at::pow(2, inter).expand_as(bmat);
|
| 40 |
+
return at::mul(inter, bmat).sum(-1);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
/// All definitions below this point are data. These are constant, and should not be modified
|
| 44 |
+
/// without notice
|
| 45 |
+
|
| 46 |
+
constexpr int64_t MAXDIM = 21201;
|
| 47 |
+
constexpr int64_t MAXDEG = 18;
|
| 48 |
+
constexpr int64_t MAXBIT = 30;
|
| 49 |
+
constexpr int64_t LARGEST_NUMBER = 1 << MAXBIT;
|
| 50 |
+
constexpr float RECIPD = 1.0 / LARGEST_NUMBER;
|
| 51 |
+
|
| 52 |
+
extern const int64_t poly[MAXDIM];
|
| 53 |
+
extern const int64_t initsobolstate[MAXDIM][MAXDEG];
|
| 54 |
+
|
| 55 |
+
} // namespace at::native::sobol_utils
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorCompare.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <ATen/native/DispatchStub.h>

// Forward declarations keep this header lightweight: everything below only
// uses pointers/references to these types.
namespace c10 {
class Scalar;
}

namespace at {
class Tensor;
struct TensorIterator;
struct TensorIteratorBase;
}

namespace at::native {

// min/max reduction kernel signatures; presumably
// (values, indices, self, dim, keepdim) — confirm against the kernel impls.
using reduce_minmax_fn =
    void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
// Variant with const output references, used by structured kernels.
using structured_reduce_minmax_fn =
    void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);

DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);

// Ternary select (torch.where) over a TensorIterator.
using where_fn = void (*)(TensorIterator &);
DECLARE_DISPATCH(where_fn, where_kernel);

// Elementwise infinity predicates (isposinf / isneginf).
using is_infinity_op_fn = void (*)(TensorIteratorBase &);
DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);

// torch.mode kernel; same argument shape as reduce_minmax_fn above.
using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
DECLARE_DISPATCH(mode_fn, mode_stub);

// Clamp with tensor-valued bounds.
using clamp_tensor_fn = void (*)(TensorIteratorBase &);
DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub);

namespace detail {
// Which clamp bound(s) are active for a given call.
enum class ClampLimits {Min, Max, MinMax};
}

// Clamp with scalar bounds (both, min-only, max-only).
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, const c10::Scalar&, const c10::Scalar&), clamp_scalar_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_min_scalar_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_max_scalar_stub);

// torch.isin default (non-sorted) path; presumably
// (elements, test_elements, invert, out) — confirm against the kernel impl.
using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
DECLARE_DISPATCH(isin_default_fn, isin_default_stub);

} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIterator.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/TensorIterator.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/native/LinearAlgebraUtils.h>
|
| 3 |
+
|
| 4 |
+
namespace at::native {
|
| 5 |
+
|
| 6 |
+
/*
|
| 7 |
+
* Given batches of matrices with arbitrary batch dim,
|
| 8 |
+
* computes the number of batches for Triu and Tril. This ignores stride 0 dimension
|
| 9 |
+
*/
|
| 10 |
+
static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
|
| 11 |
+
int64_t result = 1;
|
| 12 |
+
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
|
| 13 |
+
if (batched_matrices.stride(i) != 0) {
|
| 14 |
+
result *= batched_matrices.size(i);
|
| 15 |
+
}
|
| 16 |
+
}
|
| 17 |
+
return result;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
/* Checks a necessary property for the triu and tril implementations, hence the name.
 * Here batch contiguity is checked for tensors with greater than 4 dimensions.
 * Contiguous tensors and tensors with less than 3 dimensions pass this check
 */
// Returns {true, tensor} when the layout is already acceptable, otherwise
// {false, fixed} where `fixed` is a restrided view or a contiguous copy.
static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
  // Complete contiguity is the most desired property, which is why
  // we return true if the tensor is contiguous
  if (tensor.is_contiguous()) {
    auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
    if (tensor.strides() == default_strides_for_size) {
      return std::make_tuple(true, tensor);
    } else {
      // Contiguous but with a nonstandard stride layout; reinterpret the same
      // storage with the canonical batched-matrix strides.
      return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
    }
  }

  int64_t dims = tensor.dim();

  // Tensors with dimension less than 4 are handled by default
  if (allow_zero_stride && dims <= 3) {
    return std::make_tuple(true, tensor);
  }

  // Walk the batch dimensions from innermost to outermost, requiring each
  // stride to equal the product of all inner sizes ("batch contiguous").
  int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
  for (int64_t i = dims - 3; i >= 0; i--) {
    // Skip trivial dimension;
    if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
      continue;
    }
    if (expected_stride != tensor.stride(i)) {
      // Not batch contiguous: fall back to a full contiguous copy.
      return std::make_tuple(false, tensor.contiguous());
    }
    expected_stride *= tensor.size(i);
  }
  return std::make_tuple(true, tensor);
}
|
| 56 |
+
|
| 57 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

namespace at { namespace native { inline namespace CPU_CAPABILITY {

// n: number of function arguments (arity)
// traits: function_traits (see FunctionTraits.h)
// s: index of scalar argument or -1
// Recursive compile-time check that strides[stride_index] matches the size of
// the (n-1)-th argument type, then recurses to the previous argument.
template <int n, int stride_index, typename traits, int s=-1>
struct IsContiguous {
  static bool eval(const int64_t* strides) {
    using type = typename traits::template arg<n - 1>::type;
    // The scalar argument (index s) is broadcast and must have stride 0.
    return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
           IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
  }
};

// will be called when there is an output exists
template <typename traits, int s>
struct IsContiguous<0, 0, traits, s> {
  static bool eval(const int64_t* strides) {
    // strides[0] belongs to the output; it must be densely packed too.
    return strides[0] == sizeof(typename traits::result_type);
  }
};

// will be called when there is no output
template <typename traits, int s>
struct IsContiguous<0, -1, traits, s> {
  static bool eval(const int64_t* /*strides*/) {
    return true;
  }
};

// output and all inputs are contiguous
template <typename traits,
    typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  // void result: only the inputs occupy strides[0 .. arity-1].
  return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
}

template <typename traits,
    typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  // Non-void result: strides[0] is the output, inputs follow at [1 .. arity].
  return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
}

// input at `s` is scalar (stride 0); output and other inputs are contiguous
// NB: output is typically at strides[0] so first input corresponds to s=1
template <typename traits, int s,
    typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
}

template <typename traits, int s,
    typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
}

}}}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <ATen/native/DispatchStub.h>
#include <cstdint>

namespace at {
class Tensor;

namespace native {

// Function-pointer types for dispatch-stub registered (log)softmax kernels
// operating over the last dimension (two-tensor forward, three-tensor
// backward). Exact argument roles are defined by the kernel implementations —
// confirm there before relying on a particular order.
using forward_fn = void (*)(const Tensor&, const Tensor&);
using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);

DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);

// Variants that take an explicit reduction dimension as the trailing argument.
using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
using backward_fn_with_dim =
    void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);

DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
}
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/jit_macros.h>
|
| 3 |
+
|
| 4 |
+
// Jiterator functions are guarded behind this macro
|
| 5 |
+
#if AT_USE_JITERATOR()
|
| 6 |
+
|
| 7 |
+
#include <ATen/OpMathType.h>
|
| 8 |
+
#include <ATen/TensorIterator.h>
|
| 9 |
+
#include <ATen/core/Array.h>
|
| 10 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 11 |
+
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
| 12 |
+
#include <ATen/native/cuda/jit_utils.h>
|
| 13 |
+
#include <ATen/native/cuda/MemoryAccess.cuh>
|
| 14 |
+
#include <ATen/native/cuda/thread_constants.h>
|
| 15 |
+
|
| 16 |
+
#include <ATen/native/cuda/Loops.cuh>
|
| 17 |
+
|
| 18 |
+
#include <c10/macros/Macros.h>
|
| 19 |
+
#include <c10/core/ScalarType.h>
|
| 20 |
+
#include <c10/util/SmallBuffer.h>
|
| 21 |
+
|
| 22 |
+
#include <initializer_list>
|
| 23 |
+
#include <type_traits>
|
| 24 |
+
#include <tuple>
|
| 25 |
+
#include <mutex>
|
| 26 |
+
|
| 27 |
+
namespace at {
|
| 28 |
+
namespace native {
|
| 29 |
+
|
| 30 |
+
// Expands the index sequence to collect the address of every tuple element
// into a std::array<void*, N>. Pointers alias `tup`, so they are valid only
// while the tuple is alive.
template <typename Tuple, std::size_t... Idx>
constexpr auto tuple_to_array_helper(Tuple& tup, std::index_sequence<Idx...> indices) {
  constexpr auto count = indices.size();
  (void)tup; // silences the unused-parameter warning for empty tuples
  return std::array<void*, count>{static_cast<void*>(&std::get<Idx>(tup))...};
}
|
| 36 |
+
|
| 37 |
+
// Helper function convert tuple to std::array<void*, N>
// for passing the arguments to CUDA Kernel
// NOTE: We capture tuple by reference,
// so the pointers in returned array are only valid
// till tuple is alive.
// Delegates to tuple_to_array_helper with an index sequence over all elements.
template <typename ...Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  constexpr auto tuple_size = sizeof...(Args);
  return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
}
|
| 47 |
+
|
| 48 |
+
// Per-vector-width cache of NVRTC-compiled kernels.
struct JittedVecKernelCache {
  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  at::cuda::jit::NvrtcFunction vec1;
  at::cuda::jit::NvrtcFunction vec2;
  at::cuda::jit::NvrtcFunction vec4;
};
|
| 54 |
+
|
| 55 |
+
// One compiled kernel per launch configuration; the fields correspond to the
// four dispatch cases in jitted_gpu_kernel_generic (contiguity x dynamic cast).
struct JittedKernelVariantCache {
  JittedVecKernelCache vec;                     // contiguous, no dynamic casting
  at::cuda::jit::NvrtcFunction noncontiguous;   // noncontiguous, no dynamic casting
  at::cuda::jit::NvrtcFunction dynamic_contiguous;
  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
};
|
| 61 |
+
|
| 62 |
+
// Concatenates the fixed kernel arguments with user-supplied extra arguments
// into one flat buffer of void*, the layout expected by the driver-level
// kernel launch. Pointers are copied verbatim; callers keep ownership.
inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
  std::copy(args.begin(), args.end(), ret.data());
  std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
  return ret;
}
|
| 70 |
+
|
| 71 |
+
// Compiles (on first use) and launches the non-vectorized "unrolled" jitted
// elementwise kernel. `fn_cache` memoizes the compiled function; compilation
// is serialized through `jiterator_mutex`.
template<typename array_t,
         typename inp_calc_t,
         typename out_calc_t,
         typename loader_t,
         typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  //casting result to int is always safe, intermediate is int64 and won't overflow
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  // NOTE(review): double-checked locking on a non-atomic field; assumes the
  // unsynchronized first read is benign — confirm the jiterator's threading
  // contract before relying on this from concurrent streams.
  if (!fn_cache.function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_cache.function) {
      // Dynamic casting is required whenever the loader/storer are not the
      // no-cast variants.
      constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
                                       !std::is_same<decltype(s), memory::StoreWithoutCast>();
      auto code = at::cuda::jit::generate_code(
          desc, contiguous, dynamic_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
    }
  }

  // Kernel parameters are passed by address; extra_args appended verbatim.
  auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
  {num_threads(), 1u, 1u});
}
|
| 110 |
+
|
| 111 |
+
// Compiles (on first use) and launches the vectorized jitted elementwise
// kernel. Chooses a vec1/vec2/vec4 variant from the data's alignment; the
// vec1 fallback is launched through the unrolled entry point with trivial
// offset calculators and no-cast loader/storer.
template<int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val, c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast result to int
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  // fn_ptr is set to the appropriate function based on the vec size and GPU used
  at::cuda::jit::NvrtcFunction* fn_ptr;
  if (vec_size == 4) {
    fn_ptr = &fn_cache.vec4;
  } else if (vec_size == 2) {
    fn_ptr = &fn_cache.vec2;
  } else if (vec_size ==1) {
    fn_ptr = &fn_cache.vec1;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
  }

  bool vectorized = vec_size > 1;

  // NOTE(review): double-checked locking on a non-atomic field; see the same
  // pattern in launch_jitted_unrolled_kernel — confirm threading contract.
  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!

      // Generates program
      auto code = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;

      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  if (vectorized) {
    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
    // NVCC complains about unused variables l and s.
    // It should be false positive in most cases, so we suppress the warnings.
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
    // vec1 path: reuse the unrolled kernel interface with trivial
    // (contiguous) offset calculators and no dtype casting.
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    auto args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
#pragma nv_diagnostic pop
  }
}
|
| 174 |
+
|
| 175 |
+
// Dispatches a jitted elementwise computation to one of four kernel variants
// based on iterator contiguity and whether dynamic dtype casting is needed.
template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  // Collect raw data pointers: slot 0 is the output, slots 1..arity inputs.
  constexpr int ntensors = arity + 1;
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  //   - Case 1: no dynamic casting and contiguous
  //   - Case 2: no dynamic casting and noncontiguous
  //   - Case 3: dynamic casting and contiguous
  //   - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 are handled below
  // Both require construction of a storer (this asserts 1 output) and one or more loaders

  // Creates store cast to output (the zeroth tensor in TensorIterator)
  auto storer = memory::StoreWithCast<1>(iter);

  // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
      output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
}
|
| 253 |
+
|
| 254 |
+
// NOTE: static to reduce chances of name collision.
// Entry point for a jitted elementwise kernel: builds (once per template
// instantiation) the kernel descriptor, looks up the per-device compiled
// kernel cache, and forwards to jitted_gpu_kernel_generic.
template <
    char const* name,
    typename result_type,
    typename f_inputs_type,
    int arity,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    typename... ExtraArgs>
static void jitted_gpu_kernel_impl(
    TensorIteratorBase& iter,
    const std::string &f,
    const bool dynamic_casting,
    at::opmath_type<f_inputs_type> scalar_val,
    std::tuple<ExtraArgs...> extra_args) {

  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  // the same compute capability
  static std::mutex jiterator_mutex;
  // One cache entry per CUDA device, indexed by iter's device below.
  static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());

  constexpr int nInputs = arity;
  constexpr int nOutputs = 1; // TODO: Support more than 1 output
  // Function-local static: only the first call's `f` is used for a given
  // instantiation; later calls reuse the cached descriptor.
  static const auto desc = at::cuda::jit::make_kernel_descriptor<
    result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);

  auto &cache = device_caches[iter.device().index()];
  // Pointers into `extra_args` remain valid for the duration of this call.
  auto extra_args_array = tuple_to_array(extra_args);
  return jitted_gpu_kernel_generic<arity>(
      jiterator_mutex,
      cache,
      desc,
      scalar_pos,
      extra_args_array,
      iter,
      dynamic_casting,
      &scalar_val
    );
}
|
| 293 |
+
|
| 294 |
+
}} // at::native
|
| 295 |
+
|
| 296 |
+
#endif // AT_USE_JITERATOR()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.h
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
#include <array>
#include <cstdint>

// Forward declaration keeps this header free of the full TensorBase include.
namespace at {
class TensorBase;
}

namespace at {
namespace native {

// Launchers for the CUDA grid_sampler forward kernels (2-D / 3-D).
// `interpolation_mode` / `padding_mode` carry the aten enum values —
// confirm their meaning against the grid sampler enum definitions.
void launch_grid_sampler_2d_forward_kernel(
    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
    int64_t interpolation_mode, int64_t padding_mode, bool align_corners);

void launch_grid_sampler_3d_forward_kernel(
    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
    int64_t interpolation_mode, int64_t padding_mode, bool align_corners);

// Backward launchers; `output_mask` presumably selects which of
// {grad_input, grad_grid} are computed — verify against the kernel impls.
void launch_grid_sampler_2d_backward_kernel(
    const TensorBase &grad_input, const TensorBase &grad_grid,
    const TensorBase &grad_output, const TensorBase &input,
    const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
    bool align_corners, std::array<bool, 2> output_mask);

void launch_grid_sampler_3d_backward_kernel(
    const TensorBase &grad_input, const TensorBase &grad_grid,
    const TensorBase &grad_output, const TensorBase &input,
    const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
    bool align_corners, std::array<bool, 2> output_mask);

}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
#include <type_traits>
|
| 5 |
+
#include <c10/core/DynamicCast.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <c10/util/TypeCast.h>
|
| 8 |
+
#include <c10/macros/Macros.h>
|
| 9 |
+
#include <ATen/core/Array.h>
|
| 10 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 11 |
+
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
| 12 |
+
#include <ATen/native/cuda/thread_constants.h>
|
| 13 |
+
|
| 14 |
+
#include <thrust/tuple.h>
|
| 15 |
+
|
| 16 |
+
// References:
|
| 17 |
+
// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
|
| 18 |
+
|
| 19 |
+
namespace at { namespace native { namespace memory {
|
| 20 |
+
|
| 21 |
+
namespace detail {
|
| 22 |
+
|
| 23 |
+
// What does the `static_unroll` do?
|
| 24 |
+
//
|
| 25 |
+
// We want to do something like:
|
| 26 |
+
//
|
| 27 |
+
// using args_t = typename traits::ArgsTuple;
|
| 28 |
+
// args_t args;
|
| 29 |
+
// #pragma unroll
|
| 30 |
+
// for (int i = 0; i < traits::arity; i++) {
|
| 31 |
+
// std::get<i>(args) = ....
|
| 32 |
+
// }
|
| 33 |
+
//
|
| 34 |
+
// but unfortunately the above code does not work because
|
| 35 |
+
// the template argument has to be a compile time constant
|
| 36 |
+
// so `static_unroll` is created to simulate `#pragma unroll`
|
| 37 |
+
// using template metaprogramming.
|
| 38 |
+
|
| 39 |
+
// Compile-time loop: invokes func<0>::apply, func<1>::apply, ...,
// func<end-1>::apply with the same argument pack.
template<template<int i> typename func, int end, int current=0>
struct static_unroll {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
    func<current>::apply(std::forward<Args>(args)...);
    // NB: the recursive call passes `args` without re-forwarding, so later
    // steps receive lvalues and observe the same (unmoved) arguments.
    static_unroll<func, end, current+1>::with_args(args...);
  }
};

// Base case: recursion terminates once `current` reaches `end`.
template<template<int i> typename func, int end>
struct static_unroll<func, end, end> {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args... args) {}
};
|
| 53 |
+
|
| 54 |
+
// helper structs to be used with static_unroll to load arguments
|
| 55 |
+
// one by one
|
| 56 |
+
|
| 57 |
+
// static_unroll functor: performs the vectorized load of input #arg_index
// into the per-thread argument tuples `args`, delegating the per-element
// work to the policy's load_single_arg.
template<int arg_index>
struct vectorized_load_helper {
  template <typename args_t, typename policy_t>
  static __device__ void apply(policy_t &self, args_t *args, int idx) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
    // Maps a thread-unroll slot to the matching element of its argument tuple.
    auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
    self.load_single_arg(args_accessor, ptr);
  }
};
|
| 69 |
+
|
| 70 |
+
// static_unroll functor: scalar load of input #arg_index for unroll slot `j`,
// going through `loader` (which may perform a runtime dtype cast).
template<int arg_index>
struct unroll_load_helper {
  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` hold the data_ptr for tensors [output0, ..., input0, input1, ...],
    // so we need a +num_outputs offset to get the input
    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
  }
};
|
| 80 |
+
|
| 81 |
+
// static_unroll functor: stores element `current` of the returned
// thrust::tuple to output tensor `current` at its precomputed offset.
template <int current>
struct multi_outputs_store_helper {
  template<int ntensors, int num_outputs, typename ...Args>
  C10_HOST_DEVICE static void apply(
      at::detail::Array<char*, ntensors> data,
      at::detail::Array<uint32_t, num_outputs> offsets,
      thrust::tuple<Args...> ret) {
    using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
    T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
    *to = thrust::get<current>(ret);
  }
};
|
| 93 |
+
|
| 94 |
+
} // namespace detail
|
| 95 |
+
|
| 96 |
+
struct LoadWithoutCast {
|
| 97 |
+
template<typename scalar_t>
|
| 98 |
+
__device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
|
| 99 |
+
return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
|
| 100 |
+
}
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
// Loader that converts each input element from its runtime dtype (captured
// from the TensorIterator on the host) to the compile-time requested scalar_t.
template <int N>
struct LoadWithCast {
  // std::max(N, 1) keeps the arrays non-zero-sized even for N == 0.
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;              // runtime dtype of each of the N inputs
  size_array_t element_sizes;  // byte size of each input's dtype

  // Host-side setup: records dtype and element size for every input of `iter`.
  // Inputs follow the outputs in the iterator's operand order, hence the
  // `+ iter.noutputs()` index shift.
  LoadWithCast(const TensorIteratorBase& iter) {
    CUDA_KERNEL_ASSERT(iter.ninputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i + iter.noutputs());
      element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
    }
  }

  // Loads element `offset` of input `arg` and casts it to scalar_t.
  // Byte offset is computed with the input's *runtime* element size.
  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
  }
};
|
| 126 |
+
|
| 127 |
+
struct StoreWithoutCast {
|
| 128 |
+
template<typename scalar_t>
|
| 129 |
+
__device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
|
| 130 |
+
*(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
|
| 131 |
+
}
|
| 132 |
+
};
|
| 133 |
+
|
| 134 |
+
// Storer that converts each result from the compile-time scalar_t to the
// output's runtime dtype (captured from the TensorIterator on the host).
template <int N = 1>
struct StoreWithCast {
  // std::max(N, 1) keeps the arrays non-zero-sized even for N == 0.
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;              // runtime dtype of each of the N outputs
  size_array_t element_sizes;  // byte size of each output's dtype

  // Host-side setup: records dtype and element size for every output of `iter`.
  StoreWithCast(const TensorIteratorBase& iter) {
    CUDA_KERNEL_ASSERT(iter.noutputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i);
      element_sizes[i] = c10::elementSize(iter.dtype(i));
    }
  }

  // Casts `value` to output `arg`'s runtime dtype and writes it at element
  // `offset` (byte offset uses the output's runtime element size).
  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
  }
};
|
| 157 |
+
|
| 158 |
+
// aligned vector generates vectorized load/store on CUDA
|
| 159 |
+
// The alignas() guarantees the whole group of vec_size scalars is aligned to
// its full byte width, which is what lets the compiler emit a single
// vectorized memory instruction for a load/store of this struct.
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};
|
| 163 |
+
|
| 164 |
+
// Reads one aligned_vector (vec_size consecutive scalars) at *vector* index
// `offset`, i.e. starting at byte offset * vec_size * sizeof(scalar_t).
// base_ptr must satisfy the aligned_vector alignment for this access.
template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
  using vec_t = aligned_vector<scalar_t, vec_size>;
  auto *from = reinterpret_cast<const vec_t *>(base_ptr);
  return from[offset];
}
|
| 170 |
+
|
| 171 |
+
template <int vec_size>
|
| 172 |
+
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
|
| 173 |
+
// See NOTE [Loading boolean values]
|
| 174 |
+
auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
|
| 175 |
+
aligned_vector<bool, vec_size> ret;
|
| 176 |
+
for (int i = 0; i < vec_size; ++i) {
|
| 177 |
+
ret.val[i] = bool(tmp.val[i]);
|
| 178 |
+
}
|
| 179 |
+
return ret;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
namespace policies {
|
| 183 |
+
|
| 184 |
+
// Assumption:
|
| 185 |
+
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
|
| 186 |
+
// Unrolled (non-vectorized) load/store policy with bounds checking.
// Each thread handles thread_work_size() elements, strided by num_threads().
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
struct unroll {

  data_t data;                          // tensor pointers: [output, input0, input1, ...]
  int remaining;                        // elements left for this block (for the tail block)
  inp_calc_t input_offset_calculator;   // linear index -> per-input element offsets
  out_calc_t output_offset_calculator;  // linear index -> output element offset(s)
  loader_t loader;                      // per-element loader (may cast dtypes)
  storer_t storer;                      // per-element storer (may cast dtypes)

  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}

  // True when this thread's `thread_work_elem`-th element lies within the
  // block's remaining work.
  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  // Fills up to thread_work_size() argument tuples for this thread, returning
  // early once the per-thread index passes `remaining`.
  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      // Loads every input (one per tuple slot) for unroll slot i.
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  // Stores up to thread_work_size() results to output 0, with the same
  // early-return bounds check as load().
  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      int offset = output_offset_calculator.get(linear_idx)[0];
      storer.store(from[i], data[0], offset);
      thread_idx += num_threads();
    }
  }
};
|
| 234 |
+
|
| 235 |
+
// Assumption:
|
| 236 |
+
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
|
| 237 |
+
// Note:
|
| 238 |
+
// Functions in vectorized policy does not do boundary check. It assumes the whole block
|
| 239 |
+
// has its job to do. So the reminders should be handled by the caller manually.
|
| 240 |
+
// Vectorized load/store policy: moves vec_size scalars per memory operation.
// No bounds checks are performed (see the note above); callers must only use
// this policy for blocks with a full workload.
template <int vec_size, typename data_t>  // vec_size: number of scalars, can be 1, 2, or 4.
struct vectorized {

  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
  // Number of vector-wide iterations each thread performs.
  static constexpr int loop_size = thread_work_size() / vec_size;

  data_t data;  // tensor pointers: [output, input0, input1, ...]

  __device__ vectorized(data_t data) : data(data) {}

  // Always in bounds by this policy's contract (whole block has a full workload).
  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
    return true;
  }

  // Loads one input: loop_size vector loads per thread, scattered into the
  // argument tuples through the `to` accessor.
  template<typename accessor_t, typename scalar_t>
  __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      auto v = load_vector<vec_size>(from, index);
      #pragma unroll
      for (int j = 0; j < vec_size; j++) {
        to(vec_size * i + j) = v.val[j];
      }
    }
  }

  // Loads all inputs by unrolling load_single_arg over the tuple arity.
  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
  }

  // Stores results to output 0 with vector-wide writes: packs vec_size
  // scalars into an aligned_vector, then writes it in one operation.
  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    using vec_t = aligned_vector<scalar_t, vec_size>;
    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
    vec_t *to_ = reinterpret_cast<vec_t *>(to);
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      vec_t v;
      for (int j = 0; j < vec_size; j++) {
        v.val[j] = from[vec_size * i + j];
      }
      to_[index] = v;
    }
  }
};
|
| 291 |
+
|
| 292 |
+
// Variant of the unroll policy for kernels that produce multiple outputs
// (results arrive as a thrust::tuple and are fanned out per output tensor).
template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
struct multi_outputs_unroll {
  //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct
  //we don't use inheritance because of compiler bug in cuda 10.2+
  data_t data;                          // tensor pointers: [output0, ..., input0, ...]
  int remaining;                        // elements left for this block
  inp_calc_t input_offset_calculator;   // linear index -> per-input offsets
  out_calc_t output_offset_calculator;  // linear index -> per-output offsets
  LoadWithoutCast loader;               // no dtype casting in the multi-output path
  StoreWithoutCast storer;

  __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}

  // True when this thread's `thread_work_elem`-th element lies within the
  // block's remaining work.
  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  // Fills up to thread_work_size() argument tuples for this thread, returning
  // early once the per-thread index passes `remaining`.
  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }


  // Scatters each result tuple across the num_outputs output tensors via
  // multi_outputs_store_helper, with the same bounds check as load().
  template <typename return_t>
  __device__ inline void store(return_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= this->remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offsets = this->output_offset_calculator.get(linear_idx);
      memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
      thread_idx += num_threads();
    }
  }
};
|
| 342 |
+
|
| 343 |
+
} // namespace policies
|
| 344 |
+
|
| 345 |
+
// This is only used in host, but we will wrap this into some templates
|
| 346 |
+
// which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE
|
| 347 |
+
// in order to compile
|
| 348 |
+
template<typename scalar_t>
|
| 349 |
+
inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
|
| 350 |
+
uint64_t address = reinterpret_cast<uint64_t>(pointer);
|
| 351 |
+
constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
|
| 352 |
+
constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
|
| 353 |
+
if (address % vec4_alignment == 0) {
|
| 354 |
+
return 4;
|
| 355 |
+
} else if (address % vec2_alignment == 0) {
|
| 356 |
+
return 2;
|
| 357 |
+
}
|
| 358 |
+
return 1;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
// static_unroll functor: clamps `result` by the vectorization width allowed
// for input #i's pointer, using the argument type deduced from `traits`.
template<int i>
struct can_vectorize_up_to_helper {
  template <typename array_t, typename traits>
  static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
    using arg_t = typename traits::template arg<i>::type;
    // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
  }
};
|
| 371 |
+
|
| 372 |
+
// Returns the widest vector width usable for *all* operands of `func_t`:
// the minimum over the output pointer and every input pointer, each checked
// with its own (compile-time) argument type.
template<typename func_t, typename array_t>
inline int can_vectorize_up_to(array_t pointers) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  constexpr int arity = traits::arity;
  // Start from the output's limit (pointers[0]).
  int result = can_vectorize_up_to<return_t>(pointers[0]);
  // We need to get the type for each argument of `func_t`, this can only
  // be done at compile time.
  detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
  return result;
}
|
| 383 |
+
|
| 384 |
+
}}} // namespace at::native::memory
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 4 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 5 |
+
#include <ATen/native/cuda/Loops.cuh>
|
| 6 |
+
#include <ATen/native/cuda/MemoryAccess.cuh>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
namespace {
|
| 12 |
+
|
| 13 |
+
static constexpr int64_t kILP = 4;
|
| 14 |
+
static constexpr int64_t kChunkSize = 65536;
|
| 15 |
+
static constexpr int64_t kBlockSize = 512;
|
| 16 |
+
|
| 17 |
+
// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
|
| 18 |
+
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
|
| 19 |
+
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
|
| 20 |
+
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
|
| 21 |
+
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
|
| 22 |
+
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
|
| 23 |
+
72,
|
| 24 |
+
60};
|
| 25 |
+
|
| 26 |
+
template <typename T>
|
| 27 |
+
__device__ __forceinline__ bool is_aligned(T* p) {
|
| 28 |
+
return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// Copies kILP elements at once by reinterpreting both buffers as
// aligned_vector<T, kILP>; offsets are in units of whole vectors.
// Callers are expected to have verified alignment first (see is_aligned).
template <typename T>
__device__ __forceinline__ void load_store(
    T* dst,
    T* src,
    int64_t dst_offset,
    int64_t src_offset) {
  using LT = at::native::memory::aligned_vector<T, kILP>;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
|
| 40 |
+
|
| 41 |
+
// Kernel-argument payload describing one launch over `n` parallel tensor
// lists. Sizes are capped by the lookup tables above to keep the struct
// under the kernel-argument size limit.
template <int n>
struct TensorListMetadata {
  const void* addresses[n][depth_to_max_tensors[n - 1]];      // data ptr per (list, tensor)
  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];      // element count per tensor
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];  // CUDA block -> tensor slot
  int block_to_chunk[depth_to_max_blocks[n - 1]];             // CUDA block -> chunk index within its tensor
  int start_tensor_this_launch;                               // first tensor index covered by this launch
};
|
| 49 |
+
|
| 50 |
+
// Like TensorListMetadata, but additionally carries one scalar value per
// tensor (for the scalar-list overload of multi_tensor_apply).
template <typename scalar_vals_t, int n>
struct TensorListScalarListMetadata {
  const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];      // data ptr per (list, tensor)
  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];      // element count per tensor
  scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];     // per-tensor scalar
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];             // CUDA block -> tensor slot
  int block_to_chunk[depth_to_max_blocks[n - 1]];                        // CUDA block -> chunk index
};
|
| 58 |
+
|
| 59 |
+
// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
|
| 60 |
+
// 4kb with `c10::complex<double>`
|
| 61 |
+
// Specialization with reduced capacity (see the note above): the generic
// table sizes would push the struct past the 4KB kernel-argument limit when
// the scalar type is c10::complex<double>.
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 1> {
  const void* addresses[1]
                       [depth_to_max_tensors_scalarlist_of_complex_double[0]];
  int64_t
      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
  c10::complex<double>
      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
  unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
  int block_to_chunk[depth_to_max_blocks[1 - 1]];
};
|
| 72 |
+
|
| 73 |
+
// Depth-2 counterpart of the complex<double> specialization above, with its
// own reduced tensor cap to stay under the kernel-argument size limit.
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 2> {
  const void* addresses[2]
                       [depth_to_max_tensors_scalarlist_of_complex_double[1]];
  int64_t
      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
  c10::complex<double>
      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
  unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
  int block_to_chunk[depth_to_max_blocks[2 - 1]];
};
|
| 84 |
+
|
| 85 |
+
// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
|
| 86 |
+
// whose each element is `at::Tensor` of 1 element representing the number of
|
| 87 |
+
// `step`s called so far.
|
| 88 |
+
// Metadata variant for fused optimizers: like TensorListMetadata plus one
// extra pointer per tensor for its `step` state tensor (see the note above).
template <int n>
struct FusedOptimizerTensorListMetadata {
  const void* addresses[n][depth_to_max_tensors[n - 1]];                      // data ptr per (list, tensor)
  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];                      // element count per tensor
  const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];  // per-tensor step tensor ptr
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];                  // CUDA block -> tensor slot
  int block_to_chunk[depth_to_max_blocks[n - 1]];                             // CUDA block -> chunk index
  int start_tensor_this_launch;                                               // first tensor index in this launch
};
|
| 97 |
+
|
| 98 |
+
// Thin device entry point: every launched block simply hands the chunk size
// and the metadata to the user-supplied functor, which does all the work.
template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void multi_tensor_apply_kernel(
    T tensorListMeta,
    U callable,
    ArgTypes... args) {
  // Hand the chunk information to the user-supplied functor to process however
  // it likes.
  callable(kChunkSize, tensorListMeta, args...);
}
|
| 108 |
+
|
| 109 |
+
} // namespace
|
| 110 |
+
|
| 111 |
+
// multi_tensor_apply enables horizontal fusion across lists of tensors.
|
| 112 |
+
// For example, whereas you once had a for-loop of a + b = c, where a, b,
|
| 113 |
+
// and c are individual tensors in lists as, bs, and cs, you can now with
|
| 114 |
+
// fewer kernel launches compute as + bs = cs.
|
| 115 |
+
//
|
| 116 |
+
// You can also imagine bs to be a scalar list vs a tensor list.
|
| 117 |
+
//
|
| 118 |
+
// The function below takes in tensor lists, scalars, and a callable and
|
| 119 |
+
// chunks up the computation to launch as few kernels as possible by iterating
|
| 120 |
+
// through every "chunk" in every tensor (thus the nested for loops). In the
|
| 121 |
+
// simplest case, everything gets bundled into just one kernel launch, but
|
| 122 |
+
// due to blocksize constraints, we may need to launch multiple kernels.
|
| 123 |
+
// Each kernel launch is defined by one tensorListMeta construct, which we
|
| 124 |
+
// use to track and reset the necessary metadata for each launch.
|
| 125 |
+
// Scalar-list overload: each tensor group additionally carries one scalar
// (scalars[t], converted to the functor's opmath type). Fills the metadata
// struct tensor-by-tensor / chunk-by-chunk and launches a kernel whenever the
// tensor or block capacity is reached, carrying over the partially-processed
// tensor between launches.
//
// Preconditions visible here: tensor_lists has exactly `depth` lists; the
// chunking is driven by tensor_lists[0]'s numels. NOTE(review): presumably
// all lists are same-length and element-wise same-numel — enforced by
// callers, not checked in this function.
template <int depth, typename scalar_T, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::ArrayRef<Scalar> scalars,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth.");
  const size_t n_tensors = tensor_lists[0].size();
  using scalar_vals_t = typename T::opmath_t;
  TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;

  int loc_block_info = 0;   // next free slot in block_to_tensor/block_to_chunk
  int loc_tensor_info = 0;  // next free tensor slot in the metadata
  for (size_t t = 0; t < n_tensors; t++) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][t].numel() == 0) {
      continue;
    }
    // Note: scalar is indexed by the original position t, while storage uses
    // the compacted slot loc_tensor_info (skipped empties create the gap).
    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][t].const_data_ptr();
    }
    loc_tensor_info++;

    // now we enter [chunking territory].
    // we will launch a kernel when EITHER the blocks get filled up OR
    // the tensors get filled up. There will always be at least one block
    // per tensor since the zero-sized ones will not enter the loop, so
    // the nested forloop within represents iterating through the chunks
    // of a single tensor.
    const auto numel = tensor_lists[0][t].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    for (auto chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      // a tensor is not considered full unless all its chunks have been
      // processed
      const bool tensors_full =
          (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
           chunk == chunks - 1);
      const bool blocks_full =
          (loc_block_info == depth_to_max_blocks[depth - 1]);

      if (tensors_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        // all chunks have already been handled in the kernel
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        } else { // blocks were full and tensor chunks remain
          // Carry the partially-processed tensor into slot 0 for the next
          // launch so its remaining chunks can still reference it.
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          tensorListMeta.scalar_vals[0] =
              tensorListMeta.scalar_vals[loc_tensor_info - 1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }

  // note: [finishing what we started]
  // if there's remaining work to be done but the tensors/blocks aren't full
  // yet we are at the end, submit the kernel to do the work!
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
|
| 216 |
+
|
| 217 |
+
// Plain overload (no per-tensor scalars). Same fill-and-flush scheme as the
// scalar-list overload above, but also maintains start_tensor_this_launch so
// the functor can map metadata slots back to original tensor indices.
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth.");
  const size_t n_tensors = tensor_lists[0].size();
  TensorListMetadata<depth> tensorListMeta;
  tensorListMeta.start_tensor_this_launch = 0;

  int loc_block_info = 0;   // next free slot in block_to_tensor/block_to_chunk
  int loc_tensor_info = 0;  // next free tensor slot in the metadata
  for (size_t t = 0; t < n_tensors; t++) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][t].numel() == 0) {
      continue;
    }
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][t].const_data_ptr();
    }
    loc_tensor_info++;

    // see note: [chunking territory].
    const auto numel = tensor_lists[0][t].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    for (auto chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const bool tensors_full =
          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
           chunk == chunks - 1);
      const bool blocks_full =
          (loc_block_info == depth_to_max_blocks[depth - 1]);

      if (tensors_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          // Current tensor fully consumed; next launch starts after it.
          loc_tensor_info = 0;
          tensorListMeta.start_tensor_this_launch = t + 1;
        } else {
          // Carry the partially-processed tensor into slot 0 for the next
          // launch; that launch therefore still starts at tensor t.
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
          tensorListMeta.start_tensor_this_launch = t;
        }
      }
    }
  }

  // see note: [finishing what we started]
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
|
| 296 |
+
|
| 297 |
+
template <int depth, typename T, typename... ArgTypes>
|
| 298 |
+
void multi_tensor_apply_for_fused_optimizer(
|
| 299 |
+
std::vector<std::vector<at::Tensor>>& tensor_lists,
|
| 300 |
+
at::TensorList state_steps,
|
| 301 |
+
T callable,
|
| 302 |
+
ArgTypes... args) {
|
| 303 |
+
TORCH_CHECK(
|
| 304 |
+
tensor_lists.size() == depth,
|
| 305 |
+
"Number of tensor lists has to match the depth");
|
| 306 |
+
const auto num_tensors = tensor_lists[0].size();
|
| 307 |
+
FusedOptimizerTensorListMetadata<depth> tensorListMeta;
|
| 308 |
+
|
| 309 |
+
int loc_block_info = 0;
|
| 310 |
+
int loc_tensor_info = 0;
|
| 311 |
+
for (const auto& tensor_index : c10::irange(num_tensors)) {
|
| 312 |
+
// short-circuit to avoid adding empty tensors to tensorListMeta
|
| 313 |
+
if (tensor_lists[0][tensor_index].numel() == 0) {
|
| 314 |
+
continue;
|
| 315 |
+
}
|
| 316 |
+
tensorListMeta.state_steps_addresses[loc_tensor_info] =
|
| 317 |
+
state_steps[tensor_index].const_data_ptr();
|
| 318 |
+
tensorListMeta.numel_for_tensor[loc_tensor_info] =
|
| 319 |
+
tensor_lists[0][tensor_index].numel();
|
| 320 |
+
for (const auto& d : c10::irange(depth)) {
|
| 321 |
+
tensorListMeta.addresses[d][loc_tensor_info] =
|
| 322 |
+
tensor_lists[d][tensor_index].const_data_ptr();
|
| 323 |
+
}
|
| 324 |
+
loc_tensor_info++;
|
| 325 |
+
|
| 326 |
+
// see above note: [chunking territory]
|
| 327 |
+
const auto numel = tensor_lists[0][tensor_index].numel();
|
| 328 |
+
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
|
| 329 |
+
TORCH_CHECK(chunks > -1);
|
| 330 |
+
for (const auto& chunk : c10::irange(chunks)) {
|
| 331 |
+
tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
| 332 |
+
tensorListMeta.block_to_chunk[loc_block_info] = chunk;
|
| 333 |
+
loc_block_info++;
|
| 334 |
+
|
| 335 |
+
const auto tensor_full =
|
| 336 |
+
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
| 337 |
+
chunk == chunks - 1);
|
| 338 |
+
const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
|
| 339 |
+
|
| 340 |
+
if (tensor_full || blocks_full) {
|
| 341 |
+
multi_tensor_apply_kernel<<<
|
| 342 |
+
loc_block_info,
|
| 343 |
+
kBlockSize,
|
| 344 |
+
0,
|
| 345 |
+
at::cuda::getCurrentCUDAStream()>>>(
|
| 346 |
+
tensorListMeta, callable, args...);
|
| 347 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 348 |
+
|
| 349 |
+
// Reset.
|
| 350 |
+
loc_block_info = 0;
|
| 351 |
+
if (chunk == chunks - 1) {
|
| 352 |
+
loc_tensor_info = 0;
|
| 353 |
+
} else {
|
| 354 |
+
tensorListMeta.numel_for_tensor[0] =
|
| 355 |
+
tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
|
| 356 |
+
tensorListMeta.state_steps_addresses[0] =
|
| 357 |
+
tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
|
| 358 |
+
for (const auto& d : c10::irange(depth)) {
|
| 359 |
+
tensorListMeta.addresses[d][0] =
|
| 360 |
+
tensorListMeta.addresses[d][loc_tensor_info - 1];
|
| 361 |
+
}
|
| 362 |
+
loc_tensor_info = 1;
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
// see above note: [finishing what we've started]
|
| 369 |
+
if (loc_block_info != 0) {
|
| 370 |
+
multi_tensor_apply_kernel<<<
|
| 371 |
+
loc_block_info,
|
| 372 |
+
kBlockSize,
|
| 373 |
+
0,
|
| 374 |
+
at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
|
| 375 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/EmptyTensor.h>
|
| 4 |
+
#include <ATen/native/ResizeCommon.h>
|
| 5 |
+
|
| 6 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 7 |
+
|
| 8 |
+
namespace at { namespace native {
|
| 9 |
+
|
| 10 |
+
TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
|
| 11 |
+
|
| 12 |
+
static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
|
| 13 |
+
// It does not make sense to try to resize a storage
|
| 14 |
+
// to hold 0 elements, and this can break
|
| 15 |
+
// if storage_offset is positive but
|
| 16 |
+
// new_size is 0, so just bail in that case
|
| 17 |
+
// (same comment is in Resize.h)
|
| 18 |
+
if (self->numel() == 0) {
|
| 19 |
+
return;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
const Storage &storage = self->unsafe_storage();
|
| 23 |
+
TORCH_CHECK(storage, "Tensor: invalid null storage");
|
| 24 |
+
if (new_size_bytes > storage.nbytes()) {
|
| 25 |
+
resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
|
| 26 |
+
}
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
inline TensorImpl* resize_impl_cuda_(
|
| 30 |
+
TensorImpl* self,
|
| 31 |
+
IntArrayRef size,
|
| 32 |
+
at::OptionalIntArrayRef stride,
|
| 33 |
+
bool device_guard = true) {
|
| 34 |
+
if (self->sizes() == size && (!stride || self->strides() == stride)) {
|
| 35 |
+
return self;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// NB: We don't need to hold the device guard when calling from TH
|
| 39 |
+
cuda::OptionalCUDAGuard guard;
|
| 40 |
+
if (device_guard) {
|
| 41 |
+
guard.set_index(self->storage().device().index());
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
const auto itemsize = self->dtype().itemsize();
|
| 45 |
+
const auto storage_offset = self->storage_offset();
|
| 46 |
+
size_t storage_size = 1;
|
| 47 |
+
if (stride) {
|
| 48 |
+
self->set_sizes_and_strides(size, *stride);
|
| 49 |
+
storage_size = at::detail::computeStorageNbytes(
|
| 50 |
+
size, *stride, itemsize, storage_offset);
|
| 51 |
+
} else {
|
| 52 |
+
self->set_sizes_contiguous(size);
|
| 53 |
+
storage_size = at::detail::computeStorageNbytesContiguous(
|
| 54 |
+
size, itemsize, storage_offset);
|
| 55 |
+
}
|
| 56 |
+
maybe_resize_storage_cuda(self, storage_size);
|
| 57 |
+
|
| 58 |
+
return self;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
}}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sort.h
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <cstdint>
|
| 3 |
+
#include <ATen/core/TensorBase.h>
|
| 4 |
+
#include <ATen/native/cuda/SortStable.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
namespace native {
|
| 8 |
+
|
| 9 |
+
inline bool should_use_small_sort(const TensorBase &self, int64_t dim) {
|
| 10 |
+
return self.size(dim) <= 4096;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
void sortKeyValueInplace(
|
| 14 |
+
const TensorBase &key, const TensorBase &value, int dim,
|
| 15 |
+
bool descending, bool stable=false);
|
| 16 |
+
|
| 17 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace native {
|
| 6 |
+
|
| 7 |
+
void _fused_adam_amsgrad_cuda_impl_(
|
| 8 |
+
at::TensorList params,
|
| 9 |
+
at::TensorList grads,
|
| 10 |
+
at::TensorList exp_avgs,
|
| 11 |
+
at::TensorList exp_avg_sqs,
|
| 12 |
+
at::TensorList max_exp_avg_sqs,
|
| 13 |
+
at::TensorList state_steps,
|
| 14 |
+
const double lr,
|
| 15 |
+
const double beta1,
|
| 16 |
+
const double beta2,
|
| 17 |
+
const double weight_decay,
|
| 18 |
+
const double eps,
|
| 19 |
+
const bool maximize,
|
| 20 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 21 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 22 |
+
|
| 23 |
+
void _fused_adam_amsgrad_cuda_impl_(
|
| 24 |
+
at::TensorList params,
|
| 25 |
+
at::TensorList grads,
|
| 26 |
+
at::TensorList exp_avgs,
|
| 27 |
+
at::TensorList exp_avg_sqs,
|
| 28 |
+
at::TensorList max_exp_avg_sqs,
|
| 29 |
+
at::TensorList state_steps,
|
| 30 |
+
const at::Tensor& lr,
|
| 31 |
+
const double beta1,
|
| 32 |
+
const double beta2,
|
| 33 |
+
const double weight_decay,
|
| 34 |
+
const double eps,
|
| 35 |
+
const bool maximize,
|
| 36 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 37 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 38 |
+
|
| 39 |
+
} // namespace native
|
| 40 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace native {
|
| 6 |
+
|
| 7 |
+
void _fused_adam_cuda_impl_(
|
| 8 |
+
at::TensorList params,
|
| 9 |
+
at::TensorList grads,
|
| 10 |
+
at::TensorList exp_avgs,
|
| 11 |
+
at::TensorList exp_avg_sqs,
|
| 12 |
+
at::TensorList state_steps,
|
| 13 |
+
const double lr,
|
| 14 |
+
const double beta1,
|
| 15 |
+
const double beta2,
|
| 16 |
+
const double weight_decay,
|
| 17 |
+
const double eps,
|
| 18 |
+
const bool maximize,
|
| 19 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 20 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 21 |
+
|
| 22 |
+
void _fused_adam_cuda_impl_(
|
| 23 |
+
at::TensorList params,
|
| 24 |
+
at::TensorList grads,
|
| 25 |
+
at::TensorList exp_avgs,
|
| 26 |
+
at::TensorList exp_avg_sqs,
|
| 27 |
+
at::TensorList state_steps,
|
| 28 |
+
const at::Tensor& lr,
|
| 29 |
+
const double beta1,
|
| 30 |
+
const double beta2,
|
| 31 |
+
const double weight_decay,
|
| 32 |
+
const double eps,
|
| 33 |
+
const bool maximize,
|
| 34 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 35 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 36 |
+
|
| 37 |
+
} // namespace native
|
| 38 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace native {
|
| 6 |
+
|
| 7 |
+
void _fused_adamw_amsgrad_cuda_impl_(
|
| 8 |
+
at::TensorList params,
|
| 9 |
+
at::TensorList grads,
|
| 10 |
+
at::TensorList exp_avgs,
|
| 11 |
+
at::TensorList exp_avg_sqs,
|
| 12 |
+
at::TensorList max_exp_avg_sqs,
|
| 13 |
+
at::TensorList state_steps,
|
| 14 |
+
const double lr,
|
| 15 |
+
const double beta1,
|
| 16 |
+
const double beta2,
|
| 17 |
+
const double weight_decay,
|
| 18 |
+
const double eps,
|
| 19 |
+
const bool maximize,
|
| 20 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 21 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 22 |
+
|
| 23 |
+
void _fused_adamw_amsgrad_cuda_impl_(
|
| 24 |
+
at::TensorList params,
|
| 25 |
+
at::TensorList grads,
|
| 26 |
+
at::TensorList exp_avgs,
|
| 27 |
+
at::TensorList exp_avg_sqs,
|
| 28 |
+
at::TensorList max_exp_avg_sqs,
|
| 29 |
+
at::TensorList state_steps,
|
| 30 |
+
const at::Tensor& lr,
|
| 31 |
+
const double beta1,
|
| 32 |
+
const double beta2,
|
| 33 |
+
const double weight_decay,
|
| 34 |
+
const double eps,
|
| 35 |
+
const bool maximize,
|
| 36 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 37 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 38 |
+
|
| 39 |
+
} // namespace native
|
| 40 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace native {
|
| 6 |
+
|
| 7 |
+
void _fused_adamw_cuda_impl_(
|
| 8 |
+
at::TensorList params,
|
| 9 |
+
at::TensorList grads,
|
| 10 |
+
at::TensorList exp_avgs,
|
| 11 |
+
at::TensorList exp_avg_sqs,
|
| 12 |
+
at::TensorList state_steps,
|
| 13 |
+
const double lr,
|
| 14 |
+
const double beta1,
|
| 15 |
+
const double beta2,
|
| 16 |
+
const double weight_decay,
|
| 17 |
+
const double eps,
|
| 18 |
+
const bool maximize,
|
| 19 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 20 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 21 |
+
|
| 22 |
+
void _fused_adamw_cuda_impl_(
|
| 23 |
+
at::TensorList params,
|
| 24 |
+
at::TensorList grads,
|
| 25 |
+
at::TensorList exp_avgs,
|
| 26 |
+
at::TensorList exp_avg_sqs,
|
| 27 |
+
at::TensorList state_steps,
|
| 28 |
+
const at::Tensor& lr,
|
| 29 |
+
const double beta1,
|
| 30 |
+
const double beta2,
|
| 31 |
+
const double weight_decay,
|
| 32 |
+
const double eps,
|
| 33 |
+
const bool maximize,
|
| 34 |
+
const c10::optional<at::Tensor>& grad_scale,
|
| 35 |
+
const c10::optional<at::Tensor>& found_inf);
|
| 36 |
+
|
| 37 |
+
} // namespace native
|
| 38 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
namespace at {
|
| 2 |
+
namespace cuda {
|
| 3 |
+
//windows doesn't like large string literals, so split in two
|
| 4 |
+
const std::string reduction_template_0 = R"ESCAPE(
|
| 5 |
+
#define C10_HOST_DEVICE __host__ __device__
|
| 6 |
+
#define C10_DEVICE __device__
|
| 7 |
+
#if defined(__clang__) && defined(__HIP__)
|
| 8 |
+
#ifndef __forceinline__
|
| 9 |
+
#define __forceinline__ inline __attribute__((always_inline))
|
| 10 |
+
#endif
|
| 11 |
+
// until ROCm support for kernel asserts is restored
|
| 12 |
+
#define assert(expr) (static_cast<void>(0))
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
template <typename T>
|
| 16 |
+
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
|
| 17 |
+
{
|
| 18 |
+
#if defined(__clang__) && defined(__HIP__)
|
| 19 |
+
return __shfl_down(value, delta, width);
|
| 20 |
+
#else
|
| 21 |
+
return __shfl_down_sync(mask, value, delta, width);
|
| 22 |
+
#endif
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
#if ${complex}
|
| 27 |
+
template <typename T>
|
| 28 |
+
__device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
|
| 29 |
+
{
|
| 30 |
+
return std::complex<T>(
|
| 31 |
+
#if defined(__clang__) && defined(__HIP__)
|
| 32 |
+
__shfl_down(value.real(), delta, width),
|
| 33 |
+
__shfl_down(value.imag(), delta, width));
|
| 34 |
+
#else
|
| 35 |
+
__shfl_down_sync(mask, value.real(), delta, width),
|
| 36 |
+
__shfl_down_sync(mask, value.imag(), delta, width));
|
| 37 |
+
#endif
|
| 38 |
+
}
|
| 39 |
+
#endif
|
| 40 |
+
|
| 41 |
+
// aligned vector generates vectorized load/store on CUDA
|
| 42 |
+
template<typename scalar_t, int vec_size>
|
| 43 |
+
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
|
| 44 |
+
scalar_t val[vec_size];
|
| 45 |
+
};
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
|
| 49 |
+
// get GCD of num and denom using Euclid's algorithm.
|
| 50 |
+
// Can replace this with std::gcd if we ever support c++17.
|
| 51 |
+
size_t a = denominator;
|
| 52 |
+
size_t b = numerator;
|
| 53 |
+
while (b != 0) {
|
| 54 |
+
a %= b;
|
| 55 |
+
// swap(a,b)
|
| 56 |
+
size_t tmp = a;
|
| 57 |
+
a = b;
|
| 58 |
+
b = tmp;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// a is now the GCD
|
| 62 |
+
numerator /= a;
|
| 63 |
+
denominator /= a;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
struct ReduceConfig {
|
| 70 |
+
//has to match host-side ReduceConfig in the eager code
|
| 71 |
+
static constexpr int BLOCK_X = 0;
|
| 72 |
+
static constexpr int BLOCK_Y = 1;
|
| 73 |
+
static constexpr int CTA = 2;
|
| 74 |
+
|
| 75 |
+
static constexpr int input_vec_size = 4;
|
| 76 |
+
int element_size_bytes;
|
| 77 |
+
int num_inputs;
|
| 78 |
+
int num_outputs;
|
| 79 |
+
int step_input = 1;
|
| 80 |
+
int step_output = 1;
|
| 81 |
+
int ctas_per_output = 1;
|
| 82 |
+
int input_mult[3] = {0, 0, 0};
|
| 83 |
+
int output_mult[2] = {0, 0};
|
| 84 |
+
|
| 85 |
+
int block_width;
|
| 86 |
+
int block_height;
|
| 87 |
+
int num_threads;
|
| 88 |
+
|
| 89 |
+
bool vectorize_input = false;
|
| 90 |
+
int output_vec_size = 1;
|
| 91 |
+
|
| 92 |
+
C10_HOST_DEVICE bool should_block_x_reduce() const {
|
| 93 |
+
return input_mult[BLOCK_X] != 0;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
C10_HOST_DEVICE bool should_block_y_reduce() const {
|
| 97 |
+
return input_mult[BLOCK_Y] != 0;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
C10_HOST_DEVICE bool should_global_reduce() const {
|
| 101 |
+
return input_mult[CTA] != 0;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
C10_DEVICE bool should_store(int output_idx) const {
|
| 105 |
+
return output_idx < num_outputs &&
|
| 106 |
+
(!should_block_x_reduce() || threadIdx.x == 0) &&
|
| 107 |
+
(!should_block_y_reduce() || threadIdx.y == 0);
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
C10_DEVICE bool should_reduce_tail() const {
|
| 111 |
+
return (!should_block_y_reduce() || threadIdx.y == 0) &&
|
| 112 |
+
(!should_global_reduce() || blockIdx.y == 0);
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
C10_HOST_DEVICE int input_idx() const {
|
| 116 |
+
int lane = threadIdx.x;
|
| 117 |
+
int warp = threadIdx.y;
|
| 118 |
+
int cta2 = blockIdx.y;
|
| 119 |
+
return (lane * input_mult[BLOCK_X] +
|
| 120 |
+
warp * input_mult[BLOCK_Y] +
|
| 121 |
+
cta2 * input_mult[CTA]);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
template <int output_vec_size>
|
| 125 |
+
C10_HOST_DEVICE int output_idx() const {
|
| 126 |
+
int lane = threadIdx.x;
|
| 127 |
+
int warp = threadIdx.y;
|
| 128 |
+
int cta1 = blockIdx.x;
|
| 129 |
+
return (lane * output_mult[BLOCK_X] +
|
| 130 |
+
warp * output_mult[BLOCK_Y] +
|
| 131 |
+
cta1 * step_output) * output_vec_size;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
C10_DEVICE int shared_memory_offset(int offset) const {
|
| 135 |
+
return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
C10_DEVICE int staging_memory_offset(int cta2) const {
|
| 139 |
+
int offset = cta2 + blockIdx.x * gridDim.y;
|
| 140 |
+
if (!should_block_x_reduce()) {
|
| 141 |
+
offset = threadIdx.x + offset * blockDim.x;
|
| 142 |
+
}
|
| 143 |
+
return offset;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
//TODO this will need to be different for more generic reduction functions
|
| 151 |
+
namespace reducer {
|
| 152 |
+
|
| 153 |
+
using scalar_t = ${scalar_type};
|
| 154 |
+
using arg_t = ${reduction_accum_type};
|
| 155 |
+
using out_scalar_t = ${result_type};
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
inline __device__ ${functor}
|
| 159 |
+
|
| 160 |
+
inline __device__ out_scalar_t project(arg_t arg) {
|
| 161 |
+
return (out_scalar_t) arg;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
|
| 165 |
+
return WARP_SHFL_DOWN(arg, offset);
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
|
| 169 |
+
return acc;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
// wrap a normal reduction that ignores the index
|
| 173 |
+
inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
|
| 174 |
+
return combine(acc, val);
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
struct ReduceJitOp {
|
| 180 |
+
using scalar_t = ${scalar_type};
|
| 181 |
+
using arg_t = ${reduction_accum_type};
|
| 182 |
+
using out_scalar_t = ${result_type};
|
| 183 |
+
|
| 184 |
+
using InputCalculator = OffsetCalculator<1>;
|
| 185 |
+
using OutputCalculator = OffsetCalculator<2>;
|
| 186 |
+
|
| 187 |
+
// static constexpr bool can_accumulate_in_output =
|
| 188 |
+
// std::is_convertible<arg_t, out_scalar_t>::value
|
| 189 |
+
// && std::is_convertible<out_scalar_t, arg_t>::value;
|
| 190 |
+
|
| 191 |
+
static constexpr int input_vec_size = ReduceConfig::input_vec_size;
|
| 192 |
+
|
| 193 |
+
arg_t ident;
|
| 194 |
+
ReduceConfig config;
|
| 195 |
+
InputCalculator input_calc;
|
| 196 |
+
OutputCalculator output_calc;
|
| 197 |
+
const void* src;
|
| 198 |
+
const char* dst[2]; //it accepts at most two destinations
|
| 199 |
+
// acc_buf used for accumulation among sub Tensor Iterator when accumulation on
|
| 200 |
+
// output is not permissible
|
| 201 |
+
void* acc_buf;
|
| 202 |
+
// cta_buf used for accumulation between blocks during global reduction
|
| 203 |
+
void* cta_buf;
|
| 204 |
+
int* semaphores;
|
| 205 |
+
int64_t base_idx;
|
| 206 |
+
bool accumulate;
|
| 207 |
+
bool final_output;
|
| 208 |
+
int noutputs;
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
C10_DEVICE void run() const {
|
| 212 |
+
extern __shared__ char shared_memory[];
|
| 213 |
+
uint32_t output_idx = config.output_idx<${output_vec_size}>();
|
| 214 |
+
uint32_t input_idx = config.input_idx();
|
| 215 |
+
auto base_offsets1 = output_calc.get(output_idx)[1];
|
| 216 |
+
|
| 217 |
+
using arg_vec_t = Array<arg_t, ${output_vec_size}>;
|
| 218 |
+
arg_vec_t value;
|
| 219 |
+
|
| 220 |
+
if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
|
| 221 |
+
const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
|
| 222 |
+
|
| 223 |
+
value = thread_reduce<${output_vec_size}>(input_slice);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
if (config.should_block_y_reduce()) {
|
| 227 |
+
value = block_y_reduce<${output_vec_size}>(value, shared_memory);
|
| 228 |
+
}
|
| 229 |
+
if (config.should_block_x_reduce()) {
|
| 230 |
+
value = block_x_reduce<${output_vec_size}>(value, shared_memory);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
using out_ptr_vec_t = Array<out_scalar_t*, ${output_vec_size}>;
|
| 234 |
+
using offset_vec_t = Array<uint32_t, ${output_vec_size}>;
|
| 235 |
+
offset_vec_t base_offsets;
|
| 236 |
+
out_ptr_vec_t out;
|
| 237 |
+
|
| 238 |
+
#pragma unroll
|
| 239 |
+
for (int i = 0; i < ${output_vec_size}; i++) {
|
| 240 |
+
base_offsets[i] = output_calc.get(output_idx + i)[0];
|
| 241 |
+
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
arg_vec_t* acc = nullptr;
|
| 245 |
+
if (acc_buf != nullptr) {
|
| 246 |
+
size_t numerator = sizeof(arg_t);
|
| 247 |
+
size_t denominator = sizeof(out_scalar_t);
|
| 248 |
+
reduce_fraction(numerator, denominator);
|
| 249 |
+
acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
if (config.should_global_reduce()) {
|
| 253 |
+
value = global_reduce<${output_vec_size}>(value, acc, shared_memory);
|
| 254 |
+
} else if (config.should_store(output_idx)) {
|
| 255 |
+
if (accumulate) {
|
| 256 |
+
#pragma unroll
|
| 257 |
+
for (int i = 0; i < ${output_vec_size}; i++) {
|
| 258 |
+
value[i] = reducer::translate_idx(value[i], base_idx);
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
if (acc == nullptr) {
|
| 263 |
+
if (accumulate) {
|
| 264 |
+
value = accumulate_in_output<${output_vec_size}>(out, value);
|
| 265 |
+
}
|
| 266 |
+
if (final_output) {
|
| 267 |
+
set_results_to_output<${output_vec_size}>(value, base_offsets);
|
| 268 |
+
} else {
|
| 269 |
+
#pragma unroll
|
| 270 |
+
for (int i = 0; i < ${output_vec_size}; i++) {
|
| 271 |
+
*(out[i]) = get_accumulated_output(out[i], value[i]);
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
} else {
|
| 275 |
+
if (accumulate) {
|
| 276 |
+
#pragma unroll
|
| 277 |
+
for (int i = 0; i < ${output_vec_size}; i++) {
|
| 278 |
+
value[i] = reducer::combine((*acc)[i], value[i]);
|
| 279 |
+
}
|
| 280 |
+
}
|
| 281 |
+
if (final_output) {
|
| 282 |
+
set_results_to_output<${output_vec_size}>(value, base_offsets);
|
| 283 |
+
} else {
|
| 284 |
+
*acc = value;
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
}
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
template <int output_vec_size>
|
| 291 |
+
C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
|
| 292 |
+
if (config.vectorize_input) {
|
| 293 |
+
assert(output_vec_size == 1);
|
| 294 |
+
// reduce at the header of input_slice where memory is not aligned,
|
| 295 |
+
// so that thread_reduce will have an aligned memory to work on.
|
| 296 |
+
return {input_vectorized_thread_reduce_impl(data)};
|
| 297 |
+
} else {
|
| 298 |
+
uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
|
| 299 |
+
bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
|
| 300 |
+
if (is_contiguous) {
|
| 301 |
+
return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
|
| 302 |
+
} else if (input_calc.dims == 1) {
|
| 303 |
+
return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
|
| 304 |
+
} else {
|
| 305 |
+
return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
|
| 306 |
+
}
|
| 307 |
+
}
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
|
| 311 |
+
uint32_t end = config.num_inputs;
|
| 312 |
+
|
| 313 |
+
// Handle the head of input slice where data is not aligned
|
| 314 |
+
arg_t value = ident;
|
| 315 |
+
constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
|
| 316 |
+
constexpr int align_elements = align_bytes / sizeof(scalar_t);
|
| 317 |
+
int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
|
| 318 |
+
if (shift > 0) {
|
| 319 |
+
data -= shift;
|
| 320 |
+
end += shift;
|
| 321 |
+
if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
|
| 322 |
+
value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
|
| 323 |
+
}
|
| 324 |
+
end -= align_elements;
|
| 325 |
+
data += align_elements;
|
| 326 |
+
shift = align_elements - shift;
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
// Do the vectorized reduction
|
| 330 |
+
using load_t = aligned_vector<scalar_t, input_vec_size>;
|
| 331 |
+
|
| 332 |
+
uint32_t idx = config.input_idx();
|
| 333 |
+
const uint32_t stride = config.step_input;
|
| 334 |
+
|
| 335 |
+
// Multiple accumulators to remove dependency between unrolled loops.
|
| 336 |
+
arg_t value_list[input_vec_size];
|
| 337 |
+
value_list[0] = value;
|
| 338 |
+
|
| 339 |
+
#pragma unroll
|
| 340 |
+
for (int i = 1; i < input_vec_size; i++) {
|
| 341 |
+
value_list[i] = ident;
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
scalar_t values[input_vec_size];
|
| 345 |
+
|
| 346 |
+
load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);
|
| 347 |
+
|
| 348 |
+
while (idx * input_vec_size + input_vec_size - 1 < end) {
|
| 349 |
+
*values_vector = reinterpret_cast<const load_t*>(data)[idx];
|
| 350 |
+
#pragma unroll
|
| 351 |
+
for (uint32_t i = 0; i < input_vec_size; i++) {
|
| 352 |
+
value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
|
| 353 |
+
}
|
| 354 |
+
idx += stride;
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
// tail
|
| 358 |
+
uint32_t tail_start = end - end % input_vec_size;
|
| 359 |
+
if (config.should_reduce_tail()) {
|
| 360 |
+
int idx = tail_start + threadIdx.x;
|
| 361 |
+
if (idx < end) {
|
| 362 |
+
value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
// combine accumulators
|
| 367 |
+
#pragma unroll
|
| 368 |
+
for (int i = 1; i < input_vec_size; i++) {
|
| 369 |
+
value_list[0] = reducer::combine(value_list[0], value_list[i]);
|
| 370 |
+
}
|
| 371 |
+
return value_list[0];
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
template <int output_vec_size, typename offset_calc_t>
|
| 375 |
+
C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
|
| 376 |
+
uint32_t idx = config.input_idx();
|
| 377 |
+
const uint32_t end = config.num_inputs;
|
| 378 |
+
const uint32_t stride = config.step_input;
|
| 379 |
+
const int vt0=${vt0};
|
| 380 |
+
|
| 381 |
+
using arg_vec_t = Array<arg_t, output_vec_size>;
|
| 382 |
+
using load_t = aligned_vector<scalar_t, output_vec_size>;
|
| 383 |
+
const load_t* data = reinterpret_cast<const load_t*>(data_);
|
| 384 |
+
|
| 385 |
+
// Multiple accumulators to remove dependency between unrolled loops.
|
| 386 |
+
arg_vec_t value_list[vt0];
|
| 387 |
+
|
| 388 |
+
#pragma unroll
|
| 389 |
+
for (int i = 0; i < vt0; i++) {
|
| 390 |
+
#pragma unroll
|
| 391 |
+
for (int j = 0; j < output_vec_size; j++) {
|
| 392 |
+
value_list[i][j] = ident;
|
| 393 |
+
}
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
load_t values[vt0];
|
| 397 |
+
|
| 398 |
+
while (idx + (vt0 - 1) * stride < end) {
|
| 399 |
+
#pragma unroll
|
| 400 |
+
for (uint32_t i = 0; i < vt0; i++) {
|
| 401 |
+
values[i] = data[calc(idx + i * stride) / output_vec_size];
|
| 402 |
+
}
|
| 403 |
+
#pragma unroll
|
| 404 |
+
for (uint32_t i = 0; i < vt0; i++) {
|
| 405 |
+
#pragma unroll
|
| 406 |
+
for (uint32_t j = 0; j < output_vec_size; j++) {
|
| 407 |
+
value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
+
idx += stride * vt0;
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
// tail
|
| 414 |
+
int idx_ = idx;
|
| 415 |
+
#pragma unroll
|
| 416 |
+
for (uint32_t i = 0; i < vt0; i++) {
|
| 417 |
+
if (idx >= end) {
|
| 418 |
+
break;
|
| 419 |
+
}
|
| 420 |
+
values[i] = data[calc(idx) / output_vec_size];
|
| 421 |
+
idx += stride;
|
| 422 |
+
}
|
| 423 |
+
idx = idx_;
|
| 424 |
+
#pragma unroll
|
| 425 |
+
for (uint32_t i = 0; i < vt0; i++) {
|
| 426 |
+
if (idx >= end) {
|
| 427 |
+
break;
|
| 428 |
+
}
|
| 429 |
+
#pragma unroll
|
| 430 |
+
for (uint32_t j = 0; j < output_vec_size; j++) {
|
| 431 |
+
value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
|
| 432 |
+
}
|
| 433 |
+
idx += stride;
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
// combine accumulators
|
| 437 |
+
#pragma unroll
|
| 438 |
+
for (int i = 1; i < vt0; i++) {
|
| 439 |
+
#pragma unroll
|
| 440 |
+
for (uint32_t j = 0; j < output_vec_size; j++) {
|
| 441 |
+
value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
|
| 442 |
+
}
|
| 443 |
+
}
|
| 444 |
+
return value_list[0];
|
| 445 |
+
}
|
| 446 |
+
template <int output_vec_size>
|
| 447 |
+
C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
|
| 448 |
+
using args_vec_t = Array<arg_t, output_vec_size>;
|
| 449 |
+
int dim_x = blockDim.x;
|
| 450 |
+
args_vec_t* shared = (args_vec_t*)shared_memory;
|
| 451 |
+
if (dim_x > warpSize) {
|
| 452 |
+
int address_base = threadIdx.x + threadIdx.y*blockDim.x;
|
| 453 |
+
shared[address_base] = value;
|
| 454 |
+
for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
|
| 455 |
+
__syncthreads();
|
| 456 |
+
if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
|
| 457 |
+
args_vec_t other = shared[address_base + offset];
|
| 458 |
+
#pragma unroll
|
| 459 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 460 |
+
value[i] = reducer::combine(value[i], other[i]);
|
| 461 |
+
}
|
| 462 |
+
shared[address_base] = value;
|
| 463 |
+
}
|
| 464 |
+
}
|
| 465 |
+
dim_x = warpSize;
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
__syncthreads();
|
| 469 |
+
|
| 470 |
+
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
| 471 |
+
#pragma unroll
|
| 472 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 473 |
+
arg_t other = reducer::warp_shfl_down(value[i], offset);
|
| 474 |
+
value[i] = reducer::combine(value[i], other);
|
| 475 |
+
}
|
| 476 |
+
}
|
| 477 |
+
return value;
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
template <int output_vec_size>
|
| 481 |
+
C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
|
| 482 |
+
using args_vec_t = Array<arg_t, output_vec_size>;
|
| 483 |
+
args_vec_t* shared = (args_vec_t*)shared_memory;
|
| 484 |
+
shared[config.shared_memory_offset(0)] = value;
|
| 485 |
+
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
|
| 486 |
+
__syncthreads();
|
| 487 |
+
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
|
| 488 |
+
args_vec_t other = shared[config.shared_memory_offset(offset)];
|
| 489 |
+
#pragma unroll
|
| 490 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 491 |
+
value[i] = reducer::combine(value[i], other[i]);
|
| 492 |
+
}
|
| 493 |
+
shared[config.shared_memory_offset(0)] = value;
|
| 494 |
+
}
|
| 495 |
+
}
|
| 496 |
+
return value;
|
| 497 |
+
}
|
| 498 |
+
)ESCAPE";
|
| 499 |
+
|
| 500 |
+
const std::string reduction_template_1 = R"ESCAPE(
|
| 501 |
+
|
| 502 |
+
C10_DEVICE bool mark_block_finished() const {
|
| 503 |
+
__shared__ bool is_last_block_done_shared;
|
| 504 |
+
|
| 505 |
+
__syncthreads();
|
| 506 |
+
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
| 507 |
+
int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
|
| 508 |
+
is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
__syncthreads();
|
| 512 |
+
|
| 513 |
+
return is_last_block_done_shared;
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
template <int output_vec_size>
|
| 517 |
+
C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
|
| 518 |
+
Array<out_scalar_t*, output_vec_size> out,
|
| 519 |
+
Array<arg_t, output_vec_size> value
|
| 520 |
+
) const {
|
| 521 |
+
Array<arg_t, output_vec_size> ret;
|
| 522 |
+
#pragma unroll
|
| 523 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 524 |
+
ret[i] = reducer::combine(*(out[i]), value[i]);
|
| 525 |
+
}
|
| 526 |
+
return ret;
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
C10_DEVICE out_scalar_t get_accumulated_output(
|
| 531 |
+
out_scalar_t* out, arg_t value
|
| 532 |
+
) const {
|
| 533 |
+
assert(!final_output);
|
| 534 |
+
return (out_scalar_t)value;
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
template<class T>
|
| 538 |
+
C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
|
| 539 |
+
assert(noutputs == 1);
|
| 540 |
+
auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
|
| 541 |
+
*res = x;
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
//TODO - multi-output reduction - we won't be able to use thrust::pair
|
| 545 |
+
//just explicitly specify typed output reads/writes
|
| 546 |
+
//Currently implemented for max of two outputs
|
| 547 |
+
// template<class T1, class T2>
|
| 548 |
+
// C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
|
| 549 |
+
// if (noutputs >= 1) {
|
| 550 |
+
// auto res0 = (T1*)((char*)dst[0] + base_offset);
|
| 551 |
+
// *res0 = x.first;
|
| 552 |
+
// }
|
| 553 |
+
// if (noutputs >= 2) {
|
| 554 |
+
// // base offset is computed assuming element size being sizeof(T1), so we need to make a
|
| 555 |
+
// // correction to obtain the correct base offset
|
| 556 |
+
// auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
|
| 557 |
+
// *res1 = x.second;
|
| 558 |
+
// }
|
| 559 |
+
// }
|
| 560 |
+
|
| 561 |
+
template <int output_vec_size>
|
| 562 |
+
C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
|
| 563 |
+
assert(final_output);
|
| 564 |
+
#pragma unroll
|
| 565 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 566 |
+
set_results(reducer::project(value[i]), base_offset[i]);
|
| 567 |
+
}
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
template <int output_vec_size>
|
| 571 |
+
C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
|
| 572 |
+
using arg_vec_t = Array<arg_t, output_vec_size>;
|
| 573 |
+
using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
|
| 574 |
+
using offset_vec_t = Array<uint32_t, output_vec_size>;
|
| 575 |
+
|
| 576 |
+
arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
|
| 577 |
+
uint32_t output_idx = config.output_idx<output_vec_size>();
|
| 578 |
+
offset_vec_t base_offsets;
|
| 579 |
+
out_ptr_vec_t out;
|
| 580 |
+
|
| 581 |
+
#pragma unroll
|
| 582 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 583 |
+
base_offsets[i] = output_calc.get(output_idx + i)[0];
|
| 584 |
+
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
bool should_store = config.should_store(output_idx);
|
| 588 |
+
if (should_store) {
|
| 589 |
+
uint32_t offset = config.staging_memory_offset(blockIdx.y);
|
| 590 |
+
reduce_buffer[offset] = value;
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
__threadfence(); // make sure writes are globally visible
|
| 594 |
+
__syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
|
| 595 |
+
bool is_last_block_done = mark_block_finished();
|
| 596 |
+
|
| 597 |
+
if (is_last_block_done) {
|
| 598 |
+
value = ident;
|
| 599 |
+
if (config.should_block_x_reduce()) {
|
| 600 |
+
uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
|
| 601 |
+
uint32_t step = blockDim.x * blockDim.y;
|
| 602 |
+
for (; input_offset < config.ctas_per_output; input_offset += step) {
|
| 603 |
+
uint32_t idx = config.staging_memory_offset(input_offset);
|
| 604 |
+
arg_vec_t next = reduce_buffer[idx];
|
| 605 |
+
#pragma unroll
|
| 606 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 607 |
+
value[i] = reducer::combine(value[i], next[i]);
|
| 608 |
+
}
|
| 609 |
+
}
|
| 610 |
+
} else {
|
| 611 |
+
uint32_t input_offset = threadIdx.y;
|
| 612 |
+
uint32_t step = blockDim.y;
|
| 613 |
+
for (; input_offset < config.ctas_per_output; input_offset += step) {
|
| 614 |
+
uint32_t idx = config.staging_memory_offset(input_offset);
|
| 615 |
+
arg_vec_t next = reduce_buffer[idx];
|
| 616 |
+
#pragma unroll
|
| 617 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 618 |
+
value[i] = reducer::combine(value[i], next[i]);
|
| 619 |
+
}
|
| 620 |
+
}
|
| 621 |
+
}
|
| 622 |
+
value = block_y_reduce(value, shared_memory);
|
| 623 |
+
if (config.should_block_x_reduce()) {
|
| 624 |
+
value = block_x_reduce<output_vec_size>(value, shared_memory);
|
| 625 |
+
}
|
| 626 |
+
if (should_store) {
|
| 627 |
+
if (accumulate) {
|
| 628 |
+
#pragma unroll
|
| 629 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 630 |
+
value[i] = reducer::translate_idx(value[i], base_idx);
|
| 631 |
+
}
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
if (acc == nullptr) {
|
| 635 |
+
if (accumulate) {
|
| 636 |
+
value = accumulate_in_output<output_vec_size>(out, value);
|
| 637 |
+
}
|
| 638 |
+
if (final_output) {
|
| 639 |
+
set_results_to_output<output_vec_size>(value, base_offsets);
|
| 640 |
+
} else {
|
| 641 |
+
#pragma unroll
|
| 642 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 643 |
+
*(out[i]) = get_accumulated_output(out[i], value[i]);
|
| 644 |
+
}
|
| 645 |
+
}
|
| 646 |
+
} else {
|
| 647 |
+
if (accumulate) {
|
| 648 |
+
#pragma unroll
|
| 649 |
+
for (int i = 0; i < output_vec_size; i++) {
|
| 650 |
+
value[i] = reducer::combine((*acc)[i], value[i]);
|
| 651 |
+
}
|
| 652 |
+
}
|
| 653 |
+
if (final_output) {
|
| 654 |
+
set_results_to_output<output_vec_size>(value, base_offsets);
|
| 655 |
+
} else {
|
| 656 |
+
*acc = value;
|
| 657 |
+
}
|
| 658 |
+
}
|
| 659 |
+
}
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
return value;
|
| 663 |
+
}
|
| 664 |
+
};
|
| 665 |
+
|
| 666 |
+
extern "C"
|
| 667 |
+
__launch_bounds__(${max_threads_lb}, 4)
|
| 668 |
+
__global__ void reduction_${name}_kernel(ReduceJitOp r){
|
| 669 |
+
r.run();
|
| 670 |
+
}
|
| 671 |
+
)ESCAPE";
|
| 672 |
+
|
| 673 |
+
const std::string reduction_template = reduction_template_0 + reduction_template_1;
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
const std::string &get_reduction_template() {
|
| 677 |
+
return reduction_template;
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
}}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/thread_constants.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/macros/Macros.h>
|
| 3 |
+
|
| 4 |
+
// Marks a lambda as executable on both the host and device. The __host__
|
| 5 |
+
// attribute is important so that we can access static type information from
|
| 6 |
+
// the host, even if the function is typically only executed on the device.
|
| 7 |
+
#ifndef GPU_LAMBDA
|
| 8 |
+
#define GPU_LAMBDA __host__ __device__
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
#if defined(USE_ROCM)
|
| 12 |
+
constexpr int num_threads() {
|
| 13 |
+
return 256;
|
| 14 |
+
}
|
| 15 |
+
#else
|
| 16 |
+
constexpr uint32_t num_threads() {
|
| 17 |
+
return C10_WARP_SIZE * 4;
|
| 18 |
+
}
|
| 19 |
+
#endif
|
| 20 |
+
|
| 21 |
+
constexpr int thread_work_size() { return 4; }
|
| 22 |
+
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/OperationUtils.h
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
| 6 |
+
#include <ATen/Tensor.h>
|
| 7 |
+
#include <ATen/Utils.h>
|
| 8 |
+
#include <ATen/mps/MPSStream.h>
|
| 9 |
+
#include <ATen/native/mps/TensorFactory.h>
|
| 10 |
+
#include <c10/util/Optional.h>
|
| 11 |
+
#include <c10/core/ScalarType.h>
|
| 12 |
+
#include <torch/library.h>
|
| 13 |
+
#include <exception>
|
| 14 |
+
#include <unordered_map>
|
| 15 |
+
|
| 16 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 17 |
+
#include <ATen/Functions.h>
|
| 18 |
+
#include <ATen/NativeFunctions.h>
|
| 19 |
+
#else
|
| 20 |
+
#include <ATen/ops/empty.h>
|
| 21 |
+
#include <ATen/ops/empty_like.h>
|
| 22 |
+
#include <ATen/ops/zeros.h>
|
| 23 |
+
#include <ATen/ops/zeros_like.h>
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
| 27 |
+
|
| 28 |
+
// Fwd declarations
|
| 29 |
+
namespace at {
|
| 30 |
+
struct TensorIteratorBase;
|
| 31 |
+
}
|
| 32 |
+
using namespace at::mps;
|
| 33 |
+
|
| 34 |
+
namespace at::native::mps {
|
| 35 |
+
|
| 36 |
+
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
|
| 37 |
+
|
| 38 |
+
struct MPSScalar {
|
| 39 |
+
id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }
|
| 40 |
+
|
| 41 |
+
size_t size = 0;
|
| 42 |
+
ScalarType type = ScalarType::Undefined;
|
| 43 |
+
c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
|
| 44 |
+
union {
|
| 45 |
+
float f; // MPS doesn't support 'double'
|
| 46 |
+
at::Half h;
|
| 47 |
+
int64_t i;
|
| 48 |
+
bool b;
|
| 49 |
+
c10::complex<float> cf;
|
| 50 |
+
c10::complex<at::Half> ch;
|
| 51 |
+
at::BFloat16 bf16;
|
| 52 |
+
} value {};
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
void runMPSGraph(MPSStream* mpsStream,
|
| 56 |
+
MPSGraph* mpsGraph,
|
| 57 |
+
NSDictionary* feeds,
|
| 58 |
+
NSDictionary* results);
|
| 59 |
+
|
| 60 |
+
MPSDataType getMPSDataType(ScalarType scalar_type);
|
| 61 |
+
static inline MPSDataType getMPSDataType(const Tensor& t) {
|
| 62 |
+
return getMPSDataType(t.scalar_type());
|
| 63 |
+
}
|
| 64 |
+
MPSDataType getMPSScalarType(ScalarType scalar_type);
|
| 65 |
+
static inline MPSDataType getMPSScalarType(const Tensor& t) {
|
| 66 |
+
return getMPSScalarType(t.scalar_type());
|
| 67 |
+
}
|
| 68 |
+
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
|
| 69 |
+
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
|
| 70 |
+
static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) {
|
| 71 |
+
return getMPSTypeString(t.scalar_type(), short_name);
|
| 72 |
+
}
|
| 73 |
+
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
|
| 74 |
+
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
|
| 75 |
+
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
|
| 76 |
+
std::string getMPSShapeString(MPSShape* shape);
|
| 77 |
+
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true);
|
| 78 |
+
std::string getArrayRefString(const IntArrayRef s);
|
| 79 |
+
// use has_storage() on the returned tensor to determine if src actually is a view
|
| 80 |
+
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
|
| 81 |
+
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
|
| 82 |
+
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape);
|
| 83 |
+
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
|
| 84 |
+
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
|
| 85 |
+
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
|
| 86 |
+
|
| 87 |
+
// The MPSShape could vary based on memory format
|
| 88 |
+
MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
|
| 89 |
+
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
|
| 90 |
+
|
| 91 |
+
static inline id<MTLBuffer> getMTLBufferStorage(const at::Tensor& tensor) {
|
| 92 |
+
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
class Placeholder {
|
| 96 |
+
public:
|
| 97 |
+
Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
|
| 98 |
+
Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
|
| 99 |
+
Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
|
| 100 |
+
bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid);
|
| 101 |
+
MPSGraphTensor* getMPSGraphTensor() {
|
| 102 |
+
return _placeholder;
|
| 103 |
+
}
|
| 104 |
+
MPSGraphTensorData* getMPSGraphTensorData() {
|
| 105 |
+
return _value;
|
| 106 |
+
}
|
| 107 |
+
bool isIntermediate() {
|
| 108 |
+
return _value == nullptr;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
private:
|
| 112 |
+
MPSGraphTensor* _placeholder;
|
| 113 |
+
MPSGraphTensorData* _value;
|
| 114 |
+
Tensor _tensor;
|
| 115 |
+
};
|
| 116 |
+
|
| 117 |
+
void resize_tensor(Tensor* output);
|
| 118 |
+
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
|
| 119 |
+
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
|
| 120 |
+
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
|
| 121 |
+
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
|
| 122 |
+
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
|
| 123 |
+
MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
|
| 124 |
+
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
|
| 125 |
+
|
| 126 |
+
MPSGraph* make_mps_graph();
|
| 127 |
+
void printTensorNDArray(const Tensor& t);
|
| 128 |
+
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType);
|
| 129 |
+
|
| 130 |
+
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
|
| 131 |
+
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
|
| 132 |
+
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor);
|
| 133 |
+
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
|
| 134 |
+
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar);
|
| 135 |
+
|
| 136 |
+
string get_mem_format_string(c10::MemoryFormat memory_format);
|
| 137 |
+
|
| 138 |
+
using MPSCacheKey = uint64_t;
|
| 139 |
+
|
| 140 |
+
// derive this class to cache a graph and its inputs/outputs
|
| 141 |
+
// can be used to store any NSObject
|
| 142 |
+
struct MPSCachedGraph
|
| 143 |
+
{
|
| 144 |
+
MPSCachedGraph(NSObject *object) : _object([object retain]) {}
|
| 145 |
+
virtual ~MPSCachedGraph() {
|
| 146 |
+
[_object release];
|
| 147 |
+
_object = nullptr;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
template<typename T>
|
| 151 |
+
inline T* as() {
|
| 152 |
+
return static_cast<T*>(this);
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
MPSGraph *graph() const { return (MPSGraph *)_object; }
|
| 156 |
+
NSObject *object() const { return _object; }
|
| 157 |
+
private:
|
| 158 |
+
NSObject *_object = nullptr;
|
| 159 |
+
};
|
| 160 |
+
|
| 161 |
+
struct MPSUnaryCachedGraph : public MPSCachedGraph
|
| 162 |
+
{
|
| 163 |
+
MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
| 164 |
+
MPSGraphTensor *inputTensor_ = nil;
|
| 165 |
+
MPSGraphTensor *outputTensor_ = nil;
|
| 166 |
+
};
|
| 167 |
+
|
| 168 |
+
struct MPSUnaryGradCachedGraph : public MPSCachedGraph
|
| 169 |
+
{
|
| 170 |
+
MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
| 171 |
+
MPSGraphTensor *gradOutputTensor_ = nil;
|
| 172 |
+
MPSGraphTensor *inputTensor_ = nil;
|
| 173 |
+
MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output
|
| 174 |
+
MPSGraphTensor *gradInputTensor_ = nil;
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
struct MPSBinaryCachedGraph : public MPSCachedGraph
|
| 178 |
+
{
|
| 179 |
+
MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
| 180 |
+
MPSGraphTensor *inputTensor_ = nil;
|
| 181 |
+
MPSGraphTensor *otherTensor_ = nil;
|
| 182 |
+
MPSGraphTensor *outputTensor_ = nil;
|
| 183 |
+
};
|
| 184 |
+
|
| 185 |
+
struct MPSBinaryGradCachedGraph : public MPSCachedGraph
|
| 186 |
+
{
|
| 187 |
+
MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
| 188 |
+
MPSGraphTensor *gradOutputTensor_ = nil;
|
| 189 |
+
MPSGraphTensor *inputTensor_ = nil;
|
| 190 |
+
MPSGraphTensor *otherTensor_ = nil;
|
| 191 |
+
MPSGraphTensor *gradInputTensor_ = nil;
|
| 192 |
+
};
|
| 193 |
+
|
| 194 |
+
// TODO: Improve the overall design of MPSGraphCache.
|
| 195 |
+
// https://github.com/pytorch/pytorch/issues/77176
|
| 196 |
+
// Cache holding various keys mapped to graphs
|
| 197 |
+
struct MPSGraphCache
|
| 198 |
+
{
|
| 199 |
+
typedef MPSCachedGraph * (^CreateCachedGraphBlock)();
|
| 200 |
+
|
| 201 |
+
struct CacheEntry {
|
| 202 |
+
CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
|
| 203 |
+
MPSCachedGraph* cachedGraph_ = nullptr;
|
| 204 |
+
std::string key_;
|
| 205 |
+
};
|
| 206 |
+
|
| 207 |
+
public:
|
| 208 |
+
|
| 209 |
+
static MPSGraphCache* getInstance() {
|
| 210 |
+
if(_instance_cache == nullptr) {
|
| 211 |
+
_instance_cache = new MPSGraphCache();
|
| 212 |
+
}
|
| 213 |
+
return _instance_cache;
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
~MPSGraphCache() {
|
| 217 |
+
dispatch_release(serialQueue_);
|
| 218 |
+
|
| 219 |
+
for (const auto& i : cache_) {
|
| 220 |
+
delete i.second.cachedGraph_;
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
// Disallow the copy constructor and operator= functions
|
| 225 |
+
MPSGraphCache(const MPSGraphCache&) = delete;
|
| 226 |
+
void operator=(const MPSGraphCache&) = delete;
|
| 227 |
+
|
| 228 |
+
MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
|
| 229 |
+
|
| 230 |
+
__block MPSCachedGraph* cachedGraph = nil;
|
| 231 |
+
|
| 232 |
+
MPSCacheKey hash = std::hash<std::string>{}(key);
|
| 233 |
+
|
| 234 |
+
dispatch_sync_with_rethrow(serialQueue_, ^() {
|
| 235 |
+
// verify the cached entry doesn't already exist
|
| 236 |
+
if (cache_.count(hash) != 0) {
|
| 237 |
+
auto& entry = cache_.at(hash);
|
| 238 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
|
| 239 |
+
cachedGraph = entry.cachedGraph_;
|
| 240 |
+
} else {
|
| 241 |
+
cachedGraph = createCacheBlock();
|
| 242 |
+
CacheEntry entry(key, cachedGraph);
|
| 243 |
+
cache_.emplace(hash, entry);
|
| 244 |
+
profileCachedGraph(entry);
|
| 245 |
+
}
|
| 246 |
+
});
|
| 247 |
+
return cachedGraph;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
template<typename T>
|
| 251 |
+
inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
|
| 252 |
+
return static_cast<T *>(CreateCachedGraph(key, createCacheBlock));
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
MPSCachedGraph* LookUp(const std::string& key) const {
|
| 256 |
+
|
| 257 |
+
__block MPSCachedGraph* cachedGraph = nullptr;
|
| 258 |
+
|
| 259 |
+
MPSCacheKey hash = std::hash<std::string>{}(key);
|
| 260 |
+
|
| 261 |
+
dispatch_sync(serialQueue_, ^() {
|
| 262 |
+
|
| 263 |
+
if (cache_.count(hash) != 0) {
|
| 264 |
+
auto& entry = cache_.at(hash);
|
| 265 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
|
| 266 |
+
cachedGraph = entry.cachedGraph_;
|
| 267 |
+
profileCachedGraph(entry);
|
| 268 |
+
}
|
| 269 |
+
});
|
| 270 |
+
return cachedGraph;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
template<typename T>
|
| 274 |
+
inline T* LookUpAs(const std::string& key) const {
|
| 275 |
+
return static_cast<T *>(LookUp(key));
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
private:
|
| 279 |
+
MPSGraphCache() {
|
| 280 |
+
serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
|
| 281 |
+
}
|
| 282 |
+
// this is defined in OperationUtils.mm to not include
|
| 283 |
+
// MPSProfiler.h in header OperationUtils.h
|
| 284 |
+
void profileCachedGraph(const CacheEntry& cacheEntry) const;
|
| 285 |
+
|
| 286 |
+
static MPSGraphCache* _instance_cache;
|
| 287 |
+
std::unordered_map<MPSCacheKey, CacheEntry> cache_;
|
| 288 |
+
dispatch_queue_t serialQueue_ = nullptr;
|
| 289 |
+
|
| 290 |
+
};
|
| 291 |
+
|
| 292 |
+
// Common template for creating graph with a specified cache if missing
|
| 293 |
+
template<typename T>
|
| 294 |
+
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
|
| 295 |
+
auto cache_ = MPSGraphCache::getInstance();
|
| 296 |
+
if (auto rc = cache_->LookUpAs<T>(key)) {
|
| 297 |
+
return rc;
|
| 298 |
+
}
|
| 299 |
+
return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
|
| 300 |
+
T* newCachedGraph = nil;
|
| 301 |
+
@autoreleasepool {
|
| 302 |
+
// Initialize graph
|
| 303 |
+
auto mpsGraph = mps::make_mps_graph();
|
| 304 |
+
newCachedGraph = new T(mpsGraph);
|
| 305 |
+
instantiate(mpsGraph, newCachedGraph);
|
| 306 |
+
}
|
| 307 |
+
return newCachedGraph;
|
| 308 |
+
});
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
// Common math operations
|
| 312 |
+
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
|
| 313 |
+
|
| 314 |
+
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
|
| 315 |
+
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
|
| 316 |
+
TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \
|
| 317 |
+
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
/**
|
| 321 |
+
* Returns distance from lowest to highest element offset in given tensor.
|
| 322 |
+
*/
|
| 323 |
+
size_t compute_storage_numel_distance(const at::Tensor& t);
|
| 324 |
+
|
| 325 |
+
/**
|
| 326 |
+
* Checks whether tensor is mapped to a contiguous area in the storage.
|
| 327 |
+
*/
|
| 328 |
+
inline bool is_dense_in_storage(const at::Tensor& t) {
|
| 329 |
+
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {
|
| 333 |
+
[encoder setBuffer:getMTLBufferStorage(t)
|
| 334 |
+
offset:t.storage_offset() * t.element_size()
|
| 335 |
+
atIndex:idx];
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
|
| 339 |
+
id<MTLComputePipelineState> cplState,
|
| 340 |
+
uint32_t length) {
|
| 341 |
+
const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
|
| 342 |
+
auto size = MTLSizeMake(length, 1, 1);
|
| 343 |
+
auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
|
| 344 |
+
[encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);
|
| 348 |
+
|
| 349 |
+
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
|
| 350 |
+
return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() };
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
|
| 354 |
+
return @{
|
| 355 |
+
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
|
| 356 |
+
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
|
| 357 |
+
};
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
|
| 361 |
+
return @{
|
| 362 |
+
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
|
| 363 |
+
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
|
| 364 |
+
p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
|
| 365 |
+
};
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
|
| 369 |
+
return @{
|
| 370 |
+
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
|
| 371 |
+
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
|
| 372 |
+
p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
|
| 373 |
+
p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(),
|
| 374 |
+
};
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
|
| 378 |
+
runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
inline bool supportsComplex() {
|
| 382 |
+
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
// MPS yet to support double types, but starting from MacOS 14, supports bfloat16
|
| 386 |
+
inline bool supportedFloatingType(ScalarType dtype) {
|
| 387 |
+
return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
inline bool supportedFloatingType(const Tensor& t) {
|
| 391 |
+
return supportedFloatingType(t.scalar_type());
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
} // namespace at::native::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/TensorFactory.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
|
| 4 |
+
AT_DISPATCH_SWITCH( \
|
| 5 |
+
TYPE, NAME, \
|
| 6 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 7 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
| 8 |
+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
| 9 |
+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
| 10 |
+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
| 11 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
| 12 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Transformer-specific NestedTensor utility functions.
|
| 3 |
+
*
|
| 4 |
+
* Not co-located with NestedTensor core code yet because they only
|
| 5 |
+
* support specific cases needed in transformers.
|
| 6 |
+
*/
|
| 7 |
+
#pragma once
|
| 8 |
+
|
| 9 |
+
#include <vector>
|
| 10 |
+
|
| 11 |
+
#include <c10/macros/Macros.h>
|
| 12 |
+
#include <c10/util/Optional.h>
|
| 13 |
+
|
| 14 |
+
namespace c10 {
|
| 15 |
+
class Scalar;
|
| 16 |
+
} // namespace c10
|
| 17 |
+
|
| 18 |
+
namespace at {
|
| 19 |
+
class Tensor;
|
| 20 |
+
namespace native {
|
| 21 |
+
struct NestedTensorImpl;
|
| 22 |
+
|
| 23 |
+
// Requires that self is a contiguous NestedTensor, other is not a
|
| 24 |
+
// NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self
|
| 25 |
+
// must have a consistent last dimension across its included Tensors
|
| 26 |
+
// and that dimension must match other.size(0).
|
| 27 |
+
Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other);
|
| 28 |
+
|
| 29 |
+
// Requires that mat1 is a contiguous NestedTensor, self & mat2 are
|
| 30 |
+
// not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1
|
| 31 |
+
// has a consistent last dimension across its included Tensors that
|
| 32 |
+
// matches mat2.size(0).
|
| 33 |
+
Tensor NestedTensor_times_Tensor_plus_Tensor_addmm(
|
| 34 |
+
const Tensor& self,
|
| 35 |
+
const Tensor& mat1,
|
| 36 |
+
const Tensor& mat2,
|
| 37 |
+
const c10::Scalar& beta,
|
| 38 |
+
const c10::Scalar& alpha,
|
| 39 |
+
c10::optional<bool> use_gelu = c10::nullopt);
|
| 40 |
+
|
| 41 |
+
Tensor NestedTensor_add_NestedTensor_in_place(
|
| 42 |
+
const Tensor& self,
|
| 43 |
+
const Tensor& other);
|
| 44 |
+
|
| 45 |
+
TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor(
|
| 46 |
+
const Tensor& sizes,
|
| 47 |
+
int64_t extra_elements);
|
| 48 |
+
|
| 49 |
+
Tensor NestedTensor_from_padded_tensor_cpu(
|
| 50 |
+
const Tensor& padded,
|
| 51 |
+
const NestedTensorImpl& nt);
|
| 52 |
+
|
| 53 |
+
Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c10::optional<int64_t> mask_dim_length);
|
| 54 |
+
|
| 55 |
+
template <typename T>
|
| 56 |
+
void remove_padding_kernelLauncher(
|
| 57 |
+
const T* input,
|
| 58 |
+
T* output,
|
| 59 |
+
const int* offsets,
|
| 60 |
+
const int* input_sizes,
|
| 61 |
+
const int* output_sizes,
|
| 62 |
+
int output_dim,
|
| 63 |
+
const int batch_size);
|
| 64 |
+
|
| 65 |
+
template <typename T>
|
| 66 |
+
void remove_padding_transform0213_kernelLauncher(
|
| 67 |
+
const T* input,
|
| 68 |
+
T* output,
|
| 69 |
+
const int* offsets,
|
| 70 |
+
const int* input_sizes,
|
| 71 |
+
const int* output_sizes,
|
| 72 |
+
int output_dim,
|
| 73 |
+
const int batch_size);
|
| 74 |
+
|
| 75 |
+
template <typename T>
|
| 76 |
+
void add_padding_kernelLauncher(
|
| 77 |
+
T* input,
|
| 78 |
+
T* output,
|
| 79 |
+
T padding_value,
|
| 80 |
+
const int* offsets,
|
| 81 |
+
const int* input_sizes,
|
| 82 |
+
int input_dim,
|
| 83 |
+
const std::vector<int64_t>& output_sizes,
|
| 84 |
+
const int batch_size,
|
| 85 |
+
const int output_batch_size);
|
| 86 |
+
|
| 87 |
+
TORCH_API Tensor flash_attention_helper(
|
| 88 |
+
const Tensor& query,
|
| 89 |
+
const Tensor& key,
|
| 90 |
+
const Tensor& value,
|
| 91 |
+
double dropout_p,
|
| 92 |
+
bool need_attn_weights,
|
| 93 |
+
bool is_causal);
|
| 94 |
+
|
| 95 |
+
TORCH_API std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
|
| 96 |
+
const Tensor& query,
|
| 97 |
+
const Tensor& key,
|
| 98 |
+
const Tensor& value,
|
| 99 |
+
double dropout_p,
|
| 100 |
+
bool need_attn_weights,
|
| 101 |
+
bool is_causal);
|
| 102 |
+
} // namespace native
|
| 103 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <ATen/native/quantized/AffineQuantizerBase.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
Tensor& quantize_tensor_per_tensor_affine(
|
| 12 |
+
const Tensor& rtensor,
|
| 13 |
+
Tensor& qtensor,
|
| 14 |
+
double scale,
|
| 15 |
+
int64_t zero_point);
|
| 16 |
+
Tensor& quantize_tensor_per_channel_affine(
|
| 17 |
+
const Tensor& rtensor,
|
| 18 |
+
Tensor& qtensor,
|
| 19 |
+
Tensor scales,
|
| 20 |
+
Tensor zero_points,
|
| 21 |
+
int64_t axis);
|
| 22 |
+
|
| 23 |
+
Tensor& quantize_tensor_per_channel_float_qparams(
|
| 24 |
+
const Tensor& rtensor,
|
| 25 |
+
Tensor& qtensor,
|
| 26 |
+
Tensor scales,
|
| 27 |
+
Tensor zero_points,
|
| 28 |
+
int64_t axis);
|
| 29 |
+
|
| 30 |
+
Tensor& dequantize_tensor_per_tensor_affine(
|
| 31 |
+
const Tensor& qtensor,
|
| 32 |
+
Tensor& rtensor,
|
| 33 |
+
double scale,
|
| 34 |
+
int64_t zero_point);
|
| 35 |
+
Tensor& dequantize_tensor_per_channel_affine(
|
| 36 |
+
const Tensor& qtensor,
|
| 37 |
+
Tensor& rtensor,
|
| 38 |
+
Tensor scales,
|
| 39 |
+
Tensor zero_points,
|
| 40 |
+
int64_t axis);
|
| 41 |
+
Tensor& dequantize_tensor_per_channel_float_qparams(
|
| 42 |
+
const Tensor& qtensor,
|
| 43 |
+
Tensor& rtensor,
|
| 44 |
+
Tensor scales,
|
| 45 |
+
Tensor zero_points,
|
| 46 |
+
int64_t axis);
|
| 47 |
+
|
| 48 |
+
using quantize_tensor_per_tensor_affine_fn =
|
| 49 |
+
void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);
|
| 50 |
+
|
| 51 |
+
using quantize_tensor_per_channel_affine_fn = void (*)(
|
| 52 |
+
const Tensor& rtensor,
|
| 53 |
+
Tensor& qtensor,
|
| 54 |
+
const Tensor& scales,
|
| 55 |
+
const Tensor& zero_points,
|
| 56 |
+
int64_t axis);
|
| 57 |
+
|
| 58 |
+
using quantize_tensor_per_channel_float_qparams_fn = void (*)(
|
| 59 |
+
const Tensor& rtensor,
|
| 60 |
+
Tensor& qtensor,
|
| 61 |
+
const Tensor& scales,
|
| 62 |
+
const Tensor& zero_points,
|
| 63 |
+
int64_t axis);
|
| 64 |
+
|
| 65 |
+
using dequantize_tensor_per_tensor_affine_fn =
|
| 66 |
+
void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);
|
| 67 |
+
|
| 68 |
+
using dequantize_tensor_per_channel_affine_fn = void (*)(
|
| 69 |
+
const Tensor& qtensor,
|
| 70 |
+
Tensor& rtensor,
|
| 71 |
+
const Tensor& scales,
|
| 72 |
+
const Tensor& zero_points,
|
| 73 |
+
int64_t axis);
|
| 74 |
+
|
| 75 |
+
using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
|
| 76 |
+
const Tensor& qtensor,
|
| 77 |
+
Tensor& rtensor,
|
| 78 |
+
const Tensor& scales,
|
| 79 |
+
const Tensor& zero_points,
|
| 80 |
+
int64_t axis);
|
| 81 |
+
|
| 82 |
+
using quantize_tensor_per_tensor_affine_sub_byte_fn =
|
| 83 |
+
void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);
|
| 84 |
+
|
| 85 |
+
using dequantize_tensor_per_tensor_affine_sub_byte_fn =
|
| 86 |
+
void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);
|
| 87 |
+
|
| 88 |
+
DECLARE_DISPATCH(
|
| 89 |
+
quantize_tensor_per_tensor_affine_fn,
|
| 90 |
+
quantize_tensor_per_tensor_affine_stub);
|
| 91 |
+
DECLARE_DISPATCH(
|
| 92 |
+
quantize_tensor_per_channel_affine_fn,
|
| 93 |
+
quantize_tensor_per_channel_affine_stub);
|
| 94 |
+
DECLARE_DISPATCH(
|
| 95 |
+
quantize_tensor_per_channel_float_qparams_fn,
|
| 96 |
+
quantize_tensor_per_channel_float_qparams_stub);
|
| 97 |
+
|
| 98 |
+
DECLARE_DISPATCH(
|
| 99 |
+
dequantize_tensor_per_tensor_affine_fn,
|
| 100 |
+
dequantize_tensor_per_tensor_affine_stub);
|
| 101 |
+
DECLARE_DISPATCH(
|
| 102 |
+
dequantize_tensor_per_channel_affine_fn,
|
| 103 |
+
dequantize_tensor_per_channel_affine_stub);
|
| 104 |
+
DECLARE_DISPATCH(
|
| 105 |
+
dequantize_tensor_per_channel_float_qparams_fn,
|
| 106 |
+
dequantize_tensor_per_channel_float_qparams_stub);
|
| 107 |
+
|
| 108 |
+
DECLARE_DISPATCH(
|
| 109 |
+
quantize_tensor_per_tensor_affine_sub_byte_fn,
|
| 110 |
+
quantize_tensor_per_tensor_affine_sub_byte_stub);
|
| 111 |
+
|
| 112 |
+
DECLARE_DISPATCH(
|
| 113 |
+
dequantize_tensor_per_tensor_affine_sub_byte_fn,
|
| 114 |
+
dequantize_tensor_per_tensor_affine_sub_byte_stub);
|
| 115 |
+
|
| 116 |
+
template <typename T>
|
| 117 |
+
TORCH_API Tensor quantize_tensor(
|
| 118 |
+
Tensor rtensor,
|
| 119 |
+
Tensor qtensor,
|
| 120 |
+
double scale,
|
| 121 |
+
int64_t zero_point);
|
| 122 |
+
template <typename T>
|
| 123 |
+
TORCH_API Tensor dequantize_tensor(
|
| 124 |
+
Tensor qtensor,
|
| 125 |
+
Tensor rtensor,
|
| 126 |
+
double scale,
|
| 127 |
+
int64_t zero_point);
|
| 128 |
+
|
| 129 |
+
} // namespace native
|
| 130 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/List.h>
|
| 3 |
+
#include <ATen/native/ConvUtils.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native::quantized {
|
| 6 |
+
namespace {
|
| 7 |
+
// MakeConvOutputShape used from both CPU and CUDA libraries
|
| 8 |
+
// and exporting symbol from torch_cpu would probably take more storage
|
| 9 |
+
// than duplicating implementation which likely be inlined away
|
| 10 |
+
template <int kSpatialDim>
|
| 11 |
+
at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
|
| 12 |
+
int N, // mini-batch
|
| 13 |
+
int M, // output channels
|
| 14 |
+
const std::array<int64_t, kSpatialDim>& input_image_shape,
|
| 15 |
+
const std::vector<int64_t>& kernel,
|
| 16 |
+
const torch::List<int64_t>& stride,
|
| 17 |
+
const torch::List<int64_t>& padding,
|
| 18 |
+
const torch::List<int64_t>& dilation);
|
| 19 |
+
|
| 20 |
+
#if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK)
|
| 21 |
+
template <>
|
| 22 |
+
at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
|
| 23 |
+
int N, // mini-batch
|
| 24 |
+
int M, // output channels
|
| 25 |
+
const std::array<int64_t, 2>& input_image_shape,
|
| 26 |
+
const std::vector<int64_t>& kernel,
|
| 27 |
+
const at::List<int64_t>& stride,
|
| 28 |
+
const at::List<int64_t>& padding,
|
| 29 |
+
const at::List<int64_t>& dilation) {
|
| 30 |
+
const int H = input_image_shape[0];
|
| 31 |
+
const int W = input_image_shape[1];
|
| 32 |
+
const int64_t Y_H =
|
| 33 |
+
(H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
|
| 34 |
+
const int64_t Y_W =
|
| 35 |
+
(W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
|
| 36 |
+
return {N, M, Y_H, Y_W};
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
template <>
|
| 40 |
+
at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
|
| 41 |
+
int N, // mini-batch
|
| 42 |
+
int M, // output channels
|
| 43 |
+
const std::array<int64_t, 3>& input_image_shape,
|
| 44 |
+
const std::vector<int64_t>& kernel,
|
| 45 |
+
const at::List<int64_t>& stride,
|
| 46 |
+
const at::List<int64_t>& padding,
|
| 47 |
+
const torch::List<int64_t>& dilation) {
|
| 48 |
+
const int D = input_image_shape[0];
|
| 49 |
+
const int H = input_image_shape[1];
|
| 50 |
+
const int W = input_image_shape[2];
|
| 51 |
+
const int64_t Y_D =
|
| 52 |
+
(D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
|
| 53 |
+
const int64_t Y_H =
|
| 54 |
+
(H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
|
| 55 |
+
const int64_t Y_W =
|
| 56 |
+
(W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1;
|
| 57 |
+
return {N, M, Y_D, Y_H, Y_W};
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
#endif
|
| 61 |
+
} // anonymous namespace
|
| 62 |
+
} // namespace at::native::quantized
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/TensorIterator.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace native {
|
| 6 |
+
using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point);
|
| 7 |
+
using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point);
|
| 8 |
+
|
| 9 |
+
DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub);
|
| 10 |
+
DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub);
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
} // native
|
| 14 |
+
} // at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/core/ivalue.h>
|
| 5 |
+
|
| 6 |
+
struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
|
| 7 |
+
virtual at::Tensor apply(
|
| 8 |
+
at::Tensor input,
|
| 9 |
+
double output_scale,
|
| 10 |
+
int64_t output_zero_point) = 0;
|
| 11 |
+
virtual at::Tensor apply_relu(
|
| 12 |
+
at::Tensor input,
|
| 13 |
+
double output_scale,
|
| 14 |
+
int64_t output_zero_point) = 0;
|
| 15 |
+
|
| 16 |
+
// out variant of LinearPackedParamsBase::apply
|
| 17 |
+
virtual at::Tensor& apply_out(
|
| 18 |
+
const at::Tensor& /*input*/,
|
| 19 |
+
double /*output_scale*/,
|
| 20 |
+
int64_t /*output_zero_point*/,
|
| 21 |
+
at::Tensor& output) {
|
| 22 |
+
throw std::runtime_error(
|
| 23 |
+
"apply_out is not implemented for this packed "
|
| 24 |
+
"parameter type");
|
| 25 |
+
return output;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
virtual at::Tensor& apply_relu_out(
|
| 29 |
+
const at::Tensor& /*input*/,
|
| 30 |
+
double /*output_scale*/,
|
| 31 |
+
int64_t /*output_zero_point*/,
|
| 32 |
+
at::Tensor& output) {
|
| 33 |
+
throw std::runtime_error(
|
| 34 |
+
"apply_relu_out is not implemented for this packed "
|
| 35 |
+
"parameter type");
|
| 36 |
+
return output;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Corresponding pattern (the ops with `*` are part of the pattern that
|
| 40 |
+
// represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
|
| 41 |
+
// input -> q* -> dq* -> linear* ->
|
| 42 |
+
// qweight -> dq* /
|
| 43 |
+
//
|
| 44 |
+
// After fusion:
|
| 45 |
+
// input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
|
| 46 |
+
// qweight /
|
| 47 |
+
//
|
| 48 |
+
// Additional Note: the weight is packed as well
|
| 49 |
+
// Params:
|
| 50 |
+
// X: float32 Tensor, will be quantized to quint8 in the op
|
| 51 |
+
// W_prepack: packed qint8 quantized weight and bias
|
| 52 |
+
// Returns:
|
| 53 |
+
// Y: float32 Tensor
|
| 54 |
+
virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
|
| 55 |
+
at::Tensor input,
|
| 56 |
+
double input_scale,
|
| 57 |
+
int64_t input_zero_point) {
|
| 58 |
+
throw std::runtime_error(
|
| 59 |
+
"apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
|
| 60 |
+
"parameter type");
|
| 61 |
+
return {};
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
// Corresponding pattern (the ops with `*` are part of the pattern that
|
| 65 |
+
// represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
|
| 66 |
+
// input -> q* -> dq* -> linear* -> relu* ->
|
| 67 |
+
// qweight -> dq* /
|
| 68 |
+
//
|
| 69 |
+
// After fusion:
|
| 70 |
+
// input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
|
| 71 |
+
// qweight /
|
| 72 |
+
//
|
| 73 |
+
// Additional Note: the weight is packed as well
|
| 74 |
+
// Params:
|
| 75 |
+
// input: float32 Tensor, will be quantized to quint8 in the op
|
| 76 |
+
// Returns:
|
| 77 |
+
// float32 Tensor
|
| 78 |
+
virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
|
| 79 |
+
at::Tensor input,
|
| 80 |
+
double input_scale,
|
| 81 |
+
int64_t input_zero_point) {
|
| 82 |
+
throw std::runtime_error(
|
| 83 |
+
"apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
|
| 84 |
+
"parameter type");
|
| 85 |
+
return {};
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
virtual at::Tensor apply_dynamic(
|
| 89 |
+
at::Tensor input,
|
| 90 |
+
bool reduce_range = false) = 0;
|
| 91 |
+
virtual at::Tensor apply_dynamic_relu(
|
| 92 |
+
at::Tensor input,
|
| 93 |
+
bool reduce_range = false) = 0;
|
| 94 |
+
|
| 95 |
+
virtual at::Tensor& apply_dynamic_out(
|
| 96 |
+
const at::Tensor& /* input */,
|
| 97 |
+
at::Tensor& output,
|
| 98 |
+
bool /* reduce_range */) {
|
| 99 |
+
throw std::runtime_error(
|
| 100 |
+
"apply_dynamic_out is not implemented for this packed "
|
| 101 |
+
"parameter type");
|
| 102 |
+
return output;
|
| 103 |
+
}
|
| 104 |
+
virtual at::Tensor& apply_dynamic_relu_out(
|
| 105 |
+
const at::Tensor& /* input */,
|
| 106 |
+
at::Tensor& output,
|
| 107 |
+
bool /* reduce_range */) {
|
| 108 |
+
throw std::runtime_error(
|
| 109 |
+
"apply_dynamic_relu_out is not implemented for this packed "
|
| 110 |
+
"parameter type");
|
| 111 |
+
return output;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;
|
| 115 |
+
|
| 116 |
+
virtual c10::optional<at::Tensor> bias() = 0;
|
| 117 |
+
|
| 118 |
+
virtual void set_bias(c10::optional<at::Tensor> /*bias*/) {
|
| 119 |
+
throw std::runtime_error(
|
| 120 |
+
"set_bias is not implemented for this packed "
|
| 121 |
+
"parameter type");
|
| 122 |
+
}
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
template <int kSpatialDim = 2>
|
| 126 |
+
struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
|
| 127 |
+
virtual at::Tensor apply(
|
| 128 |
+
const at::Tensor& input,
|
| 129 |
+
double output_scale,
|
| 130 |
+
int64_t output_zero_point) = 0;
|
| 131 |
+
virtual at::Tensor apply_relu(
|
| 132 |
+
const at::Tensor& input,
|
| 133 |
+
double output_scale,
|
| 134 |
+
int64_t output_zero_point) = 0;
|
| 135 |
+
virtual at::Tensor apply_dynamic(
|
| 136 |
+
const at::Tensor& input,
|
| 137 |
+
bool reduce_range) = 0;
|
| 138 |
+
|
| 139 |
+
virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;
|
| 140 |
+
|
| 141 |
+
virtual torch::List<int64_t> stride() const = 0;
|
| 142 |
+
virtual torch::List<int64_t> padding() const = 0;
|
| 143 |
+
virtual torch::List<int64_t> output_padding() const = 0;
|
| 144 |
+
virtual torch::List<int64_t> dilation() const = 0;
|
| 145 |
+
virtual int64_t groups() const = 0;
|
| 146 |
+
virtual bool transpose() const = 0;
|
| 147 |
+
};
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/core/ivalue.h>
|
| 5 |
+
|
| 6 |
+
struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
|
| 7 |
+
virtual at::Tensor embeddingbag_byte(
|
| 8 |
+
const at::Tensor& indices,
|
| 9 |
+
const c10::optional<at::Tensor>& offsets,
|
| 10 |
+
bool pruned_weights,
|
| 11 |
+
const c10::optional<at::Tensor>& per_sample_weights_,
|
| 12 |
+
const c10::optional<at::Tensor>& compressed_indices_mapping,
|
| 13 |
+
bool include_last_offset,
|
| 14 |
+
bool is_embedding_op) = 0;
|
| 15 |
+
|
| 16 |
+
virtual at::Tensor embeddingbag_4bit(
|
| 17 |
+
const at::Tensor& indices,
|
| 18 |
+
const c10::optional<at::Tensor>& offsets,
|
| 19 |
+
bool pruned_weights,
|
| 20 |
+
const c10::optional<at::Tensor>& per_sample_weights_,
|
| 21 |
+
const c10::optional<at::Tensor>& compressed_indices_mapping,
|
| 22 |
+
bool include_last_offset,
|
| 23 |
+
bool is_embedding_op) = 0;
|
| 24 |
+
|
| 25 |
+
virtual at::Tensor unpack() = 0;
|
| 26 |
+
|
| 27 |
+
virtual int64_t bit_rate() const = 0;
|
| 28 |
+
virtual int64_t version() const = 0;
|
| 29 |
+
};
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef USE_PYTORCH_QNNPACK
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <c10/util/irange.h>
|
| 6 |
+
#include <pytorch_qnnpack.h>
|
| 7 |
+
#include <qnnpack_func.h>
|
| 8 |
+
#include <ATen/native/quantized/cpu/XnnpackUtils.h>
|
| 9 |
+
#include <ATen/native/quantized/PackedParams.h>
|
| 10 |
+
#include <ATen/native/utils/Factory.h>
|
| 11 |
+
|
| 12 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 13 |
+
#include <ATen/Functions.h>
|
| 14 |
+
#else
|
| 15 |
+
#include <ATen/ops/empty.h>
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#include <utility>
|
| 19 |
+
inline int kPaddingChannels = 8;
|
| 20 |
+
struct QnnpackOperatorDeleter {
|
| 21 |
+
void operator()(pytorch_qnnp_operator_t op) {
|
| 22 |
+
pytorch_qnnp_delete_operator(op);
|
| 23 |
+
}
|
| 24 |
+
};
|
| 25 |
+
|
| 26 |
+
// PackedWeight struct for QNNPACK stores the original weight and bias,
// since QNNPACK currently does not support an unpack function.
// For PyTorch Mobile, once the model is scripted and serialized we don't need
// to call unpack, so we can save some memory by checking for this case and
// freeing the original weights after packing.
// Input scale is set to null in the pre-pack step. QNNPACK needs the bias
// quantized with the input scale, which is only available at runtime in
// PyTorch. During runtime, if the input scale value changes then we
// requantize the bias with the updated scale. For inference we expect the
// graph to be static, so the input scale should not change across
// consecutive inference calls.
struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
  PackedLinearWeightsQnnp(
      std::unique_ptr<qnnpack::PackBMatrix> w,
      at::Tensor orig_weight,
      at::Tensor bias,
      c10::optional<double> input_scale,
      at::Tensor w_scales,
      std::vector<uint8_t>&& w_zps)
      : w(std::move(w)),
        orig_weight(std::move(orig_weight)),
        // Bias may be re-allocated padded/contiguous so QNNPACK can read it.
        bias_(at::native::mobile::allocate_padded_contiguous_if_needed(
            bias, bias.suggest_memory_format())),
        per_channel_(this->orig_weight.qscheme() == at::kPerChannelAffine),
        input_scale(std::move(input_scale)),
        w_scales(std::move(w_scales)),
        w_zero_points(std::move(w_zps)),
        q_scheme(this->orig_weight.qscheme()) {
    weight_sizes = this->orig_weight.sizes().vec();
  }

  std::unique_ptr<qnnpack::PackBMatrix> w; // QNNPACK-packed weight matrix
  at::Tensor orig_weight; // retained because QNNPACK cannot unpack (see above)
  at::Tensor bias_;
  bool per_channel_;
  c10::optional<double> input_scale; // unset until known at runtime (see above)
  at::Tensor w_scales;
  std::vector<uint8_t> w_zero_points;
  std::vector<float> requantization_scales;
  std::vector<int64_t> weight_sizes; // original weight shape, cached at pack time
  c10::QScheme q_scheme;

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  // Same as apply(), with ReLU fused into the output clamping.
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  // Dynamic variants: output quantization params are derived from the input
  // at call time instead of being passed in.
  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  c10::optional<at::Tensor> bias() override {
    return bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias);

  bool per_channel() const {
    return per_channel_;
  }

 private:
  // Guards mutation of the packed state across concurrent apply() calls.
  std::mutex qnnp_mutex_;

#ifdef USE_XNNPACK
  xnnpack_operator xnnp_linear_op;

  template <typename scalar_t, bool kReluFused>
  at::Tensor apply_impl_xnnp(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
#endif // USE_XNNPACK

  template <bool ReluFused>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range);
};
|
| 115 |
+
|
| 116 |
+
// Packed convolution weights for the QNNPACK backend. The constructor
// eagerly allocates and configures the underlying pytorch_qnnp_operator
// (micro-kernel selection, geometry, zero-padding buffer) so apply() only
// has to fill in runtime quantization parameters.
template <int kSpatialDim = 2>
struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsQnnp(
      std::unique_ptr<qnnpack::PrePackConvWeights> w,
      at::Tensor orig_weight,
      at::Tensor bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose,
      c10::optional<double> input_scale,
      std::vector<int64_t> kernel,
      at::Tensor w_scale,
      std::vector<uint8_t>&& w_zps,
      bool is_per_channel)
      : w(std::move(w)),
        orig_weight(std::move(orig_weight)),
        bias(std::move(bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose),
        is_per_channel_(is_per_channel),
        input_scale(input_scale),
        kernel_(std::move(kernel)),
        w_scales(std::move(w_scale)),
        w_zero_points(std::move(w_zps)) {
    const bool any_padding = std::any_of(
        padding_.begin(), padding_.end(), [](const auto& e) { return e != 0; });
    const size_t kernel_size =
        std::accumulate(kernel_.begin(), kernel_.end(), 1, std::multiplies<>());

    // Channel counts per group. For transposed conv the weight's dim 0 is
    // grouped input channels and dim 1 is output channels per group; for
    // regular conv it is the other way around.
    const size_t group_input_channels = transpose
        ? this->orig_weight.size(0) / groups
        : this->orig_weight.size(1);
    const size_t group_output_channels = transpose
        ? this->orig_weight.size(1)
        : this->orig_weight.size(0) / groups;

    const size_t kernel_depth = kSpatialDim == 3 ? kernel_[0] : 1;
    const size_t kernel_height = kernel_[kSpatialDim - 2];
    const size_t kernel_width = kernel_[kSpatialDim - 1];

    // Micro-kernel selection:
    //  - transposed convs always use the generic conv kernel;
    //  - 3x3/5x5 (2d) or 3x3x3 (3d) kernels with 1-in/1-out channels per
    //    group and >1 groups -> depthwise kernel;
    //  - pointwise (1x1, stride 1, no padding) -> GEMM;
    //  - anything else -> generic conv.
    pytorch_qnnp_ukernel_type ukernel_type;
    if (transpose_) {
      ukernel_type = pytorch_qnnp_ukernel_type_conv;
    } else {
      ukernel_type = pytorch_qnnp_ukernel_type_none;

      const bool has_depthwise_dimensions =
          (kSpatialDim == 2 &&
           ((kernel_height == 3 && kernel_width == 3) ||
            (kernel_height == 5 && kernel_width == 5))) ||
          (kSpatialDim == 3 && kernel_height == 3 && kernel_width == 3 &&
           kernel_depth == 3);
      const bool has_depthwise_grouping =
          group_input_channels == 1 && group_output_channels == 1 && groups > 1;

      if (has_depthwise_dimensions && has_depthwise_grouping) {
        ukernel_type = pytorch_qnnp_ukernel_type_dwconv;
      } else if (
          kernel_size == 1 &&
          std::all_of(
              stride_.begin(),
              stride_.end(),
              [](const auto& e) { return e == 1; }) &&
          !any_padding) {
        // NOTE(review): `group_input_channels >= SIZE_MAX` is true only in a
        // degenerate case, so the xzp_gemm branch is effectively disabled —
        // presumably intentional; confirm before relying on XZP here.
        ukernel_type = group_input_channels >= SIZE_MAX
            ? pytorch_qnnp_ukernel_type_xzp_gemm
            : pytorch_qnnp_ukernel_type_gemm;
      } else {
        ukernel_type = pytorch_qnnp_ukernel_type_conv;
      }
    }

    if (is_per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
      TORCH_INTERNAL_ASSERT(
          false, "Per channel quantized weights are not supported for XZP kernels");
    }

    pytorch_qnnp_operator_t convolution{nullptr};
    // Initially all the params are set to zero (calloc zero-fills).
    convolution = static_cast<pytorch_qnnp_operator_t>(
        calloc(1, sizeof(struct pytorch_qnnp_operator)));
    if (convolution == nullptr) {
      TORCH_INTERNAL_ASSERT(
          false, "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
          sizeof(struct pytorch_qnnp_operator));
    }

    // Hand ownership to the RAII member immediately so the allocation is
    // released by QnnpackOperatorDeleter on every path below.
    convolution_op =
        std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>(
            convolution);

    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    convolution->ukernel_type = ukernel_type;
    convolution->groups = groups;
    convolution->group_input_channels = group_input_channels;
    convolution->group_output_channels = group_output_channels;
    convolution->kernel_depth = kernel_depth;
    convolution->kernel_height = kernel_height;
    convolution->kernel_width = kernel_width;
    // Depth-related fields are only meaningful for 3-d conv; default to
    // stride/dilation 1 and padding 0 otherwise.
    convolution->stride_depth = kSpatialDim == 3 ? stride_[0] : 1;
    convolution->stride_height = stride_[kSpatialDim - 2];
    convolution->stride_width = stride_[kSpatialDim - 1];
    convolution->dilation_depth = kSpatialDim == 3 ? dilation_[0] : 1;
    convolution->dilation_height = dilation_[kSpatialDim - 2];
    convolution->dilation_width = dilation_[kSpatialDim - 1];
    convolution->input_padding_height = padding_[kSpatialDim - 2];
    convolution->input_padding_width = padding_[kSpatialDim - 1];
    convolution->input_padding_depth = kSpatialDim == 3 ? padding_[0] : 0;
    convolution->per_channel = is_per_channel_;
    convolution->transpose = transpose_;

    // Round group_input_channels up to a multiple of kr. The `& -kr` mask
    // assumes kr is a power of two.
    const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
    const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;

    // Size and read-offset of the zero-padding buffer the micro-kernels use
    // for out-of-bounds input positions.
    size_t zero_size = sizeof(uint8_t) * k_stride;
    size_t zero_offset = 0;

    if (transpose_) {
      convolution->adjustment_width = output_padding_[1];
      convolution->adjustment_height = output_padding_[0];
      // Extra 8 bytes of head-room for small channel counts — presumably to
      // absorb micro-kernel over-reads; confirm against the kernels.
      if (group_input_channels < 8) {
        zero_size += 8;
        zero_offset = 8;
      }
    } else {
      // NOTE(review): dead store — zero_buffer_size is unconditionally
      // overwritten with zero_size after the allocation below.
      zero_buffer_size = 0;
      if (any_padding) {
        zero_size = 0;
        zero_offset = 0;
        if (ukernel_type == pytorch_qnnp_ukernel_type_dwconv) {
          // Depthwise kernels index the zero buffer per group, so size it by
          // the cr-rounded group count instead of k_stride.
          const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
          const size_t group_stride = (groups + (cr - 1)) & -cr;
          if (groups >= 8) {
            zero_size = sizeof(uint8_t) * group_stride;
            zero_offset = 0;
          } else {
            zero_size = sizeof(uint8_t) * group_stride + 8;
            zero_offset = sizeof(uint8_t) * 8;
          }
        } else if (
            ukernel_type == pytorch_qnnp_ukernel_type_conv ||
            ukernel_type == pytorch_qnnp_ukernel_type_gemm) {
          if (group_input_channels >= 8) {
            zero_size = sizeof(uint8_t) * k_stride;
            zero_offset = 0;
          } else {
            zero_size = sizeof(uint8_t) * k_stride + 8;
            zero_offset = 8;
          }
        }
      }
    }

    // NOLINTNEXTLINE(clang-analyzer-optin.portability.UnixAPI)
    void* zero_buffer = malloc(zero_size);
    if (zero_buffer == nullptr) {
      pytorch_qnnp_delete_operator(convolution);
      TORCH_INTERNAL_ASSERT(
          false, "failed to allocate %zu bytes for zero padding",
          zero_size);
    }
    // Need to set to input zero point
    // memset(zero_buffer, input_zero_point, zero_size);
    zero_buffer_size = zero_size;
    convolution->zero_buffer = zero_buffer;
    convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset);
  }

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter> convolution_op;
#ifdef USE_XNNPACK
  xnnpack_operator xnnp_convolution_op;
#endif // USE_XNNPACK
  std::unique_ptr<qnnpack::PrePackConvWeights> w; // packed weights
  at::Tensor orig_weight;
  at::Tensor bias;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  bool transpose_;
  bool is_per_channel_;
  c10::optional<double> input_scale;
  std::vector<int64_t> kernel_;
  at::Tensor w_scales;
  std::vector<uint8_t> w_zero_points;
  std::vector<float> requantization_scales;
  size_t zero_buffer_size;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  // Same as apply(), with ReLU fused into the output clamping.
  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  // Output quantization params derived from the input at call time.
  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range=false) override;

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return transpose_;
  }

  bool per_channel() const {
    return is_per_channel_;
  }

 private:
  // Guards mutation of the packed state across concurrent apply() calls.
  std::mutex qnnp_mutex_;
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

#ifdef USE_XNNPACK
  template <typename scalar_t, bool ReluFused>
  at::Tensor apply_impl_xnnp(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
#endif // USE_XNNPACK
};
|
| 382 |
+
|
| 383 |
+
// Activation fused into a quantized kernel; selects the output clamping
// range (see activationLimits below).
enum class Activation : uint8_t { NONE = 0, RELU = 1 };
|
| 384 |
+
|
| 385 |
+
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
// Fallback for old Android toolchains — presumably std::nearbyint is
// unavailable there, so call the C library functions directly. The unused
// template parameter keeps call sites identical to the generic overload
// below.
template <class T>
inline float Round(const float x) {
  return ::nearbyintf(x);
}
inline double Round(const double x) {
  return ::nearbyint(x);
}
#else
// Round to nearest integer value using the current floating-point rounding
// mode (ties-to-even by default).
template <class T>
inline T Round(const T x) {
  return std::nearbyint(x);
}
#endif
|
| 399 |
+
|
| 400 |
+
template<typename T>
|
| 401 |
+
inline T QuantizeValue(float scale, int32_t zero_point, float value) {
|
| 402 |
+
const int32_t qmin = std::numeric_limits<T>::min();
|
| 403 |
+
const int32_t qmax = std::numeric_limits<T>::max();
|
| 404 |
+
auto r = zero_point + static_cast<int32_t>(Round(value / scale));
|
| 405 |
+
r = std::max(r, qmin);
|
| 406 |
+
r = std::min(r, qmax);
|
| 407 |
+
return static_cast<T>(r);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
template<typename T>
|
| 411 |
+
inline std::pair<T, T> activationLimits(
|
| 412 |
+
float scale,
|
| 413 |
+
int32_t zero_point,
|
| 414 |
+
Activation Ac) {
|
| 415 |
+
switch (Ac) {
|
| 416 |
+
case Activation::NONE:
|
| 417 |
+
return {std::numeric_limits<T>::min(),
|
| 418 |
+
std::numeric_limits<T>::max()};
|
| 419 |
+
case Activation::RELU:
|
| 420 |
+
return {QuantizeValue<T>(scale, zero_point, 0.0),
|
| 421 |
+
std::numeric_limits<T>::max()};
|
| 422 |
+
default:
|
| 423 |
+
#ifdef _MSC_VER
|
| 424 |
+
__assume(0);
|
| 425 |
+
#else
|
| 426 |
+
__builtin_unreachable();
|
| 427 |
+
#endif
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
namespace at {
namespace native {
namespace qnnp_avgpool_helper {
// QNNPACK-backed quantized 2-d average pooling. Declared here for dispatch;
// defined elsewhere in the quantized-op sources.
Tensor qnnpack_avg_pool2d(
    Tensor input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override);
} // qnnp_avgpool_helper
} // namespace native
} // namespace at
|
| 445 |
+
|
| 446 |
+
namespace {
|
| 447 |
+
// Computes the per-(padded-)output-channel requantization scales:
//   requant_scales[i] = weight_scale[i] * input_scale / output_scale
// Populates `requant_scales` in place (growing it if needed) and returns a
// copy of it. Each scale is checked to be finite and positive, since the
// QNNPACK op cannot be created otherwise.
C10_UNUSED std::vector<float> generate_requantization_scales(
    const at::Tensor& weight_scales,
    const float input_scale,
    const float output_scale,
    std::vector<float>& requant_scales) {
  // Since weight scale is allocated with padding,
  // weight_scales.numel() gives us the padded element count.
  const auto num_output_channels_padded = weight_scales.numel();
  float *const weight_scales_data = weight_scales.data_ptr<float>();
  if (static_cast<int64_t>(requant_scales.size()) < num_output_channels_padded) {
    requant_scales.resize(num_output_channels_padded);
  }
  // Hoisted out of the loop: 1/output_scale is loop-invariant.
  const float inverse_output_scale = 1.f / output_scale;
  for (const auto i : c10::irange(num_output_channels_padded)) {
    requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale;
    TORCH_CHECK(
        (requant_scales[i] > 0.0f && std::isnormal(requant_scales[i])),
        "failed to create op with requantization scale: ",
        requant_scales[i],
        ": requantization scale must be finite and positive");
  }
  return requant_scales;
}
|
| 470 |
+
|
| 471 |
+
// Builds the padded per-output-channel zero-point vector and scale tensor
// that QNNPACK consumes, from a (contiguous) quantized weight tensor.
// Supports per-tensor and per-channel affine schemes; any other scheme
// asserts. For transposed weights the output-channel dim is 1 and is
// multiplied by `groups` to recover the full output-channel count.
C10_UNUSED std::pair<std::vector<uint8_t>, at::Tensor> make_zero_points_and_scales_tensor(
    const at::Tensor& weight_contig,
    bool transpose = false,
    uint32_t groups = 1
  ) {
  const int out_ch_idx = transpose ? 1 : 0;
  const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1);
  // Add 8 to account for buffering needed by QNNPACK.
  // (kPaddingChannels is declared elsewhere in this header — presumably 8,
  // per the comment above; confirm against its definition.)
  const auto num_output_channels_padded = num_output_channels + kPaddingChannels;
  const auto qtype = weight_contig.qscheme();
  std::vector<uint8_t> weight_zp(num_output_channels_padded, 0);
  // Adjust weight zero point, similar to weight data: shift by 128 into the
  // uint8 domain (presumably from int8-quantized weights — confirm).
  if (qtype == at::kPerTensorAffine) {
    for (const auto i : c10::irange(num_output_channels)) {
      weight_zp[i] = (uint8_t)(weight_contig.q_zero_point() + 128);
    }
  } else if (qtype == at::kPerChannelAffine) {
    TORCH_CHECK(
        weight_contig.q_per_channel_zero_points().scalar_type() == at::kLong,
        "Per channel zero points dtype must be long int.");
    const int64_t* per_channel_zero_points =
        weight_contig.q_per_channel_zero_points().data_ptr<int64_t>();
    for (const auto i : c10::irange(num_output_channels)) {
      weight_zp[i] = (uint8_t)(per_channel_zero_points[i] + 128);
    }
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
  }
  at::Tensor weight_scales =
      at::empty(
          {num_output_channels_padded},
          at::device(at::kCPU).dtype(at::kFloat));
  float *const weight_scales_data = weight_scales.data_ptr<float>();
  if (qtype == at::kPerTensorAffine) {
    // One shared scale broadcast to every real output channel.
    for (const auto i : c10::irange(num_output_channels)) {
      weight_scales_data[i] = weight_contig.q_scale();
    }
  } else if (qtype == at::kPerChannelAffine) {
    TORCH_CHECK(
        weight_contig.q_per_channel_scales().scalar_type() == at::kDouble,
        "Per channel scales dtype must be double.");
    const double *const per_channel_scales =
        weight_contig.q_per_channel_scales().data_ptr<double>();
    for (const auto i : c10::irange(num_output_channels)) {
      weight_scales_data[i] = static_cast<float>(per_channel_scales[i]);
    }
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
  }
  // Padding channels get a neutral scale of 1.
  for (const auto i : c10::irange(num_output_channels, num_output_channels_padded)) {
    weight_scales_data[i] = 1.f;
  }
  return {weight_zp, weight_scales};
}
|
| 525 |
+
} // namespace
|
| 526 |
+
|
| 527 |
+
#endif
|